summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/JSON_parser.h152
-rw-r--r--src/Makefile.am67
-rw-r--r--src/aligner.cc204
-rw-r--r--src/aligner.h23
-rw-r--r--src/apply_models.cc344
-rw-r--r--src/apply_models.h20
-rw-r--r--src/array2d.h171
-rw-r--r--src/bottom_up_parser.cc260
-rw-r--r--src/bottom_up_parser.h27
-rw-r--r--src/cdec.cc474
-rw-r--r--src/cdec_ff.cc18
-rw-r--r--src/collapse_weights.cc102
-rw-r--r--src/dict.h40
-rw-r--r--src/dict_test.cc30
-rw-r--r--src/earley_composer.cc726
-rw-r--r--src/earley_composer.h29
-rw-r--r--src/exp_semiring.h71
-rw-r--r--src/fdict.cc4
-rw-r--r--src/fdict.h21
-rw-r--r--src/ff.cc93
-rw-r--r--src/ff.h121
-rw-r--r--src/ff_factory.cc35
-rw-r--r--src/ff_factory.h39
-rw-r--r--src/ff_itg_span.h7
-rw-r--r--src/ff_test.cc134
-rw-r--r--src/ff_wordalign.cc221
-rw-r--r--src/ff_wordalign.h133
-rw-r--r--src/filelib.cc22
-rw-r--r--src/filelib.h66
-rw-r--r--src/forest_writer.cc23
-rw-r--r--src/forest_writer.h16
-rw-r--r--src/freqdict.cc23
-rw-r--r--src/freqdict.h19
-rw-r--r--src/fst_translator.cc91
-rw-r--r--src/grammar.cc163
-rw-r--r--src/grammar.h83
-rw-r--r--src/grammar_test.cc59
-rw-r--r--src/gzstream.cc165
-rw-r--r--src/gzstream.h121
-rw-r--r--src/hg.cc483
-rw-r--r--src/hg.h225
-rw-r--r--src/hg_intersect.cc121
-rw-r--r--src/hg_intersect.h13
-rw-r--r--src/hg_io.cc585
-rw-r--r--src/hg_io.h37
-rw-r--r--src/hg_test.cc441
-rw-r--r--src/ibm_model1.cc4
-rw-r--r--src/inside_outside.h111
-rw-r--r--src/json_parse.cc50
-rw-r--r--src/json_parse.h58
-rw-r--r--src/kbest.h207
-rw-r--r--src/lattice.cc27
-rw-r--r--src/lattice.h31
-rw-r--r--src/lexcrf.cc112
-rw-r--r--src/lexcrf.h18
-rw-r--r--src/lm_ff.cc328
-rw-r--r--src/lm_ff.h32
-rw-r--r--src/logval.h136
-rw-r--r--src/maxtrans_blunsom.cc287
-rw-r--r--src/parser_test.cc35
-rw-r--r--src/phrasebased_translator.cc206
-rw-r--r--src/phrasebased_translator.h18
-rw-r--r--src/phrasetable_fst.cc141
-rw-r--r--src/phrasetable_fst.h34
-rw-r--r--src/prob.h8
-rw-r--r--src/sampler.h136
-rw-r--r--src/scfg_translator.cc66
-rw-r--r--src/sentence_metadata.h42
-rw-r--r--src/small_vector.h187
-rw-r--r--src/small_vector_test.cc129
-rw-r--r--src/sparse_vector.cc98
-rw-r--r--src/sparse_vector.h264
-rw-r--r--src/stringlib.cc97
-rw-r--r--src/stringlib.h91
-rw-r--r--src/synparse.cc212
-rw-r--r--src/tdict.cc49
-rw-r--r--src/tdict.h19
-rw-r--r--src/timing_stats.cc24
-rw-r--r--src/timing_stats.h25
-rw-r--r--src/translator.h54
-rw-r--r--src/trule.cc237
-rw-r--r--src/trule.h122
-rw-r--r--src/trule_test.cc65
-rw-r--r--src/ttables.cc31
-rw-r--r--src/ttables.h87
-rw-r--r--src/viterbi.cc39
-rw-r--r--src/viterbi.h130
-rw-r--r--src/weights.cc73
-rw-r--r--src/weights.h21
-rw-r--r--src/weights_test.cc28
-rw-r--r--src/wordid.h6
91 files changed, 10497 insertions, 0 deletions
diff --git a/src/JSON_parser.h b/src/JSON_parser.h
new file mode 100644
index 00000000..ceb5b24b
--- /dev/null
+++ b/src/JSON_parser.h
@@ -0,0 +1,152 @@
+#ifndef JSON_PARSER_H
+#define JSON_PARSER_H
+
+/* JSON_parser.h */
+
+
+#include <stddef.h>
+
+/* Windows DLL stuff */
+#ifdef _WIN32
+# ifdef JSON_PARSER_DLL_EXPORTS
+# define JSON_PARSER_DLL_API __declspec(dllexport)
+# else
+# define JSON_PARSER_DLL_API __declspec(dllimport)
+# endif
+#else
+# define JSON_PARSER_DLL_API
+#endif
+
+/* Determine the integer type use to parse non-floating point numbers */
+#if __STDC_VERSION__ >= 199901L || HAVE_LONG_LONG == 1
+typedef long long JSON_int_t;
+#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%lld"
+#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%lld"
+#else
+typedef long JSON_int_t;
+#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%ld"
+#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%ld"
+#endif
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+typedef enum
+{
+ JSON_T_NONE = 0,
+ JSON_T_ARRAY_BEGIN, // 1
+ JSON_T_ARRAY_END, // 2
+ JSON_T_OBJECT_BEGIN, // 3
+ JSON_T_OBJECT_END, // 4
+ JSON_T_INTEGER, // 5
+ JSON_T_FLOAT, // 6
+ JSON_T_NULL, // 7
+ JSON_T_TRUE, // 8
+ JSON_T_FALSE, // 9
+ JSON_T_STRING, // 10
+ JSON_T_KEY, // 11
+ JSON_T_MAX // 12
+} JSON_type;
+
+typedef struct JSON_value_struct {
+ union {
+ JSON_int_t integer_value;
+
+ long double float_value;
+
+ struct {
+ const char* value;
+ size_t length;
+ } str;
+ } vu;
+} JSON_value;
+
+typedef struct JSON_parser_struct* JSON_parser;
+
+/*! \brief JSON parser callback
+
+ \param ctx The pointer passed to new_JSON_parser.
+ \param type An element of JSON_type but not JSON_T_NONE.
+ \param value A representation of the parsed value. This parameter is NULL for
+ JSON_T_ARRAY_BEGIN, JSON_T_ARRAY_END, JSON_T_OBJECT_BEGIN, JSON_T_OBJECT_END,
+ JSON_T_NULL, JSON_T_TRUE, and SON_T_FALSE. String values are always returned
+ as zero-terminated C strings.
+
+ \return Non-zero if parsing should continue, else zero.
+*/
+typedef int (*JSON_parser_callback)(void* ctx, int type, const struct JSON_value_struct* value);
+
+
+/*! \brief The structure used to configure a JSON parser object
+
+ \param depth If negative, the parser can parse arbitrary levels of JSON, otherwise
+ the depth is the limit
+ \param Pointer to a callback. This parameter may be NULL. In this case the input is merely checked for validity.
+ \param Callback context. This parameter may be NULL.
+ \param depth. Specifies the levels of nested JSON to allow. Negative numbers yield unlimited nesting.
+ \param allowComments. To allow C style comments in JSON, set to non-zero.
+ \param handleFloatsManually. To decode floating point numbers manually set this parameter to non-zero.
+
+ \return The parser object.
+*/
+typedef struct {
+ JSON_parser_callback callback;
+ void* callback_ctx;
+ int depth;
+ int allow_comments;
+ int handle_floats_manually;
+} JSON_config;
+
+
+/*! \brief Initializes the JSON parser configuration structure to default values.
+
+ The default configuration is
+ - 127 levels of nested JSON (depends on JSON_PARSER_STACK_SIZE, see json_parser.c)
+ - no parsing, just checking for JSON syntax
+ - no comments
+
+ \param config. Used to configure the parser.
+*/
+JSON_PARSER_DLL_API void init_JSON_config(JSON_config* config);
+
+/*! \brief Create a JSON parser object
+
+ \param config. Used to configure the parser. Set to NULL to use the default configuration.
+ See init_JSON_config
+
+ \return The parser object.
+*/
+JSON_PARSER_DLL_API extern JSON_parser new_JSON_parser(JSON_config* config);
+
+/*! \brief Destroy a previously created JSON parser object. */
+JSON_PARSER_DLL_API extern void delete_JSON_parser(JSON_parser jc);
+
+/*! \brief Parse a character.
+
+ \return Non-zero, if all characters passed to this function are part of are valid JSON.
+*/
+JSON_PARSER_DLL_API extern int JSON_parser_char(JSON_parser jc, int next_char);
+
+/*! \brief Finalize parsing.
+
+ Call this method once after all input characters have been consumed.
+
+ \return Non-zero, if all parsed characters are valid JSON, zero otherwise.
+*/
+JSON_PARSER_DLL_API extern int JSON_parser_done(JSON_parser jc);
+
+/*! \brief Determine if a given string is valid JSON white space
+
+ \return Non-zero if the string is valid, zero otherwise.
+*/
+JSON_PARSER_DLL_API extern int JSON_parser_is_legal_white_space_string(const char* s);
+
+
+#ifdef __cplusplus
+}
+#endif
+
+
+#endif /* JSON_PARSER_H */
diff --git a/src/Makefile.am b/src/Makefile.am
new file mode 100644
index 00000000..2af6cab7
--- /dev/null
+++ b/src/Makefile.am
@@ -0,0 +1,67 @@
+bin_PROGRAMS = \
+ dict_test \
+ weights_test \
+ trule_test \
+ hg_test \
+ ff_test \
+ parser_test \
+ grammar_test \
+ cdec \
+ small_vector_test
+
+cdec_SOURCES = cdec.cc forest_writer.cc maxtrans_blunsom.cc cdec_ff.cc ff_factory.cc timing_stats.cc
+small_vector_test_SOURCES = small_vector_test.cc
+small_vector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+parser_test_SOURCES = parser_test.cc
+parser_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+dict_test_SOURCES = dict_test.cc
+dict_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+ff_test_SOURCES = ff_test.cc
+ff_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+grammar_test_SOURCES = grammar_test.cc
+grammar_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+hg_test_SOURCES = hg_test.cc
+hg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+trule_test_SOURCES = trule_test.cc
+trule_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+weights_test_SOURCES = weights_test.cc
+weights_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a
+
+LDADD = libhg.a
+
+AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS)
+AM_LDFLAGS = $(BOOST_LDFLAGS) $(BOOST_PROGRAM_OPTIONS_LIB) -lz
+
+noinst_LIBRARIES = libhg.a
+
+libhg_a_SOURCES = \
+ fst_translator.cc \
+ scfg_translator.cc \
+ hg.cc \
+ hg_io.cc \
+ viterbi.cc \
+ lattice.cc \
+ logging.cc \
+ aligner.cc \
+ gzstream.cc \
+ apply_models.cc \
+ earley_composer.cc \
+ phrasetable_fst.cc \
+ sparse_vector.cc \
+ trule.cc \
+ filelib.cc \
+ stringlib.cc \
+ fdict.cc \
+ tdict.cc \
+ weights.cc \
+ ttables.cc \
+ ff.cc \
+ lm_ff.cc \
+ ff_wordalign.cc \
+ hg_intersect.cc \
+ lexcrf.cc \
+ bottom_up_parser.cc \
+ phrasebased_translator.cc \
+ JSON_parser.c \
+ json_parse.cc \
+ grammar.cc
diff --git a/src/aligner.cc b/src/aligner.cc
new file mode 100644
index 00000000..d9d067e5
--- /dev/null
+++ b/src/aligner.cc
@@ -0,0 +1,204 @@
+#include "aligner.h"
+
+#include "array2d.h"
+#include "hg.h"
+#include "inside_outside.h"
+#include <set>
+
+using namespace std;
+
+struct EdgeCoverageInfo {
+ set<int> src_indices;
+ set<int> trg_indices;
+};
+
+static bool is_digit(char x) { return x >= '0' && x <= '9'; }
+
+boost::shared_ptr<Array2D<bool> > AlignerTools::ReadPharaohAlignmentGrid(const string& al) {
+ int max_x = 0;
+ int max_y = 0;
+ int i = 0;
+ while (i < al.size()) {
+ int x = 0;
+ while(i < al.size() && is_digit(al[i])) {
+ x *= 10;
+ x += al[i] - '0';
+ ++i;
+ }
+ if (x > max_x) max_x = x;
+ assert(i < al.size());
+ assert(al[i] == '-');
+ ++i;
+ int y = 0;
+ while(i < al.size() && is_digit(al[i])) {
+ y *= 10;
+ y += al[i] - '0';
+ ++i;
+ }
+ if (y > max_y) max_y = y;
+ while(i < al.size() && al[i] == ' ') { ++i; }
+ }
+
+ boost::shared_ptr<Array2D<bool> > grid(new Array2D<bool>(max_x + 1, max_y + 1));
+ i = 0;
+ while (i < al.size()) {
+ int x = 0;
+ while(i < al.size() && is_digit(al[i])) {
+ x *= 10;
+ x += al[i] - '0';
+ ++i;
+ }
+ assert(i < al.size());
+ assert(al[i] == '-');
+ ++i;
+ int y = 0;
+ while(i < al.size() && is_digit(al[i])) {
+ y *= 10;
+ y += al[i] - '0';
+ ++i;
+ }
+ (*grid)(x, y) = true;
+ while(i < al.size() && al[i] == ' ') { ++i; }
+ }
+ // cerr << *grid << endl;
+ return grid;
+}
+
+void AlignerTools::SerializePharaohFormat(const Array2D<bool>& alignment, ostream* out) {
+ bool need_space = false;
+ for (int i = 0; i < alignment.width(); ++i)
+ for (int j = 0; j < alignment.height(); ++j)
+ if (alignment(i,j)) {
+ if (need_space) (*out) << ' '; else need_space = true;
+ (*out) << i << '-' << j;
+ }
+ (*out) << endl;
+}
+
+// compute the coverage vectors of each edge
+// prereq: all derivations yield the same string pair
+void ComputeCoverages(const Hypergraph& g,
+ vector<EdgeCoverageInfo>* pcovs) {
+ for (int i = 0; i < g.edges_.size(); ++i) {
+ const Hypergraph::Edge& edge = g.edges_[i];
+ EdgeCoverageInfo& cov = (*pcovs)[i];
+ // no words
+ if (edge.rule_->EWords() == 0 || edge.rule_->FWords() == 0)
+ continue;
+ // aligned to NULL (crf ibm variant only)
+ if (edge.prev_i_ == -1 || edge.i_ == -1)
+ continue;
+ assert(edge.j_ >= 0);
+ assert(edge.prev_j_ >= 0);
+ if (edge.Arity() == 0) {
+ for (int k = edge.i_; k < edge.j_; ++k)
+ cov.trg_indices.insert(k);
+ for (int k = edge.prev_i_; k < edge.prev_j_; ++k)
+ cov.src_indices.insert(k);
+ } else {
+ // note: this code, which handles mixed NT and terminal
+ // rules assumes that nodes uniquely define a src and trg
+ // span.
+ int k = edge.prev_i_;
+ int j = 0;
+ const vector<WordID>& f = edge.rule_->e(); // rules are inverted
+ while (k < edge.prev_j_) {
+ if (f[j] > 0) {
+ cov.src_indices.insert(k);
+ // cerr << "src: " << k << endl;
+ ++k;
+ ++j;
+ } else {
+ const Hypergraph::Node& tailnode = g.nodes_[edge.tail_nodes_[-f[j]]];
+ assert(tailnode.in_edges_.size() > 0);
+ // any edge will do:
+ const Hypergraph::Edge& rep_edge = g.edges_[tailnode.in_edges_.front()];
+ //cerr << "skip " << (rep_edge.prev_j_ - rep_edge.prev_i_) << endl; // src span
+ k += (rep_edge.prev_j_ - rep_edge.prev_i_); // src span
+ ++j;
+ }
+ }
+ int tc = 0;
+ const vector<WordID>& e = edge.rule_->f(); // rules are inverted
+ k = edge.i_;
+ j = 0;
+ // cerr << edge.rule_->AsString() << endl;
+ // cerr << "i=" << k << " j=" << edge.j_ << endl;
+ while (k < edge.j_) {
+ //cerr << " k=" << k << endl;
+ if (e[j] > 0) {
+ cov.trg_indices.insert(k);
+ // cerr << "trg: " << k << endl;
+ ++k;
+ ++j;
+ } else {
+ assert(tc < edge.tail_nodes_.size());
+ const Hypergraph::Node& tailnode = g.nodes_[edge.tail_nodes_[tc]];
+ assert(tailnode.in_edges_.size() > 0);
+ // any edge will do:
+ const Hypergraph::Edge& rep_edge = g.edges_[tailnode.in_edges_.front()];
+ // cerr << "t skip " << (rep_edge.j_ - rep_edge.i_) << endl; // src span
+ k += (rep_edge.j_ - rep_edge.i_); // src span
+ ++j;
+ ++tc;
+ }
+ }
+ //abort();
+ }
+ }
+}
+
+void AlignerTools::WriteAlignment(const string& input,
+ const Lattice& ref,
+ const Hypergraph& g,
+ bool map_instead_of_viterbi) {
+ if (!map_instead_of_viterbi) {
+ assert(!"not implemented!");
+ }
+ vector<prob_t> edge_posteriors(g.edges_.size());
+ {
+ SparseVector<prob_t> posts;
+ InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, TransitionEventWeightFunction>(g, &posts);
+ for (int i = 0; i < edge_posteriors.size(); ++i)
+ edge_posteriors[i] = posts[i];
+ }
+ vector<EdgeCoverageInfo> edge2cov(g.edges_.size());
+ ComputeCoverages(g, &edge2cov);
+
+ Lattice src;
+ // currently only dealing with src text, even if the
+ // model supports lattice translation (which it probably does)
+ LatticeTools::ConvertTextToLattice(input, &src);
+ // TODO assert that src is a "real lattice"
+
+ Array2D<prob_t> align(src.size(), ref.size(), prob_t::Zero());
+ for (int c = 0; c < g.edges_.size(); ++c) {
+ const prob_t& p = edge_posteriors[c];
+ const EdgeCoverageInfo& eci = edge2cov[c];
+ for (set<int>::const_iterator si = eci.src_indices.begin();
+ si != eci.src_indices.end(); ++si) {
+ for (set<int>::const_iterator ti = eci.trg_indices.begin();
+ ti != eci.trg_indices.end(); ++ti) {
+ align(*si, *ti) += p;
+ }
+ }
+ }
+ prob_t threshold(0.9);
+ const bool use_soft_threshold = true; // TODO configure
+
+ Array2D<bool> grid(src.size(), ref.size(), false);
+ for (int j = 0; j < ref.size(); ++j) {
+ if (use_soft_threshold) {
+ threshold = prob_t::Zero();
+ for (int i = 0; i < src.size(); ++i)
+ if (align(i, j) > threshold) threshold = align(i, j);
+ //threshold *= prob_t(0.99);
+ }
+ for (int i = 0; i < src.size(); ++i)
+ grid(i, j) = align(i, j) >= threshold;
+ }
+ cerr << align << endl;
+ cerr << grid << endl;
+ SerializePharaohFormat(grid, &cout);
+};
+
diff --git a/src/aligner.h b/src/aligner.h
new file mode 100644
index 00000000..970c72f2
--- /dev/null
+++ b/src/aligner.h
@@ -0,0 +1,23 @@
+#ifndef _ALIGNER_H_
+
+#include <string>
+#include <iostream>
+#include <boost/shared_ptr.hpp>
+#include "array2d.h"
+#include "lattice.h"
+
+class Hypergraph;
+
+struct AlignerTools {
+ static boost::shared_ptr<Array2D<bool> > ReadPharaohAlignmentGrid(const std::string& al);
+ static void SerializePharaohFormat(const Array2D<bool>& alignment, std::ostream* out);
+
+ // assumption: g contains derivations of input/ref and
+ // ONLY input/ref.
+ static void WriteAlignment(const std::string& input,
+ const Lattice& ref,
+ const Hypergraph& g,
+ bool map_instead_of_viterbi = true);
+};
+
+#endif
diff --git a/src/apply_models.cc b/src/apply_models.cc
new file mode 100644
index 00000000..8efb331b
--- /dev/null
+++ b/src/apply_models.cc
@@ -0,0 +1,344 @@
+#include "apply_models.h"
+
+#include <vector>
+#include <algorithm>
+#include <tr1/unordered_map>
+#include <tr1/unordered_set>
+
+#include <boost/functional/hash.hpp>
+
+#include "hg.h"
+#include "ff.h"
+
+using namespace std;
+using namespace std::tr1;
+
+struct Candidate;
+typedef SmallVector JVector;
+typedef vector<Candidate*> CandidateHeap;
+typedef vector<Candidate*> CandidateList;
+
+// life cycle: candidates are created, placed on the heap
+// and retrieved by their estimated cost, when they're
+// retrieved, they're incorporated into the +LM hypergraph
+// where they also know the head node index they are
+// attached to. After they are added to the +LM hypergraph
+// vit_prob_ and est_prob_ fields may be updated as better
+// derivations are found (this happens since the successor's
+// of derivation d may have a better score- they are
+// explored lazily). However, the updates don't happen
+// when a candidate is in the heap so maintaining the heap
+// property is not an issue.
+struct Candidate {
+ int node_index_; // -1 until incorporated
+ // into the +LM forest
+ const Hypergraph::Edge* in_edge_; // in -LM forest
+ Hypergraph::Edge out_edge_;
+ string state_;
+ const JVector j_;
+ prob_t vit_prob_; // these are fixed until the cand
+ // is popped, then they may be updated
+ prob_t est_prob_;
+
+ Candidate(const Hypergraph::Edge& e,
+ const JVector& j,
+ const Hypergraph& out_hg,
+ const vector<CandidateList>& D,
+ const SentenceMetadata& smeta,
+ const ModelSet& models,
+ bool is_goal) :
+ node_index_(-1),
+ in_edge_(&e),
+ j_(j) {
+ InitializeCandidate(out_hg, smeta, D, models, is_goal);
+ }
+
+ // used to query uniqueness
+ Candidate(const Hypergraph::Edge& e,
+ const JVector& j) : in_edge_(&e), j_(j) {}
+
+ bool IsIncorporatedIntoHypergraph() const {
+ return node_index_ >= 0;
+ }
+
+ void InitializeCandidate(const Hypergraph& out_hg,
+ const SentenceMetadata& smeta,
+ const vector<vector<Candidate*> >& D,
+ const ModelSet& models,
+ const bool is_goal) {
+ const Hypergraph::Edge& in_edge = *in_edge_;
+ out_edge_.rule_ = in_edge.rule_;
+ out_edge_.feature_values_ = in_edge.feature_values_;
+ out_edge_.i_ = in_edge.i_;
+ out_edge_.j_ = in_edge.j_;
+ out_edge_.prev_i_ = in_edge.prev_i_;
+ out_edge_.prev_j_ = in_edge.prev_j_;
+ Hypergraph::TailNodeVector& tail = out_edge_.tail_nodes_;
+ tail.resize(j_.size());
+ prob_t p = prob_t::One();
+ // cerr << "\nEstimating application of " << in_edge.rule_->AsString() << endl;
+ for (int i = 0; i < tail.size(); ++i) {
+ const Candidate& ant = *D[in_edge.tail_nodes_[i]][j_[i]];
+ assert(ant.IsIncorporatedIntoHypergraph());
+ tail[i] = ant.node_index_;
+ p *= ant.vit_prob_;
+ }
+ prob_t edge_estimate = prob_t::One();
+ if (is_goal) {
+ assert(tail.size() == 1);
+ const string& ant_state = out_hg.nodes_[tail.front()].state_;
+ models.AddFinalFeatures(ant_state, &out_edge_);
+ } else {
+ models.AddFeaturesToEdge(smeta, out_hg, &out_edge_, &state_, &edge_estimate);
+ }
+ vit_prob_ = out_edge_.edge_prob_ * p;
+ est_prob_ = vit_prob_ * edge_estimate;
+ }
+};
+
+ostream& operator<<(ostream& os, const Candidate& cand) {
+ os << "CAND[";
+ if (!cand.IsIncorporatedIntoHypergraph()) { os << "PENDING "; }
+ else { os << "+LM_node=" << cand.node_index_; }
+ os << " edge=" << cand.in_edge_->id_;
+ os << " j=<";
+ for (int i = 0; i < cand.j_.size(); ++i)
+ os << (i==0 ? "" : " ") << cand.j_[i];
+ os << "> vit=" << log(cand.vit_prob_);
+ os << " est=" << log(cand.est_prob_);
+ return os << ']';
+}
+
+struct HeapCandCompare {
+ bool operator()(const Candidate* l, const Candidate* r) const {
+ return l->est_prob_ < r->est_prob_;
+ }
+};
+
+struct EstProbSorter {
+ bool operator()(const Candidate* l, const Candidate* r) const {
+ return l->est_prob_ > r->est_prob_;
+ }
+};
+
+// the same candidate <edge, j> can be added multiple times if
+// j is multidimensional (if you're going NW in Manhattan, you
+// can first go north, then west, or you can go west then north)
+// this is a hash function on the relevant variables from
+// Candidate to enforce this.
+struct CandidateUniquenessHash {
+ size_t operator()(const Candidate* c) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ c->in_edge_->id_;
+ for (int i = 0; i < c->j_.size(); ++i)
+ x = ((x << 5) + x) ^ c->j_[i];
+ return x;
+ }
+};
+
+struct CandidateUniquenessEquals {
+ bool operator()(const Candidate* a, const Candidate* b) const {
+ return (a->in_edge_ == b->in_edge_) && (a->j_ == b->j_);
+ }
+};
+
+typedef unordered_set<const Candidate*, CandidateUniquenessHash, CandidateUniquenessEquals> UniqueCandidateSet;
+typedef unordered_map<string, Candidate*, boost::hash<string> > State2Node;
+
+class CubePruningRescorer {
+
+public:
+ CubePruningRescorer(const ModelSet& m,
+ const SentenceMetadata& sm,
+ const Hypergraph& i,
+ int pop_limit,
+ Hypergraph* o) :
+ models(m),
+ smeta(sm),
+ in(i),
+ out(*o),
+ D(in.nodes_.size()),
+ pop_limit_(pop_limit) {
+ cerr << " Rescoring forest (cube pruning, pop_limit = " << pop_limit_ << ')' << endl;
+ }
+
+ void Apply() {
+ int num_nodes = in.nodes_.size();
+ int goal_id = num_nodes - 1;
+ int pregoal = goal_id - 1;
+ int every = 1;
+ if (num_nodes > 100) every = 10;
+ assert(in.nodes_[pregoal].out_edges_.size() == 1);
+ cerr << " ";
+ for (int i = 0; i < in.nodes_.size(); ++i) {
+ if (i % every == 0) cerr << '.';
+ KBest(i, i == goal_id);
+ }
+ cerr << endl;
+ cerr << " Best path: " << log(D[goal_id].front()->vit_prob_)
+ << "\t" << log(D[goal_id].front()->est_prob_) << endl;
+ out.PruneUnreachable(D[goal_id].front()->node_index_);
+ FreeAll();
+ }
+
+ private:
+ void FreeAll() {
+ for (int i = 0; i < D.size(); ++i) {
+ CandidateList& D_i = D[i];
+ for (int j = 0; j < D_i.size(); ++j)
+ delete D_i[j];
+ }
+ D.clear();
+ }
+
+ void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) {
+ Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_.rule_, item->out_edge_.tail_nodes_);
+ new_edge->feature_values_ = item->out_edge_.feature_values_;
+ new_edge->edge_prob_ = item->out_edge_.edge_prob_;
+ new_edge->i_ = item->out_edge_.i_;
+ new_edge->j_ = item->out_edge_.j_;
+ new_edge->prev_i_ = item->out_edge_.prev_i_;
+ new_edge->prev_j_ = item->out_edge_.prev_j_;
+ Candidate*& o_item = (*s2n)[item->state_];
+ if (!o_item) o_item = item;
+
+ int& node_id = o_item->node_index_;
+ if (node_id < 0) {
+ Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_, item->state_);
+ node_id = new_node->id_;
+ }
+ Hypergraph::Node* node = &out.nodes_[node_id];
+ out.ConnectEdgeToHeadNode(new_edge, node);
+
+ // update candidate if we have a better derivation
+ // note: the difference between the vit score and the estimated
+ // score is the same for all items with a common residual DP
+ // state
+ if (item->vit_prob_ > o_item->vit_prob_) {
+ assert(o_item->state_ == item->state_); // sanity check!
+ o_item->est_prob_ = item->est_prob_;
+ o_item->vit_prob_ = item->vit_prob_;
+ }
+ if (item != o_item) freelist->push_back(item);
+ }
+
+ void KBest(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ UniqueCandidateSet unique_cands;
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, out, D, smeta, models, is_goal));
+ assert(unique_cands.insert(cand.back()).second); // these should all be unique!
+ }
+// cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ // cerr << "POPPED: " << *item << endl;
+ PushSucc(*item, is_goal, &cand, &unique_cands);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i)
+ D_v[c++] = i->second;
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
+ void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) {
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate query_unique(*item.in_edge_, j);
+ if (cs->count(&query_unique) == 0) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, smeta, models, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ assert(cs->insert(new_cand).second); // insert into uniqueness set, sanity check
+ }
+ }
+ }
+ }
+
+ const ModelSet& models;
+ const SentenceMetadata& smeta;
+ const Hypergraph& in;
+ Hypergraph& out;
+
+ vector<CandidateList> D; // maps nodes in in-HG to the
+ // equivalent nodes (many due to state
+ // splits) in the out-HG.
+ const int pop_limit_;
+};
+
+struct NoPruningRescorer {
+ NoPruningRescorer(const ModelSet& m, const Hypergraph& i, Hypergraph* o) :
+ models(m),
+ in(i),
+ out(*o) {
+ cerr << " Rescoring forest (full intersection)\n";
+ }
+
+ void RescoreNode(const int node_num, const bool is_goal) {
+ }
+
+ void Apply() {
+ int num_nodes = in.nodes_.size();
+ int goal_id = num_nodes - 1;
+ int pregoal = goal_id - 1;
+ int every = 1;
+ if (num_nodes > 100) every = 10;
+ assert(in.nodes_[pregoal].out_edges_.size() == 1);
+ cerr << " ";
+ for (int i = 0; i < in.nodes_.size(); ++i) {
+ if (i % every == 0) cerr << '.';
+ RescoreNode(i, i == goal_id);
+ }
+ cerr << endl;
+ }
+
+ private:
+ const ModelSet& models;
+ const Hypergraph& in;
+ Hypergraph& out;
+};
+
+// each node in the graph has one of these, it keeps track of
+void ApplyModelSet(const Hypergraph& in,
+ const SentenceMetadata& smeta,
+ const ModelSet& models,
+ const PruningConfiguration& config,
+ Hypergraph* out) {
+ int pl = config.pop_limit;
+ if (pl > 100 && in.nodes_.size() > 80000) {
+ cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n";
+ pl = 30;
+ }
+ CubePruningRescorer ma(models, smeta, in, pl, out);
+ ma.Apply();
+}
+
diff --git a/src/apply_models.h b/src/apply_models.h
new file mode 100644
index 00000000..08fce037
--- /dev/null
+++ b/src/apply_models.h
@@ -0,0 +1,20 @@
+#ifndef _APPLY_MODELS_H_
+#define _APPLY_MODELS_H_
+
+struct ModelSet;
+struct Hypergraph;
+struct SentenceMetadata;
+
+struct PruningConfiguration {
+ const int algorithm; // 0 = full intersection, 1 = cube pruning
+ const int pop_limit; // max number of pops off the heap at each node
+ explicit PruningConfiguration(int k) : algorithm(1), pop_limit(k) {}
+};
+
+void ApplyModelSet(const Hypergraph& in,
+ const SentenceMetadata& smeta,
+ const ModelSet& models,
+ const PruningConfiguration& config,
+ Hypergraph* out);
+
+#endif
diff --git a/src/array2d.h b/src/array2d.h
new file mode 100644
index 00000000..09d84d0b
--- /dev/null
+++ b/src/array2d.h
@@ -0,0 +1,171 @@
+#ifndef ARRAY2D_H_
+#define ARRAY2D_H_
+
+#include <iostream>
+#include <algorithm>
+#include <cassert>
+#include <vector>
+#include <string>
+
+template<typename T>
+class Array2D {
+ public:
+ typedef typename std::vector<T>::reference reference;
+ typedef typename std::vector<T>::const_reference const_reference;
+ typedef typename std::vector<T>::iterator iterator;
+ typedef typename std::vector<T>::const_iterator const_iterator;
+ Array2D() : width_(0), height_(0) {}
+ Array2D(int w, int h, const T& d = T()) :
+ width_(w), height_(h), data_(w*h, d) {}
+ Array2D(const Array2D& rhs) :
+ width_(rhs.width_), height_(rhs.height_), data_(rhs.data_) {}
+ void resize(int w, int h, const T& d = T()) {
+ data_.resize(w * h, d);
+ width_ = w;
+ height_ = h;
+ }
+ const Array2D& operator=(const Array2D& rhs) {
+ data_ = rhs.data_;
+ width_ = rhs.width_;
+ height_ = rhs.height_;
+ return *this;
+ }
+ void fill(const T& v) { data_.assign(data_.size(), v); }
+ int width() const { return width_; }
+ int height() const { return height_; }
+ reference operator()(int i, int j) {
+ return data_[offset(i, j)];
+ }
+ void clear() { data_.clear(); width_=0; height_=0; }
+ const_reference operator()(int i, int j) const {
+ return data_[offset(i, j)];
+ }
+ iterator begin_col(int j) {
+ return data_.begin() + offset(0,j);
+ }
+ const_iterator begin_col(int j) const {
+ return data_.begin() + offset(0,j);
+ }
+ iterator end_col(int j) {
+ return data_.begin() + offset(0,j) + width_;
+ }
+ const_iterator end_col(int j) const {
+ return data_.begin() + offset(0,j) + width_;
+ }
+ iterator end() { return data_.end(); }
+ const_iterator end() const { return data_.end(); }
+ const Array2D<T>& operator*=(const T& x) {
+ std::transform(data_.begin(), data_.end(), data_.begin(),
+ std::bind2nd(std::multiplies<T>(), x));
+ }
+ const Array2D<T>& operator/=(const T& x) {
+ std::transform(data_.begin(), data_.end(), data_.begin(),
+ std::bind2nd(std::divides<T>(), x));
+ }
+ const Array2D<T>& operator+=(const Array2D<T>& m) {
+ std::transform(m.data_.begin(), m.data_.end(), data_.begin(), data_.begin(), std::plus<T>());
+ }
+ const Array2D<T>& operator-=(const Array2D<T>& m) {
+ std::transform(m.data_.begin(), m.data_.end(), data_.begin(), data_.begin(), std::minus<T>());
+ }
+
+ private:
+ inline int offset(int i, int j) const {
+ assert(i<width_);
+ assert(j<height_);
+ return i + j * width_;
+ }
+
+ int width_;
+ int height_;
+
+ std::vector<T> data_;
+};
+
+template <typename T>
+Array2D<T> operator*(const Array2D<T>& l, const T& scalar) {
+ Array2D<T> res(l);
+ res *= scalar;
+ return res;
+}
+
+template <typename T>
+Array2D<T> operator*(const T& scalar, const Array2D<T>& l) {
+ Array2D<T> res(l);
+ res *= scalar;
+ return res;
+}
+
+template <typename T>
+Array2D<T> operator/(const Array2D<T>& l, const T& scalar) {
+ Array2D<T> res(l);
+ res /= scalar;
+ return res;
+}
+
+template <typename T>
+Array2D<T> operator+(const Array2D<T>& l, const Array2D<T>& r) {
+ Array2D<T> res(l);
+ res += r;
+ return res;
+}
+
+template <typename T>
+Array2D<T> operator-(const Array2D<T>& l, const Array2D<T>& r) {
+ Array2D<T> res(l);
+ res -= r;
+ return res;
+}
+
+template <typename T>
+inline std::ostream& operator<<(std::ostream& os, const Array2D<T>& m) {
+ for (int i=0; i<m.width(); ++i) {
+ for (int j=0; j<m.height(); ++j)
+ os << '\t' << m(i,j);
+ os << '\n';
+ }
+ return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const Array2D<bool>& m) {
+ os << ' ';
+ for (int j=0; j<m.height(); ++j)
+ os << (j%10);
+ os << "\n";
+ for (int i=0; i<m.width(); ++i) {
+ os << (i%10);
+ for (int j=0; j<m.height(); ++j)
+ os << (m(i,j) ? '*' : '.');
+ os << (i%10) << "\n";
+ }
+ os << ' ';
+ for (int j=0; j<m.height(); ++j)
+ os << (j%10);
+ os << "\n";
+ return os;
+}
+
+inline std::ostream& operator<<(std::ostream& os, const Array2D<std::vector<bool> >& m) {
+ os << ' ';
+ for (int j=0; j<m.height(); ++j)
+ os << (j%10) << "\t";
+ os << "\n";
+ for (int i=0; i<m.width(); ++i) {
+ os << (i%10);
+ for (int j=0; j<m.height(); ++j) {
+ const std::vector<bool>& ar = m(i,j);
+ for (int k=0; k<ar.size(); ++k)
+ os << (ar[k] ? '*' : '.');
+ }
+ os << "\t";
+ os << (i%10) << "\n";
+ }
+ os << ' ';
+ for (int j=0; j<m.height(); ++j)
+ os << (j%10) << "\t";
+ os << "\n";
+ return os;
+}
+
+#endif
+
diff --git a/src/bottom_up_parser.cc b/src/bottom_up_parser.cc
new file mode 100644
index 00000000..349ed2de
--- /dev/null
+++ b/src/bottom_up_parser.cc
@@ -0,0 +1,260 @@
+#include "bottom_up_parser.h"
+
+#include <map>
+
+#include "hg.h"
+#include "array2d.h"
+#include "tdict.h"
+
+using namespace std;
+
+class ActiveChart;
+class PassiveChart {
+ public:
+ PassiveChart(const string& goal,
+ const vector<GrammarPtr>& grammars,
+ const Lattice& input,
+ Hypergraph* forest);
+ ~PassiveChart();
+
+ inline const vector<int>& operator()(int i, int j) const { return chart_(i,j); }
+ bool Parse();
+ inline int size() const { return chart_.width(); }
+ inline bool GoalFound() const { return goal_idx_ >= 0; }
+ inline int GetGoalIndex() const { return goal_idx_; }
+
+ private:
+ void ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail);
+ void ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes);
+ void ApplyUnaryRules(const int i, const int j);
+
+ const vector<GrammarPtr>& grammars_;
+ const Lattice& input_;
+ Hypergraph* forest_;
+ Array2D<vector<int> > chart_; // chart_(i,j) is the list of nodes derived spanning i,j
+ typedef map<int, int> Cat2NodeMap;
+ Array2D<Cat2NodeMap> nodemap_;
+ vector<ActiveChart*> act_chart_;
+ const WordID goal_cat_; // category that is being searched for at [0,n]
+ TRulePtr goal_rule_;
+ int goal_idx_; // index of goal node, if found
+
+ static WordID kGOAL; // [Goal]
+};
+
+WordID PassiveChart::kGOAL = 0;
+
+class ActiveChart {
+ public:
+ ActiveChart(const Hypergraph* hg, const PassiveChart& psv_chart) :
+ hg_(hg),
+ act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {}
+
+ struct ActiveItem {
+ ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, double lcost) :
+ gptr_(g), ant_nodes_(a), lattice_cost(lcost) {}
+ explicit ActiveItem(const GrammarIter* g) :
+ gptr_(g), ant_nodes_(), lattice_cost() {}
+
+ void ExtendTerminal(int symbol, double src_cost, vector<ActiveItem>* out_cell) const {
+ const GrammarIter* ni = gptr_->Extend(symbol);
+ if (ni) out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost));
+ }
+ void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector<ActiveItem>* out_cell) const {
+ int symbol = hg->nodes_[node_index].cat_;
+ const GrammarIter* ni = gptr_->Extend(symbol);
+ if (!ni) return;
+ Hypergraph::TailNodeVector na(ant_nodes_.size() + 1);
+ for (int i = 0; i < ant_nodes_.size(); ++i)
+ na[i] = ant_nodes_[i];
+ na[ant_nodes_.size()] = node_index;
+ out_cell->push_back(ActiveItem(ni, na, lattice_cost));
+ }
+
+ const GrammarIter* gptr_;
+ Hypergraph::TailNodeVector ant_nodes_;
+ double lattice_cost; // TODO? use SparseVector<double>
+ };
+
+ inline const vector<ActiveItem>& operator()(int i, int j) const { return act_chart_(i,j); }
+ void SeedActiveChart(const Grammar& g) {
+ int size = act_chart_.width();
+ for (int i = 0; i < size; ++i)
+ if (g.HasRuleForSpan(i,i))
+ act_chart_(i,i).push_back(ActiveItem(g.GetRoot()));
+ }
+
+ void ExtendActiveItems(int i, int k, int j) {
+ //cerr << " LOOK(" << i << "," << k << ") for completed items in (" << k << "," << j << ")\n";
+ vector<ActiveItem>& cell = act_chart_(i,j);
+ const vector<ActiveItem>& icell = act_chart_(i,k);
+ const vector<int>& idxs = psv_chart_(k, j);
+ //if (!idxs.empty()) { cerr << "FOUND IN (" << k << "," << j << ")\n"; }
+ for (vector<ActiveItem>::const_iterator di = icell.begin(); di != icell.end(); ++di) {
+ for (vector<int>::const_iterator ni = idxs.begin(); ni != idxs.end(); ++ni) {
+ di->ExtendNonTerminal(hg_, *ni, &cell);
+ }
+ }
+ }
+
+ void AdvanceDotsForAllItemsInCell(int i, int j, const vector<vector<LatticeArc> >& input) {
+ //cerr << "ADVANCE(" << i << "," << j << ")\n";
+ for (int k=i+1; k < j; ++k)
+ ExtendActiveItems(i, k, j);
+
+ const vector<LatticeArc>& out_arcs = input[j-1];
+ for (vector<LatticeArc>::const_iterator ai = out_arcs.begin();
+ ai != out_arcs.end(); ++ai) {
+ const WordID& f = ai->label;
+ const double& c = ai->cost;
+ const int& len = ai->dist2next;
+ //VLOG(1) << "F: " << TD::Convert(f) << endl;
+ const vector<ActiveItem>& ec = act_chart_(i, j-1);
+ for (vector<ActiveItem>::const_iterator di = ec.begin(); di != ec.end(); ++di)
+ di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1));
+ }
+ }
+
+ private:
+ const Hypergraph* hg_;
+ Array2D<vector<ActiveItem> > act_chart_;
+ const PassiveChart& psv_chart_;
+};
+
+PassiveChart::PassiveChart(const string& goal,
+ const vector<GrammarPtr>& grammars,
+ const Lattice& input,
+ Hypergraph* forest) :
+ grammars_(grammars),
+ input_(input),
+ forest_(forest),
+ chart_(input.size()+1, input.size()+1),
+ nodemap_(input.size()+1, input.size()+1),
+ goal_cat_(TD::Convert(goal) * -1),
+ goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")),
+ goal_idx_(-1) {
+ act_chart_.resize(grammars_.size());
+ for (int i = 0; i < grammars_.size(); ++i)
+ act_chart_[i] = new ActiveChart(forest, *this);
+ if (!kGOAL) kGOAL = TD::Convert("Goal") * -1;
+ cerr << " Goal category: [" << goal << ']' << endl;
+}
+
+void PassiveChart::ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes) {
+ Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
+ new_edge->prev_i_ = r->prev_i;
+ new_edge->prev_j_ = r->prev_j;
+ new_edge->i_ = i;
+ new_edge->j_ = j;
+ new_edge->feature_values_ = r->GetFeatureValues();
+ Cat2NodeMap& c2n = nodemap_(i,j);
+ const bool is_goal = (r->GetLHS() == kGOAL);
+ const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS());
+ Hypergraph::Node* node = NULL;
+ if (ni == c2n.end()) {
+ node = forest_->AddNode(r->GetLHS(), "");
+ c2n[r->GetLHS()] = node->id_;
+ if (is_goal) {
+ assert(goal_idx_ == -1);
+ goal_idx_ = node->id_;
+ } else {
+ chart_(i,j).push_back(node->id_);
+ }
+ } else {
+ node = &forest_->nodes_[ni->second];
+ }
+ forest_->ConnectEdgeToHeadNode(new_edge, node);
+}
+
+void PassiveChart::ApplyRules(const int i,
+ const int j,
+ const RuleBin* rules,
+ const Hypergraph::TailNodeVector& tail) {
+ const int n = rules->GetNumRules();
+ for (int k = 0; k < n; ++k)
+ ApplyRule(i, j, rules->GetIthRule(k), tail);
+}
+
+void PassiveChart::ApplyUnaryRules(const int i, const int j) {
+ const vector<int>& nodes = chart_(i,j); // reference is important!
+ for (int gi = 0; gi < grammars_.size(); ++gi) {
+ if (!grammars_[gi]->HasRuleForSpan(i,j)) continue;
+ for (int di = 0; di < nodes.size(); ++di) {
+ const WordID& cat = forest_->nodes_[nodes[di]].cat_;
+ const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat);
+ for (int ri = 0; ri < unaries.size(); ++ri) {
+ // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl;
+ const Hypergraph::TailNodeVector ant(1, nodes[di]);
+ ApplyRule(i, j, unaries[ri], ant); // may update nodes
+ }
+ }
+ }
+}
+
+bool PassiveChart::Parse() {
+ forest_->nodes_.reserve(input_.size() * input_.size() * 2);
+ forest_->edges_.reserve(input_.size() * input_.size() * 1000); // TODO: reservation??
+ goal_idx_ = -1;
+ for (int gi = 0; gi < grammars_.size(); ++gi)
+ act_chart_[gi]->SeedActiveChart(*grammars_[gi]);
+
+ cerr << " ";
+ for (int l=1; l<input_.size()+1; ++l) {
+ cerr << '.';
+ for (int i=0; i<input_.size() + 1 - l; ++i) {
+ int j = i + l;
+ for (int gi = 0; gi < grammars_.size(); ++gi) {
+ const Grammar& g = *grammars_[gi];
+ if (g.HasRuleForSpan(i, j)) {
+ act_chart_[gi]->AdvanceDotsForAllItemsInCell(i, j, input_);
+
+ const vector<ActiveChart::ActiveItem>& cell = (*act_chart_[gi])(i,j);
+ for (vector<ActiveChart::ActiveItem>::const_iterator ai = cell.begin();
+ ai != cell.end(); ++ai) {
+ const RuleBin* rules = (ai->gptr_->GetRules());
+ if (!rules) continue;
+ ApplyRules(i, j, rules, ai->ant_nodes_);
+ }
+ }
+ }
+ ApplyUnaryRules(i,j);
+
+ for (int gi = 0; gi < grammars_.size(); ++gi) {
+ const Grammar& g = *grammars_[gi];
+ // deal with non-terminals that were just proved
+ if (g.HasRuleForSpan(i, j))
+ act_chart_[gi]->ExtendActiveItems(i, i, j);
+ }
+ }
+ const vector<int>& dh = chart_(0, input_.size());
+ for (int di = 0; di < dh.size(); ++di) {
+ const Hypergraph::Node& node = forest_->nodes_[dh[di]];
+ if (node.cat_ == goal_cat_) {
+ Hypergraph::TailNodeVector ant(1, node.id_);
+ ApplyRule(0, input_.size(), goal_rule_, ant);
+ }
+ }
+ }
+ cerr << endl;
+
+ if (GoalFound())
+ forest_->PruneUnreachable(forest_->nodes_.size() - 1);
+ return GoalFound();
+}
+
+PassiveChart::~PassiveChart() {
+ for (int i = 0; i < act_chart_.size(); ++i)
+ delete act_chart_[i];
+}
+
+ExhaustiveBottomUpParser::ExhaustiveBottomUpParser(
+ const string& goal_sym,
+ const vector<GrammarPtr>& grammars) :
+ goal_sym_(goal_sym),
+ grammars_(grammars) {}
+
+bool ExhaustiveBottomUpParser::Parse(const Lattice& input,
+ Hypergraph* forest) const {
+ PassiveChart chart(goal_sym_, grammars_, input, forest);
+ return chart.Parse();
+}
diff --git a/src/bottom_up_parser.h b/src/bottom_up_parser.h
new file mode 100644
index 00000000..546bfb54
--- /dev/null
+++ b/src/bottom_up_parser.h
@@ -0,0 +1,27 @@
+#ifndef _BOTTOM_UP_PARSER_H_
+#define _BOTTOM_UP_PARSER_H_
+
+#include <vector>
+#include <string>
+
+#include "lattice.h"
+#include "grammar.h"
+
+class Hypergraph;
+
+class ExhaustiveBottomUpParser {
+ public:
+ ExhaustiveBottomUpParser(const std::string& goal_sym,
+ const std::vector<GrammarPtr>& grammars);
+
+ // returns true if goal reached spanning the full input
+ // forest contains the full (i.e., unpruned) parse forest
+ bool Parse(const Lattice& input,
+ Hypergraph* forest) const;
+
+ private:
+ const std::string goal_sym_;
+ const std::vector<GrammarPtr> grammars_;
+};
+
+#endif
diff --git a/src/cdec.cc b/src/cdec.cc
new file mode 100644
index 00000000..c5780cef
--- /dev/null
+++ b/src/cdec.cc
@@ -0,0 +1,474 @@
+#include <iostream>
+#include <fstream>
+#include <tr1/unordered_map>
+#include <tr1/unordered_set>
+
+#include <boost/shared_ptr.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "timing_stats.h"
+#include "translator.h"
+#include "phrasebased_translator.h"
+#include "aligner.h"
+#include "stringlib.h"
+#include "forest_writer.h"
+#include "filelib.h"
+#include "sampler.h"
+#include "sparse_vector.h"
+#include "lexcrf.h"
+#include "weights.h"
+#include "tdict.h"
+#include "ff.h"
+#include "ff_factory.h"
+#include "hg_intersect.h"
+#include "apply_models.h"
+#include "viterbi.h"
+#include "kbest.h"
+#include "inside_outside.h"
+#include "exp_semiring.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+using namespace std::tr1;
+using boost::shared_ptr;
+namespace po = boost::program_options;
+
+// some globals ...
+boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
+
+namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); }
+
+void ShowBanner() {
+ cerr << "cdec v1.0 (c) 2009 by Chris Dyer\n";
+}
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("formalism,f",po::value<string>()->default_value("scfg"),"Translation formalism; values include SCFG, FST, or PB. Specify LexicalCRF for experimental unsupervised CRF word alignment")
+ ("input,i",po::value<string>()->default_value("-"),"Source file")
+ ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
+ ("weights,w",po::value<string>(),"Feature weights file")
+ ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)")
+ ("list_feature_functions,L","List available feature functions")
+ ("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
+ ("k_best,k",po::value<int>(),"Extract the k best derivations")
+ ("unique_k_best,r", "Unique k-best translation list")
+ ("aligner,a", "Run as a word/phrase aligner (src & ref required)")
+ ("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node")
+ ("goal",po::value<string>()->default_value("S"),"Goal symbol (SCFG & FST)")
+ ("scfg_extra_glue_grammar", po::value<string>(), "Extra glue grammar file (Glue grammars apply when i=0 but have no other span restrictions)")
+ ("scfg_no_hiero_glue_grammar,n", "No Hiero glue grammar (nb. by default the SCFG decoder adds Hiero glue rules)")
+ ("scfg_default_nt,d",po::value<string>()->default_value("X"),"Default non-terminal symbol in SCFG")
+ ("scfg_max_span_limit,S",po::value<int>()->default_value(10),"Maximum non-terminal span limit (except \"glue\" grammar)")
+ ("show_tree_structure,T", "Show the Viterbi derivation structure")
+ ("show_expected_length", "Show the expected translation length under the model")
+ ("show_partition,z", "Compute and show the partition (inside score)")
+ ("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file")
+ ("graphviz","Show (constrained) translation forest in GraphViz format")
+ ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
+ ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
+ ("pb_max_distortion,D", po::value<int>()->default_value(4), "Phrase-based decoder: maximum distortion")
+ ("gradient,G","Compute d log p(e|f) / d lambda_i and write to STDOUT (src & ref required)")
+ ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")
+ ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
+ ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
+ ("forest_output,O",po::value<string>(),"Directory to write forests to")
+ ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc.");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config,c", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ const string cfg = (*conf)["config"].as<string>();
+ cerr << "Configuration file: " << cfg << endl;
+ ifstream config(cfg.c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("list_feature_functions")) {
+ cerr << "Available feature functions (specify with -F):\n";
+ global_ff_registry->DisplayList();
+ cerr << endl;
+ exit(1);
+ }
+
+ if (conf->count("help") || conf->count("grammar") == 0) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+
+ const string formalism = LowercaseString((*conf)["formalism"].as<string>());
+ if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb") {
+ cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', or 'lexcrf'\n";
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+// TODO move out of cdec into some sampling decoder file
+void SampleRecurse(const Hypergraph& hg, const vector<SampleSet>& ss, int n, vector<WordID>* out) {
+ const SampleSet& s = ss[n];
+ int i = rng->SelectSample(s);
+ const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]];
+ vector<vector<WordID> > ants(edge.tail_nodes_.size());
+ for (int j = 0; j < ants.size(); ++j)
+ SampleRecurse(hg, ss, edge.tail_nodes_[j], &ants[j]);
+
+ vector<const vector<WordID>*> pants(ants.size());
+ for (int j = 0; j < ants.size(); ++j) pants[j] = &ants[j];
+ edge.rule_->ESubstitute(pants, out);
+}
+
+struct SampleSort {
+ bool operator()(const pair<int,string>& a, const pair<int,string>& b) const {
+ return a.first > b.first;
+ }
+};
+
+// TODO move out of cdec into some sampling decoder file
+void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) {
+ unordered_map<string, int, boost::hash<string> > m;
+ hg->PushWeightsToGoal();
+ const int num_nodes = hg->nodes_.size();
+ vector<SampleSet> ss(num_nodes);
+ for (int i = 0; i < num_nodes; ++i) {
+ SampleSet& s = ss[i];
+ const vector<int>& in_edges = hg->nodes_[i].in_edges_;
+ for (int j = 0; j < in_edges.size(); ++j) {
+ s.add(hg->edges_[in_edges[j]].edge_prob_);
+ }
+ }
+ for (int i = 0; i < samples; ++i) {
+ vector<WordID> yield;
+ SampleRecurse(*hg, ss, hg->nodes_.size() - 1, &yield);
+ const string trans = TD::GetString(yield);
+ ++m[trans];
+ }
+ vector<pair<int, string> > dist;
+ for (unordered_map<string, int, boost::hash<string> >::iterator i = m.begin();
+ i != m.end(); ++i) {
+ dist.push_back(make_pair(i->second, i->first));
+ }
+ sort(dist.begin(), dist.end(), SampleSort());
+ if (k) {
+ for (int i = 0; i < k; ++i)
+ cout << dist[i].first << " ||| " << dist[i].second << endl;
+ } else {
+ cout << dist[0].second << endl;
+ }
+}
+
+// TODO decoder output should probably be moved to another file
+void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique) {
+ if (unique) {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k);
+ for (int i = 0; i < k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal, KBest::FilterUnique>::Derivation* d =
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+ if (!d) break;
+ cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl;
+ }
+ } else {
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(forest, k);
+ for (int i = 0; i < k; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(forest.nodes_.size() - 1, i);
+ if (!d) break;
+ cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl;
+ }
+ }
+}
+
+struct ELengthWeightFunction {
+ double operator()(const Hypergraph::Edge& e) const {
+ return e.rule_->ELength() - e.rule_->Arity();
+ }
+};
+
+
+struct TRPHash {
+ size_t operator()(const TRulePtr& o) const { return reinterpret_cast<size_t>(o.get()); }
+};
+static void ExtractRulesDedupe(const Hypergraph& hg, ostream* os) {
+ static unordered_set<TRulePtr, TRPHash> written;
+ for (int i = 0; i < hg.edges_.size(); ++i) {
+ const TRulePtr& rule = hg.edges_[i].rule_;
+ if (written.insert(rule).second) {
+ (*os) << rule->AsString() << endl;
+ }
+ }
+}
+
+void register_feature_functions();
+
+int main(int argc, char** argv) {
+ global_ff_registry.reset(new FFRegistry);
+ register_feature_functions();
+ ShowBanner();
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ const bool write_gradient = conf.count("gradient");
+ const bool feature_expectations = conf.count("feature_expectations");
+ if (write_gradient && feature_expectations) {
+ cerr << "You can only specify --gradient or --feature_expectations, not both!\n";
+ exit(1);
+ }
+ const bool output_training_vector = (write_gradient || feature_expectations);
+
+ boost::shared_ptr<Translator> translator;
+ const string formalism = LowercaseString(conf["formalism"].as<string>());
+ if (formalism == "scfg")
+ translator.reset(new SCFGTranslator(conf));
+ else if (formalism == "fst")
+ translator.reset(new FSTTranslator(conf));
+ else if (formalism == "pb")
+ translator.reset(new PhraseBasedTranslator(conf));
+ else if (formalism == "lexcrf")
+ translator.reset(new LexicalCRF(conf));
+ else
+ assert(!"error");
+
+ vector<double> wv;
+ Weights w;
+ if (conf.count("weights")) {
+ w.InitFromFile(conf["weights"].as<string>());
+ wv.resize(FD::NumFeats());
+ w.InitVector(&wv);
+ }
+
+ // set up additional scoring features
+ vector<shared_ptr<FeatureFunction> > pffs;
+ vector<const FeatureFunction*> late_ffs;
+ if (conf.count("feature_function") > 0) {
+ const vector<string>& add_ffs = conf["feature_function"].as<vector<string> >();
+ for (int i = 0; i < add_ffs.size(); ++i) {
+ string ff, param;
+ SplitCommandAndParam(add_ffs[i], &ff, &param);
+ if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n";
+ else cerr << " (no config parameters)\n";
+ shared_ptr<FeatureFunction> pff = global_ff_registry->Create(ff, param);
+ if (!pff) { exit(1); }
+ // TODO check that multiple features aren't trying to set the same fid
+ pffs.push_back(pff);
+ late_ffs.push_back(pff.get());
+ }
+ }
+ ModelSet late_models(wv, late_ffs);
+
+ const int sample_max_trans = conf.count("max_translation_sample") ?
+ conf["max_translation_sample"].as<int>() : 0;
+ if (sample_max_trans)
+ rng.reset(new RandomNumberGenerator<boost::mt19937>);
+ const bool aligner_mode = conf.count("aligner");
+ const bool minimal_forests = conf.count("minimal_forests");
+ const bool graphviz = conf.count("graphviz");
+ const bool encode_b64 = conf["vector_format"].as<string>() == "b64";
+ const bool kbest = conf.count("k_best");
+ const bool unique_kbest = conf.count("unique_k_best");
+ shared_ptr<WriteFile> extract_file;
+ if (conf.count("extract_rules"))
+ extract_file.reset(new WriteFile(conf["extract_rules"].as<string>()));
+
+ int combine_size = conf["combine_size"].as<int>();
+ if (combine_size < 1) combine_size = 1;
+ const string input = conf["input"].as<string>();
+ cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl;
+ ReadFile in_read(input);
+ istream *in = in_read.stream();
+ assert(*in);
+
+ SparseVector<double> acc_vec; // accumulate gradient
+ double acc_obj = 0; // accumulate objective
+ int g_count = 0; // number of gradient pieces computed
+ int sent_id = -1; // line counter
+
+ while(*in) {
+ Timer::Summarize();
+ ++sent_id;
+ string buf;
+ getline(*in, buf);
+ if (buf.empty()) continue;
+ map<string, string> sgml;
+ ProcessAndStripSGML(&buf, &sgml);
+ if (sgml.find("id") != sgml.end())
+ sent_id = atoi(sgml["id"].c_str());
+
+ cerr << "\nINPUT: ";
+ if (buf.size() < 100)
+ cerr << buf << endl;
+ else {
+ size_t x = buf.rfind(" ", 100);
+ if (x == string::npos) x = 100;
+ cerr << buf.substr(0, x) << " ..." << endl;
+ }
+ cerr << " id = " << sent_id << endl;
+ string to_translate;
+ Lattice ref;
+ ParseTranslatorInputLattice(buf, &to_translate, &ref);
+ const bool has_ref = ref.size() > 0;
+ SentenceMetadata smeta(sent_id, ref);
+ const bool hadoop_counters = (write_gradient);
+ Hypergraph forest; // -LM forest
+ Timer t("Translation");
+ if (!translator->Translate(to_translate, &smeta, wv, &forest)) {
+ cerr << " NO PARSE FOUND.\n";
+ if (hadoop_counters)
+ cerr << "reporter:counter:UserCounters,FParseFailed,1" << endl;
+ cout << endl << flush;
+ continue;
+ }
+ cerr << " -LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+ cerr << " -LM forest (paths): " << forest.NumberOfPaths() << endl;
+ if (conf.count("show_expected_length")) {
+ const PRPair<double, double> res =
+ Inside<PRPair<double, double>,
+ PRWeightFunction<double, EdgeProb, double, ELengthWeightFunction> >(forest);
+ cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl;
+ }
+ if (conf.count("show_partition")) {
+ const prob_t z = Inside<prob_t, EdgeProb>(forest);
+ cerr << " -LM partition log(Z): " << log(z) << endl;
+ }
+ if (extract_file)
+ ExtractRulesDedupe(forest, extract_file->stream());
+ vector<WordID> trans;
+ const prob_t vs = ViterbiESentence(forest, &trans);
+ cerr << " -LM Viterbi: " << TD::GetString(trans) << endl;
+ if (conf.count("show_tree_structure"))
+ cerr << " -LM tree: " << ViterbiETree(forest) << endl;;
+ cerr << " -LM Viterbi: " << log(vs) << endl;
+
+ bool has_late_models = !late_models.empty();
+ if (has_late_models) {
+ forest.Reweight(wv);
+ forest.SortInEdgesByEdgeWeights();
+ Hypergraph lm_forest;
+ int cubepruning_pop_limit = conf["cubepruning_pop_limit"].as<int>();
+ ApplyModelSet(forest,
+ smeta,
+ late_models,
+ PruningConfiguration(cubepruning_pop_limit),
+ &lm_forest);
+ forest.swap(lm_forest);
+ forest.Reweight(wv);
+ trans.clear();
+ ViterbiESentence(forest, &trans);
+ cerr << " +LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+ cerr << " +LM forest (paths): " << forest.NumberOfPaths() << endl;
+ cerr << " +LM Viterbi: " << TD::GetString(trans) << endl;
+ }
+ if (conf.count("forest_output") && !has_ref) {
+ ForestWriter writer(conf["forest_output"].as<string>(), sent_id);
+ assert(writer.Write(forest, minimal_forests));
+ }
+
+ if (sample_max_trans) {
+ MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as<int>() : 0);
+ } else {
+ if (kbest) {
+ DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest);
+ } else {
+ if (!graphviz && !has_ref) {
+ cout << TD::GetString(trans) << endl << flush;
+ }
+ }
+ }
+
+ const int max_trans_beam_size = conf.count("max_translation_beam") ?
+ conf["max_translation_beam"].as<int>() : 0;
+ if (max_trans_beam_size) {
+ Hack::MaxTrans(forest, max_trans_beam_size);
+ continue;
+ }
+
+ if (graphviz && !has_ref) forest.PrintGraphviz();
+
+ // the following are only used if write_gradient is true!
+ SparseVector<double> full_exp, ref_exp, gradient;
+ double log_z = 0, log_ref_z = 0;
+ if (write_gradient)
+ log_z = log(
+ InsideOutside<prob_t, EdgeProb, SparseVector<double>, EdgeFeaturesWeightFunction>(forest, &full_exp));
+
+ if (has_ref) {
+ if (HG::Intersect(ref, &forest)) {
+ cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl;
+ cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl;
+ forest.Reweight(wv);
+ cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl;
+ if (hadoop_counters)
+ cerr << "reporter:counter:UserCounters,SentencePairsParsed,1" << endl;
+ if (conf.count("show_partition")) {
+ const prob_t z = Inside<prob_t, EdgeProb>(forest);
+ cerr << " Contst. partition log(Z): " << log(z) << endl;
+ }
+ //DumpKBest(sent_id, forest, 1000);
+ if (conf.count("forest_output")) {
+ ForestWriter writer(conf["forest_output"].as<string>(), sent_id);
+ assert(writer.Write(forest, minimal_forests));
+ }
+ if (aligner_mode && !output_training_vector)
+ AlignerTools::WriteAlignment(to_translate, ref, forest);
+ if (write_gradient) {
+ log_ref_z = log(
+ InsideOutside<prob_t, EdgeProb, SparseVector<double>, EdgeFeaturesWeightFunction>(forest, &ref_exp));
+ if (log_z < log_ref_z) {
+ cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl;
+ exit(1);
+ }
+ //cerr << "FULL: " << full_exp << endl;
+ //cerr << " REF: " << ref_exp << endl;
+ ref_exp -= full_exp;
+ acc_vec += ref_exp;
+ acc_obj += (log_z - log_ref_z);
+ }
+ if (feature_expectations) {
+ acc_obj += log(
+ InsideOutside<prob_t, EdgeProb, SparseVector<double>, EdgeFeaturesWeightFunction>(forest, &ref_exp));
+ acc_vec += ref_exp;
+ }
+
+ if (output_training_vector) {
+ ++g_count;
+ if (g_count % combine_size == 0) {
+ if (encode_b64) {
+ cout << "0\t";
+ B64::Encode(acc_obj, acc_vec, &cout);
+ cout << endl << flush;
+ } else {
+ cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush;
+ }
+ acc_vec.clear();
+ acc_obj = 0;
+ }
+ }
+ if (conf.count("graphviz")) forest.PrintGraphviz();
+ } else {
+ cerr << " REFERENCE UNREACHABLE.\n";
+ if (write_gradient) {
+ if (hadoop_counters)
+ cerr << "reporter:counter:UserCounters,EFParseFailed,1" << endl;
+ cout << endl << flush;
+ }
+ }
+ }
+ }
+ if (output_training_vector && !acc_vec.empty()) {
+ if (encode_b64) {
+ cout << "0\t";
+ B64::Encode(acc_obj, acc_vec, &cout);
+ cout << endl << flush;
+ } else {
+ cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush;
+ }
+ }
+}
+
diff --git a/src/cdec_ff.cc b/src/cdec_ff.cc
new file mode 100644
index 00000000..86abfb4a
--- /dev/null
+++ b/src/cdec_ff.cc
@@ -0,0 +1,18 @@
+#include <boost/shared_ptr.hpp>
+
+#include "ff.h"
+#include "lm_ff.h"
+#include "ff_factory.h"
+#include "ff_wordalign.h"
+
+boost::shared_ptr<FFRegistry> global_ff_registry;
+
+void register_feature_functions() {
+ global_ff_registry->Register("LanguageModel", new FFFactory<LanguageModel>);
+ global_ff_registry->Register("WordPenalty", new FFFactory<WordPenalty>);
+ global_ff_registry->Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
+ global_ff_registry->Register("MarkovJump", new FFFactory<MarkovJump>);
+ global_ff_registry->Register("BlunsomSynchronousParseHack", new FFFactory<BlunsomSynchronousParseHack>);
+ global_ff_registry->Register("AlignerResults", new FFFactory<AlignerResults>);
+};
+
diff --git a/src/collapse_weights.cc b/src/collapse_weights.cc
new file mode 100644
index 00000000..5e0f3f72
--- /dev/null
+++ b/src/collapse_weights.cc
@@ -0,0 +1,102 @@
+#include <iostream>
+#include <fstream>
+#include <tr1/unordered_map>
+
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "prob.h"
+#include "filelib.h"
+#include "trule.h"
+#include "weights.h"
+
+namespace po = boost::program_options;
+using namespace std;
+
+typedef std::tr1::unordered_map<vector<WordID>, prob_t, boost::hash<vector<WordID> > > MarginalMap;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("grammar,g", po::value<string>(), "Grammar file")
+ ("weights,w", po::value<string>(), "Weights file");
+ po::options_description clo("Command line options");
+ clo.add_options()
+ ("config,c", po::value<string>(), "Configuration file")
+ ("help,h", "Print this help message and exit");
+ po::options_description dconfig_options, dcmdline_options;
+ dconfig_options.add(opts);
+ dcmdline_options.add(opts).add(clo);
+
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ if (conf->count("config")) {
+ const string cfg = (*conf)["config"].as<string>();
+ cerr << "Configuration file: " << cfg << endl;
+ ifstream config(cfg.c_str());
+ po::store(po::parse_config_file(config, dconfig_options), *conf);
+ }
+ po::notify(*conf);
+
+ if (conf->count("help") || !conf->count("grammar") || !conf->count("weights")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ const string wfile = conf["weights"].as<string>();
+ const string gfile = conf["grammar"].as<string>();
+ Weights wm;
+ wm.InitFromFile(wfile);
+ vector<double> w;
+ wm.InitVector(&w);
+ MarginalMap e_tots;
+ MarginalMap f_tots;
+ prob_t tot;
+ {
+ ReadFile rf(gfile);
+ assert(*rf.stream());
+ istream& in = *rf.stream();
+ cerr << "Computing marginals...\n";
+ int lc = 0;
+ while(in) {
+ string line;
+ getline(in, line);
+ ++lc;
+ if (line.empty()) continue;
+ TRule tr(line, true);
+ if (tr.GetFeatureValues().empty())
+ cerr << "Line " << lc << ": empty features - may introduce bias\n";
+ prob_t prob;
+ prob.logeq(tr.GetFeatureValues().dot(w));
+ e_tots[tr.e_] += prob;
+ f_tots[tr.f_] += prob;
+ tot += prob;
+ }
+ }
+ bool normalized = (fabs(log(tot)) < 0.001);
+ cerr << "Total: " << tot << (normalized ? " [normalized]" : " [scaled]") << endl;
+ ReadFile rf(gfile);
+ istream&in = *rf.stream();
+ while(in) {
+ string line;
+ getline(in, line);
+ if (line.empty()) continue;
+ TRule tr(line, true);
+ const double lp = tr.GetFeatureValues().dot(w);
+ if (isinf(lp)) { continue; }
+ tr.scores_.clear();
+
+ cout << tr.AsString() << " ||| F_and_E=" << lp - log(tot);
+ if (!normalized) {
+ cout << ";ZF_and_E=" << lp;
+ }
+ cout << ";F_given_E=" << lp - log(e_tots[tr.e_])
+ << ";E_given_F=" << lp - log(f_tots[tr.f_]) << endl;
+ }
+ return 0;
+}
+
diff --git a/src/dict.h b/src/dict.h
new file mode 100644
index 00000000..bae9debe
--- /dev/null
+++ b/src/dict.h
@@ -0,0 +1,40 @@
+#ifndef DICT_H_
+#define DICT_H_
+
+#include <cassert>
+#include <cstring>
+#include <tr1/unordered_map>
+#include <string>
+#include <vector>
+
+#include <boost/functional/hash.hpp>
+
+#include "wordid.h"
+
+class Dict {
+ typedef std::tr1::unordered_map<std::string, WordID, boost::hash<std::string> > Map;
+ public:
+ Dict() : b0_("<bad0>") { words_.reserve(1000); }
+ inline int max() const { return words_.size(); }
+ inline WordID Convert(const std::string& word) {
+ Map::iterator i = d_.find(word);
+ if (i == d_.end()) {
+ words_.push_back(word);
+ d_[word] = words_.size();
+ return words_.size();
+ } else {
+ return i->second;
+ }
+ }
+ inline const std::string& Convert(const WordID& id) const {
+ if (id == 0) return b0_;
+ assert(id <= words_.size());
+ return words_[id-1];
+ }
+ private:
+ const std::string b0_;
+ std::vector<std::string> words_;
+ Map d_;
+};
+
+#endif
diff --git a/src/dict_test.cc b/src/dict_test.cc
new file mode 100644
index 00000000..5c5d84f0
--- /dev/null
+++ b/src/dict_test.cc
@@ -0,0 +1,30 @@
+#include "dict.h"
+
+#include <gtest/gtest.h>
+#include <cassert>
+
+class DTest : public testing::Test {
+ public:
+ DTest() {}
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+};
+
+TEST_F(DTest, Convert) {
+ Dict d;
+ WordID a = d.Convert("foo");
+ WordID b = d.Convert("bar");
+ std::string x = "foo";
+ WordID c = d.Convert(x);
+ EXPECT_NE(a, b);
+ EXPECT_EQ(a, c);
+ EXPECT_EQ(d.Convert(a), "foo");
+ EXPECT_EQ(d.Convert(b), "bar");
+}
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
diff --git a/src/earley_composer.cc b/src/earley_composer.cc
new file mode 100644
index 00000000..a59686e0
--- /dev/null
+++ b/src/earley_composer.cc
@@ -0,0 +1,726 @@
+#include "earley_composer.h"
+
+#include <iostream>
+#include <fstream>
+#include <map>
+#include <queue>
+#include <tr1/unordered_set>
+
+#include <boost/shared_ptr.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+#include <boost/lexical_cast.hpp>
+
+#include "phrasetable_fst.h"
+#include "sparse_vector.h"
+#include "tdict.h"
+#include "hg.h"
+
+using boost::shared_ptr;
+namespace po = boost::program_options;
+using namespace std;
+using namespace std::tr1;
+
+// Define the following macro if you want to see lots of debugging output
+// when you run the chart parser
+#undef DEBUG_CHART_PARSER
+
+// A few constants used by the chart parser ///////////////
+static const int kMAX_NODES = 2000000;
+static const string kPHRASE_STRING = "X";
+static bool constants_need_init = true;
+static WordID kUNIQUE_START;
+static WordID kPHRASE;
+static TRulePtr kX1X2;
+static TRulePtr kX1;
+static WordID kEPS;
+static TRulePtr kEPSRule;
+
+static void InitializeConstants() {
+ if (constants_need_init) {
+ kPHRASE = TD::Convert(kPHRASE_STRING) * -1;
+ kUNIQUE_START = TD::Convert("S") * -1;
+ kX1X2.reset(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]"));
+ kX1.reset(new TRule("[X] ||| [X,1] ||| [X,1]"));
+ kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>"));
+ kEPS = TD::Convert("<eps>");
+ constants_need_init = false;
+ }
+}
+////////////////////////////////////////////////////////////
+
+class EGrammarNode {
+ friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest);
+ friend void AddGrammarRule(const string& r, map<WordID, EGrammarNode>* g);
+ public:
+#ifdef DEBUG_CHART_PARSER
+ string hint;
+#endif
+ EGrammarNode() : is_some_rule_complete(false), is_root(false) {}
+ const map<WordID, EGrammarNode>& GetTerminals() const { return tptr; }
+ const map<WordID, EGrammarNode>& GetNonTerminals() const { return ntptr; }
+ bool HasNonTerminals() const { return (!ntptr.empty()); }
+ bool HasTerminals() const { return (!tptr.empty()); }
+ bool RuleCompletes() const {
+ return (is_some_rule_complete || (ntptr.empty() && tptr.empty()));
+ }
+ bool GrammarContinues() const {
+ return !(ntptr.empty() && tptr.empty());
+ }
+ bool IsRoot() const {
+ return is_root;
+ }
+ // these are the features associated with the rule from the start
+ // node up to this point. If you use these features, you must
+ // not Extend() this rule.
+ const SparseVector<double>& GetCFGProductionFeatures() const {
+ return input_features;
+ }
+
+ const EGrammarNode* Extend(const WordID& t) const {
+ if (t < 0) {
+ map<WordID, EGrammarNode>::const_iterator it = ntptr.find(t);
+ if (it == ntptr.end()) return NULL;
+ return &it->second;
+ } else {
+ map<WordID, EGrammarNode>::const_iterator it = tptr.find(t);
+ if (it == tptr.end()) return NULL;
+ return &it->second;
+ }
+ }
+
+ private:
+ map<WordID, EGrammarNode> tptr;
+ map<WordID, EGrammarNode> ntptr;
+ SparseVector<double> input_features;
+ bool is_some_rule_complete;
+ bool is_root;
+};
+typedef map<WordID, EGrammarNode> EGrammar; // indexed by the rule LHS
+
+// edges are immutable once created
+struct Edge {
+#ifdef DEBUG_CHART_PARSER
+ static int id_count;
+ const int id;
+#endif
+ const WordID cat; // lhs side of rule proved/being proved
+ const EGrammarNode* const dot; // dot position
+ const FSTNode* const q; // start of span
+ const FSTNode* const r; // end of span
+ const Edge* const active_parent; // back pointer, NULL for PREDICT items
+ const Edge* const passive_parent; // back pointer, NULL for SCAN and PREDICT items
+ const TargetPhraseSet* const tps; // translations
+ shared_ptr<SparseVector<double> > features; // features from CFG rule
+
+ bool IsPassive() const {
+ // when a rule is completed, this value will be set
+ return static_cast<bool>(features);
+ }
+ bool IsActive() const { return !IsPassive(); }
+ bool IsInitial() const {
+ return !(active_parent || passive_parent);
+ }
+ bool IsCreatedByScan() const {
+ return active_parent && !passive_parent && !dot->IsRoot();
+ }
+ bool IsCreatedByPredict() const {
+ return dot->IsRoot();
+ }
+ bool IsCreatedByComplete() const {
+ return active_parent && passive_parent;
+ }
+
+ // constructor for PREDICT
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* q_and_r) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(NULL), passive_parent(NULL), tps(NULL) {}
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* q_and_r, const Edge* act_parent) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(act_parent), passive_parent(NULL), tps(NULL) {}
+
+ // constructors for SCAN
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j,
+ const Edge* act_par, const TargetPhraseSet* translations) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations) {}
+
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j,
+ const Edge* act_par, const TargetPhraseSet* translations,
+ const SparseVector<double>& feats) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations),
+ features(new SparseVector<double>(feats)) {}
+
+ // constructors for COMPLETE
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j,
+ const Edge* act_par, const Edge *pas_par) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps(NULL) {
+ assert(pas_par->IsPassive());
+ assert(act_par->IsActive());
+ }
+
+ Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j,
+ const Edge* act_par, const Edge *pas_par, const SparseVector<double>& feats) :
+#ifdef DEBUG_CHART_PARSER
+ id(++id_count),
+#endif
+ cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps(NULL),
+ features(new SparseVector<double>(feats)) {
+ assert(pas_par->IsPassive());
+ assert(act_par->IsActive());
+ }
+
+ // constructor for COMPLETE query
+ Edge(const FSTNode* _r) :
+#ifdef DEBUG_CHART_PARSER
+ id(0),
+#endif
+ cat(0), dot(NULL), q(NULL),
+ r(_r), active_parent(NULL), passive_parent(NULL), tps(NULL) {}
+ // constructor for MERGE quere
+ Edge(const FSTNode* _q, int) :
+#ifdef DEBUG_CHART_PARSER
+ id(0),
+#endif
+ cat(0), dot(NULL), q(_q),
+ r(NULL), active_parent(NULL), passive_parent(NULL), tps(NULL) {}
+};
+#ifdef DEBUG_CHART_PARSER
+int Edge::id_count = 0;
+#endif
+
+ostream& operator<<(ostream& os, const Edge& e) {
+ string type = "PREDICT";
+ if (e.IsCreatedByScan())
+ type = "SCAN";
+ else if (e.IsCreatedByComplete())
+ type = "COMPLETE";
+ os << "["
+#ifdef DEBUG_CHART_PARSER
+ << '(' << e.id << ") "
+#else
+ << '(' << &e << ") "
+#endif
+ << "q=" << e.q << ", r=" << e.r
+ << ", cat="<< TD::Convert(e.cat*-1) << ", dot="
+ << e.dot
+#ifdef DEBUG_CHART_PARSER
+ << e.dot->hint
+#endif
+ << (e.IsActive() ? ", Active" : ", Passive")
+ << ", " << type;
+#ifdef DEBUG_CHART_PARSER
+ if (e.active_parent) { os << ", act.parent=(" << e.active_parent->id << ')'; }
+ if (e.passive_parent) { os << ", psv.parent=(" << e.passive_parent->id << ')'; }
+#endif
+ if (e.tps) { os << ", tps=" << e.tps; }
+ return os << ']';
+}
+
+struct Traversal {
+ const Edge* const edge; // result from the active / passive combination
+ const Edge* const active;
+ const Edge* const passive;
+ Traversal(const Edge* me, const Edge* a, const Edge* p) : edge(me), active(a), passive(p) {}
+};
+
+struct UniqueTraversalHash {
+ size_t operator()(const Traversal* t) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(t->active);
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(t->passive);
+ x = ((x << 5) + x) ^ t->edge->IsActive();
+ return x;
+ }
+};
+
+struct UniqueTraversalEquals {
+ size_t operator()(const Traversal* a, const Traversal* b) const {
+ return (a->passive == b->passive && a->active == b->active && a->edge->IsActive() == b->edge->IsActive());
+ }
+};
+
+struct UniqueEdgeHash {
+ size_t operator()(const Edge* e) const {
+ size_t x = 5381;
+ if (e->IsActive()) {
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->dot);
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->q);
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->r);
+ x = ((x << 5) + x) ^ static_cast<size_t>(e->cat);
+ x += 13;
+ } else { // with passive edges, we don't care about the dot
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->q);
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->r);
+ x = ((x << 5) + x) ^ static_cast<size_t>(e->cat);
+ }
+ return x;
+ }
+};
+
+struct UniqueEdgeEquals {
+ bool operator()(const Edge* a, const Edge* b) const {
+ if (a->IsActive() != b->IsActive()) return false;
+ if (a->IsActive()) {
+ return (a->cat == b->cat) && (a->dot == b->dot) && (a->q == b->q) && (a->r == b->r);
+ } else {
+ return (a->cat == b->cat) && (a->q == b->q) && (a->r == b->r);
+ }
+ }
+};
+
+struct REdgeHash {
+ size_t operator()(const Edge* e) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->r);
+ return x;
+ }
+};
+
+struct REdgeEquals {
+ bool operator()(const Edge* a, const Edge* b) const {
+ return (a->r == b->r);
+ }
+};
+
+struct QEdgeHash {
+ size_t operator()(const Edge* e) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ reinterpret_cast<size_t>(e->q);
+ return x;
+ }
+};
+
+struct QEdgeEquals {
+ bool operator()(const Edge* a, const Edge* b) const {
+ return (a->q == b->q);
+ }
+};
+
+struct EdgeQueue {
+ queue<const Edge*> q;
+ EdgeQueue() {}
+ void clear() { while(!q.empty()) q.pop(); }
+ bool HasWork() const { return !q.empty(); }
+ const Edge* Next() { const Edge* res = q.front(); q.pop(); return res; }
+ void AddEdge(const Edge* s) { q.push(s); }
+};
+
+class EarleyComposerImpl {
+ public:
+ EarleyComposerImpl(WordID start_cat, const FSTNode& q_0) : start_cat_(start_cat), q_0_(&q_0) {}
+
+ // returns false if the intersection is empty
+ bool Compose(const EGrammar& g, Hypergraph* forest) {
+ goal_node = NULL;
+ EGrammar::const_iterator sit = g.find(start_cat_);
+ forest->ReserveNodes(kMAX_NODES);
+ assert(sit != g.end());
+ Edge* init = new Edge(start_cat_, &sit->second, q_0_);
+ assert(IncorporateNewEdge(init));
+ while (exp_agenda.HasWork() || agenda.HasWork()) {
+ while(exp_agenda.HasWork()) {
+ const Edge* edge = exp_agenda.Next();
+ FinishEdge(edge, forest);
+ }
+ if (agenda.HasWork()) {
+ const Edge* edge = agenda.Next();
+#ifdef DEBUG_CHART_PARSER
+ cerr << "processing (" << edge->id << ')' << endl;
+#endif
+ if (edge->IsActive()) {
+ if (edge->dot->HasTerminals())
+ DoScan(edge);
+ if (edge->dot->HasNonTerminals()) {
+ DoMergeWithPassives(edge);
+ DoPredict(edge, g);
+ }
+ } else {
+ DoComplete(edge);
+ }
+ }
+ }
+ if (goal_node) {
+ forest->PruneUnreachable(goal_node->id_);
+ forest->EpsilonRemove(kEPS);
+ }
+ FreeAll();
+ return goal_node;
+ }
+
+ void FreeAll() {
+ for (int i = 0; i < free_list_.size(); ++i)
+ delete free_list_[i];
+ free_list_.clear();
+ for (int i = 0; i < traversal_free_list_.size(); ++i)
+ delete traversal_free_list_[i];
+ traversal_free_list_.clear();
+ all_traversals.clear();
+ exp_agenda.clear();
+ agenda.clear();
+ tps2node.clear();
+ edge2node.clear();
+ all_edges.clear();
+ passive_edges.clear();
+ active_edges.clear();
+ }
+
+ ~EarleyComposerImpl() {
+ FreeAll();
+ }
+
+ // returns the total number of edges created during composition
+ int EdgesCreated() const {
+ return free_list_.size();
+ }
+
+ private:
+ void DoScan(const Edge* edge) {
+ // here, we assume that the FST will potentially have many more outgoing
+ // edges than the grammar, which will be just a couple. If you want to
+ // efficiently handle the case where both are relatively large, this code
+ // will need to change how the intersection is done. The best general
+ // solution would probably be the Baeza-Yates double binary search.
+
+ const EGrammarNode* dot = edge->dot;
+ const FSTNode* r = edge->r;
+ const map<WordID, EGrammarNode>& terms = dot->GetTerminals();
+ for (map<WordID, EGrammarNode>::const_iterator git = terms.begin();
+ git != terms.end(); ++git) {
+ const FSTNode* next_r = r->Extend(git->first);
+ if (!next_r) continue;
+ const EGrammarNode* next_dot = &git->second;
+ const bool grammar_continues = next_dot->GrammarContinues();
+ const bool rule_completes = next_dot->RuleCompletes();
+ assert(grammar_continues || rule_completes);
+ const SparseVector<double>& input_features = next_dot->GetCFGProductionFeatures();
+ // create up to 4 new edges!
+ if (next_r->HasOutgoingNonEpsilonEdges()) { // are there further symbols in the FST?
+ const TargetPhraseSet* translations = NULL;
+ if (rule_completes)
+ IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, translations, input_features));
+ if (grammar_continues)
+ IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, translations));
+ }
+ if (next_r->HasData()) { // indicates a loop back to q_0 in the FST
+ const TargetPhraseSet* translations = next_r->GetTranslations();
+ if (rule_completes)
+ IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, q_0_, edge, translations, input_features));
+ if (grammar_continues)
+ IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, q_0_, edge, translations));
+ }
+ }
+ }
+
+ void DoPredict(const Edge* edge, const EGrammar& g) {
+ const EGrammarNode* dot = edge->dot;
+ const map<WordID, EGrammarNode>& non_terms = dot->GetNonTerminals();
+ for (map<WordID, EGrammarNode>::const_iterator git = non_terms.begin();
+ git != non_terms.end(); ++git) {
+ const WordID nt_to_predict = git->first;
+ //cerr << edge->id << " -- " << TD::Convert(nt_to_predict*-1) << endl;
+ EGrammar::const_iterator egi = g.find(nt_to_predict);
+ if (egi == g.end()) {
+ cerr << "[ERROR] Can't find any grammar rules with a LHS of type "
+ << TD::Convert(-1*nt_to_predict) << '!' << endl;
+ continue;
+ }
+ assert(edge->IsActive());
+ const EGrammarNode* new_dot = &egi->second;
+ Edge* new_edge = new Edge(nt_to_predict, new_dot, edge->r, edge);
+ IncorporateNewEdge(new_edge);
+ }
+ }
+
+ void DoComplete(const Edge* passive) {
+#ifdef DEBUG_CHART_PARSER
+ cerr << " complete: " << *passive << endl;
+#endif
+ const WordID completed_nt = passive->cat;
+ const FSTNode* q = passive->q;
+ const FSTNode* next_r = passive->r;
+ const Edge query(q);
+ const pair<unordered_multiset<const Edge*, REdgeHash, REdgeEquals>::iterator,
+ unordered_multiset<const Edge*, REdgeHash, REdgeEquals>::iterator > p =
+ active_edges.equal_range(&query);
+ for (unordered_multiset<const Edge*, REdgeHash, REdgeEquals>::iterator it = p.first;
+ it != p.second; ++it) {
+ const Edge* active = *it;
+#ifdef DEBUG_CHART_PARSER
+ cerr << " pos: " << *active << endl;
+#endif
+ const EGrammarNode* next_dot = active->dot->Extend(completed_nt);
+ if (!next_dot) continue;
+ const SparseVector<double>& input_features = next_dot->GetCFGProductionFeatures();
+ // add up to 2 rules
+ if (next_dot->RuleCompletes())
+ IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features));
+ if (next_dot->GrammarContinues())
+ IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive));
+ }
+ }
+
+ void DoMergeWithPassives(const Edge* active) {
+ // edge is active, has non-terminals, we need to find the passives that can extend it
+ assert(active->IsActive());
+ assert(active->dot->HasNonTerminals());
+#ifdef DEBUG_CHART_PARSER
+ cerr << " merge active with passives: ACT=" << *active << endl;
+#endif
+ const Edge query(active->r, 1);
+ const pair<unordered_multiset<const Edge*, QEdgeHash, QEdgeEquals>::iterator,
+ unordered_multiset<const Edge*, QEdgeHash, QEdgeEquals>::iterator > p =
+ passive_edges.equal_range(&query);
+ for (unordered_multiset<const Edge*, QEdgeHash, QEdgeEquals>::iterator it = p.first;
+ it != p.second; ++it) {
+ const Edge* passive = *it;
+ const EGrammarNode* next_dot = active->dot->Extend(passive->cat);
+ if (!next_dot) continue;
+ const FSTNode* next_r = passive->r;
+ const SparseVector<double>& input_features = next_dot->GetCFGProductionFeatures();
+ if (next_dot->RuleCompletes())
+ IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features));
+ if (next_dot->GrammarContinues())
+ IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive));
+ }
+ }
+
+ // take ownership of edge memory, add to various indexes, etc
+ // returns true if this edge is new
+ bool IncorporateNewEdge(Edge* edge) {
+ free_list_.push_back(edge);
+ if (edge->passive_parent && edge->active_parent) {
+ Traversal* t = new Traversal(edge, edge->active_parent, edge->passive_parent);
+ traversal_free_list_.push_back(t);
+ if (all_traversals.find(t) != all_traversals.end()) {
+ return false;
+ } else {
+ all_traversals.insert(t);
+ }
+ }
+ exp_agenda.AddEdge(edge);
+ return true;
+ }
+
+ bool FinishEdge(const Edge* edge, Hypergraph* hg) {
+ bool is_new = false;
+ if (all_edges.find(edge) == all_edges.end()) {
+#ifdef DEBUG_CHART_PARSER
+ cerr << *edge << " is NEW\n";
+#endif
+ all_edges.insert(edge);
+ is_new = true;
+ if (edge->IsPassive()) passive_edges.insert(edge);
+ if (edge->IsActive()) active_edges.insert(edge);
+ agenda.AddEdge(edge);
+ } else {
+#ifdef DEBUG_CHART_PARSER
+ cerr << *edge << " is NOT NEW.\n";
+#endif
+ }
+ AddEdgeToTranslationForest(edge, hg);
+ return is_new;
+ }
+
+ // build the translation forest
+ void AddEdgeToTranslationForest(const Edge* edge, Hypergraph* hg) {
+ assert(hg->nodes_.size() < kMAX_NODES);
+ Hypergraph::Node* tps = NULL;
+ // first add any target language rules
+ if (edge->tps) {
+ Hypergraph::Node*& node = tps2node[(size_t)edge->tps];
+ if (!node) {
+ // cerr << "Creating phrases for " << edge->tps << endl;
+ const vector<TRulePtr>& rules = edge->tps->GetRules();
+ node = hg->AddNode(kPHRASE, "");
+ for (int i = 0; i < rules.size(); ++i) {
+ Hypergraph::Edge* hg_edge = hg->AddEdge(rules[i], Hypergraph::TailNodeVector());
+ hg_edge->feature_values_ += rules[i]->GetFeatureValues();
+ hg->ConnectEdgeToHeadNode(hg_edge, node);
+ }
+ }
+ tps = node;
+ }
+ Hypergraph::Node*& head_node = edge2node[edge];
+ if (!head_node)
+ head_node = hg->AddNode(kPHRASE, "");
+ if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) {
+ assert(goal_node == NULL || goal_node == head_node);
+ goal_node = head_node;
+ }
+ Hypergraph::TailNodeVector tail;
+ SparseVector<double> extra;
+ if (edge->IsCreatedByPredict()) {
+ // extra.set_value(FD::Convert("predict"), 1);
+ } else if (edge->IsCreatedByScan()) {
+ tail.push_back(edge2node[edge->active_parent]->id_);
+ if (tps) {
+ tail.push_back(tps->id_);
+ }
+ //extra.set_value(FD::Convert("scan"), 1);
+ } else if (edge->IsCreatedByComplete()) {
+ tail.push_back(edge2node[edge->active_parent]->id_);
+ tail.push_back(edge2node[edge->passive_parent]->id_);
+ //extra.set_value(FD::Convert("complete"), 1);
+ } else {
+ assert(!"unexpected edge type!");
+ }
+ //cerr << head_node->id_ << "<--" << *edge << endl;
+
+#ifdef DEBUG_CHART_PARSER
+ for (int i = 0; i < tail.size(); ++i)
+ if (tail[i] == head_node->id_) {
+ cerr << "ERROR: " << *edge << "\n i=" << i << endl;
+ if (i == 1) { cerr << "\tP: " << *edge->passive_parent << endl; }
+ if (i == 0) { cerr << "\tA: " << *edge->active_parent << endl; }
+ assert(!"self-loop found!");
+ }
+#endif
+ Hypergraph::Edge* hg_edge = NULL;
+ if (tail.size() == 0) {
+ hg_edge = hg->AddEdge(kEPSRule, tail);
+ } else if (tail.size() == 1) {
+ hg_edge = hg->AddEdge(kX1, tail);
+ } else if (tail.size() == 2) {
+ hg_edge = hg->AddEdge(kX1X2, tail);
+ }
+ if (edge->features)
+ hg_edge->feature_values_ += *edge->features;
+ hg_edge->feature_values_ += extra;
+ hg->ConnectEdgeToHeadNode(hg_edge, head_node);
+ }
+
+ Hypergraph::Node* goal_node;
+ EdgeQueue exp_agenda;
+ EdgeQueue agenda;
+ unordered_map<size_t, Hypergraph::Node*> tps2node;
+ unordered_map<const Edge*, Hypergraph::Node*, UniqueEdgeHash, UniqueEdgeEquals> edge2node;
+ unordered_set<const Traversal*, UniqueTraversalHash, UniqueTraversalEquals> all_traversals;
+ unordered_set<const Edge*, UniqueEdgeHash, UniqueEdgeEquals> all_edges;
+ unordered_multiset<const Edge*, QEdgeHash, QEdgeEquals> passive_edges;
+ unordered_multiset<const Edge*, REdgeHash, REdgeEquals> active_edges;
+ vector<Edge*> free_list_;
+ vector<Traversal*> traversal_free_list_;
+ const WordID start_cat_;
+ const FSTNode* const q_0_;
+};
+
+#ifdef DEBUG_CHART_PARSER
+static string TrimRule(const string& r) {
+ size_t start = r.find(" |||") + 5;
+ size_t end = r.rfind(" |||");
+ return r.substr(start, end - start);
+}
+#endif
+
+void AddGrammarRule(const string& r, EGrammar* g) {
+ const size_t pos = r.find(" ||| ");
+ if (pos == string::npos || r[0] != '[') {
+ cerr << "Bad rule: " << r << endl;
+ return;
+ }
+ const size_t rpos = r.rfind(" ||| ");
+ string feats;
+ string rs = r;
+ if (rpos != pos) {
+ feats = r.substr(rpos + 5);
+ rs = r.substr(0, rpos);
+ }
+ string rhs = rs.substr(pos + 5);
+ string trule = rs + " ||| " + rhs + " ||| " + feats;
+ TRule tr(trule);
+#ifdef DEBUG_CHART_PARSER
+ string hint_last_rule;
+#endif
+ EGrammarNode* cur = &(*g)[tr.GetLHS()];
+ cur->is_root = true;
+ for (int i = 0; i < tr.FLength(); ++i) {
+ WordID sym = tr.f()[i];
+#ifdef DEBUG_CHART_PARSER
+ hint_last_rule = TD::Convert(sym < 0 ? -sym : sym);
+ cur->hint += " <@@> (*" + hint_last_rule + ") " + TrimRule(tr.AsString());
+#endif
+ if (sym < 0)
+ cur = &cur->ntptr[sym];
+ else
+ cur = &cur->tptr[sym];
+ }
+#ifdef DEBUG_CHART_PARSER
+ cur->hint += " <@@> (" + hint_last_rule + "*) " + TrimRule(tr.AsString());
+#endif
+ cur->is_some_rule_complete = true;
+ cur->input_features = tr.GetFeatureValues();
+}
+
+EarleyComposer::~EarleyComposer() {
+ delete pimpl_;
+}
+
+EarleyComposer::EarleyComposer(const FSTNode* fst) {
+ InitializeConstants();
+ pimpl_ = new EarleyComposerImpl(kUNIQUE_START, *fst);
+}
+
+bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest) {
+ // first, convert the src forest into an EGrammar
+ EGrammar g;
+ const int nedges = src_forest.edges_.size();
+ const int nnodes = src_forest.nodes_.size();
+ vector<int> cats(nnodes);
+ bool assign_cats = false;
+ for (int i = 0; i < nnodes; ++i)
+ if (assign_cats) {
+ cats[i] = TD::Convert("CAT_" + boost::lexical_cast<string>(i)) * -1;
+ } else {
+ cats[i] = src_forest.nodes_[i].cat_;
+ }
+ // construct the grammar
+ for (int i = 0; i < nedges; ++i) {
+ const Hypergraph::Edge& edge = src_forest.edges_[i];
+ const vector<WordID>& src = edge.rule_->f();
+ EGrammarNode* cur = &g[cats[edge.head_node_]];
+ cur->is_root = true;
+ int ntc = 0;
+ for (int j = 0; j < src.size(); ++j) {
+ WordID sym = src[j];
+ if (sym <= 0) {
+ sym = cats[edge.tail_nodes_[ntc]];
+ ++ntc;
+ cur = &cur->ntptr[sym];
+ } else {
+ cur = &cur->tptr[sym];
+ }
+ }
+ cur->is_some_rule_complete = true;
+ cur->input_features = edge.feature_values_;
+ }
+ EGrammarNode& goal_rule = g[kUNIQUE_START];
+ assert((goal_rule.ntptr.size() == 1 && goal_rule.tptr.size() == 0) ||
+ (goal_rule.ntptr.size() == 0 && goal_rule.tptr.size() == 1));
+
+ return pimpl_->Compose(g, trg_forest);
+}
+
+bool EarleyComposer::Compose(istream* in, Hypergraph* trg_forest) {
+ EGrammar g;
+ while(*in) {
+ string line;
+ getline(*in, line);
+ if (line.empty()) continue;
+ AddGrammarRule(line, &g);
+ }
+
+ return pimpl_->Compose(g, trg_forest);
+}
diff --git a/src/earley_composer.h b/src/earley_composer.h
new file mode 100644
index 00000000..9f786bf6
--- /dev/null
+++ b/src/earley_composer.h
@@ -0,0 +1,29 @@
+#ifndef _EARLEY_COMPOSER_H_
+#define _EARLEY_COMPOSER_H_
+
+#include <iostream>
+
+class EarleyComposerImpl;
+class FSTNode;
+class Hypergraph;
+
+class EarleyComposer {
+ public:
+ ~EarleyComposer();
+ EarleyComposer(const FSTNode* phrasetable_root);
+ bool Compose(const Hypergraph& src_forest, Hypergraph* trg_forest);
+
+ // reads the grammar from a file. There must be a single top-level
+ // S -> X rule. Anything else is possible. Format is:
+ // [S] ||| [SS,1]
+ // [SS] ||| [NP,1] [VP,2] ||| Feature1=0.2 Feature2=-2.3
+ // [SS] ||| [VP,1] [NP,2] ||| Feature1=0.8
+ // [NP] ||| [DET,1] [N,2] ||| Feature3=2
+ // ...
+ bool Compose(std::istream* grammar_file, Hypergraph* trg_forest);
+
+ private:
+ EarleyComposerImpl* pimpl_;
+};
+
+#endif
diff --git a/src/exp_semiring.h b/src/exp_semiring.h
new file mode 100644
index 00000000..f91beee4
--- /dev/null
+++ b/src/exp_semiring.h
@@ -0,0 +1,71 @@
+#ifndef _EXP_SEMIRING_H_
+#define _EXP_SEMIRING_H_
+
+#include <iostream>
+
+// this file implements the first-order expectation semiring described
+// in Li & Eisner (EMNLP 2009)
+
+// requirements:
+// RType * RType ==> RType
+// PType * PType ==> PType
+// RType * PType ==> RType
+// good examples:
+// PType scalar, RType vector
+// BAD examples:
+// PType vector, RType scalar
+template <typename PType, typename RType>
+struct PRPair {
+ PRPair() : p(), r() {}
+ // Inside algorithm requires that T(0) and T(1)
+ // return the 0 and 1 values of the semiring
+ explicit PRPair(double x) : p(x), r() {}
+ PRPair(const PType& p, const RType& r) : p(p), r(r) {}
+ PRPair& operator+=(const PRPair& o) {
+ p += o.p;
+ r += o.r;
+ return *this;
+ }
+ PRPair& operator*=(const PRPair& o) {
+ r = (o.r * p) + (o.p * r);
+ p *= o.p;
+ return *this;
+ }
+ PType p;
+ RType r;
+};
+
+template <typename P, typename R>
+std::ostream& operator<<(std::ostream& o, const PRPair<P,R>& x) {
+ return o << '<' << x.p << ", " << x.r << '>';
+}
+
+template <typename P, typename R>
+const PRPair<P,R> operator+(const PRPair<P,R>& a, const PRPair<P,R>& b) {
+ PRPair<P,R> result = a;
+ result += b;
+ return result;
+}
+
+template <typename P, typename R>
+const PRPair<P,R> operator*(const PRPair<P,R>& a, const PRPair<P,R>& b) {
+ PRPair<P,R> result = a;
+ result *= b;
+ return result;
+}
+
+template <typename P, typename PWeightFunction, typename R, typename RWeightFunction>
+struct PRWeightFunction {
+ explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(),
+ const RWeightFunction& rwf = RWeightFunction()) :
+ pweight(pwf), rweight(rwf) {}
+ PRPair<P,R> operator()(const Hypergraph::Edge& e) const {
+ const P p = pweight(e);
+ const R r = rweight(e);
+ return PRPair<P,R>(p, r * p);
+ }
+ const PWeightFunction pweight;
+ const RWeightFunction rweight;
+};
+
+#endif
diff --git a/src/fdict.cc b/src/fdict.cc
new file mode 100644
index 00000000..83aa7cea
--- /dev/null
+++ b/src/fdict.cc
@@ -0,0 +1,4 @@
+#include "fdict.h"
+
+Dict FD::dict_;
+
diff --git a/src/fdict.h b/src/fdict.h
new file mode 100644
index 00000000..ff491cfb
--- /dev/null
+++ b/src/fdict.h
@@ -0,0 +1,21 @@
+#ifndef _FDICT_H_
+#define _FDICT_H_
+
+#include <string>
+#include <vector>
+#include "dict.h"
+
+struct FD {
+ static Dict dict_;
+ static inline int NumFeats() {
+ return dict_.max() + 1;
+ }
+ static inline WordID Convert(const std::string& s) {
+ return dict_.Convert(s);
+ }
+ static inline const std::string& Convert(const WordID& w) {
+ return dict_.Convert(w);
+ }
+};
+
+#endif
diff --git a/src/ff.cc b/src/ff.cc
new file mode 100644
index 00000000..488e6468
--- /dev/null
+++ b/src/ff.cc
@@ -0,0 +1,93 @@
+#include "ff.h"
+
+#include "tdict.h"
+#include "hg.h"
+
+using namespace std;
+
+FeatureFunction::~FeatureFunction() {}
+
+
+void FeatureFunction::FinalTraversalFeatures(const void* ant_state,
+ SparseVector<double>* features) const {
+ (void) ant_state;
+ (void) features;
+}
+
+// Hiero and Joshua use log_10(e) as the value, so I do to
+WordPenalty::WordPenalty(const string& param) :
+ fid_(FD::Convert("WordPenalty")),
+ value_(-1.0 / log(10)) {
+ if (!param.empty()) {
+ cerr << "Warning WordPenalty ignoring parameter: " << param << endl;
+ }
+}
+
+void WordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ (void) smeta;
+ (void) ant_states;
+ (void) state;
+ (void) estimated_features;
+ features->set_value(fid_, edge.rule_->EWords() * value_);
+}
+
+ModelSet::ModelSet(const vector<double>& w, const vector<const FeatureFunction*>& models) :
+ models_(models),
+ weights_(w),
+ state_size_(0),
+ model_state_pos_(models.size()) {
+ for (int i = 0; i < models_.size(); ++i) {
+ model_state_pos_[i] = state_size_;
+ state_size_ += models_[i]->NumBytesContext();
+ }
+}
+
+void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta,
+ const Hypergraph& hg,
+ Hypergraph::Edge* edge,
+ string* context,
+ prob_t* combination_cost_estimate) const {
+ context->resize(state_size_);
+ memset(&(*context)[0], 0, state_size_);
+ SparseVector<double> est_vals; // only computed if combination_cost_estimate is non-NULL
+ if (combination_cost_estimate) *combination_cost_estimate = prob_t::One();
+ for (int i = 0; i < models_.size(); ++i) {
+ const FeatureFunction& ff = *models_[i];
+ void* cur_ff_context = NULL;
+ vector<const void*> ants(edge->tail_nodes_.size());
+ bool has_context = ff.NumBytesContext() > 0;
+ if (has_context) {
+ int spos = model_state_pos_[i];
+ cur_ff_context = &(*context)[spos];
+ for (int i = 0; i < ants.size(); ++i) {
+ ants[i] = &hg.nodes_[edge->tail_nodes_[i]].state_[spos];
+ }
+ }
+ ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context);
+ }
+ if (combination_cost_estimate)
+ combination_cost_estimate->logeq(est_vals.dot(weights_));
+ edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
+}
+
+void ModelSet::AddFinalFeatures(const std::string& state, Hypergraph::Edge* edge) const {
+ assert(1 == edge->rule_->Arity());
+
+ for (int i = 0; i < models_.size(); ++i) {
+ const FeatureFunction& ff = *models_[i];
+ const void* ant_state = NULL;
+ bool has_context = ff.NumBytesContext() > 0;
+ if (has_context) {
+ int spos = model_state_pos_[i];
+ ant_state = &state[spos];
+ }
+ ff.FinalTraversalFeatures(ant_state, &edge->feature_values_);
+ }
+ edge->edge_prob_.logeq(edge->feature_values_.dot(weights_));
+}
+
diff --git a/src/ff.h b/src/ff.h
new file mode 100644
index 00000000..c97e2fe2
--- /dev/null
+++ b/src/ff.h
@@ -0,0 +1,121 @@
+#ifndef _FF_H_
+#define _FF_H_
+
+#include <vector>
+
+#include "fdict.h"
+#include "hg.h"
+
+class SentenceMetadata;
+class FeatureFunction; // see definition below
+
+// if you want to develop a new feature, inherit from this class and
+// override TraversalFeaturesImpl(...). If it's a feature that returns /
+// depends on context, you may also need to implement
+// FinalTraversalFeatures(...)
+class FeatureFunction {
+ public:
+ FeatureFunction() : state_size_() {}
+ explicit FeatureFunction(int state_size) : state_size_(state_size) {}
+ virtual ~FeatureFunction();
+
+ // returns the number of bytes of context that this feature function will
+ // (maximally) use. By default, 0 ("stateless" models in Hiero/Joshua).
+ // NOTE: this value is fixed for the instance of your class, you cannot
+ // use different amounts of memory for different nodes in the forest.
+ inline int NumBytesContext() const { return state_size_; }
+
+ // Compute the feature values and (if this applies) the estimates of the
+ // feature values when this edge is used incorporated into a larger context
+ inline void TraversalFeatures(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_state) const {
+ TraversalFeaturesImpl(smeta, edge, ant_contexts,
+ features, estimated_features, out_state);
+ // TODO it's easy for careless feature function developers to overwrite
+ // the end of their state and clobber someone else's memory. These bugs
+ // will be horrendously painful to track down. There should be some
+ // optional strict mode that's enforced here that adds some kind of
+ // barrier between the blocks reserved for the residual contexts
+ }
+
+ // if there's some state left when you transition to the goal state, score
+ // it here. For example, the language model computes the cost of adding
+ // <s> and </s>.
+ virtual void FinalTraversalFeatures(const void* residual_state,
+ SparseVector<double>* final_features) const;
+
+ protected:
+ // context is a pointer to a buffer of size NumBytesContext() that the
+ // feature function can write its state to. It's up to the feature function
+ // to determine how much space it needs and to determine how to encode its
+ // residual contextual information since it is OPAQUE to all clients outside
+ // of the particular FeatureFunction class. There is one exception:
+ // equality of the contents (i.e., memcmp) is required to determine whether
+ // two states can be combined.
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const = 0;
+
+ // !!! ONLY call this from subclass *CONSTRUCTORS* !!!
+ void SetStateSize(size_t state_size) {
+ state_size_ = state_size;
+ }
+
+ private:
+ int state_size_;
+};
+
+// word penalty feature, for each word on the E side of a rule,
+// add value_
+class WordPenalty : public FeatureFunction {
+ public:
+ WordPenalty(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ private:
+ const int fid_;
+ const double value_;
+};
+
+// this class is a set of FeatureFunctions that can be used to score, rescore,
+// etc. a (translation?) forest
+class ModelSet {
+ public:
+ ModelSet() : state_size_(0) {}
+
+ ModelSet(const std::vector<double>& weights,
+ const std::vector<const FeatureFunction*>& models);
+
+ // sets edge->feature_values_ and edge->edge_prob_
+ // NOTE: edge must not necessarily be in hg.edges_ but its TAIL nodes
+ // must be.
+ void AddFeaturesToEdge(const SentenceMetadata& smeta,
+ const Hypergraph& hg,
+ Hypergraph::Edge* edge,
+ std::string* residual_context,
+ prob_t* combination_cost_estimate = NULL) const;
+
+ void AddFinalFeatures(const std::string& residual_context,
+ Hypergraph::Edge* edge) const;
+
+ bool empty() const { return models_.empty(); }
+ private:
+ std::vector<const FeatureFunction*> models_;
+ std::vector<double> weights_;
+ int state_size_;
+ std::vector<int> model_state_pos_;
+};
+
+#endif
diff --git a/src/ff_factory.cc b/src/ff_factory.cc
new file mode 100644
index 00000000..1854e0bb
--- /dev/null
+++ b/src/ff_factory.cc
@@ -0,0 +1,35 @@
+#include "ff_factory.h"
+
+#include "ff.h"
+
+using boost::shared_ptr;
+using namespace std;
+
+FFFactoryBase::~FFFactoryBase() {}
+
+void FFRegistry::DisplayList() const {
+ for (map<string, shared_ptr<FFFactoryBase> >::const_iterator it = reg_.begin();
+ it != reg_.end(); ++it) {
+ cerr << " " << it->first << endl;
+ }
+}
+
+shared_ptr<FeatureFunction> FFRegistry::Create(const string& ffname, const string& param) const {
+ map<string, shared_ptr<FFFactoryBase> >::const_iterator it = reg_.find(ffname);
+ shared_ptr<FeatureFunction> res;
+ if (it == reg_.end()) {
+ cerr << "I don't know how to create feature " << ffname << endl;
+ } else {
+ res = it->second->Create(param);
+ }
+ return res;
+}
+
+void FFRegistry::Register(const string& ffname, FFFactoryBase* factory) {
+ if (reg_.find(ffname) != reg_.end()) {
+ cerr << "Duplicate registration of FeatureFunction with name " << ffname << "!\n";
+ abort();
+ }
+ reg_[ffname].reset(factory);
+}
+
diff --git a/src/ff_factory.h b/src/ff_factory.h
new file mode 100644
index 00000000..bc586567
--- /dev/null
+++ b/src/ff_factory.h
@@ -0,0 +1,39 @@
+#ifndef _FF_FACTORY_H_
+#define _FF_FACTORY_H_
+
+#include <iostream>
+#include <string>
+#include <map>
+
+#include <boost/shared_ptr.hpp>
+
+class FeatureFunction;
+class FFRegistry;
+class FFFactoryBase;
+extern boost::shared_ptr<FFRegistry> global_ff_registry;
+
+class FFRegistry {
+ friend int main(int argc, char** argv);
+ friend class FFFactoryBase;
+ public:
+ boost::shared_ptr<FeatureFunction> Create(const std::string& ffname, const std::string& param) const;
+ void DisplayList() const;
+ void Register(const std::string& ffname, FFFactoryBase* factory);
+ private:
+ FFRegistry() {}
+ std::map<std::string, boost::shared_ptr<FFFactoryBase> > reg_;
+};
+
+struct FFFactoryBase {
+ virtual ~FFFactoryBase();
+ virtual boost::shared_ptr<FeatureFunction> Create(const std::string& param) const = 0;
+};
+
+template<class FF>
+class FFFactory : public FFFactoryBase {
+ boost::shared_ptr<FeatureFunction> Create(const std::string& param) const {
+ return boost::shared_ptr<FeatureFunction>(new FF(param));
+ }
+};
+
+#endif
diff --git a/src/ff_itg_span.h b/src/ff_itg_span.h
new file mode 100644
index 00000000..b990f86a
--- /dev/null
+++ b/src/ff_itg_span.h
@@ -0,0 +1,7 @@
+#ifndef _FF_ITG_SPAN_H_
+#define _FF_ITG_SPAN_H_
+
+class ITGSpanFeatures : public FeatureFunction {
+};
+
+#endif
diff --git a/src/ff_test.cc b/src/ff_test.cc
new file mode 100644
index 00000000..1c20f9ac
--- /dev/null
+++ b/src/ff_test.cc
@@ -0,0 +1,134 @@
+#include <cassert>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <gtest/gtest.h>
+#include "hg.h"
+#include "lm_ff.h"
+#include "ff.h"
+#include "trule.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+LanguageModel* lm_ = NULL;
+LanguageModel* lm3_ = NULL;
+
+class FFTest : public testing::Test {
+ public:
+ FFTest() : smeta(0,Lattice()) {
+ if (!lm_) {
+ static LanguageModel slm("-o 2 ./test_data/brown.lm.gz");
+ lm_ = &slm;
+ static LanguageModel slm3("./test_data/dummy.3gram.lm -o 3");
+ lm3_ = &slm3;
+ }
+ }
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+ SentenceMetadata smeta;
+};
+
+TEST_F(FFTest,LanguageModel) {
+ vector<const FeatureFunction*> ms(1, lm_);
+ TRulePtr tr1(new TRule("[X] ||| [X,1] said"));
+ TRulePtr tr2(new TRule("[X] ||| the man said"));
+ TRulePtr tr3(new TRule("[X] ||| the fat man"));
+ Hypergraph hg;
+ const int lm_fid = FD::Convert("LanguageModel");
+ vector<double> w(lm_fid + 1,1);
+ ModelSet models(w, ms);
+ string state;
+ Hypergraph::Edge edge;
+ edge.rule_ = tr2;
+ models.AddFeaturesToEdge(smeta, hg, &edge, &state);
+ double lm1 = edge.feature_values_.dot(w);
+ cerr << "lm=" << edge.feature_values_[lm_fid] << endl;
+
+ hg.nodes_.resize(1);
+ hg.edges_.resize(2);
+ hg.edges_[0].rule_ = tr3;
+ models.AddFeaturesToEdge(smeta, hg, &hg.edges_[0], &hg.nodes_[0].state_);
+ hg.edges_[1].tail_nodes_.push_back(0);
+ hg.edges_[1].rule_ = tr1;
+ string state2;
+ models.AddFeaturesToEdge(smeta, hg, &hg.edges_[1], &state2);
+ double tot = hg.edges_[1].feature_values_[lm_fid] + hg.edges_[0].feature_values_[lm_fid];
+ cerr << "lm=" << tot << endl;
+ EXPECT_TRUE(state2 == state);
+ EXPECT_FALSE(state == hg.nodes_[0].state_);
+}
+
+TEST_F(FFTest, Small) {
+ WordPenalty wp("");
+ vector<const FeatureFunction*> ms(2, lm_);
+ ms[1] = &wp;
+ TRulePtr tr1(new TRule("[X] ||| [X,1] said"));
+ TRulePtr tr2(new TRule("[X] ||| john said"));
+ TRulePtr tr3(new TRule("[X] ||| john"));
+ cerr << "RULE: " << tr1->AsString() << endl;
+ Hypergraph hg;
+ vector<double> w(2); w[0]=1.0; w[1]=-2.0;
+ ModelSet models(w, ms);
+ string state;
+ Hypergraph::Edge edge;
+ edge.rule_ = tr2;
+ cerr << tr2->AsString() << endl;
+ models.AddFeaturesToEdge(smeta, hg, &edge, &state);
+ double s1 = edge.feature_values_.dot(w);
+ cerr << "lm=" << edge.feature_values_[0] << endl;
+ cerr << "wp=" << edge.feature_values_[1] << endl;
+
+ hg.nodes_.resize(1);
+ hg.edges_.resize(2);
+ hg.edges_[0].rule_ = tr3;
+ models.AddFeaturesToEdge(smeta, hg, &hg.edges_[0], &hg.nodes_[0].state_);
+ double acc = hg.edges_[0].feature_values_.dot(w);
+ cerr << hg.edges_[0].feature_values_[0] << endl;
+ hg.edges_[1].tail_nodes_.push_back(0);
+ hg.edges_[1].rule_ = tr1;
+ string state2;
+ models.AddFeaturesToEdge(smeta, hg, &hg.edges_[1], &state2);
+ acc += hg.edges_[1].feature_values_.dot(w);
+ double tot = hg.edges_[1].feature_values_[0] + hg.edges_[0].feature_values_[0];
+ cerr << "lm=" << tot << endl;
+ cerr << "acc=" << acc << endl;
+ cerr << " s1=" << s1 << endl;
+ EXPECT_TRUE(state2 == state);
+ EXPECT_FALSE(state == hg.nodes_[0].state_);
+ EXPECT_FLOAT_EQ(acc, s1);
+}
+
+TEST_F(FFTest, LM3) {
+ int x = lm3_->NumBytesContext();
+ Hypergraph::Edge edge1;
+ edge1.rule_.reset(new TRule("[X] ||| x y ||| one ||| 1.0 -2.4 3.0"));
+ Hypergraph::Edge edge2;
+ edge2.rule_.reset(new TRule("[X] ||| [X,1] a ||| [X,1] two ||| 1.0 -2.4 3.0"));
+ Hypergraph::Edge edge3;
+ edge3.rule_.reset(new TRule("[X] ||| [X,1] a ||| zero [X,1] two ||| 1.0 -2.4 3.0"));
+ vector<const void*> ants1;
+ string state(x, '\0');
+ SparseVector<double> feats;
+ SparseVector<double> est;
+ lm3_->TraversalFeatures(smeta, edge1, ants1, &feats, &est, (void *)&state[0]);
+ cerr << "returned " << feats << endl;
+ cerr << edge1.feature_values_ << endl;
+ cerr << lm3_->DebugStateToString((const void*)&state[0]) << endl;
+ EXPECT_EQ("[ one ]", lm3_->DebugStateToString((const void*)&state[0]));
+ ants1.push_back((const void*)&state[0]);
+ string state2(x, '\0');
+ lm3_->TraversalFeatures(smeta, edge2, ants1, &feats, &est, (void *)&state2[0]);
+ cerr << lm3_->DebugStateToString((const void*)&state2[0]) << endl;
+ EXPECT_EQ("[ one two ]", lm3_->DebugStateToString((const void*)&state2[0]));
+ string state3(x, '\0');
+ lm3_->TraversalFeatures(smeta, edge3, ants1, &feats, &est, (void *)&state3[0]);
+ cerr << lm3_->DebugStateToString((const void*)&state3[0]) << endl;
+ EXPECT_EQ("[ zero one <{STAR}> one two ]", lm3_->DebugStateToString((const void*)&state3[0]));
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/ff_wordalign.cc b/src/ff_wordalign.cc
new file mode 100644
index 00000000..e605ac8d
--- /dev/null
+++ b/src/ff_wordalign.cc
@@ -0,0 +1,221 @@
+#include "ff_wordalign.h"
+
+#include <string>
+#include <cmath>
+
+#include "stringlib.h"
+#include "sentence_metadata.h"
+#include "hg.h"
+#include "fdict.h"
+#include "aligner.h"
+#include "tdict.h" // Blunsom hack
+#include "filelib.h" // Blunsom hack
+
+using namespace std;
+
+RelativeSentencePosition::RelativeSentencePosition(const string& param) :
+ fid_(FD::Convert("RelativeSentencePosition")) {}
+
+void RelativeSentencePosition::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ // if the source word is either null or the generated word
+ // has no position in the reference
+ if (edge.i_ == -1 || edge.prev_i_ == -1)
+ return;
+
+ assert(smeta.GetTargetLength() > 0);
+ const double val = fabs(static_cast<double>(edge.i_) / smeta.GetSourceLength() -
+ static_cast<double>(edge.prev_i_) / smeta.GetTargetLength());
+ features->set_value(fid_, val);
+// cerr << f_len_ << " " << e_len_ << " [" << edge.i_ << "," << edge.j_ << "|" << edge.prev_i_ << "," << edge.prev_j_ << "]\t" << edge.rule_->AsString() << "\tVAL=" << val << endl;
+}
+
+MarkovJump::MarkovJump(const string& param) :
+ FeatureFunction(1),
+ fid_(FD::Convert("MarkovJump")),
+ individual_params_per_jumpsize_(false),
+ condition_on_flen_(false) {
+ cerr << " MarkovJump: Blunsom&Cohn feature";
+ vector<string> argv;
+ int argc = SplitOnWhitespace(param, &argv);
+ if (argc > 0) {
+ if (argc != 1 || !(argv[0] == "-f" || argv[0] == "-i" || argv[0] == "-if")) {
+ cerr << "MarkovJump: expected parameters to be -f, -i, or -if\n";
+ exit(1);
+ }
+ individual_params_per_jumpsize_ = (argv[0][1] == 'i');
+ condition_on_flen_ = (argv[0][argv[0].size() - 1] == 'f');
+ if (individual_params_per_jumpsize_) {
+ template_ = "Jump:000";
+ cerr << ", individual jump parameters";
+ if (condition_on_flen_) {
+ template_ += ":F00";
+ cerr << " (split by f-length)";
+ }
+ }
+ }
+ cerr << endl;
+}
+
+void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ unsigned char& dpstate = *((unsigned char*)state);
+ if (edge.Arity() == 0) {
+ dpstate = static_cast<unsigned int>(edge.i_);
+ } else if (edge.Arity() == 1) {
+ dpstate = *((unsigned char*)ant_states[0]);
+ } else if (edge.Arity() == 2) {
+ int left_index = *((unsigned char*)ant_states[0]);
+ int right_index = *((unsigned char*)ant_states[1]);
+ if (right_index == -1)
+ dpstate = static_cast<unsigned int>(left_index);
+ else
+ dpstate = static_cast<unsigned int>(right_index);
+ const int jumpsize = right_index - left_index;
+ features->set_value(fid_, fabs(jumpsize - 1)); // Blunsom and Cohn def
+
+ if (individual_params_per_jumpsize_) {
+ string fname = template_;
+ int param = jumpsize;
+ if (jumpsize < 0) {
+ param *= -1;
+ fname[5]='L';
+ } else if (jumpsize > 0) {
+ fname[5]='R';
+ }
+ if (param) {
+ fname[6] = '0' + (param / 10);
+ fname[7] = '0' + (param % 10);
+ }
+ if (condition_on_flen_) {
+ const int flen = smeta.GetSourceLength();
+ fname[10] = '0' + (flen / 10);
+ fname[11] = '0' + (flen % 10);
+ }
+ features->set_value(FD::Convert(fname), 1.0);
+ }
+ } else {
+ assert(!"something really unexpected is happening");
+ }
+}
+
+AlignerResults::AlignerResults(const std::string& param) :
+ cur_sent_(-1),
+ cur_grid_(NULL) {
+ vector<string> argv;
+ int argc = SplitOnWhitespace(param, &argv);
+ if (argc != 2) {
+ cerr << "Required format: AlignerResults [FeatureName] [file.pharaoh]\n";
+ exit(1);
+ }
+ cerr << " feature: " << argv[0] << "\talignments: " << argv[1] << endl;
+ fid_ = FD::Convert(argv[0]);
+ ReadFile rf(argv[1]);
+ istream& in = *rf.stream(); int lc = 0;
+ while(in) {
+ string line;
+ getline(in, line);
+ if (!in) break;
+ ++lc;
+ is_aligned_.push_back(AlignerTools::ReadPharaohAlignmentGrid(line));
+ }
+ cerr << " Loaded " << lc << " refs\n";
+}
+
+void AlignerResults::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ if (edge.i_ == -1 || edge.prev_i_ == -1)
+ return;
+
+ if (cur_sent_ != smeta.GetSentenceID()) {
+ assert(smeta.HasReference());
+ cur_sent_ = smeta.GetSentenceID();
+ assert(cur_sent_ < is_aligned_.size());
+ cur_grid_ = is_aligned_[cur_sent_].get();
+ }
+
+ //cerr << edge.rule_->AsString() << endl;
+
+ int j = edge.i_; // source side (f)
+ int i = edge.prev_i_; // target side (e)
+ if (j < cur_grid_->height() && i < cur_grid_->width() && (*cur_grid_)(i, j)) {
+// if (edge.rule_->e_[0] == smeta.GetReference()[i][0].label) {
+ features->set_value(fid_, 1.0);
+// cerr << edge.rule_->AsString() << " (" << i << "," << j << ")\n";
+// }
+ }
+}
+
+BlunsomSynchronousParseHack::BlunsomSynchronousParseHack(const string& param) :
+ FeatureFunction((100 / 8) + 1), fid_(FD::Convert("NotRef")), cur_sent_(-1) {
+ ReadFile rf(param);
+ istream& in = *rf.stream(); int lc = 0;
+ while(in) {
+ string line;
+ getline(in, line);
+ if (!in) break;
+ ++lc;
+ refs_.push_back(vector<WordID>());
+ TD::ConvertSentence(line, &refs_.back());
+ }
+ cerr << " Loaded " << lc << " refs\n";
+}
+
+void BlunsomSynchronousParseHack::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ if (cur_sent_ != smeta.GetSentenceID()) {
+ // assert(smeta.HasReference());
+ cur_sent_ = smeta.GetSentenceID();
+ assert(cur_sent_ < refs_.size());
+ cur_ref_ = &refs_[cur_sent_];
+ cur_map_.clear();
+ for (int i = 0; i < cur_ref_->size(); ++i) {
+ vector<WordID> phrase;
+ for (int j = i; j < cur_ref_->size(); ++j) {
+ phrase.push_back((*cur_ref_)[j]);
+ cur_map_[phrase] = i;
+ }
+ }
+ }
+ //cerr << edge.rule_->AsString() << endl;
+ for (int i = 0; i < ant_states.size(); ++i) {
+ if (DoesNotBelong(ant_states[i])) {
+ //cerr << " ant " << i << " does not belong\n";
+ return;
+ }
+ }
+ vector<vector<WordID> > ants(ant_states.size());
+ vector<const vector<WordID>* > pants(ant_states.size());
+ for (int i = 0; i < ant_states.size(); ++i) {
+ AppendAntecedentString(ant_states[i], &ants[i]);
+ //cerr << " ant[" << i << "]: " << ((int)*(static_cast<const unsigned char*>(ant_states[i]))) << " " << TD::GetString(ants[i]) << endl;
+ pants[i] = &ants[i];
+ }
+ vector<WordID> yield;
+ edge.rule_->ESubstitute(pants, &yield);
+ //cerr << "YIELD: " << TD::GetString(yield) << endl;
+ Vec2Int::iterator it = cur_map_.find(yield);
+ if (it == cur_map_.end()) {
+ features->set_value(fid_, 1);
+ //cerr << " BAD!\n";
+ return;
+ }
+ SetStateMask(it->second, it->second + yield.size(), state);
+}
+
diff --git a/src/ff_wordalign.h b/src/ff_wordalign.h
new file mode 100644
index 00000000..1581641c
--- /dev/null
+++ b/src/ff_wordalign.h
@@ -0,0 +1,133 @@
+#ifndef _FF_WORD_ALIGN_H_
+#define _FF_WORD_ALIGN_H_
+
+#include "ff.h"
+#include "array2d.h"
+
+class RelativeSentencePosition : public FeatureFunction {
+ public:
+ RelativeSentencePosition(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ const int fid_;
+};
+
+class MarkovJump : public FeatureFunction {
+ public:
+ MarkovJump(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ const int fid_;
+ bool individual_params_per_jumpsize_;
+ bool condition_on_flen_;
+ std::string template_;
+};
+
+class AlignerResults : public FeatureFunction {
+ public:
+ AlignerResults(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ int fid_;
+ std::vector<boost::shared_ptr<Array2D<bool> > > is_aligned_;
+ mutable int cur_sent_;
+ const Array2D<bool> mutable* cur_grid_;
+};
+
+#include <tr1/unordered_map>
+#include <boost/functional/hash.hpp>
+#include <cassert>
+class BlunsomSynchronousParseHack : public FeatureFunction {
+ public:
+ BlunsomSynchronousParseHack(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ inline bool DoesNotBelong(const void* state) const {
+ for (int i = 0; i < NumBytesContext(); ++i) {
+ if (*(static_cast<const unsigned char*>(state) + i)) return false;
+ }
+ return true;
+ }
+
+ inline void AppendAntecedentString(const void* state, std::vector<WordID>* yield) const {
+ int i = 0;
+ int ind = 0;
+ while (i < NumBytesContext() && !(*(static_cast<const unsigned char*>(state) + i))) { ++i; ind += 8; }
+ // std::cerr << i << " " << NumBytesContext() << std::endl;
+ assert(i != NumBytesContext());
+ assert(ind < cur_ref_->size());
+ int cur = *(static_cast<const unsigned char*>(state) + i);
+ int comp = 1;
+ while (comp < 256 && (comp & cur) == 0) { comp <<= 1; ++ind; }
+ assert(ind < cur_ref_->size());
+ assert(comp < 256);
+ do {
+ assert(ind < cur_ref_->size());
+ yield->push_back((*cur_ref_)[ind]);
+ ++ind;
+ comp <<= 1;
+ if (comp == 256) {
+ comp = 1;
+ ++i;
+ cur = *(static_cast<const unsigned char*>(state) + i);
+ }
+ } while (comp & cur);
+ }
+
+ inline void SetStateMask(int start, int end, void* state) const {
+ assert((end / 8) < NumBytesContext());
+ int i = 0;
+ int comp = 1;
+ for (int j = 0; j < start; ++j) {
+ comp <<= 1;
+ if (comp == 256) {
+ ++i;
+ comp = 1;
+ }
+ }
+ //std::cerr << "SM: " << i << "\n";
+ for (int j = start; j < end; ++j) {
+ *(static_cast<unsigned char*>(state) + i) |= comp;
+ //std::cerr << " " << comp << "\n";
+ comp <<= 1;
+ if (comp == 256) {
+ ++i;
+ comp = 1;
+ }
+ }
+ //std::cerr << " MASK: " << ((int)*(static_cast<unsigned char*>(state))) << "\n";
+ }
+
+ const int fid_;
+ mutable int cur_sent_;
+ typedef std::tr1::unordered_map<std::vector<WordID>, int, boost::hash<std::vector<WordID> > > Vec2Int;
+ mutable Vec2Int cur_map_;
+ const std::vector<WordID> mutable * cur_ref_;
+ mutable std::vector<std::vector<WordID> > refs_;
+};
+
+#endif
diff --git a/src/filelib.cc b/src/filelib.cc
new file mode 100644
index 00000000..79ad2847
--- /dev/null
+++ b/src/filelib.cc
@@ -0,0 +1,22 @@
+#include "filelib.h"
+
+#include <unistd.h>
+#include <sys/stat.h>
+
+using namespace std;
+
+bool FileExists(const std::string& fn) {
+ struct stat info;
+ int s = stat(fn.c_str(), &info);
+ return (s==0);
+}
+
+bool DirectoryExists(const string& dir) {
+ if (access(dir.c_str(),0) == 0) {
+ struct stat status;
+ stat(dir.c_str(), &status);
+ if (status.st_mode & S_IFDIR) return true;
+ }
+ return false;
+}
+
diff --git a/src/filelib.h b/src/filelib.h
new file mode 100644
index 00000000..62cb9427
--- /dev/null
+++ b/src/filelib.h
@@ -0,0 +1,66 @@
+#ifndef _FILELIB_H_
+#define _FILELIB_H_
+
+#include <cassert>
+#include <string>
+#include <iostream>
+#include <cstdlib>
+#include "gzstream.h"
+
+// reads from standard in if filename is -
+// uncompresses if file ends with .gz
+// otherwise, reads from a normal file
+class ReadFile {
+ public:
+ ReadFile(const std::string& filename) :
+ no_delete_on_exit_(filename == "-"),
+ in_(no_delete_on_exit_ ? static_cast<std::istream*>(&std::cin) :
+ (EndsWith(filename, ".gz") ?
+ static_cast<std::istream*>(new igzstream(filename.c_str())) :
+ static_cast<std::istream*>(new std::ifstream(filename.c_str())))) {
+ if (!*in_) {
+ std::cerr << "Failed to open " << filename << std::endl;
+ abort();
+ }
+ }
+ ~ReadFile() {
+ if (!no_delete_on_exit_) delete in_;
+ }
+
+ inline std::istream* stream() { return in_; }
+
+ private:
+ static bool EndsWith(const std::string& f, const std::string& suf) {
+ return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size());
+ }
+ const bool no_delete_on_exit_;
+ std::istream* const in_;
+};
+
+class WriteFile {
+ public:
+ WriteFile(const std::string& filename) :
+ no_delete_on_exit_(filename == "-"),
+ out_(no_delete_on_exit_ ? static_cast<std::ostream*>(&std::cout) :
+ (EndsWith(filename, ".gz") ?
+ static_cast<std::ostream*>(new ogzstream(filename.c_str())) :
+ static_cast<std::ostream*>(new std::ofstream(filename.c_str())))) {}
+ ~WriteFile() {
+ (*out_) << std::flush;
+ if (!no_delete_on_exit_) delete out_;
+ }
+
+ inline std::ostream* stream() { return out_; }
+
+ private:
+ static bool EndsWith(const std::string& f, const std::string& suf) {
+ return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size());
+ }
+ const bool no_delete_on_exit_;
+ std::ostream* const out_;
+};
+
+bool FileExists(const std::string& file_name);
+bool DirectoryExists(const std::string& dir_name);
+
+#endif
diff --git a/src/forest_writer.cc b/src/forest_writer.cc
new file mode 100644
index 00000000..a9117d18
--- /dev/null
+++ b/src/forest_writer.cc
@@ -0,0 +1,23 @@
+#include "forest_writer.h"
+
+#include <iostream>
+
+#include <boost/lexical_cast.hpp>
+
+#include "filelib.h"
+#include "hg_io.h"
+#include "hg.h"
+
+using namespace std;
+
+ForestWriter::ForestWriter(const std::string& path, int num) :
+ fname_(path + '/' + boost::lexical_cast<string>(num) + ".json.gz"), used_(false) {}
+
+bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) {
+ assert(!used_);
+ used_ = true;
+ cerr << " Writing forest to " << fname_ << endl;
+ WriteFile wf(fname_);
+ return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream());
+}
+
diff --git a/src/forest_writer.h b/src/forest_writer.h
new file mode 100644
index 00000000..819a8940
--- /dev/null
+++ b/src/forest_writer.h
@@ -0,0 +1,16 @@
+#ifndef _FOREST_WRITER_H_
+#define _FOREST_WRITER_H_
+
+#include <string>
+
+class Hypergraph;
+
+struct ForestWriter {
+ ForestWriter(const std::string& path, int num);
+ bool Write(const Hypergraph& forest, bool minimal_rules);
+
+ const std::string fname_;
+ bool used_;
+};
+
+#endif
diff --git a/src/freqdict.cc b/src/freqdict.cc
new file mode 100644
index 00000000..4cfffe58
--- /dev/null
+++ b/src/freqdict.cc
@@ -0,0 +1,23 @@
+#include <iostream>
+#include <fstream>
+#include <cassert>
+#include "freqdict.h"
+
+void FreqDict::load(const std::string& fname) {
+ std::ifstream ifs(fname.c_str());
+ int cc=0;
+ while (!ifs.eof()) {
+ std::string word;
+ ifs >> word;
+ if (word.size() == 0) continue;
+ if (word[0] == '#') continue;
+ double count = 0;
+ ifs >> count;
+ assert(count > 0.0); // use -log(f)
+ counts_[word]=count;
+ ++cc;
+ if (cc % 10000 == 0) { std::cerr << "."; }
+ }
+ std::cerr << "\n";
+ std::cerr << "Loaded " << cc << " words\n";
+}
diff --git a/src/freqdict.h b/src/freqdict.h
new file mode 100644
index 00000000..c9bb4c42
--- /dev/null
+++ b/src/freqdict.h
@@ -0,0 +1,19 @@
+#ifndef _FREQDICT_H_
+#define _FREQDICT_H_
+
+#include <map>
+#include <string>
+
+class FreqDict {
+ public:
+ void load(const std::string& fname);
+ float frequency(const std::string& word) const {
+ std::map<std::string,float>::const_iterator i = counts_.find(word);
+ if (i == counts_.end()) return 0;
+ return i->second;
+ }
+ private:
+ std::map<std::string, float> counts_;
+};
+
+#endif
diff --git a/src/fst_translator.cc b/src/fst_translator.cc
new file mode 100644
index 00000000..57feb227
--- /dev/null
+++ b/src/fst_translator.cc
@@ -0,0 +1,91 @@
+#include "translator.h"
+
+#include <sstream>
+#include <boost/shared_ptr.hpp>
+
+#include "sentence_metadata.h"
+#include "filelib.h"
+#include "hg.h"
+#include "hg_io.h"
+#include "earley_composer.h"
+#include "phrasetable_fst.h"
+#include "tdict.h"
+
+using namespace std;
+
+struct FSTTranslatorImpl {
+ FSTTranslatorImpl(const boost::program_options::variables_map& conf) :
+ goal_sym(conf["goal"].as<string>()),
+ kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")),
+ kGOAL(TD::Convert("Goal") * -1),
+ add_pass_through_rules(conf.count("add_pass_through_rules")) {
+ fst.reset(LoadTextPhrasetable(conf["grammar"].as<vector<string> >()));
+ ec.reset(new EarleyComposer(fst.get()));
+ }
+
+ bool Translate(const string& input,
+ const vector<double>& weights,
+ Hypergraph* forest) {
+ bool composed = false;
+ if (input.find("{\"rules\"") == 0) {
+ istringstream is(input);
+ Hypergraph src_cfg_hg;
+ assert(HypergraphIO::ReadFromJSON(&is, &src_cfg_hg));
+ if (add_pass_through_rules) {
+ SparseVector<double> feats;
+ feats.set_value(FD::Convert("PassThrough"), 1);
+ for (int i = 0; i < src_cfg_hg.edges_.size(); ++i) {
+ const vector<WordID>& f = src_cfg_hg.edges_[i].rule_->f_;
+ for (int j = 0; j < f.size(); ++j) {
+ if (f[j] > 0) {
+ fst->AddPassThroughTranslation(f[j], feats);
+ }
+ }
+ }
+ }
+ composed = ec->Compose(src_cfg_hg, forest);
+ } else {
+ const string dummy_grammar("[" + goal_sym + "] ||| " + input + " ||| TOP=1");
+ cerr << " Dummy grammar: " << dummy_grammar << endl;
+ istringstream is(dummy_grammar);
+ if (add_pass_through_rules) {
+ vector<WordID> words;
+ TD::ConvertSentence(input, &words);
+ SparseVector<double> feats;
+ feats.set_value(FD::Convert("PassThrough"), 1);
+ for (int i = 0; i < words.size(); ++i)
+ fst->AddPassThroughTranslation(words[i], feats);
+ }
+ composed = ec->Compose(&is, forest);
+ }
+ if (composed) {
+ Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
+ Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1, "");
+ Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
+ forest->ConnectEdgeToHeadNode(hg_edge, goal);
+ forest->Reweight(weights);
+ }
+ if (add_pass_through_rules)
+ fst->ClearPassThroughTranslations();
+ return composed;
+ }
+
+ const string goal_sym;
+ const TRulePtr kGOAL_RULE;
+ const WordID kGOAL;
+ const bool add_pass_through_rules;
+ boost::shared_ptr<EarleyComposer> ec;
+ boost::shared_ptr<FSTNode> fst;
+};
+
+FSTTranslator::FSTTranslator(const boost::program_options::variables_map& conf) :
+ pimpl_(new FSTTranslatorImpl(conf)) {}
+
+bool FSTTranslator::Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ smeta->SetSourceLength(0); // don't know how to compute this
+ return pimpl_->Translate(input, weights, minus_lm_forest);
+}
+
diff --git a/src/grammar.cc b/src/grammar.cc
new file mode 100644
index 00000000..69e38320
--- /dev/null
+++ b/src/grammar.cc
@@ -0,0 +1,163 @@
+#include "grammar.h"
+
+#include <algorithm>
+#include <utility>
+#include <map>
+
+#include "filelib.h"
+#include "tdict.h"
+
+using namespace std;
+
+const vector<TRulePtr> Grammar::NO_RULES;
+
+RuleBin::~RuleBin() {}
+GrammarIter::~GrammarIter() {}
+Grammar::~Grammar() {}
+
+bool Grammar::HasRuleForSpan(int i, int j) const {
+ (void) i;
+ (void) j;
+ return true; // always true by default
+}
+
+struct TextRuleBin : public RuleBin {
+ int GetNumRules() const {
+ return rules_.size();
+ }
+ TRulePtr GetIthRule(int i) const {
+ return rules_[i];
+ }
+ void AddRule(TRulePtr t) {
+ rules_.push_back(t);
+ }
+ int Arity() const {
+ return rules_.front()->Arity();
+ }
+ void Dump() const {
+ for (int i = 0; i < rules_.size(); ++i)
+ VLOG(1) << rules_[i]->AsString() << endl;
+ }
+ private:
+ vector<TRulePtr> rules_;
+};
+
+struct TextGrammarNode : public GrammarIter {
+ TextGrammarNode() : rb_(NULL) {}
+ ~TextGrammarNode() {
+ delete rb_;
+ }
+ const GrammarIter* Extend(int symbol) const {
+ map<WordID, TextGrammarNode>::const_iterator i = tree_.find(symbol);
+ if (i == tree_.end()) return NULL;
+ return &i->second;
+ }
+
+ const RuleBin* GetRules() const {
+ if (rb_) {
+ //rb_->Dump();
+ }
+ return rb_;
+ }
+
+ map<WordID, TextGrammarNode> tree_;
+ TextRuleBin* rb_;
+};
+
+struct TGImpl {
+ TextGrammarNode root_;
+};
+
+TextGrammar::TextGrammar() : max_span_(10), pimpl_(new TGImpl) {}
+TextGrammar::TextGrammar(const string& file) :
+ max_span_(10),
+ pimpl_(new TGImpl) {
+ ReadFromFile(file);
+}
+
+const GrammarIter* TextGrammar::GetRoot() const {
+ return &pimpl_->root_;
+}
+
+void TextGrammar::AddRule(const TRulePtr& rule) {
+ if (rule->IsUnary()) {
+ rhs2unaries_[rule->f().front()].push_back(rule);
+ unaries_.push_back(rule);
+ } else {
+ TextGrammarNode* cur = &pimpl_->root_;
+ for (int i = 0; i < rule->f_.size(); ++i)
+ cur = &cur->tree_[rule->f_[i]];
+ if (cur->rb_ == NULL)
+ cur->rb_ = new TextRuleBin;
+ cur->rb_->AddRule(rule);
+ }
+}
+
+void TextGrammar::ReadFromFile(const string& filename) {
+ ReadFile in(filename);
+ istream& in_file = *in.stream();
+ assert(in_file);
+ long long int rule_count = 0;
+ bool fl = false;
+ while(in_file) {
+ string line;
+ getline(in_file, line);
+ if (line.empty()) continue;
+ ++rule_count;
+ if (rule_count % 50000 == 0) { cerr << '.' << flush; fl = true; }
+ if (rule_count % 2000000 == 0) { cerr << " [" << rule_count << "]\n"; fl = false; }
+ TRulePtr rule(TRule::CreateRuleSynchronous(line));
+ if (rule) {
+ AddRule(rule);
+ } else {
+ if (fl) { cerr << endl; }
+ cerr << "Skipping badly formatted rule in line " << rule_count << " of " << filename << endl;
+ fl = false;
+ }
+ }
+ if (fl) cerr << endl;
+ cerr << " " << rule_count << " rules read.\n";
+}
+
+bool TextGrammar::HasRuleForSpan(int i, int j) const {
+ return (max_span_ >= (j - i));
+}
+
+GlueGrammar::GlueGrammar(const string& file) : TextGrammar(file) {}
+
+GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt) {
+ TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]"));
+ TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] ["
+ + default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1"));
+
+ AddRule(stop_glue);
+ AddRule(glue);
+ //cerr << "GLUE: " << stop_glue->AsString() << endl;
+ //cerr << "GLUE: " << glue->AsString() << endl;
+}
+
+bool GlueGrammar::HasRuleForSpan(int i, int j) const {
+ (void) j;
+ return (i == 0);
+}
+
+PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat) :
+ has_rule_(input.size() + 1) {
+ for (int i = 0; i < input.size(); ++i) {
+ const vector<LatticeArc>& alts = input[i];
+ for (int k = 0; k < alts.size(); ++k) {
+ const int j = alts[k].dist2next + i;
+ has_rule_[i].insert(j);
+ const string& src = TD::Convert(alts[k].label);
+ TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1"));
+ AddRule(pt);
+// cerr << "PT: " << pt->AsString() << endl;
+ }
+ }
+}
+
+bool PassThroughGrammar::HasRuleForSpan(int i, int j) const {
+ const set<int>& hr = has_rule_[i];
+ if (i == j) { return !hr.empty(); }
+ return (hr.find(j) != hr.end());
+}
diff --git a/src/grammar.h b/src/grammar.h
new file mode 100644
index 00000000..4a03c505
--- /dev/null
+++ b/src/grammar.h
@@ -0,0 +1,83 @@
+#ifndef GRAMMAR_H_
+#define GRAMMAR_H_
+
+#include <vector>
+#include <map>
+#include <set>
+#include <boost/shared_ptr.hpp>
+
+#include "lattice.h"
+#include "trule.h"
+
+struct RuleBin {
+ virtual ~RuleBin();
+ virtual int GetNumRules() const = 0;
+ virtual TRulePtr GetIthRule(int i) const = 0;
+ virtual int Arity() const = 0;
+};
+
+struct GrammarIter {
+ virtual ~GrammarIter();
+ virtual const RuleBin* GetRules() const = 0;
+ virtual const GrammarIter* Extend(int symbol) const = 0;
+};
+
+struct Grammar {
+ typedef std::map<WordID, std::vector<TRulePtr> > Cat2Rules;
+ static const std::vector<TRulePtr> NO_RULES;
+
+ virtual ~Grammar();
+ virtual const GrammarIter* GetRoot() const = 0;
+ virtual bool HasRuleForSpan(int i, int j) const;
+
+ // cat is the category to be rewritten
+ inline const std::vector<TRulePtr>& GetAllUnaryRules() const {
+ return unaries_;
+ }
+
+ // get all the unary rules that rewrite category cat
+ inline const std::vector<TRulePtr>& GetUnaryRulesForRHS(const WordID& cat) const {
+ Cat2Rules::const_iterator found = rhs2unaries_.find(cat);
+ if (found == rhs2unaries_.end())
+ return NO_RULES;
+ else
+ return found->second;
+ }
+
+ protected:
+ Cat2Rules rhs2unaries_; // these must be filled in by subclasses!
+ std::vector<TRulePtr> unaries_;
+};
+
+typedef boost::shared_ptr<Grammar> GrammarPtr;
+
+class TGImpl;
+struct TextGrammar : public Grammar {
+ TextGrammar();
+ TextGrammar(const std::string& file);
+ void SetMaxSpan(int m) { max_span_ = m; }
+ virtual const GrammarIter* GetRoot() const;
+ void AddRule(const TRulePtr& rule);
+ void ReadFromFile(const std::string& filename);
+ virtual bool HasRuleForSpan(int i, int j) const;
+ const std::vector<TRulePtr>& GetUnaryRules(const WordID& cat) const;
+ private:
+ int max_span_;
+ boost::shared_ptr<TGImpl> pimpl_;
+};
+
+struct GlueGrammar : public TextGrammar {
+ // read glue grammar from file
+ explicit GlueGrammar(const std::string& file);
+ GlueGrammar(const std::string& goal_nt, const std::string& default_nt); // "S", "X"
+ virtual bool HasRuleForSpan(int i, int j) const;
+};
+
+struct PassThroughGrammar : public TextGrammar {
+ PassThroughGrammar(const Lattice& input, const std::string& cat);
+ virtual bool HasRuleForSpan(int i, int j) const;
+ private:
+ std::vector<std::set<int> > has_rule_; // index by [i][j]
+};
+
+#endif
diff --git a/src/grammar_test.cc b/src/grammar_test.cc
new file mode 100644
index 00000000..62b8f958
--- /dev/null
+++ b/src/grammar_test.cc
@@ -0,0 +1,59 @@
+#include <cassert>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <gtest/gtest.h>
+#include "trule.h"
+#include "tdict.h"
+#include "grammar.h"
+#include "bottom_up_parser.h"
+#include "ff.h"
+#include "weights.h"
+
+using namespace std;
+
+class GrammarTest : public testing::Test {
+ public:
+ GrammarTest() {
+ wts.InitFromFile("test_data/weights.gt");
+ }
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+ Weights wts;
+};
+
+TEST_F(GrammarTest,TestTextGrammar) {
+ vector<double> w;
+ vector<const FeatureFunction*> ms;
+ ModelSet models(w, ms);
+
+ TextGrammar g;
+ TRulePtr r1(new TRule("[X] ||| a b c ||| A B C ||| 0.1 0.2 0.3", true));
+ TRulePtr r2(new TRule("[X] ||| a b c ||| 1 2 3 ||| 0.2 0.3 0.4", true));
+ TRulePtr r3(new TRule("[X] ||| a b c d ||| A B C D ||| 0.1 0.2 0.3", true));
+ cerr << r1->AsString() << endl;
+ g.AddRule(r1);
+ g.AddRule(r2);
+ g.AddRule(r3);
+}
+
+TEST_F(GrammarTest,TestTextGrammarFile) {
+ GrammarPtr g(new TextGrammar("./test_data/grammar.prune"));
+ vector<GrammarPtr> grammars(1, g);
+
+ LatticeArc a(TD::Convert("ein"), 0.0, 1);
+ LatticeArc b(TD::Convert("haus"), 0.0, 1);
+ Lattice lattice(2);
+ lattice[0].push_back(a);
+ lattice[1].push_back(b);
+ Hypergraph forest;
+ ExhaustiveBottomUpParser parser("PHRASE", grammars);
+ parser.Parse(lattice, &forest);
+ forest.PrintGraphviz();
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/gzstream.cc b/src/gzstream.cc
new file mode 100644
index 00000000..9703e6ad
--- /dev/null
+++ b/src/gzstream.cc
@@ -0,0 +1,165 @@
+// ============================================================================
+// gzstream, C++ iostream classes wrapping the zlib compression library.
+// Copyright (C) 2001 Deepak Bandyopadhyay, Lutz Kettner
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+// ============================================================================
+//
+// File : gzstream.C
+// Revision : $Revision: 1.1 $
+// Revision_date : $Date: 2006/03/30 04:05:52 $
+// Author(s) : Deepak Bandyopadhyay, Lutz Kettner
+//
+// Standard streambuf implementation following Nicolai Josuttis, "The
+// Standard C++ Library".
+// ============================================================================
+
+#include "gzstream.h"
+#include <iostream>
+#include <cstring>
+
+#ifdef GZSTREAM_NAMESPACE
+namespace GZSTREAM_NAMESPACE {
+#endif
+
+// ----------------------------------------------------------------------------
+// Internal classes to implement gzstream. See header file for user classes.
+// ----------------------------------------------------------------------------
+
+// --------------------------------------
+// class gzstreambuf:
+// --------------------------------------
+
+gzstreambuf* gzstreambuf::open( const char* name, int open_mode) {
+ if ( is_open())
+ return (gzstreambuf*)0;
+ mode = open_mode;
+ // no append nor read/write mode
+ if ((mode & std::ios::ate) || (mode & std::ios::app)
+ || ((mode & std::ios::in) && (mode & std::ios::out)))
+ return (gzstreambuf*)0;
+ char fmode[10];
+ char* fmodeptr = fmode;
+ if ( mode & std::ios::in)
+ *fmodeptr++ = 'r';
+ else if ( mode & std::ios::out)
+ *fmodeptr++ = 'w';
+ *fmodeptr++ = 'b';
+ *fmodeptr = '\0';
+ file = gzopen( name, fmode);
+ if (file == 0)
+ return (gzstreambuf*)0;
+ opened = 1;
+ return this;
+}
+
+gzstreambuf * gzstreambuf::close() {
+ if ( is_open()) {
+ sync();
+ opened = 0;
+ if ( gzclose( file) == Z_OK)
+ return this;
+ }
+ return (gzstreambuf*)0;
+}
+
+int gzstreambuf::underflow() { // used for input buffer only
+ if ( gptr() && ( gptr() < egptr()))
+ return * reinterpret_cast<unsigned char *>( gptr());
+
+ if ( ! (mode & std::ios::in) || ! opened)
+ return EOF;
+ // Josuttis' implementation of inbuf
+ int n_putback = gptr() - eback();
+ if ( n_putback > 4)
+ n_putback = 4;
+ memcpy( buffer + (4 - n_putback), gptr() - n_putback, n_putback);
+
+ int num = gzread( file, buffer+4, bufferSize-4);
+ if (num <= 0) // ERROR or EOF
+ return EOF;
+
+ // reset buffer pointers
+ setg( buffer + (4 - n_putback), // beginning of putback area
+ buffer + 4, // read position
+ buffer + 4 + num); // end of buffer
+
+ // return next character
+ return * reinterpret_cast<unsigned char *>( gptr());
+}
+
+int gzstreambuf::flush_buffer() {
+ // Separate the writing of the buffer from overflow() and
+ // sync() operation.
+ int w = pptr() - pbase();
+ if ( gzwrite( file, pbase(), w) != w)
+ return EOF;
+ pbump( -w);
+ return w;
+}
+
+int gzstreambuf::overflow( int c) { // used for output buffer only
+ if ( ! ( mode & std::ios::out) || ! opened)
+ return EOF;
+ if (c != EOF) {
+ *pptr() = c;
+ pbump(1);
+ }
+ if ( flush_buffer() == EOF)
+ return EOF;
+ return c;
+}
+
+int gzstreambuf::sync() {
+ // Changed to use flush_buffer() instead of overflow( EOF)
+ // which caused improper behavior with std::endl and flush(),
+ // bug reported by Vincent Ricard.
+ if ( pptr() && pptr() > pbase()) {
+ if ( flush_buffer() == EOF)
+ return -1;
+ }
+ return 0;
+}
+
+// --------------------------------------
+// class gzstreambase:
+// --------------------------------------
+
+gzstreambase::gzstreambase( const char* name, int mode) {
+ init( &buf);
+ open( name, mode);
+}
+
+gzstreambase::~gzstreambase() {
+ buf.close();
+}
+
+void gzstreambase::open( const char* name, int open_mode) {
+ if ( ! buf.open( name, open_mode))
+ clear( rdstate() | std::ios::badbit);
+}
+
+void gzstreambase::close() {
+ if ( buf.is_open())
+ if ( ! buf.close())
+ clear( rdstate() | std::ios::badbit);
+}
+
+#ifdef GZSTREAM_NAMESPACE
+} // namespace GZSTREAM_NAMESPACE
+#endif
+
+// ============================================================================
+// EOF //
diff --git a/src/gzstream.h b/src/gzstream.h
new file mode 100644
index 00000000..ad9785fd
--- /dev/null
+++ b/src/gzstream.h
@@ -0,0 +1,121 @@
+// ============================================================================
+// gzstream, C++ iostream classes wrapping the zlib compression library.
+// Copyright (C) 2001 Deepak Bandyopadhyay, Lutz Kettner
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+// ============================================================================
+//
+// File : gzstream.h
+// Revision : $Revision: 1.1 $
+// Revision_date : $Date: 2006/03/30 04:05:52 $
+// Author(s) : Deepak Bandyopadhyay, Lutz Kettner
+//
+// Standard streambuf implementation following Nicolai Josuttis, "The
+// Standard C++ Library".
+// ============================================================================
+
+#ifndef GZSTREAM_H
+#define GZSTREAM_H 1
+
+// standard C++ with new header file names and std:: namespace
+#include <iostream>
+#include <fstream>
+#include <zlib.h>
+
+#ifdef GZSTREAM_NAMESPACE
+namespace GZSTREAM_NAMESPACE {
+#endif
+
+// ----------------------------------------------------------------------------
+// Internal classes to implement gzstream. See below for user classes.
+// ----------------------------------------------------------------------------
+
+class gzstreambuf : public std::streambuf {
+private:
+ static const int bufferSize = 47+256; // size of data buff
+ // totals 512 bytes under g++ for igzstream at the end.
+
+ gzFile file; // file handle for compressed file
+ char buffer[bufferSize]; // data buffer
+ char opened; // open/close state of stream
+ int mode; // I/O mode
+
+ int flush_buffer();
+public:
+ gzstreambuf() : opened(0) {
+ setp( buffer, buffer + (bufferSize-1));
+ setg( buffer + 4, // beginning of putback area
+ buffer + 4, // read position
+ buffer + 4); // end position
+ // ASSERT: both input & output capabilities will not be used together
+ }
+ int is_open() { return opened; }
+ gzstreambuf* open( const char* name, int open_mode);
+ gzstreambuf* close();
+ ~gzstreambuf() { close(); }
+
+ virtual int overflow( int c = EOF);
+ virtual int underflow();
+ virtual int sync();
+};
+
+class gzstreambase : virtual public std::ios {
+protected:
+ gzstreambuf buf;
+public:
+ gzstreambase() { init(&buf); }
+ gzstreambase( const char* name, int open_mode);
+ ~gzstreambase();
+ void open( const char* name, int open_mode);
+ void close();
+ gzstreambuf* rdbuf() { return &buf; }
+};
+
+// ----------------------------------------------------------------------------
+// User classes. Use igzstream and ogzstream analogously to ifstream and
+// ofstream respectively. They read and write files based on the gz*
+// function interface of the zlib. Files are compatible with gzip compression.
+// ----------------------------------------------------------------------------
+
+class igzstream : public gzstreambase, public std::istream {
+public:
+ igzstream() : std::istream( &buf) {}
+ igzstream( const char* name, int open_mode = std::ios::in)
+ : gzstreambase( name, open_mode), std::istream( &buf) {}
+ gzstreambuf* rdbuf() { return gzstreambase::rdbuf(); }
+ void open( const char* name, int open_mode = std::ios::in) {
+ gzstreambase::open( name, open_mode);
+ }
+};
+
+class ogzstream : public gzstreambase, public std::ostream {
+public:
+ ogzstream() : std::ostream( &buf) {}
+ ogzstream( const char* name, int mode = std::ios::out)
+ : gzstreambase( name, mode), std::ostream( &buf) {}
+ gzstreambuf* rdbuf() { return gzstreambase::rdbuf(); }
+ void open( const char* name, int open_mode = std::ios::out) {
+ gzstreambase::open( name, open_mode);
+ }
+};
+
+#ifdef GZSTREAM_NAMESPACE
+} // namespace GZSTREAM_NAMESPACE
+#endif
+
+#endif // GZSTREAM_H
+// ============================================================================
+// EOF //
+
diff --git a/src/hg.cc b/src/hg.cc
new file mode 100644
index 00000000..dd8f8eba
--- /dev/null
+++ b/src/hg.cc
@@ -0,0 +1,483 @@
+#include "hg.h"
+
+#include <cassert>
+#include <numeric>
+#include <set>
+#include <map>
+#include <iostream>
+
+#include "viterbi.h"
+#include "inside_outside.h"
+#include "tdict.h"
+
+using namespace std;
+
+double Hypergraph::NumberOfPaths() const {
+ return Inside<double, TransitionCountWeightFunction>(*this);
+}
+
+prob_t Hypergraph::ComputeEdgePosteriors(double scale, vector<prob_t>* posts) const {
+ const ScaledEdgeProb weight(scale);
+ SparseVector<double> pv;
+ const double inside = InsideOutside<prob_t,
+ ScaledEdgeProb,
+ SparseVector<double>,
+ EdgeFeaturesWeightFunction>(*this, &pv, weight);
+ posts->resize(edges_.size());
+ for (int i = 0; i < edges_.size(); ++i)
+ (*posts)[i] = prob_t(pv.value(i));
+ return prob_t(inside);
+}
+
+prob_t Hypergraph::ComputeBestPathThroughEdges(vector<prob_t>* post) const {
+ vector<prob_t> in(edges_.size());
+ vector<prob_t> out(edges_.size());
+ post->resize(edges_.size());
+
+ vector<prob_t> ins_node_best(nodes_.size());
+ for (int i = 0; i < nodes_.size(); ++i) {
+ const Node& node = nodes_[i];
+ prob_t& node_ins_best = ins_node_best[i];
+ if (node.in_edges_.empty()) node_ins_best = prob_t::One();
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ const Edge& edge = edges_[node.in_edges_[j]];
+ prob_t& in_edge_sco = in[node.in_edges_[j]];
+ in_edge_sco = edge.edge_prob_;
+ for (int k = 0; k < edge.tail_nodes_.size(); ++k)
+ in_edge_sco *= ins_node_best[edge.tail_nodes_[k]];
+ if (in_edge_sco > node_ins_best) node_ins_best = in_edge_sco;
+ }
+ }
+ const prob_t ins_sco = ins_node_best[nodes_.size() - 1];
+
+ // sanity check
+ int tots = 0;
+ for (int i = 0; i < nodes_.size(); ++i) { if (nodes_[i].out_edges_.empty()) tots++; }
+ assert(tots == 1);
+
+ // compute outside scores, potentially using inside scores
+ vector<prob_t> out_node_best(nodes_.size());
+ for (int i = nodes_.size() - 1; i >= 0; --i) {
+ const Node& node = nodes_[i];
+ prob_t& node_out_best = out_node_best[node.id_];
+ if (node.out_edges_.empty()) node_out_best = prob_t::One();
+ for (int j = 0; j < node.out_edges_.size(); ++j) {
+ const Edge& edge = edges_[node.out_edges_[j]];
+ prob_t sco = edge.edge_prob_ * out_node_best[edge.head_node_];
+ for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
+ if (edge.tail_nodes_[k] != i)
+ sco *= ins_node_best[edge.tail_nodes_[k]];
+ }
+ if (sco > node_out_best) node_out_best = sco;
+ }
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ out[node.in_edges_[j]] = node_out_best;
+ }
+ }
+
+ for (int i = 0; i < in.size(); ++i)
+ (*post)[i] = in[i] * out[i];
+
+ return ins_sco;
+}
+
+void Hypergraph::PushWeightsToSource(double scale) {
+ vector<prob_t> posts;
+ ComputeEdgePosteriors(scale, &posts);
+ for (int i = 0; i < nodes_.size(); ++i) {
+ const Hypergraph::Node& node = nodes_[i];
+ prob_t z = prob_t::Zero();
+ for (int j = 0; j < node.out_edges_.size(); ++j)
+ z += posts[node.out_edges_[j]];
+ for (int j = 0; j < node.out_edges_.size(); ++j) {
+ edges_[node.out_edges_[j]].edge_prob_ = posts[node.out_edges_[j]] / z;
+ }
+ }
+}
+
+void Hypergraph::PushWeightsToGoal(double scale) {
+ vector<prob_t> posts;
+ ComputeEdgePosteriors(scale, &posts);
+ for (int i = 0; i < nodes_.size(); ++i) {
+ const Hypergraph::Node& node = nodes_[i];
+ prob_t z = prob_t::Zero();
+ for (int j = 0; j < node.in_edges_.size(); ++j)
+ z += posts[node.in_edges_[j]];
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ edges_[node.in_edges_[j]].edge_prob_ = posts[node.in_edges_[j]] / z;
+ }
+ }
+}
+
+void Hypergraph::PruneEdges(const std::vector<bool>& prune_edge) {
+ assert(prune_edge.size() == edges_.size());
+ TopologicallySortNodesAndEdges(nodes_.size() - 1, &prune_edge);
+}
+
+void Hypergraph::DensityPruneInsideOutside(const double scale,
+ const bool use_sum_prod_semiring,
+ const double density,
+ const vector<bool>* preserve_mask) {
+ assert(density >= 1.0);
+ const int plen = ViterbiPathLength(*this);
+ vector<WordID> bp;
+ int rnum = min(static_cast<int>(edges_.size()), static_cast<int>(density * static_cast<double>(plen)));
+ if (rnum == edges_.size()) {
+ cerr << "No pruning required: denisty already sufficient";
+ return;
+ }
+ vector<prob_t> io(edges_.size());
+ if (use_sum_prod_semiring)
+ ComputeEdgePosteriors(scale, &io);
+ else
+ ComputeBestPathThroughEdges(&io);
+ assert(edges_.size() == io.size());
+ vector<prob_t> sorted = io;
+ nth_element(sorted.begin(), sorted.begin() + rnum, sorted.end(), greater<prob_t>());
+ const double cutoff = sorted[rnum];
+ vector<bool> prune(edges_.size());
+ for (int i = 0; i < edges_.size(); ++i) {
+ prune[i] = (io[i] < cutoff);
+ if (preserve_mask && (*preserve_mask)[i]) prune[i] = false;
+ }
+ PruneEdges(prune);
+}
+
+void Hypergraph::BeamPruneInsideOutside(
+ const double scale,
+ const bool use_sum_prod_semiring,
+ const double alpha,
+ const vector<bool>* preserve_mask) {
+ assert(alpha > 0.0);
+ assert(scale > 0.0);
+ vector<prob_t> io(edges_.size());
+ if (use_sum_prod_semiring)
+ ComputeEdgePosteriors(scale, &io);
+ else
+ ComputeBestPathThroughEdges(&io);
+ assert(edges_.size() == io.size());
+ prob_t best; // initializes to zero
+ for (int i = 0; i < io.size(); ++i)
+ if (io[i] > best) best = io[i];
+ const prob_t aprob(exp(-alpha));
+ const prob_t cutoff = best * aprob;
+ vector<bool> prune(edges_.size());
+ //cerr << preserve_mask.size() << " " << edges_.size() << endl;
+ int pc = 0;
+ for (int i = 0; i < io.size(); ++i) {
+ const bool prune_edge = (io[i] < cutoff);
+ if (prune_edge) ++pc;
+ prune[i] = (io[i] < cutoff);
+ if (preserve_mask && (*preserve_mask)[i]) prune[i] = false;
+ }
+ cerr << "Beam pruning " << pc << "/" << io.size() << " edges\n";
+ PruneEdges(prune);
+}
+
+void Hypergraph::PrintGraphviz() const {
+ int ei = 0;
+ cerr << "digraph G {\n rankdir=LR;\n nodesep=.05;\n";
+ for (vector<Edge>::const_iterator i = edges_.begin();
+ i != edges_.end(); ++i) {
+ const Edge& edge=*i;
+ ++ei;
+ static const string none = "<null>";
+ string rule = (edge.rule_ ? edge.rule_->AsString(false) : none);
+
+ cerr << " A_" << ei << " [label=\"" << rule << " p=" << edge.edge_prob_
+ << " F:" << edge.feature_values_
+ << "\" shape=\"rect\"];\n";
+ for (int i = 0; i < edge.tail_nodes_.size(); ++i) {
+ cerr << " " << edge.tail_nodes_[i] << " -> A_" << ei << ";\n";
+ }
+ cerr << " A_" << ei << " -> " << edge.head_node_ << ";\n";
+ }
+ for (vector<Node>::const_iterator ni = nodes_.begin();
+ ni != nodes_.end(); ++ni) {
+ cerr << " " << ni->id_ << "[label=\"" << (ni->cat_ < 0 ? TD::Convert(ni->cat_ * -1) : "")
+ //cerr << " " << ni->id_ << "[label=\"" << ni->cat_
+ << " n=" << ni->id_
+// << ",x=" << &*ni
+// << ",in=" << ni->in_edges_.size()
+// << ",out=" << ni->out_edges_.size()
+ << "\"];\n";
+ }
+ cerr << "}\n";
+}
+
+void Hypergraph::Union(const Hypergraph& other) {
+ if (&other == this) return;
+ if (nodes_.empty()) { nodes_ = other.nodes_; edges_ = other.edges_; return; }
+ int noff = nodes_.size();
+ int eoff = edges_.size();
+ int ogoal = other.nodes_.size() - 1;
+ int cgoal = noff - 1;
+ // keep a single goal node, so add nodes.size - 1
+ nodes_.resize(nodes_.size() + ogoal);
+ // add all edges
+ edges_.resize(edges_.size() + other.edges_.size());
+
+ for (int i = 0; i < ogoal; ++i) {
+ const Node& on = other.nodes_[i];
+ Node& cn = nodes_[i + noff];
+ cn.id_ = i + noff;
+ cn.in_edges_.resize(on.in_edges_.size());
+ for (int j = 0; j < on.in_edges_.size(); ++j)
+ cn.in_edges_[j] = on.in_edges_[j] + eoff;
+
+ cn.out_edges_.resize(on.out_edges_.size());
+ for (int j = 0; j < on.out_edges_.size(); ++j)
+ cn.out_edges_[j] = on.out_edges_[j] + eoff;
+ }
+
+ for (int i = 0; i < other.edges_.size(); ++i) {
+ const Edge& oe = other.edges_[i];
+ Edge& ce = edges_[i + eoff];
+ ce.id_ = i + eoff;
+ ce.rule_ = oe.rule_;
+ ce.feature_values_ = oe.feature_values_;
+ if (oe.head_node_ == ogoal) {
+ ce.head_node_ = cgoal;
+ nodes_[cgoal].in_edges_.push_back(ce.id_);
+ } else {
+ ce.head_node_ = oe.head_node_ + noff;
+ }
+ ce.tail_nodes_.resize(oe.tail_nodes_.size());
+ for (int j = 0; j < oe.tail_nodes_.size(); ++j)
+ ce.tail_nodes_[j] = oe.tail_nodes_[j] + noff;
+ }
+
+ TopologicallySortNodesAndEdges(cgoal);
+}
+
+int Hypergraph::MarkReachable(const Node& node,
+ vector<bool>* rmap,
+ const vector<bool>* prune_edges) const {
+ int total = 0;
+ if (!(*rmap)[node.id_]) {
+ total = 1;
+ (*rmap)[node.id_] = true;
+ for (int i = 0; i < node.in_edges_.size(); ++i) {
+ if (!(prune_edges && (*prune_edges)[node.in_edges_[i]])) {
+ for (int j = 0; j < edges_[node.in_edges_[i]].tail_nodes_.size(); ++j)
+ total += MarkReachable(nodes_[edges_[node.in_edges_[i]].tail_nodes_[j]], rmap, prune_edges);
+ }
+ }
+ }
+ return total;
+}
+
+void Hypergraph::PruneUnreachable(int goal_node_id) {
+ TopologicallySortNodesAndEdges(goal_node_id, NULL);
+}
+
+void Hypergraph::RemoveNoncoaccessibleStates(int goal_node_id) {
+ if (goal_node_id < 0) goal_node_id += nodes_.size();
+ assert(goal_node_id >= 0);
+ assert(goal_node_id < nodes_.size());
+
+ // TODO finish implementation
+ abort();
+}
+
+void Hypergraph::TopologicallySortNodesAndEdges(int goal_index,
+ const vector<bool>* prune_edges) {
+ vector<Edge> sedges(edges_.size());
+ // figure out which nodes are reachable from the goal
+ vector<bool> reachable(nodes_.size(), false);
+ int num_reachable = MarkReachable(nodes_[goal_index], &reachable, prune_edges);
+ vector<Node> snodes(num_reachable); snodes.clear();
+
+ // enumerate all reachable nodes in topologically sorted order
+ vector<int> old_node_to_new_id(nodes_.size(), -1);
+ vector<int> node_to_incount(nodes_.size(), -1);
+ vector<bool> node_processed(nodes_.size(), false);
+ typedef map<int, set<int> > PQueue;
+ PQueue pri_q;
+ for (int i = 0; i < nodes_.size(); ++i) {
+ if (!reachable[i])
+ continue;
+ const int inedges = nodes_[i].in_edges_.size();
+ int incount = inedges;
+ for (int j = 0; j < inedges; ++j)
+ if (edges_[nodes_[i].in_edges_[j]].tail_nodes_.size() == 0 ||
+ (prune_edges && (*prune_edges)[nodes_[i].in_edges_[j]]))
+ --incount;
+ // cerr << &nodes_[i] <<" : incount=" << incount << "\tout=" << nodes_[i].out_edges_.size() << "\t(in-edges=" << inedges << ")\n";
+ assert(node_to_incount[i] == -1);
+ node_to_incount[i] = incount;
+ pri_q[incount].insert(i);
+ }
+
+ int edge_count = 0;
+ while (!pri_q.empty()) {
+ PQueue::iterator iter = pri_q.find(0);
+ assert(iter != pri_q.end());
+ assert(!iter->second.empty());
+
+ // get first node with incount = 0
+ const int cur_index = *iter->second.begin();
+ const Node& node = nodes_[cur_index];
+ assert(reachable[cur_index]);
+ //cerr << "node: " << node << endl;
+ const int new_node_index = snodes.size();
+ old_node_to_new_id[cur_index] = new_node_index;
+ snodes.push_back(node);
+ Node& new_node = snodes.back();
+ new_node.id_ = new_node_index;
+ new_node.out_edges_.clear();
+
+ // fix up edges - we can now process the in edges and
+ // the out edges of their tails
+ int oi = 0;
+ for (int i = 0; i < node.in_edges_.size(); ++i, ++oi) {
+ if (prune_edges && (*prune_edges)[node.in_edges_[i]]) {
+ --oi;
+ continue;
+ }
+ new_node.in_edges_[oi] = edge_count;
+ Edge& edge = sedges[edge_count];
+ edge.id_ = edge_count;
+ ++edge_count;
+ const Edge& old_edge = edges_[node.in_edges_[i]];
+ edge.rule_ = old_edge.rule_;
+ edge.feature_values_ = old_edge.feature_values_;
+ edge.head_node_ = new_node_index;
+ edge.tail_nodes_.resize(old_edge.tail_nodes_.size());
+ edge.edge_prob_ = old_edge.edge_prob_;
+ edge.i_ = old_edge.i_;
+ edge.j_ = old_edge.j_;
+ edge.prev_i_ = old_edge.prev_i_;
+ edge.prev_j_ = old_edge.prev_j_;
+ for (int j = 0; j < old_edge.tail_nodes_.size(); ++j) {
+ const Node& old_tail_node = nodes_[old_edge.tail_nodes_[j]];
+ edge.tail_nodes_[j] = old_node_to_new_id[old_tail_node.id_];
+ snodes[edge.tail_nodes_[j]].out_edges_.push_back(edge_count-1);
+ assert(edge.tail_nodes_[j] != new_node_index);
+ }
+ }
+ assert(oi <= new_node.in_edges_.size());
+ new_node.in_edges_.resize(oi);
+
+ for (int i = 0; i < node.out_edges_.size(); ++i) {
+ const Edge& edge = edges_[node.out_edges_[i]];
+ const int next_index = edge.head_node_;
+ assert(cur_index != next_index);
+ if (!reachable[next_index]) continue;
+ if (prune_edges && (*prune_edges)[edge.id_]) continue;
+
+ bool dontReduce = false;
+ for (int j = 0; j < edge.tail_nodes_.size() && !dontReduce; ++j) {
+ int tail_index = edge.tail_nodes_[j];
+ dontReduce = (tail_index != cur_index) && !node_processed[tail_index];
+ }
+ if (dontReduce)
+ continue;
+
+ const int incount = node_to_incount[next_index];
+ if (incount <= 0) {
+ cerr << "incount = " << incount << ", should be > 0!\n";
+ cerr << "do you have a cycle in your hypergraph?\n";
+ abort();
+ }
+ PQueue::iterator it = pri_q.find(incount);
+ assert(it != pri_q.end());
+ it->second.erase(next_index);
+ if (it->second.empty()) pri_q.erase(it);
+
+ // reinsert node with reduced incount
+ pri_q[incount-1].insert(next_index);
+ --node_to_incount[next_index];
+ }
+
+ // remove node from set
+ iter->second.erase(cur_index);
+ if (iter->second.empty())
+ pri_q.erase(iter);
+ node_processed[cur_index] = true;
+ }
+
+ sedges.resize(edge_count);
+ nodes_.swap(snodes);
+ edges_.swap(sedges);
+ assert(nodes_.back().out_edges_.size() == 0);
+}
+
+TRulePtr Hypergraph::kEPSRule;
+TRulePtr Hypergraph::kUnaryRule;
+
+void Hypergraph::EpsilonRemove(WordID eps) {
+ if (!kEPSRule) {
+ kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>"));
+ kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]"));
+ }
+ vector<bool> kill(edges_.size(), false);
+ for (int i = 0; i < edges_.size(); ++i) {
+ const Edge& edge = edges_[i];
+ if (edge.tail_nodes_.empty() &&
+ edge.rule_->f_.size() == 1 &&
+ edge.rule_->f_[0] == eps) {
+ kill[i] = true;
+ if (!edge.feature_values_.empty()) {
+ Node& node = nodes_[edge.head_node_];
+ if (node.in_edges_.size() != 1) {
+ cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n";
+ // this *probably* means that there are multiple derivations of the
+ // same sequence via different paths through the input forest
+ // this needs to be investigated and fixed
+ } else {
+ for (int j = 0; j < node.out_edges_.size(); ++j)
+ edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_;
+ // cerr << "PROMOTED " << edge.feature_values_ << endl;
+ }
+ }
+ }
+ }
+ bool created_eps = false;
+ PruneEdges(kill);
+ for (int i = 0; i < nodes_.size(); ++i) {
+ const Node& node = nodes_[i];
+ if (node.in_edges_.empty()) {
+ for (int j = 0; j < node.out_edges_.size(); ++j) {
+ Edge& edge = edges_[node.out_edges_[j]];
+ if (edge.rule_->Arity() == 2) {
+ assert(edge.rule_->f_.size() == 2);
+ assert(edge.rule_->e_.size() == 2);
+ edge.rule_ = kUnaryRule;
+ int cur = node.id_;
+ int t = -1;
+ assert(edge.tail_nodes_.size() == 2);
+ for (int i = 0; i < 2; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; }
+ assert(t != -1);
+ edge.tail_nodes_.resize(1);
+ edge.tail_nodes_[0] = t;
+ } else {
+ edge.rule_ = kEPSRule;
+ edge.rule_->f_[0] = eps;
+ edge.rule_->e_[0] = eps;
+ edge.tail_nodes_.clear();
+ created_eps = true;
+ }
+ }
+ }
+ }
+ vector<bool> k2(edges_.size(), false);
+ PruneEdges(k2);
+ if (created_eps) EpsilonRemove(eps);
+}
+
+struct EdgeWeightSorter {
+ const Hypergraph& hg;
+ EdgeWeightSorter(const Hypergraph& h) : hg(h) {}
+ bool operator()(int a, int b) const {
+ return hg.edges_[a].edge_prob_ > hg.edges_[b].edge_prob_;
+ }
+};
+
+void Hypergraph::SortInEdgesByEdgeWeights() {
+ for (int i = 0; i < nodes_.size(); ++i) {
+ Node& node = nodes_[i];
+ sort(node.in_edges_.begin(), node.in_edges_.end(), EdgeWeightSorter(*this));
+ }
+}
+
diff --git a/src/hg.h b/src/hg.h
new file mode 100644
index 00000000..7a2658b8
--- /dev/null
+++ b/src/hg.h
@@ -0,0 +1,225 @@
+#ifndef _HG_H_
+#define _HG_H_
+
+#include <string>
+#include <vector>
+
+#include "small_vector.h"
+#include "sparse_vector.h"
+#include "wordid.h"
+#include "trule.h"
+#include "prob.h"
+
+// class representing an acyclic hypergraph
+// - edges have 1 head, 0..n tails
+class Hypergraph {
+ public:
+ Hypergraph() {}
+
+ // SmallVector is a fast, small vector<int> implementation for sizes <= 2
+ typedef SmallVector TailNodeVector;
+
+ // TODO get rid of state_ and cat_?
+ struct Node {
+ Node() : id_(), cat_() {}
+ int id_; // equal to this object's position in the nodes_ vector
+ WordID cat_; // non-terminal category if <0, 0 if not set
+ std::vector<int> in_edges_; // contents refer to positions in edges_
+ std::vector<int> out_edges_; // contents refer to positions in edges_
+ std::string state_; // opaque state
+ };
+
+ // TODO get rid of edge_prob_? (can be computed on the fly as the dot
+ // product of the weight vector and the feature values)
+ struct Edge {
+ Edge() : i_(-1), j_(-1), prev_i_(-1), prev_j_(-1) {}
+ inline int Arity() const { return tail_nodes_.size(); }
+ int head_node_; // refers to a position in nodes_
+ TailNodeVector tail_nodes_; // contents refer to positions in nodes_
+ TRulePtr rule_;
+ SparseVector<double> feature_values_;
+ prob_t edge_prob_; // dot product of weights and feat_values
+ int id_; // equal to this object's position in the edges_ vector
+
+ // span info. typically, i_ and j_ refer to indices in the source sentence
+ // if a synchronous parse has been executed i_ and j_ will refer to indices
+ // in the target sentence / lattice and prev_i_ prev_j_ will refer to
+ // positions in the source. Note: it is up to the translator implementation
+ // to properly set these values. For some models (like the Forest-input
+ // phrase based model) it may not be straightforward to do. if these values
+ // are not properly set, most things will work but alignment and any features
+ // that depend on them will be broken.
+ short int i_;
+ short int j_;
+ short int prev_i_;
+ short int prev_j_;
+ };
+
+ void swap(Hypergraph& other) {
+ other.nodes_.swap(nodes_);
+ other.edges_.swap(edges_);
+ }
+
+ void ResizeNodes(int size) {
+ nodes_.resize(size);
+ for (int i = 0; i < size; ++i) nodes_[i].id_ = i;
+ }
+
+ // reserves space in the nodes vector to prevent memory locations
+ // from changing
+ void ReserveNodes(size_t n, size_t e = 0) {
+ nodes_.reserve(n);
+ if (e) edges_.reserve(e);
+ }
+
+ Edge* AddEdge(const TRulePtr& rule, const TailNodeVector& tail) {
+ edges_.push_back(Edge());
+ Edge* edge = &edges_.back();
+ edge->rule_ = rule;
+ edge->tail_nodes_ = tail;
+ edge->id_ = edges_.size() - 1;
+ for (int i = 0; i < edge->tail_nodes_.size(); ++i)
+ nodes_[edge->tail_nodes_[i]].out_edges_.push_back(edge->id_);
+ return edge;
+ }
+
+ Node* AddNode(const WordID& cat, const std::string& state = "") {
+ nodes_.push_back(Node());
+ nodes_.back().cat_ = cat;
+ nodes_.back().state_ = state;
+ nodes_.back().id_ = nodes_.size() - 1;
+ return &nodes_.back();
+ }
+
+ void ConnectEdgeToHeadNode(const int edge_id, const int head_id) {
+ edges_[edge_id].head_node_ = head_id;
+ nodes_[head_id].in_edges_.push_back(edge_id);
+ }
+
+ // TODO remove this - use the version that takes indices
+ void ConnectEdgeToHeadNode(Edge* edge, Node* head) {
+ edge->head_node_ = head->id_;
+ head->in_edges_.push_back(edge->id_);
+ }
+
+ // merge the goal node from other with this goal node
+ void Union(const Hypergraph& other);
+
+ void PrintGraphviz() const;
+
+ // compute the total number of paths in the forest
+ double NumberOfPaths() const;
+
+ // BEWARE. this assumes that the source and target language
+ // strings are identical and that there are no loops.
+ // It assumes a bunch of other things about where the
+ // epsilons will be. It tries to assert failure if you
+ // break these assumptions, but it may not.
+ // TODO - make this work
+ void EpsilonRemove(WordID eps);
+
+ // multiple the weights vector by the edge feature vector
+ // (inner product) to set the edge probabilities
+ template <typename V>
+ void Reweight(const V& weights) {
+ for (int i = 0; i < edges_.size(); ++i) {
+ Edge& e = edges_[i];
+ e.edge_prob_.logeq(e.feature_values_.dot(weights));
+ }
+ }
+
+ // computes inside and outside scores for each
+ // edge in the hypergraph
+ // alpha->size = edges_.size = beta->size
+ // returns inside prob of goal node
+ prob_t ComputeEdgePosteriors(double scale,
+ std::vector<prob_t>* posts) const;
+
+ // find the score of the very best path passing through each edge
+ prob_t ComputeBestPathThroughEdges(std::vector<prob_t>* posts) const;
+
+ // move weights as near to the source as possible, resulting in a
+ // stochastic automaton. ONLY FUNCTIONAL FOR *LATTICES*.
+ // See M. Mohri and M. Riley. A Weight Pushing Algorithm for Large
+ // Vocabulary Speech Recognition. 2001.
+ // the log semiring (NOT tropical) is used
+ void PushWeightsToSource(double scale = 1.0);
+ // same, except weights are pushed to the goal, works for HGs,
+ // not just lattices
+ void PushWeightsToGoal(double scale = 1.0);
+
+ void SortInEdgesByEdgeWeights();
+
+ void PruneUnreachable(int goal_node_id); // DEPRECATED
+
+ void RemoveNoncoaccessibleStates(int goal_node_id = -1);
+
+ // remove edges from the hypergraph if prune_edge[edge_id] is true
+ void PruneEdges(const std::vector<bool>& prune_edge);
+
+ // if you don't know, use_sum_prod_semiring should be false
+ void DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density,
+ const std::vector<bool>* preserve_mask = NULL);
+
+ // prunes any edge whose score on the best path taking that edge is more than alpha away
+ // from the score of the global best past (or the highest edge posterior)
+ void BeamPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double alpha,
+ const std::vector<bool>* preserve_mask = NULL);
+
+ void clear() {
+ nodes_.clear();
+ edges_.clear();
+ }
+
+ inline size_t NumberOfEdges() const { return edges_.size(); }
+ inline size_t NumberOfNodes() const { return nodes_.size(); }
+ inline bool empty() const { return nodes_.empty(); }
+
+ // nodes_ is sorted in topological order
+ std::vector<Node> nodes_;
+ // edges_ is not guaranteed to be in any particular order
+ std::vector<Edge> edges_;
+
+ // reorder nodes_ so they are in topological order
+ // source nodes at 0 sink nodes at size-1
+ void TopologicallySortNodesAndEdges(int goal_idx,
+ const std::vector<bool>* prune_edges = NULL);
+ private:
+ // returns total nodes reachable
+ int MarkReachable(const Node& node,
+ std::vector<bool>* rmap,
+ const std::vector<bool>* prune_edges) const;
+
+ static TRulePtr kEPSRule;
+ static TRulePtr kUnaryRule;
+};
+
+// common WeightFunctions, map an edge -> WeightType
+// for generic Viterbi/Inside algorithms
+struct EdgeProb {
+ inline const prob_t& operator()(const Hypergraph::Edge& e) const { return e.edge_prob_; }
+};
+
+struct ScaledEdgeProb {
+ ScaledEdgeProb(const double& alpha) : alpha_(alpha) {}
+ inline prob_t operator()(const Hypergraph::Edge& e) const { return e.edge_prob_.pow(alpha_); }
+ const double alpha_;
+};
+
+struct EdgeFeaturesWeightFunction {
+ inline const SparseVector<double>& operator()(const Hypergraph::Edge& e) const { return e.feature_values_; }
+};
+
+struct TransitionEventWeightFunction {
+ inline SparseVector<prob_t> operator()(const Hypergraph::Edge& e) const {
+ SparseVector<prob_t> result;
+ result.set_value(e.id_, prob_t::One());
+ return result;
+ }
+};
+
+struct TransitionCountWeightFunction {
+ inline double operator()(const Hypergraph::Edge& e) const { (void)e; return 1.0; }
+};
+
+#endif
diff --git a/src/hg_intersect.cc b/src/hg_intersect.cc
new file mode 100644
index 00000000..a5e8913a
--- /dev/null
+++ b/src/hg_intersect.cc
@@ -0,0 +1,121 @@
+#include "hg_intersect.h"
+
+#include <vector>
+#include <tr1/unordered_map>
+#include <boost/lexical_cast.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "tdict.h"
+#include "hg.h"
+#include "trule.h"
+#include "wordid.h"
+#include "bottom_up_parser.h"
+
+using boost::lexical_cast;
+using namespace std::tr1;
+using namespace std;
+
+struct RuleFilter {
+ unordered_map<vector<WordID>, bool, boost::hash<vector<WordID> > > exists_;
+ bool true_lattice;
+ RuleFilter(const Lattice& target, int max_phrase_size) {
+ true_lattice = false;
+ for (int i = 0; i < target.size(); ++i) {
+ vector<WordID> phrase;
+ int lim = min(static_cast<int>(target.size()), i + max_phrase_size);
+ for (int j = i; j < lim; ++j) {
+ if (target[j].size() > 1) { true_lattice = true; break; }
+ phrase.push_back(target[j][0].label);
+ exists_[phrase] = true;
+ }
+ }
+ vector<WordID> sos(1, TD::Convert("<s>"));
+ exists_[sos] = true;
+ }
+ bool operator()(const TRule& r) const {
+ // TODO do some smarter filtering for lattices
+ if (true_lattice) return false; // don't filter "true lattice" input
+ const vector<WordID>& e = r.e();
+ for (int i = 0; i < e.size(); ++i) {
+ if (e[i] <= 0) continue;
+ vector<WordID> phrase;
+ for (int j = i; j < e.size(); ++j) {
+ if (e[j] <= 0) break;
+ phrase.push_back(e[j]);
+ if (exists_.count(phrase) == 0) return true;
+ }
+ }
+ return false;
+ }
+};
+
+bool HG::Intersect(const Lattice& target, Hypergraph* hg) {
+ vector<bool> rem(hg->edges_.size(), false);
+ const RuleFilter filter(target, 15); // TODO make configurable
+ for (int i = 0; i < rem.size(); ++i)
+ rem[i] = filter(*hg->edges_[i].rule_);
+ hg->PruneEdges(rem);
+
+ const int nedges = hg->edges_.size();
+ const int nnodes = hg->nodes_.size();
+
+ TextGrammar* g = new TextGrammar;
+ GrammarPtr gp(g);
+ vector<int> cats(nnodes);
+ // each node in the translation forest becomes a "non-terminal" in the new
+ // grammar, create the labels here
+ for (int i = 0; i < nnodes; ++i)
+ cats[i] = TD::Convert("CAT_" + lexical_cast<string>(i)) * -1;
+
+ // construct the grammar
+ for (int i = 0; i < nedges; ++i) {
+ const Hypergraph::Edge& edge = hg->edges_[i];
+ const vector<WordID>& tgt = edge.rule_->e();
+ const vector<WordID>& src = edge.rule_->f();
+ TRulePtr rule(new TRule);
+ rule->prev_i = edge.i_;
+ rule->prev_j = edge.j_;
+ rule->lhs_ = cats[edge.head_node_];
+ vector<WordID>& f = rule->f_;
+ vector<WordID>& e = rule->e_;
+ f.resize(tgt.size()); // swap source and target, since the parser
+ e.resize(src.size()); // parses using the source side!
+ Hypergraph::TailNodeVector tn(edge.tail_nodes_.size());
+ int ntc = 0;
+ for (int j = 0; j < tgt.size(); ++j) {
+ const WordID& cur = tgt[j];
+ if (cur > 0) {
+ f[j] = cur;
+ } else {
+ tn[ntc++] = cur;
+ f[j] = cats[edge.tail_nodes_[-cur]];
+ }
+ }
+ ntc = 0;
+ for (int j = 0; j < src.size(); ++j) {
+ const WordID& cur = src[j];
+ if (cur > 0) {
+ e[j] = cur;
+ } else {
+ e[j] = tn[ntc++];
+ }
+ }
+ rule->scores_ = edge.feature_values_;
+ rule->parent_rule_ = edge.rule_;
+ rule->ComputeArity();
+ //cerr << "ADD: " << rule->AsString() << endl;
+
+ g->AddRule(rule);
+ }
+ g->SetMaxSpan(target.size() + 1);
+ const string& new_goal = TD::Convert(cats.back() * -1);
+ vector<GrammarPtr> grammars(1, gp);
+ Hypergraph tforest;
+ ExhaustiveBottomUpParser parser(new_goal, grammars);
+ if (!parser.Parse(target, &tforest))
+ return false;
+ else
+ hg->swap(tforest);
+ return true;
+}
+
diff --git a/src/hg_intersect.h b/src/hg_intersect.h
new file mode 100644
index 00000000..826bdaae
--- /dev/null
+++ b/src/hg_intersect.h
@@ -0,0 +1,13 @@
+#ifndef _HG_INTERSECT_H_
+#define _HG_INTERSECT_H_
+
+#include <vector>
+
+#include "lattice.h"
+
+class Hypergraph;
+struct HG {
+ static bool Intersect(const Lattice& target, Hypergraph* hg);
+};
+
+#endif
diff --git a/src/hg_io.cc b/src/hg_io.cc
new file mode 100644
index 00000000..629e65f1
--- /dev/null
+++ b/src/hg_io.cc
@@ -0,0 +1,585 @@
+#include "hg_io.h"
+
+#include <sstream>
+#include <iostream>
+
+#include <boost/lexical_cast.hpp>
+
+#include "tdict.h"
+#include "json_parse.h"
+#include "hg.h"
+
+using namespace std;
+
+struct HGReader : public JSONParser {
+ HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; }
+
+ void CreateNode(const string& cat, const vector<int>& in_edges) {
+ WordID c = TD::Convert("X") * -1;
+ if (!cat.empty()) c = TD::Convert(cat) * -1;
+ Hypergraph::Node* node = hg.AddNode(c, "");
+ for (int i = 0; i < in_edges.size(); ++i) {
+ if (in_edges[i] >= hg.edges_.size()) {
+ cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i]
+ << ", but hg only has " << hg.edges_.size() << " edges!\n";
+ abort();
+ }
+ hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node);
+ }
+ }
+ void CreateEdge(const TRulePtr& rule, SparseVector<double>* feats, const SmallVector& tail) {
+ Hypergraph::Edge* edge = hg.AddEdge(rule, tail);
+ feats->swap(edge->feature_values_);
+ }
+
+ bool HandleJSONEvent(int type, const JSON_value* value) {
+ switch(state) {
+ case -1:
+ assert(type == JSON_T_OBJECT_BEGIN);
+ state = 0;
+ break;
+ case 0:
+ if (type == JSON_T_OBJECT_END) {
+ //cerr << "HG created\n"; // TODO, signal some kind of callback
+ } else if (type == JSON_T_KEY) {
+ string val = value->vu.str.value;
+ if (val == "features") { assert(fdict.empty()); state = 1; }
+ else if (val == "is_sorted") { state = 3; }
+ else if (val == "rules") { assert(rules.empty()); state = 4; }
+ else if (val == "node") { state = 8; }
+ else if (val == "edges") { state = 13; }
+ else { cerr << "Unexpected key: " << val << endl; return false; }
+ }
+ break;
+
+ // features
+ case 1:
+ if(type == JSON_T_NULL) { state = 0; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 2;
+ break;
+ case 2:
+ if(type == JSON_T_ARRAY_END) { state = 0; break; }
+ assert(type == JSON_T_STRING);
+ fdict.push_back(FD::Convert(value->vu.str.value));
+ break;
+
+ // is_sorted
+ case 3:
+ assert(type == JSON_T_TRUE || type == JSON_T_FALSE);
+ is_sorted = (type == JSON_T_TRUE);
+ if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; }
+ state = 0;
+ break;
+
+ // rules
+ case 4:
+ if(type == JSON_T_NULL) { state = 0; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 5;
+ break;
+ case 5:
+ if(type == JSON_T_ARRAY_END) { state = 0; break; }
+ assert(type == JSON_T_INTEGER);
+ state = 6;
+ rule_id = value->vu.integer_value;
+ break;
+ case 6:
+ assert(type == JSON_T_STRING);
+ rules[rule_id] = TRulePtr(new TRule(value->vu.str.value));
+ state = 5;
+ break;
+
+ // Nodes
+ case 8:
+ assert(type == JSON_T_OBJECT_BEGIN);
+ ++nodes;
+ in_edges.clear();
+ cat.clear();
+ state = 9; break;
+ case 9:
+ if (type == JSON_T_OBJECT_END) {
+ //cerr << "Creating NODE\n";
+ CreateNode(cat, in_edges);
+ state = 0; break;
+ }
+ assert(type == JSON_T_KEY);
+ cur_key = value->vu.str.value;
+ if (cur_key == "cat") { assert(cat.empty()); state = 10; break; }
+ if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; }
+ cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n";
+ return false;
+ case 10:
+ assert(type == JSON_T_STRING || type == JSON_T_NULL);
+ cat = value->vu.str.value;
+ state = 9; break;
+ case 11:
+ if (type == JSON_T_NULL) { state = 9; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 12; break;
+ case 12:
+ if (type == JSON_T_ARRAY_END) { state = 9; break; }
+ assert(type == JSON_T_INTEGER);
+ //cerr << "in_edges: " << value->vu.integer_value << endl;
+ in_edges.push_back(value->vu.integer_value);
+ break;
+
+ // "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12},
+ // { "tail": null, "feats" : [0,0.87,1,0.02], "rule": 17},
+ // { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}]
+ case 13:
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 14;
+ break;
+ case 14:
+ if (type == JSON_T_ARRAY_END) { state = 0; break; }
+ assert(type == JSON_T_OBJECT_BEGIN);
+ //cerr << "New edge\n";
+ ++edges;
+ cur_rule.reset(); feats.clear(); tail.clear();
+ state = 15; break;
+ case 15:
+ if (type == JSON_T_OBJECT_END) {
+ CreateEdge(cur_rule, &feats, tail);
+ state = 14; break;
+ }
+ assert(type == JSON_T_KEY);
+ cur_key = value->vu.str.value;
+ //cerr << "edge key " << cur_key << endl;
+ if (cur_key == "rule") { assert(!cur_rule); state = 16; break; }
+ if (cur_key == "feats") { assert(feats.empty()); state = 17; break; }
+ if (cur_key == "tail") { assert(tail.empty()); state = 20; break; }
+ cerr << "Unexpected key " << cur_key << " in edge specification\n";
+ return false;
+ case 16: // edge.rule
+ if (type == JSON_T_INTEGER) {
+ int rule_id = value->vu.integer_value;
+ if (rules.find(rule_id) == rules.end()) {
+ // rules list must come before the edge definitions!
+ cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n";
+ return false;
+ }
+ cur_rule = rules[rule_id];
+ } else if (type == JSON_T_STRING) {
+ cur_rule.reset(new TRule(value->vu.str.value));
+ } else {
+ cerr << "Rule must be either a rule id or a rule string" << endl;
+ return false;
+ }
+ // cerr << "Edge: rule=" << cur_rule->AsString() << endl;
+ state = 15;
+ break;
+ case 17: // edge.feats
+ if (type == JSON_T_NULL) { state = 15; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 18; break;
+ case 18:
+ if (type == JSON_T_ARRAY_END) { state = 15; break; }
+ if (type != JSON_T_INTEGER && type != JSON_T_STRING) {
+ cerr << "Unexpected feature id type\n"; return false;
+ }
+ if (type == JSON_T_INTEGER) {
+ fid = value->vu.integer_value;
+ assert(fid < fdict.size());
+ fid = fdict[fid];
+ } else if (JSON_T_STRING) {
+ fid = FD::Convert(value->vu.str.value);
+ } else { abort(); }
+ state = 19;
+ break;
+ case 19:
+ {
+ assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT);
+ double val = (type == JSON_T_INTEGER ? static_cast<double>(value->vu.integer_value) :
+ strtod(value->vu.str.value, NULL));
+ feats.set_value(fid, val);
+ state = 18;
+ break;
+ }
+ case 20: // edge.tail
+ if (type == JSON_T_NULL) { state = 15; break; }
+ assert(type == JSON_T_ARRAY_BEGIN);
+ state = 21; break;
+ case 21:
+ if (type == JSON_T_ARRAY_END) { state = 15; break; }
+ assert(type == JSON_T_INTEGER);
+ tail.push_back(value->vu.integer_value);
+ break;
+ }
+ return true;
+ }
+ string rp;
+ string cat;
+ SmallVector tail;
+ vector<int> in_edges;
+ TRulePtr cur_rule;
+ map<int, TRulePtr> rules;
+ vector<int> fdict;
+ SparseVector<double> feats;
+ int state;
+ int fid;
+ int nodes;
+ int edges;
+ string cur_key;
+ Hypergraph& hg;
+ int rule_id;
+ bool nodes_needed;
+ bool edges_needed;
+ bool is_sorted;
+};
+
+bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) {
+ hg->clear();
+ HGReader reader(hg);
+ return reader.Parse(in);
+}
+
+static void WriteRule(const TRule& r, ostream* out) {
+ if (!r.lhs_) { (*out) << "[X] ||| "; }
+ JSONParser::WriteEscapedString(r.AsString(), out);
+}
+
+bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) {
+ map<const TRule*, int> rid;
+ ostream& o = *out;
+ rid[NULL] = 0;
+ o << '{';
+ if (!remove_rules) {
+ o << "\"rules\":[";
+ for (int i = 0; i < hg.edges_.size(); ++i) {
+ const TRule* r = hg.edges_[i].rule_.get();
+ int &id = rid[r];
+ if (!id) {
+ id=rid.size() - 1;
+ if (id > 1) o << ',';
+ o << id << ',';
+ WriteRule(*r, &o);
+ };
+ }
+ o << "],";
+ }
+ const bool use_fdict = FD::NumFeats() < 1000;
+ if (use_fdict) {
+ o << "\"features\":[";
+ for (int i = 1; i < FD::NumFeats(); ++i) {
+ o << (i==1 ? "":",") << '"' << FD::Convert(i) << '"';
+ }
+ o << "],";
+ }
+ vector<int> edgemap(hg.edges_.size(), -1); // edges may be in non-topo order
+ int edge_count = 0;
+ for (int i = 0; i < hg.nodes_.size(); ++i) {
+ const Hypergraph::Node& node = hg.nodes_[i];
+ if (i > 0) { o << ","; }
+ o << "\"edges\":[";
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]];
+ edgemap[edge.id_] = edge_count;
+ ++edge_count;
+ o << (j == 0 ? "" : ",") << "{";
+
+ o << "\"tail\":[";
+ for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
+ o << (k > 0 ? "," : "") << edge.tail_nodes_[k];
+ }
+ o << "],";
+
+ o << "\"feats\":[";
+ bool first = true;
+ for (SparseVector<double>::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) {
+ if (!it->second) continue;
+ if (!first) o << ',';
+ if (use_fdict)
+ o << (it->first - 1);
+ else
+ o << '"' << FD::Convert(it->first) << '"';
+ o << ',' << it->second;
+ first = false;
+ }
+ o << "]";
+ if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; }
+ o << "}";
+ }
+ o << "],";
+
+ o << "\"node\":{\"in_edges\":[";
+ for (int j = 0; j < node.in_edges_.size(); ++j) {
+ int mapped_edge = edgemap[node.in_edges_[j]];
+ assert(mapped_edge >= 0);
+ o << (j == 0 ? "" : ",") << mapped_edge;
+ }
+ o << "]";
+ if (node.cat_ < 0) { o << ",\"cat\":\"" << TD::Convert(node.cat_ * -1) << '"'; }
+ o << "}";
+ }
+ o << "}\n";
+ return true;
+}
+
+bool needs_escape[128];
+void InitEscapes() {
+ memset(needs_escape, false, 128);
+ needs_escape[static_cast<size_t>('\'')] = true;
+ needs_escape[static_cast<size_t>('\\')] = true;
+}
+
+string HypergraphIO::Escape(const string& s) {
+ size_t len = s.size();
+ for (int i = 0; i < s.size(); ++i) {
+ unsigned char c = s[i];
+ if (c < 128 && needs_escape[c]) ++len;
+ }
+ if (len == s.size()) return s;
+ string res(len, ' ');
+ size_t o = 0;
+ for (int i = 0; i < s.size(); ++i) {
+ unsigned char c = s[i];
+ if (c < 128 && needs_escape[c])
+ res[o++] = '\\';
+ res[o++] = c;
+ }
+ assert(o == len);
+ return res;
+}
+
+string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses) {
+ static bool first = true;
+ if (first) { InitEscapes(); first = false; }
+ if (hg.nodes_.empty()) return "()";
+ ostringstream os;
+ if (include_global_parentheses) os << '(';
+ static const string EPS="*EPS*";
+ for (int i = 0; i < hg.nodes_.size()-1; ++i) {
+ os << '(';
+ if (hg.nodes_[i].out_edges_.empty()) abort();
+ for (int j = 0; j < hg.nodes_[i].out_edges_.size(); ++j) {
+ const Hypergraph::Edge& e = hg.edges_[hg.nodes_[i].out_edges_[j]];
+ const string output = e.rule_->e_.size() ==2 ? Escape(TD::Convert(e.rule_->e_[1])) : EPS;
+ double prob = log(e.edge_prob_);
+ if (isinf(prob)) { prob = -9e20; }
+ if (isnan(prob)) { prob = 0; }
+ os << "('" << output << "'," << prob << "," << e.head_node_ - i << "),";
+ }
+ os << "),";
+ }
+ if (include_global_parentheses) os << ')';
+ return os.str();
+}
+
+namespace PLF {
+
+const string chars = "'\\";
+const char& quote = chars[0];
+const char& slash = chars[1];
+
+// safe get
+inline char get(const std::string& in, int c) {
+ if (c < 0 || c >= (int)in.size()) return 0;
+ else return in[(size_t)c];
+}
+
+// consume whitespace
+inline void eatws(const std::string& in, int& c) {
+ while (get(in,c) == ' ') { c++; }
+}
+
+// from 'foo' return foo
+std::string getEscapedString(const std::string& in, int &c)
+{
+ eatws(in,c);
+ if (get(in,c++) != quote) return "ERROR";
+ std::string res;
+ char cur = 0;
+ do {
+ cur = get(in,c++);
+ if (cur == slash) { res += get(in,c++); }
+ else if (cur != quote) { res += cur; }
+ } while (get(in,c) != quote && (c < (int)in.size()));
+ c++;
+ eatws(in,c);
+ return res;
+}
+
+// basically atof
+float getFloat(const std::string& in, int &c)
+{
+ std::string tmp;
+ eatws(in,c);
+ while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') {
+ tmp += get(in,c++);
+ }
+ eatws(in,c);
+ return atof(tmp.c_str());
+}
+
+// basically atoi
+int getInt(const std::string& in, int &c)
+{
+ std::string tmp;
+ eatws(in,c);
+ while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') {
+ tmp += get(in,c++);
+ }
+ eatws(in,c);
+ return atoi(tmp.c_str());
+}
+
+// maximum number of nodes permitted
+#define MAX_NODES 100000000
+// parse ('foo', 0.23)
+void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
+ if (get(in,c++) != '(') { assert(!"PCN/PLF parse error: expected ( at start of cn alt block\n"); }
+ vector<WordID> ewords(2, 0);
+ ewords[1] = TD::Convert(getEscapedString(in,c));
+ TRulePtr r(new TRule(ewords));
+ //cerr << "RULE: " << r->AsString() << endl;
+ if (get(in,c++) != ',') { assert(!"PCN/PLF parse error: expected , after string\n"); }
+ size_t cnNext = 1;
+ std::vector<float> probs;
+ probs.push_back(getFloat(in,c));
+ while (get(in,c) == ',') {
+ c++;
+ float val = getFloat(in,c);
+ probs.push_back(val);
+ }
+ //if we read more than one prob, this was a lattice, last item was column increment
+ if (probs.size()>1) {
+ cnNext = static_cast<size_t>(probs.back());
+ probs.pop_back();
+ if (cnNext < 1) { assert(!"PCN/PLF parse error: bad link length at last element of cn alt block\n"); }
+ }
+ if (get(in,c++) != ')') { assert(!"PCN/PLF parse error: expected ) at end of cn alt block\n"); }
+ eatws(in,c);
+ Hypergraph::TailNodeVector tail(1, cur_node);
+ Hypergraph::Edge* edge = hg->AddEdge(r, tail);
+ //cerr << " <--" << cur_node << endl;
+ int head_node = cur_node + cnNext;
+ assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory
+ if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); }
+ hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]);
+ for (int i = 0; i < probs.size(); ++i)
+ edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast<string>(i)), probs[i]);
+}
+
+// parse (('foo', 0.23), ('bar', 0.77))
+void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) {
+ //cerr << "PLF READING NODE " << cur_node << endl;
+ if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); }
+ if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); }
+ eatws(in,c);
+ while (1) {
+ if (c > (int)in.size()) { break; }
+ if (get(in,c) == ')') {
+ c++;
+ eatws(in,c);
+ break;
+ }
+ if (get(in,c) == ',' && get(in,c+1) == ')') {
+ c+=2;
+ eatws(in,c);
+ break;
+ }
+ if (get(in,c) == ',') { c++; eatws(in,c); }
+ ReadPLFEdge(in, c, cur_node, hg);
+ }
+}
+
+} // namespace PLF
+
+void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line) {
+ hg->clear();
+ int c = 0;
+ int cur_node = 0;
+ if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); }
+ while (1) {
+ if (c > (int)in.size()) { break; }
+ if (PLF::get(in,c) == ')') {
+ c++;
+ PLF::eatws(in,c);
+ break;
+ }
+ if (PLF::get(in,c) == ',' && PLF::get(in,c+1) == ')') {
+ c+=2;
+ PLF::eatws(in,c);
+ break;
+ }
+ if (PLF::get(in,c) == ',') { c++; PLF::eatws(in,c); }
+ PLF::ReadPLFNode(in, c, cur_node, line, hg);
+ ++cur_node;
+ }
+ assert(cur_node == hg->nodes_.size() - 1);
+}
+
+void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) {
+ Lattice& l = *pl;
+ Hypergraph g;
+ ReadFromPLF(plf, &g, 0);
+ const int num_nodes = g.nodes_.size() - 1;
+ l.resize(num_nodes);
+ for (int i = 0; i < num_nodes; ++i) {
+ vector<LatticeArc>& alts = l[i];
+ const Hypergraph::Node& node = g.nodes_[i];
+ const int num_alts = node.out_edges_.size();
+ alts.resize(num_alts);
+ for (int j = 0; j < num_alts; ++j) {
+ const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]];
+ alts[j].label = edge.rule_->e_[1];
+ alts[j].cost = edge.feature_values_.value(FD::Convert("Feature_0"));
+ alts[j].dist2next = edge.head_node_ - node.id_;
+ }
+ }
+}
+
+namespace B64 {
+
+static const char cb64[]="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+static const char cd64[]="|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`abcdefghijklmnopq";
+
+static void encodeblock(const unsigned char* in, ostream* os, int len) {
+ char out[4];
+ out[0] = cb64[ in[0] >> 2 ];
+ out[1] = cb64[ ((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4) ];
+ out[2] = (len > 1 ? cb64[ ((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6) ] : '=');
+ out[3] = (len > 2 ? cb64[ in[2] & 0x3f ] : '=');
+ os->write(out, 4);
+}
+
+void b64encode(const char* data, const size_t size, ostream* out) {
+ size_t cur = 0;
+ while(cur < size) {
+ int len = min(static_cast<size_t>(3), size - cur);
+ encodeblock(reinterpret_cast<const unsigned char*>(&data[cur]), out, len);
+ cur += len;
+ }
+}
+
+static void decodeblock(const unsigned char* in, unsigned char* out) {
+ out[0] = (unsigned char ) (in[0] << 2 | in[1] >> 4);
+ out[1] = (unsigned char ) (in[1] << 4 | in[2] >> 2);
+ out[2] = (unsigned char ) (((in[2] << 6) & 0xc0) | in[3]);
+}
+
+bool b64decode(const unsigned char* data, const size_t insize, char* out, const size_t outsize) {
+ size_t cur = 0;
+ size_t ocur = 0;
+ unsigned char in[4];
+ while(cur < insize) {
+ assert(ocur < outsize);
+ for (int i = 0; i < 4; ++i) {
+ unsigned char v = data[cur];
+ v = (unsigned char) ((v < 43 || v > 122) ? '\0' : cd64[ v - 43 ]);
+ if (!v) {
+ cerr << "B64 decode error at offset " << cur << " offending character: " << (int)data[cur] << endl;
+ return false;
+ }
+ v = (unsigned char) ((v == '$') ? '\0' : v - 61);
+ if (v) in[i] = v - 1; else in[i] = 0;
+ ++cur;
+ }
+ decodeblock(in, reinterpret_cast<unsigned char*>(&out[ocur]));
+ ocur += 3;
+ }
+ return true;
+}
+}
+
diff --git a/src/hg_io.h b/src/hg_io.h
new file mode 100644
index 00000000..69a516c1
--- /dev/null
+++ b/src/hg_io.h
@@ -0,0 +1,37 @@
+#ifndef _HG_IO_H_
+#define _HG_IO_H_
+
+#include <iostream>
+
+#include "lattice.h"
+class Hypergraph;
+
+struct HypergraphIO {
+
+ // the format is basically a list of nodes and edges in topological order
+ // any edge you read, you must have already read its tail nodes
+ // any node you read, you must have already read its incoming edges
+ // this may make writing a bit more challenging if your forest is not
+ // topologically sorted (but that probably doesn't happen very often),
+ // but it makes reading much more memory efficient.
+ // see test_data/small.json.gz for an email encoding
+ static bool ReadFromJSON(std::istream* in, Hypergraph* out);
+
+ // if remove_rules is used, the hypergraph is serialized without rule information
+ // (so it only contains structure and feature information)
+ static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out);
+
+ // serialization utils
+ static void ReadFromPLF(const std::string& in, Hypergraph* out, int line = 0);
+ // return PLF string representation (undefined behavior on non-lattices)
+ static std::string AsPLF(const Hypergraph& hg, bool include_global_parentheses = true);
+ static void PLFtoLattice(const std::string& plf, Lattice* pl);
+ static std::string Escape(const std::string& s); // PLF helper
+};
+
+namespace B64 {
+ bool b64decode(const unsigned char* data, const size_t insize, char* out, const size_t outsize);
+ void b64encode(const char* data, const size_t size, std::ostream* out);
+}
+
+#endif
diff --git a/src/hg_test.cc b/src/hg_test.cc
new file mode 100644
index 00000000..ecd97508
--- /dev/null
+++ b/src/hg_test.cc
@@ -0,0 +1,441 @@
+#include <cassert>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <gtest/gtest.h>
+#include "tdict.h"
+
+#include "json_parse.h"
+#include "filelib.h"
+#include "hg.h"
+#include "hg_io.h"
+#include "hg_intersect.h"
+#include "viterbi.h"
+#include "kbest.h"
+#include "inside_outside.h"
+
+using namespace std;
+
+class HGTest : public testing::Test {
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+ void CreateHG(Hypergraph* hg) const;
+ void CreateHG_int(Hypergraph* hg) const;
+ void CreateHG_tiny(Hypergraph* hg) const;
+ void CreateHGBalanced(Hypergraph* hg) const;
+ void CreateLatticeHG(Hypergraph* hg) const;
+ void CreateTinyLatticeHG(Hypergraph* hg) const;
+};
+
+void HGTest::CreateTinyLatticeHG(Hypergraph* hg) const {
+ const string json = "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] b\",4,\"[X] ||| [1] B'\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.2],\"rule\":1},{\"tail\":[0],\"feats\":[0,-0.6],\"rule\":2}],\"node\":{\"in_edges\":[0,1]},\"edges\":[{\"tail\":[1],\"feats\":[0,-0.1],\"rule\":3},{\"tail\":[1],\"feats\":[0,-0.9],\"rule\":4}],\"node\":{\"in_edges\":[2,3]}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+void HGTest::CreateLatticeHG(Hypergraph* hg) const {
+ const string json = "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] A A\",4,\"[X] ||| [1] b\",5,\"[X] ||| [1] c\",6,\"[X] ||| [1] B C\",7,\"[X] ||| [1] A B C\",8,\"[X] ||| [1] CC\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[2,-0.3],\"rule\":1},{\"tail\":[0],\"feats\":[2,-0.6],\"rule\":2},{\"tail\":[0],\"feats\":[2,-1.7],\"rule\":3}],\"node\":{\"in_edges\":[0,1,2]},\"edges\":[{\"tail\":[1],\"feats\":[2,-0.5],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[2],\"feats\":[2,-0.6],\"rule\":5},{\"tail\":[1],\"feats\":[2,-0.8],\"rule\":6},{\"tail\":[0],\"feats\":[2,-0.01],\"rule\":7},{\"tail\":[2],\"feats\":[2,-0.8],\"rule\":8}],\"node\":{\"in_edges\":[4,5,6,7]}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+void HGTest::CreateHG_tiny(Hypergraph* hg) const {
+ const string json = "{\"rules\":[1,\"[X] ||| <s>\",2,\"[X] ||| X [1]\",3,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,-2,1,-99],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.5,1,-0.8],\"rule\":2},{\"tail\":[0],\"feats\":[0,-0.7,1,-0.9],\"rule\":3}],\"node\":{\"in_edges\":[1,2]}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+void HGTest::CreateHG_int(Hypergraph* hg) const {
+ const string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| b\",3,\"[X] ||| a [1]\",4,\"[X] ||| [1] b\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,0.1],\"rule\":1},{\"tail\":[],\"feats\":[0,0.1],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X\"},\"edges\":[{\"tail\":[0],\"feats\":[0,0.3],\"rule\":3},{\"tail\":[0],\"feats\":[0,0.2],\"rule\":4}],\"node\":{\"in_edges\":[2,3],\"cat\":\"Goal\"}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+void HGTest::CreateHG(Hypergraph* hg) const {
+ string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+void HGTest::CreateHGBalanced(Hypergraph* hg) const {
+ const string json = "{\"rules\":[1,\"[X] ||| i\",2,\"[X] ||| a\",3,\"[X] ||| b\",4,\"[X] ||| [1] [2]\",5,\"[X] ||| [1] [2]\",6,\"[X] ||| c\",7,\"[X] ||| d\",8,\"[X] ||| [1] [2]\",9,\"[X] ||| [1] [2]\",10,\"[X] ||| [1] [2]\",11,\"[X] ||| [1] [2]\",12,\"[X] ||| [1] [2]\",13,\"[X] ||| [1] [2]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[1,2],\"feats\":[],\"rule\":4},{\"tail\":[2,1],\"feats\":[],\"rule\":5}],\"node\":{\"in_edges\":[3,4]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":6}],\"node\":{\"in_edges\":[5]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":7}],\"node\":{\"in_edges\":[6]},\"edges\":[{\"tail\":[4,5],\"feats\":[],\"rule\":8},{\"tail\":[5,4],\"feats\":[],\"rule\":9}],\"node\":{\"in_edges\":[7,8]},\"edges\":[{\"tail\":[3,6],\"feats\":[],\"rule\":10},{\"tail\":[6,3],\"feats\":[],\"rule\":11}],\"node\":{\"in_edges\":[9,10]},\"edges\":[{\"tail\":[7,0],\"feats\":[],\"rule\":12},{\"tail\":[0,7],\"feats\":[],\"rule\":13}],\"node\":{\"in_edges\":[11,12]}}";
+ istringstream instr(json);
+ EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg));
+}
+
+TEST_F(HGTest,Controlled) {
+ Hypergraph hg;
+ CreateHG_tiny(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 0.8);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t prob = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "prob: " << prob << "\n";
+ EXPECT_FLOAT_EQ(-80.839996, log(prob));
+ EXPECT_EQ("X <s>", TD::GetString(trans));
+ vector<prob_t> post;
+ hg.PrintGraphviz();
+ prob_t c2 = Inside<prob_t, ScaledEdgeProb>(hg, NULL, ScaledEdgeProb(0.6));
+ EXPECT_FLOAT_EQ(-47.8577, log(c2));
+}
+
+TEST_F(HGTest,Union) {
+ Hypergraph hg1;
+ Hypergraph hg2;
+ CreateHG_tiny(&hg1);
+ CreateHG(&hg2);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 1.0);
+ hg1.Reweight(wts);
+ hg2.Reweight(wts);
+ prob_t c1,c2,c3,c4;
+ vector<WordID> t1,t2,t3,t4;
+ c1 = ViterbiESentence(hg1, &t1);
+ c2 = ViterbiESentence(hg2, &t2);
+ int l2 = ViterbiPathLength(hg2);
+ cerr << c1 << "\t" << TD::GetString(t1) << endl;
+ cerr << c2 << "\t" << TD::GetString(t2) << endl;
+ hg1.Union(hg2);
+ hg1.Reweight(wts);
+ c3 = ViterbiESentence(hg1, &t3);
+ int l3 = ViterbiPathLength(hg1);
+ cerr << c3 << "\t" << TD::GetString(t3) << endl;
+ EXPECT_FLOAT_EQ(c2, c3);
+ EXPECT_EQ(TD::GetString(t2), TD::GetString(t3));
+ EXPECT_EQ(l2, l3);
+
+ wts.set_value(FD::Convert("f2"), -1);
+ hg1.Reweight(wts);
+ c4 = ViterbiESentence(hg1, &t4);
+ cerr << c4 << "\t" << TD::GetString(t4) << endl;
+ EXPECT_EQ("Z <s>", TD::GetString(t4));
+ EXPECT_FLOAT_EQ(98.82, log(c4));
+
+ vector<pair<vector<WordID>, prob_t> > list;
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg1, 10);
+ for (int i = 0; i < 10; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg1.nodes_.size() - 1, i);
+ if (!d) break;
+ list.push_back(make_pair(d->yield, d->score));
+ }
+ EXPECT_TRUE(list[0].first == t4);
+ EXPECT_FLOAT_EQ(log(list[0].second), log(c4));
+ EXPECT_EQ(list.size(), 6);
+ EXPECT_FLOAT_EQ(log(list.back().second / list.front().second), -97.7);
+}
+
+TEST_F(HGTest,ControlledKBest) {
+ Hypergraph hg;
+ CreateHG(&hg);
+ vector<double> w(2); w[0]=0.4; w[1]=0.8;
+ hg.Reweight(w);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+
+ int best = 0;
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, 10);
+ for (int i = 0; i < 10; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << TD::GetString(d->yield) << endl;
+ ++best;
+ }
+ EXPECT_EQ(4, best);
+}
+
+
+TEST_F(HGTest,InsideScore) {
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 1.0);
+ Hypergraph hg;
+ CreateTinyLatticeHG(&hg);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+ hg.PrintGraphviz();
+ prob_t inside = Inside<prob_t, EdgeProb>(hg);
+ EXPECT_FLOAT_EQ(1.7934048, inside); // computed by hand
+ vector<prob_t> post;
+ inside = hg.ComputeBestPathThroughEdges(&post);
+ EXPECT_FLOAT_EQ(-0.3, log(inside)); // computed by hand
+ EXPECT_EQ(post.size(), 4);
+ for (int i = 0; i < 4; ++i) {
+ cerr << "edge post: " << log(post[i]) << '\t' << hg.edges_[i].rule_->AsString() << endl;
+ }
+}
+
+
+TEST_F(HGTest,PruneInsideOutside) {
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("Feature_1"), 1.0);
+ Hypergraph hg;
+ CreateLatticeHG(&hg);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+ hg.PrintGraphviz();
+ //hg.DensityPruneInsideOutside(0.5, false, 2.0);
+ hg.BeamPruneInsideOutside(0.5, false, 0.5);
+ cost = ViterbiESentence(hg, &trans);
+ cerr << "Ncst: " << cost << endl;
+ cerr << TD::GetString(trans) << "\n";
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest,TestPruneEdges) {
+ Hypergraph hg;
+ CreateLatticeHG(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 1.0);
+ hg.Reweight(wts);
+ hg.PrintGraphviz();
+ vector<bool> prune(hg.edges_.size(), true);
+ prune[6] = false;
+ hg.PruneEdges(prune);
+ cerr << "Pruned:\n";
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest,TestIntersect) {
+ Hypergraph hg;
+ CreateHG_int(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 1.0);
+ hg.Reweight(wts);
+ hg.PrintGraphviz();
+
+ int best = 0;
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, 10);
+ for (int i = 0; i < 10; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << TD::GetString(d->yield) << endl;
+ ++best;
+ }
+ EXPECT_EQ(4, best);
+
+ Lattice target(2);
+ target[0].push_back(LatticeArc(TD::Convert("a"), 0.0, 1));
+ target[1].push_back(LatticeArc(TD::Convert("b"), 0.0, 1));
+ HG::Intersect(target, &hg);
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest,TestPrune2) {
+ Hypergraph hg;
+ CreateHG_int(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 1.0);
+ hg.Reweight(wts);
+ hg.PrintGraphviz();
+ vector<bool> rem(hg.edges_.size(), false);
+ rem[0] = true;
+ rem[1] = true;
+ hg.PruneEdges(rem);
+ hg.PrintGraphviz();
+ cerr << "TODO: fix this pruning behavior-- the resulting HG should be empty!\n";
+}
+
+TEST_F(HGTest,Sample) {
+ Hypergraph hg;
+ CreateLatticeHG(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("Feature_1"), 0.0);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest,PLF) {
+ Hypergraph hg;
+ string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)";
+ HypergraphIO::ReadFromPLF(inplf, &hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("Feature_0"), 1.0);
+ hg.Reweight(wts);
+ hg.PrintGraphviz();
+ string outplf = HypergraphIO::AsPLF(hg);
+ cerr << " IN: " << inplf << endl;
+ cerr << "OUT: " << outplf << endl;
+ assert(inplf == outplf);
+}
+
+TEST_F(HGTest,PushWeightsToGoal) {
+ Hypergraph hg;
+ CreateHG(&hg);
+ vector<double> w(2); w[0]=0.4; w[1]=0.8;
+ hg.Reweight(w);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+ hg.PrintGraphviz();
+ hg.PushWeightsToGoal();
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest,TestSpecialKBest) {
+ Hypergraph hg;
+ CreateHGBalanced(&hg);
+ vector<double> w(1); w[0]=0;
+ hg.Reweight(w);
+ vector<pair<vector<WordID>, prob_t> > list;
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, 100000);
+ for (int i = 0; i < 100000; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << TD::GetString(d->yield) << endl;
+ }
+ hg.PrintGraphviz();
+}
+
+TEST_F(HGTest, TestGenericViterbi) {
+ Hypergraph hg;
+ CreateHG_tiny(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 0.8);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ const prob_t prob = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "prob: " << prob << "\n";
+ EXPECT_FLOAT_EQ(-80.839996, log(prob));
+ EXPECT_EQ("X <s>", TD::GetString(trans));
+}
+
+TEST_F(HGTest, TestGenericInside) {
+ Hypergraph hg;
+ CreateTinyLatticeHG(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 1.0);
+ hg.Reweight(wts);
+ vector<prob_t> inside;
+ prob_t ins = Inside<prob_t, EdgeProb>(hg, &inside);
+ EXPECT_FLOAT_EQ(1.7934048, ins); // computed by hand
+ vector<prob_t> outside;
+ Outside<prob_t, EdgeProb>(hg, inside, &outside);
+ EXPECT_EQ(3, outside.size());
+ EXPECT_FLOAT_EQ(1.7934048, outside[0]);
+ EXPECT_FLOAT_EQ(1.3114071, outside[1]);
+ EXPECT_FLOAT_EQ(1.0, outside[2]);
+}
+
+TEST_F(HGTest,TestGenericInside2) {
+ Hypergraph hg;
+ CreateHG(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 0.8);
+ hg.Reweight(wts);
+ vector<prob_t> inside, outside;
+ prob_t ins = Inside<prob_t, EdgeProb>(hg, &inside);
+ Outside<prob_t, EdgeProb>(hg, inside, &outside);
+ for (int i = 0; i < hg.nodes_.size(); ++i)
+ cerr << i << "\t" << log(inside[i]) << "\t" << log(outside[i]) << endl;
+ EXPECT_FLOAT_EQ(0, log(inside[0]));
+ EXPECT_FLOAT_EQ(-1.7861683, log(outside[0]));
+ EXPECT_FLOAT_EQ(-0.4, log(inside[1]));
+ EXPECT_FLOAT_EQ(-1.3861683, log(outside[1]));
+ EXPECT_FLOAT_EQ(-0.8, log(inside[2]));
+ EXPECT_FLOAT_EQ(-0.986168, log(outside[2]));
+ EXPECT_FLOAT_EQ(-0.96, log(inside[3]));
+ EXPECT_FLOAT_EQ(-0.8261683, log(outside[3]));
+ EXPECT_FLOAT_EQ(-1.562512, log(inside[4]));
+ EXPECT_FLOAT_EQ(-0.22365622, log(outside[4]));
+ EXPECT_FLOAT_EQ(-1.7861683, log(inside[5]));
+ EXPECT_FLOAT_EQ(0, log(outside[5]));
+}
+
+TEST_F(HGTest,TestAddExpectations) {
+ Hypergraph hg;
+ CreateHG(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 0.8);
+ hg.Reweight(wts);
+ SparseVector<double> feat_exps;
+ InsideOutside<prob_t, EdgeProb, SparseVector<double>, EdgeFeaturesWeightFunction>(hg, &feat_exps);
+ EXPECT_FLOAT_EQ(-2.5439765, feat_exps[FD::Convert("f1")]);
+ EXPECT_FLOAT_EQ(-2.6357865, feat_exps[FD::Convert("f2")]);
+ cerr << feat_exps << endl;
+ SparseVector<prob_t> posts;
+ InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, TransitionEventWeightFunction>(hg, &posts);
+}
+
+TEST_F(HGTest, Small) {
+ ReadFile rf("test_data/small.json.gz");
+ Hypergraph hg;
+ assert(HypergraphIO::ReadFromJSON(rf.stream(), &hg));
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("Model_0"), -2.0);
+ wts.set_value(FD::Convert("Model_1"), -0.5);
+ wts.set_value(FD::Convert("Model_2"), -1.1);
+ wts.set_value(FD::Convert("Model_3"), -1.0);
+ wts.set_value(FD::Convert("Model_4"), -1.0);
+ wts.set_value(FD::Convert("Model_5"), 0.5);
+ wts.set_value(FD::Convert("Model_6"), 0.2);
+ wts.set_value(FD::Convert("Model_7"), -3.0);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+ vector<prob_t> post;
+ prob_t c2 = Inside<prob_t, ScaledEdgeProb>(hg, NULL, ScaledEdgeProb(0.6));
+ EXPECT_FLOAT_EQ(2.1431036, log(c2));
+}
+
+TEST_F(HGTest, JSONTest) {
+ ostringstream os;
+ JSONParser::WriteEscapedString("\"I don't know\", she said.", &os);
+ EXPECT_EQ("\"\\\"I don't know\\\", she said.\"", os.str());
+ ostringstream os2;
+ JSONParser::WriteEscapedString("yes", &os2);
+ EXPECT_EQ("\"yes\"", os2.str());
+}
+
+TEST_F(HGTest, TestGenericKBest) {
+ Hypergraph hg;
+ CreateHG(&hg);
+ //CreateHGBalanced(&hg);
+ SparseVector<double> wts;
+ wts.set_value(FD::Convert("f1"), 0.4);
+ wts.set_value(FD::Convert("f2"), 1.0);
+ hg.Reweight(wts);
+ vector<WordID> trans;
+ prob_t cost = ViterbiESentence(hg, &trans);
+ cerr << TD::GetString(trans) << "\n";
+ cerr << "cost: " << cost << "\n";
+
+ KBest::KBestDerivations<vector<WordID>, ESentenceTraversal> kbest(hg, 1000);
+ for (int i = 0; i < 1000; ++i) {
+ const KBest::KBestDerivations<vector<WordID>, ESentenceTraversal>::Derivation* d =
+ kbest.LazyKthBest(hg.nodes_.size() - 1, i);
+ if (!d) break;
+ cerr << TD::GetString(d->yield) << " F:" << d->feature_values << endl;
+ }
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/ibm_model1.cc b/src/ibm_model1.cc
new file mode 100644
index 00000000..eb9c617c
--- /dev/null
+++ b/src/ibm_model1.cc
@@ -0,0 +1,4 @@
+#include <iostream>
+
+
+
diff --git a/src/inside_outside.h b/src/inside_outside.h
new file mode 100644
index 00000000..9114c9d7
--- /dev/null
+++ b/src/inside_outside.h
@@ -0,0 +1,111 @@
+#ifndef _INSIDE_H_
+#define _INSIDE_H_
+
+#include <vector>
+#include <algorithm>
+#include "hg.h"
+
+// run the inside algorithm and return the inside score
+// if result is non-NULL, result will contain the inside
+// score for each node
+// NOTE: WeightType(0) must construct the semiring's additive identity
+// WeightType(1) must construct the semiring's multiplicative identity
+template<typename WeightType, typename WeightFunction>
+WeightType Inside(const Hypergraph& hg,
+ std::vector<WeightType>* result = NULL,
+ const WeightFunction& weight = WeightFunction()) {
+ const int num_nodes = hg.nodes_.size();
+ std::vector<WeightType> dummy;
+ std::vector<WeightType>& inside_score = result ? *result : dummy;
+ inside_score.resize(num_nodes);
+ std::fill(inside_score.begin(), inside_score.end(), WeightType());
+ for (int i = 0; i < num_nodes; ++i) {
+ const Hypergraph::Node& cur_node = hg.nodes_[i];
+ WeightType* const cur_node_inside_score = &inside_score[i];
+ const int num_in_edges = cur_node.in_edges_.size();
+ if (num_in_edges == 0) {
+ *cur_node_inside_score = WeightType(1);
+ continue;
+ }
+ for (int j = 0; j < num_in_edges; ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]];
+ WeightType score = weight(edge);
+ for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
+ const int tail_node_index = edge.tail_nodes_[k];
+ score *= inside_score[tail_node_index];
+ }
+ *cur_node_inside_score += score;
+ }
+ }
+ return inside_score.back();
+}
+
+template<typename WeightType, typename WeightFunction>
+void Outside(const Hypergraph& hg,
+ std::vector<WeightType>& inside_score,
+ std::vector<WeightType>* result,
+ const WeightFunction& weight = WeightFunction()) {
+ assert(result);
+ const int num_nodes = hg.nodes_.size();
+ assert(inside_score.size() == num_nodes);
+ std::vector<WeightType>& outside_score = *result;
+ outside_score.resize(num_nodes);
+ std::fill(outside_score.begin(), outside_score.end(), WeightType(0));
+ outside_score.back() = WeightType(1);
+ for (int i = num_nodes - 1; i >= 0; --i) {
+ const Hypergraph::Node& cur_node = hg.nodes_[i];
+ const WeightType& head_node_outside_score = outside_score[i];
+ const int num_in_edges = cur_node.in_edges_.size();
+ for (int j = 0; j < num_in_edges; ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]];
+ const WeightType head_and_edge_weight = weight(edge) * head_node_outside_score;
+ const int num_tail_nodes = edge.tail_nodes_.size();
+ for (int k = 0; k < num_tail_nodes; ++k) {
+ const int update_tail_node_index = edge.tail_nodes_[k];
+ WeightType* const tail_outside_score = &outside_score[update_tail_node_index];
+ WeightType inside_contribution = WeightType(1);
+ for (int l = 0; l < num_tail_nodes; ++l) {
+ const int other_tail_node_index = edge.tail_nodes_[l];
+ if (update_tail_node_index != other_tail_node_index)
+ inside_contribution *= inside_score[other_tail_node_index];
+ }
+ *tail_outside_score += head_and_edge_weight * inside_contribution;
+ }
+ }
+ }
+}
+
+// this is the Inside-Outside optimization described in Li et al. (EMNLP 2009)
+// for computing the inside algorithm over expensive semirings
+// (such as expectations over features). See Figure 4. It is slightly different
+// in that x/k is returned not (k,x)
+// NOTE: RType * PType must be valid (and yield RType)
+template<typename PType, typename WeightFunction, typename RType, typename WeightFunction2>
+PType InsideOutside(const Hypergraph& hg,
+ RType* result_x,
+ const WeightFunction& weight1 = WeightFunction(),
+ const WeightFunction2& weight2 = WeightFunction2()) {
+ const int num_nodes = hg.nodes_.size();
+ std::vector<PType> inside, outside;
+ const PType z = Inside<PType,WeightFunction>(hg, &inside, weight1);
+ Outside<PType,WeightFunction>(hg, inside, &outside, weight1);
+ RType& x = *result_x;
+ x = RType();
+ for (int i = 0; i < num_nodes; ++i) {
+ const Hypergraph::Node& cur_node = hg.nodes_[i];
+ const int num_in_edges = cur_node.in_edges_.size();
+ for (int j = 0; j < num_in_edges; ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]];
+ PType prob = outside[i];
+ prob *= weight1(edge);
+ const int num_tail_nodes = edge.tail_nodes_.size();
+ for (int k = 0; k < num_tail_nodes; ++k)
+ prob *= inside[edge.tail_nodes_[k]];
+ prob /= z;
+ x += weight2(edge) * prob;
+ }
+ }
+ return z;
+}
+
+#endif
diff --git a/src/json_parse.cc b/src/json_parse.cc
new file mode 100644
index 00000000..f6fdfea8
--- /dev/null
+++ b/src/json_parse.cc
@@ -0,0 +1,50 @@
+#include "json_parse.h"
+
+#include <string>
+#include <iostream>
+
+using namespace std;
+
+static const char *json_hex_chars = "0123456789abcdef";
+
+void JSONParser::WriteEscapedString(const string& in, ostream* out) {
+ int pos = 0;
+ int start_offset = 0;
+ unsigned char c = 0;
+ (*out) << '"';
+ while(pos < in.size()) {
+ c = in[pos];
+ switch(c) {
+ case '\b':
+ case '\n':
+ case '\r':
+ case '\t':
+ case '"':
+ case '\\':
+ case '/':
+ if(pos - start_offset > 0)
+ (*out) << in.substr(start_offset, pos - start_offset);
+ if(c == '\b') (*out) << "\\b";
+ else if(c == '\n') (*out) << "\\n";
+ else if(c == '\r') (*out) << "\\r";
+ else if(c == '\t') (*out) << "\\t";
+ else if(c == '"') (*out) << "\\\"";
+ else if(c == '\\') (*out) << "\\\\";
+ else if(c == '/') (*out) << "\\/";
+ start_offset = ++pos;
+ break;
+ default:
+ if(c < ' ') {
+ cerr << "Warning, bad character (" << static_cast<int>(c) << ") in string\n";
+ if(pos - start_offset > 0)
+ (*out) << in.substr(start_offset, pos - start_offset);
+ (*out) << "\\u00" << json_hex_chars[c >> 4] << json_hex_chars[c & 0xf];
+ start_offset = ++pos;
+ } else pos++;
+ }
+ }
+ if(pos - start_offset > 0)
+ (*out) << in.substr(start_offset, pos - start_offset);
+ (*out) << '"';
+}
+
diff --git a/src/json_parse.h b/src/json_parse.h
new file mode 100644
index 00000000..c3cba954
--- /dev/null
+++ b/src/json_parse.h
@@ -0,0 +1,58 @@
+#ifndef _JSON_WRAPPER_H_
+#define _JSON_WRAPPER_H_
+
+#include <iostream>
+#include <cassert>
+#include "JSON_parser.h"
+
+class JSONParser {
+ public:
+ JSONParser() {
+ init_JSON_config(&config);
+ hack.mf = &JSONParser::Callback;
+ config.depth = 10;
+ config.callback_ctx = reinterpret_cast<void*>(this);
+ config.callback = hack.cb;
+ config.allow_comments = 1;
+ config.handle_floats_manually = 1;
+ jc = new_JSON_parser(&config);
+ }
+ virtual ~JSONParser() {
+ delete_JSON_parser(jc);
+ }
+ bool Parse(std::istream* in) {
+ int count = 0;
+ int lc = 1;
+ for (; in ; ++count) {
+ int next_char = in->get();
+ if (!in->good()) break;
+ if (lc == '\n') { ++lc; }
+ if (!JSON_parser_char(jc, next_char)) {
+ std::cerr << "JSON_parser_char: syntax error, line " << lc << " (byte " << count << ")" << std::endl;
+ return false;
+ }
+ }
+ if (!JSON_parser_done(jc)) {
+ std::cerr << "JSON_parser_done: syntax error\n";
+ return false;
+ }
+ return true;
+ }
+ static void WriteEscapedString(const std::string& in, std::ostream* out);
+ protected:
+ virtual bool HandleJSONEvent(int type, const JSON_value* value) = 0;
+ private:
+ int Callback(int type, const JSON_value* value) {
+ if (HandleJSONEvent(type, value)) return 1;
+ return 0;
+ }
+ JSON_parser_struct* jc;
+ JSON_config config;
+ typedef int (JSONParser::* MF)(int type, const struct JSON_value_struct* value);
+ union CBHack {
+ JSON_parser_callback cb;
+ MF mf;
+ } hack;
+};
+
+#endif
diff --git a/src/kbest.h b/src/kbest.h
new file mode 100644
index 00000000..cd9b6c2b
--- /dev/null
+++ b/src/kbest.h
@@ -0,0 +1,207 @@
+#ifndef _HG_KBEST_H_
+#define _HG_KBEST_H_
+
+#include <vector>
+#include <utility>
+#include <tr1/unordered_set>
+
+#include <boost/shared_ptr.hpp>
+
+#include "wordid.h"
+#include "hg.h"
+
+namespace KBest {
+ // default, don't filter any derivations from the k-best list
+ struct NoFilter {
+ bool operator()(const std::vector<WordID>& yield) {
+ (void) yield;
+ return false;
+ }
+ };
+
+ // optional, filter unique yield strings
+ struct FilterUnique {
+ std::tr1::unordered_set<std::vector<WordID>, boost::hash<std::vector<WordID> > > unique;
+
+ bool operator()(const std::vector<WordID>& yield) {
+ return !unique.insert(yield).second;
+ }
+ };
+
+ // utility class to lazily create the k-best derivations from a forest, uses
+ // the lazy k-best algorithm (Algorithm 3) from Huang and Chiang (IWPT 2005)
+ template<typename T, // yield type (returned by Traversal)
+ typename Traversal,
+ typename DerivationFilter = NoFilter,
+ typename WeightType = prob_t,
+ typename WeightFunction = EdgeProb>
+ struct KBestDerivations {
+ KBestDerivations(const Hypergraph& hg,
+ const size_t k,
+ const Traversal& tf = Traversal(),
+ const WeightFunction& wf = WeightFunction()) :
+ traverse(tf), w(wf), g(hg), nds(g.nodes_.size()), k_prime(k) {}
+
+ ~KBestDerivations() {
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
+ struct Derivation {
+ Derivation(const Hypergraph::Edge& e,
+ const SmallVector& jv,
+ const WeightType& w,
+ const SparseVector<double>& f) :
+ edge(&e),
+ j(jv),
+ score(w),
+ feature_values(f) {}
+
+ // dummy constructor, just for query
+ Derivation(const Hypergraph::Edge& e,
+ const SmallVector& jv) : edge(&e), j(jv) {}
+
+ T yield;
+ const Hypergraph::Edge* const edge;
+ const SmallVector j;
+ const WeightType score;
+ const SparseVector<double> feature_values;
+ };
+ struct HeapCompare {
+ bool operator()(const Derivation* a, const Derivation* b) const {
+ return a->score < b->score;
+ }
+ };
+ struct DerivationCompare {
+ bool operator()(const Derivation* a, const Derivation* b) const {
+ return a->score > b->score;
+ }
+ };
+ struct DerivationUniquenessHash {
+ size_t operator()(const Derivation* d) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ d->edge->id_;
+ for (int i = 0; i < d->j.size(); ++i)
+ x = ((x << 5) + x) ^ d->j[i];
+ return x;
+ }
+ };
+ struct DerivationUniquenessEquals {
+ bool operator()(const Derivation* a, const Derivation* b) const {
+ return (a->edge == b->edge) && (a->j == b->j);
+ }
+ };
+ typedef std::vector<Derivation*> CandidateHeap;
+ typedef std::vector<Derivation*> DerivationList;
+ typedef std::tr1::unordered_set<
+ const Derivation*, DerivationUniquenessHash, DerivationUniquenessEquals> UniqueDerivationSet;
+
+ struct NodeDerivationState {
+ CandidateHeap cand;
+ DerivationList D;
+ DerivationFilter filter;
+ UniqueDerivationSet ds;
+ explicit NodeDerivationState(const DerivationFilter& f = DerivationFilter()) : filter(f) {}
+ };
+
+ Derivation* LazyKthBest(int v, int k) {
+ NodeDerivationState& s = GetCandidates(v);
+ CandidateHeap& cand = s.cand;
+ DerivationList& D = s.D;
+ DerivationFilter& filter = s.filter;
+ bool add_next = true;
+ while (D.size() <= k) {
+ if (add_next && D.size() > 0) {
+ const Derivation* d = D.back();
+ LazyNext(d, &cand, &s.ds);
+ }
+ add_next = false;
+
+ if (cand.size() > 0) {
+ std::pop_heap(cand.begin(), cand.end(), HeapCompare());
+ Derivation* d = cand.back();
+ cand.pop_back();
+ std::vector<const T*> ants(d->edge->Arity());
+ for (int j = 0; j < ants.size(); ++j)
+ ants[j] = &LazyKthBest(d->edge->tail_nodes_[j], d->j[j])->yield;
+ traverse(*d->edge, ants, &d->yield);
+ if (!filter(d->yield)) {
+ D.push_back(d);
+ add_next = true;
+ }
+ } else {
+ break;
+ }
+ }
+ if (k < D.size()) return D[k]; else return NULL;
+ }
+
+ private:
+ // creates a derivation object with all fields set but the yield
+ // the yield is computed in LazyKthBest before the derivation is added to D
+ // returns NULL if j refers to derivation numbers larger than the
+ // antecedent structure define
+ Derivation* CreateDerivation(const Hypergraph::Edge& e, const SmallVector& j) {
+ WeightType score = w(e);
+ SparseVector<double> feats = e.feature_values_;
+ for (int i = 0; i < e.Arity(); ++i) {
+ const Derivation* ant = LazyKthBest(e.tail_nodes_[i], j[i]);
+ if (!ant) { return NULL; }
+ score *= ant->score;
+ feats += ant->feature_values;
+ }
+ freelist.push_back(new Derivation(e, j, score, feats));
+ return freelist.back();
+ }
+
+ NodeDerivationState& GetCandidates(int v) {
+ NodeDerivationState& s = nds[v];
+ if (!s.D.empty() || !s.cand.empty()) return s;
+
+ const Hypergraph::Node& node = g.nodes_[v];
+ for (int i = 0; i < node.in_edges_.size(); ++i) {
+ const Hypergraph::Edge& edge = g.edges_[node.in_edges_[i]];
+ SmallVector jv(edge.Arity(), 0);
+ Derivation* d = CreateDerivation(edge, jv);
+ assert(d);
+ s.cand.push_back(d);
+ }
+
+ const int effective_k = std::min(k_prime, s.cand.size());
+ const typename CandidateHeap::iterator kth = s.cand.begin() + effective_k;
+ std::nth_element(s.cand.begin(), kth, s.cand.end(), DerivationCompare());
+ s.cand.resize(effective_k);
+ std::make_heap(s.cand.begin(), s.cand.end(), HeapCompare());
+
+ return s;
+ }
+
+ void LazyNext(const Derivation* d, CandidateHeap* cand, UniqueDerivationSet* ds) {
+ for (int i = 0; i < d->j.size(); ++i) {
+ SmallVector j = d->j;
+ ++j[i];
+ const Derivation* ant = LazyKthBest(d->edge->tail_nodes_[i], j[i]);
+ if (ant) {
+ Derivation query_unique(*d->edge, j);
+ if (ds->count(&query_unique) == 0) {
+ Derivation* new_d = CreateDerivation(*d->edge, j);
+ if (new_d) {
+ cand->push_back(new_d);
+ std::push_heap(cand->begin(), cand->end(), HeapCompare());
+ assert(ds->insert(new_d).second); // insert into uniqueness set, sanity check
+ }
+ }
+ }
+ }
+ }
+
+ const Traversal traverse;
+ const WeightFunction w;
+ const Hypergraph& g;
+ std::vector<NodeDerivationState> nds;
+ std::vector<Derivation*> freelist;
+ const size_t k_prime;
+ };
+}
+
+#endif
diff --git a/src/lattice.cc b/src/lattice.cc
new file mode 100644
index 00000000..aa1df3db
--- /dev/null
+++ b/src/lattice.cc
@@ -0,0 +1,27 @@
+#include "lattice.h"
+
+#include "tdict.h"
+#include "hg_io.h"
+
+using namespace std;
+
+bool LatticeTools::LooksLikePLF(const string &line) {
+ return (line.size() > 5) && (line.substr(0,4) == "((('");
+}
+
+void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) {
+ Lattice& l = *pl;
+ vector<WordID> ids;
+ TD::ConvertSentence(text, &ids);
+ l.resize(ids.size());
+ for (int i = 0; i < l.size(); ++i)
+ l[i].push_back(LatticeArc(ids[i], 0.0, 1));
+}
+
+void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) {
+ if (LooksLikePLF(text_or_plf))
+ HypergraphIO::PLFtoLattice(text_or_plf, pl);
+ else
+ ConvertTextToLattice(text_or_plf, pl);
+}
+
diff --git a/src/lattice.h b/src/lattice.h
new file mode 100644
index 00000000..1177e768
--- /dev/null
+++ b/src/lattice.h
@@ -0,0 +1,31 @@
+#ifndef __LATTICE_H_
+#define __LATTICE_H_
+
+#include <string>
+#include <vector>
+#include "wordid.h"
+
+struct LatticeArc {
+ WordID label;
+ double cost;
+ int dist2next;
+ LatticeArc() : label(), cost(), dist2next() {}
+ LatticeArc(WordID w, double c, int i) : label(w), cost(c), dist2next(i) {}
+};
+
+class Lattice : public std::vector<std::vector<LatticeArc> > {
+ public:
+ Lattice() {}
+ explicit Lattice(size_t t, const std::vector<LatticeArc>& v = std::vector<LatticeArc>()) :
+ std::vector<std::vector<LatticeArc> >(t, v) {}
+
+ // TODO add distance functions
+};
+
+struct LatticeTools {
+ static bool LooksLikePLF(const std::string &line);
+ static void ConvertTextToLattice(const std::string& text, Lattice* pl);
+ static void ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl);
+};
+
+#endif
diff --git a/src/lexcrf.cc b/src/lexcrf.cc
new file mode 100644
index 00000000..33455a3d
--- /dev/null
+++ b/src/lexcrf.cc
@@ -0,0 +1,112 @@
+#include "lexcrf.h"
+
+#include <iostream>
+
+#include "filelib.h"
+#include "hg.h"
+#include "tdict.h"
+#include "grammar.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+struct LexicalCRFImpl {
+ LexicalCRFImpl(const boost::program_options::variables_map& conf) :
+ use_null(false),
+ kXCAT(TD::Convert("X")*-1),
+ kNULL(TD::Convert("<eps>")),
+ kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")),
+ kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) {
+ vector<string> gfiles = conf["grammar"].as<vector<string> >();
+ assert(gfiles.size() == 1);
+ ReadFile rf(gfiles.front());
+ TextGrammar *tg = new TextGrammar;
+ grammar.reset(tg);
+ istream* in = rf.stream();
+ int lc = 0;
+ bool flag = false;
+ while(*in) {
+ string line;
+ getline(*in, line);
+ if (line.empty()) continue;
+ ++lc;
+ TRulePtr r(TRule::CreateRulePhrasetable(line));
+ tg->AddRule(r);
+ if (lc % 50000 == 0) { cerr << '.'; flag = true; }
+ if (lc % 2000000 == 0) { cerr << " [" << lc << "]\n"; flag = false; }
+ }
+ if (flag) cerr << endl;
+ cerr << "Loaded " << lc << " rules\n";
+ }
+
+ void BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) {
+ const int e_len = smeta.GetTargetLength();
+ assert(e_len > 0);
+ const int f_len = lattice.size();
+ // hack to tell the feature function system how big the sentence pair is
+ const int f_start = (use_null ? -1 : 0);
+ int prev_node_id = -1;
+ for (int i = 0; i < e_len; ++i) { // for each word in the *ref*
+ Hypergraph::Node* node = forest->AddNode(kXCAT);
+ const int new_node_id = node->id_;
+ for (int j = f_start; j < f_len; ++j) { // for each word in the source
+ const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label);
+ const GrammarIter* gi = grammar->GetRoot()->Extend(src_sym);
+ if (!gi) {
+ cerr << "No translations found for: " << TD::Convert(src_sym) << "\n";
+ abort();
+ }
+ const RuleBin* rb = gi->GetRules();
+ assert(rb);
+ for (int k = 0; k < rb->GetNumRules(); ++k) {
+ TRulePtr rule = rb->GetIthRule(k);
+ Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
+ edge->i_ = j;
+ edge->j_ = j+1;
+ edge->prev_i_ = i;
+ edge->prev_j_ = i+1;
+ edge->feature_values_ += edge->rule_->GetFeatureValues();
+ forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
+ }
+ }
+ if (prev_node_id >= 0) {
+ const int comb_node_id = forest->AddNode(kXCAT)->id_;
+ Hypergraph::TailNodeVector tail(2, prev_node_id);
+ tail[1] = new_node_id;
+ const int edge_id = forest->AddEdge(kBINARY, tail)->id_;
+ forest->ConnectEdgeToHeadNode(edge_id, comb_node_id);
+ prev_node_id = comb_node_id;
+ } else {
+ prev_node_id = new_node_id;
+ }
+ }
+ Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
+ Hypergraph::Node* goal = forest->AddNode(TD::Convert("[Goal]")*-1);
+ Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
+ forest->ConnectEdgeToHeadNode(hg_edge, goal);
+ }
+
+ private:
+ const bool use_null;
+ const WordID kXCAT;
+ const WordID kNULL;
+ const TRulePtr kBINARY;
+ const TRulePtr kGOAL_RULE;
+ GrammarPtr grammar;
+};
+
+LexicalCRF::LexicalCRF(const boost::program_options::variables_map& conf) :
+ pimpl_(new LexicalCRFImpl(conf)) {}
+
+bool LexicalCRF::Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* forest) {
+ Lattice lattice;
+ LatticeTools::ConvertTextToLattice(input, &lattice);
+ smeta->SetSourceLength(lattice.size());
+ pimpl_->BuildTrellis(lattice, *smeta, forest);
+ forest->Reweight(weights);
+ return true;
+}
+
diff --git a/src/lexcrf.h b/src/lexcrf.h
new file mode 100644
index 00000000..99362c81
--- /dev/null
+++ b/src/lexcrf.h
@@ -0,0 +1,18 @@
+#ifndef _LEXCRF_H_
+#define _LEXCRF_H_
+
+#include "translator.h"
+#include "lattice.h"
+
+struct LexicalCRFImpl;
+struct LexicalCRF : public Translator {
+ LexicalCRF(const boost::program_options::variables_map& conf);
+ bool Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* forest);
+ private:
+ boost::shared_ptr<LexicalCRFImpl> pimpl_;
+};
+
+#endif
diff --git a/src/lm_ff.cc b/src/lm_ff.cc
new file mode 100644
index 00000000..f95140de
--- /dev/null
+++ b/src/lm_ff.cc
@@ -0,0 +1,328 @@
+#include "lm_ff.h"
+
+#include <sstream>
+#include <unistd.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <netinet/in.h>
+#include <netdb.h>
+
+#include "tdict.h"
+#include "Vocab.h"
+#include "Ngram.h"
+#include "hg.h"
+#include "stringlib.h"
+
+using namespace std;
+
+struct LMClient {
+ struct Cache {
+ map<WordID, Cache> tree;
+ float prob;
+ Cache() : prob() {}
+ };
+
+ LMClient(const char* host) : port(6666) {
+ s = strchr(host, ':');
+ if (s != NULL) {
+ *s = '\0';
+ ++s;
+ port = atoi(s);
+ }
+ sock = socket(AF_INET, SOCK_STREAM, 0);
+ hp = gethostbyname(host);
+ if (hp == NULL) {
+ cerr << "unknown host " << host << endl;
+ abort();
+ }
+ bzero((char *)&server, sizeof(server));
+ bcopy(hp->h_addr, (char *)&server.sin_addr, hp->h_length);
+ server.sin_family = hp->h_addrtype;
+ server.sin_port = htons(port);
+
+ int errors = 0;
+ while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) {
+ cerr << "Error: connect()\n";
+ sleep(1);
+ errors++;
+ if (errors > 3) exit(1);
+ }
+ cerr << "Connected to LM on " << host << " on port " << port << endl;
+ }
+
+ float wordProb(int word, int* context) {
+ Cache* cur = &cache;
+ int i = 0;
+ while (context[i] > 0) {
+ cur = &cur->tree[context[i++]];
+ }
+ cur = &cur->tree[word];
+ if (cur->prob) { return cur->prob; }
+
+ i = 0;
+ ostringstream os;
+ os << "prob " << TD::Convert(word);
+ while (context[i] > 0) {
+ os << ' ' << TD::Convert(context[i++]);
+ }
+ os << endl;
+ string out = os.str();
+ write(sock, out.c_str(), out.size());
+ int r = read(sock, res, 6);
+ int errors = 0;
+ int cnt = 0;
+ while (1) {
+ if (r < 0) {
+ errors++; sleep(1);
+ cerr << "Error: read()\n";
+ if (errors > 5) exit(1);
+ } else if (r==0 || res[cnt] == '\n') { break; }
+ else {
+ cnt += r;
+ if (cnt==6) break;
+ read(sock, &res[cnt], 6-cnt);
+ }
+ }
+ cur->prob = *reinterpret_cast<float*>(res);
+ return cur->prob;
+ }
+
+ void clear() {
+ cache.tree.clear();
+ }
+
+ private:
+ Cache cache;
+ int sock, port;
+ char *s;
+ struct hostent *hp;
+ struct sockaddr_in server;
+ char res[8];
+};
+
+class LanguageModelImpl {
+ public:
+ LanguageModelImpl(int order, const string& f) :
+ ngram_(*TD::dict_), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1),
+ floor_(-100.0),
+ client_(NULL),
+ kSTART(TD::Convert("<s>")),
+ kSTOP(TD::Convert("</s>")),
+ kUNKNOWN(TD::Convert("<unk>")),
+ kNONE(-1),
+ kSTAR(TD::Convert("<{STAR}>")) {
+ if (f.find("lm://") == 0) {
+ client_ = new LMClient(f.substr(5).c_str());
+ } else {
+ File file(f.c_str(), "r", 0);
+ assert(file);
+ cerr << "Reading " << order_ << "-gram LM from " << f << endl;
+ ngram_.read(file, false);
+ }
+ }
+
+ ~LanguageModelImpl() {
+ delete client_;
+ }
+
+ inline int StateSize(const void* state) const {
+ return *(static_cast<const char*>(state) + state_size_);
+ }
+
+ inline void SetStateSize(int size, void* state) const {
+ *(static_cast<char*>(state) + state_size_) = size;
+ }
+
+ inline double LookupProbForBufferContents(int i) {
+ double p = client_ ?
+ client_->wordProb(buffer_[i], &buffer_[i+1])
+ : ngram_.wordProb(buffer_[i], (VocabIndex*)&buffer_[i+1]);
+ if (p < floor_) p = floor_;
+ return p;
+ }
+
+ string DebugStateToString(const void* state) const {
+ int len = StateSize(state);
+ const int* astate = reinterpret_cast<const int*>(state);
+ string res = "[";
+ for (int i = 0; i < len; ++i) {
+ res += " ";
+ res += TD::Convert(astate[i]);
+ }
+ res += " ]";
+ return res;
+ }
+
+ inline double ProbNoRemnant(int i, int len) {
+ int edge = len;
+ bool flag = true;
+ double sum = 0.0;
+ while (i >= 0) {
+ if (buffer_[i] == kSTAR) {
+ edge = i;
+ flag = false;
+ } else if (buffer_[i] <= 0) {
+ edge = i;
+ flag = true;
+ } else {
+ if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART)))
+ sum += LookupProbForBufferContents(i);
+ }
+ --i;
+ }
+ return sum;
+ }
+
+ double EstimateProb(const vector<WordID>& phrase) {
+ int len = phrase.size();
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ int i = len - 1;
+ for (int j = 0; j < len; ++j,--i)
+ buffer_[i] = phrase[j];
+ return ProbNoRemnant(len - 1, len);
+ }
+
+ double EstimateProb(const void* state) {
+ int len = StateSize(state);
+ // cerr << "residual len: " << len << endl;
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ const int* astate = reinterpret_cast<const int*>(state);
+ int i = len - 1;
+ for (int j = 0; j < len; ++j,--i)
+ buffer_[i] = astate[j];
+ return ProbNoRemnant(len - 1, len);
+ }
+
+ double FinalTraversalCost(const void* state) {
+ int slen = StateSize(state);
+ int len = slen + 2;
+ // cerr << "residual len: " << len << endl;
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ buffer_[len-1] = kSTART;
+ const int* astate = reinterpret_cast<const int*>(state);
+ int i = len - 2;
+ for (int j = 0; j < slen; ++j,--i)
+ buffer_[i] = astate[j];
+ buffer_[i] = kSTOP;
+ assert(i == 0);
+ return ProbNoRemnant(len - 1, len);
+ }
+
+ double LookupWords(const TRule& rule, const vector<const void*>& ant_states, void* vstate) {
+ int len = rule.ELength() - rule.Arity();
+ for (int i = 0; i < ant_states.size(); ++i)
+ len += StateSize(ant_states[i]);
+ buffer_.resize(len + 1);
+ buffer_[len] = kNONE;
+ int i = len - 1;
+ const vector<WordID>& e = rule.e();
+ for (int j = 0; j < e.size(); ++j) {
+ if (e[j] < 1) {
+ const int* astate = reinterpret_cast<const int*>(ant_states[-e[j]]);
+ int slen = StateSize(astate);
+ for (int k = 0; k < slen; ++k)
+ buffer_[i--] = astate[k];
+ } else {
+ buffer_[i--] = e[j];
+ }
+ }
+
+ double sum = 0.0;
+ int* remnant = reinterpret_cast<int*>(vstate);
+ int j = 0;
+ i = len - 1;
+ int edge = len;
+
+ while (i >= 0) {
+ if (buffer_[i] == kSTAR) {
+ edge = i;
+ } else if (edge-i >= order_) {
+ sum += LookupProbForBufferContents(i);
+ } else if (edge == len && remnant) {
+ remnant[j++] = buffer_[i];
+ }
+ --i;
+ }
+ if (!remnant) return sum;
+
+ if (edge != len || len >= order_) {
+ remnant[j++] = kSTAR;
+ if (order_-1 < edge) edge = order_-1;
+ for (int i = edge-1; i >= 0; --i)
+ remnant[j++] = buffer_[i];
+ }
+
+ SetStateSize(j, vstate);
+ return sum;
+ }
+
+ static int OrderToStateSize(int order) {
+ return ((order-1) * 2 + 1) * sizeof(WordID) + 1;
+ }
+
+ private:
+ Ngram ngram_;
+ vector<WordID> buffer_;
+ const int order_;
+ const int state_size_;
+ const double floor_;
+ LMClient* client_;
+
+ public:
+ const WordID kSTART;
+ const WordID kSTOP;
+ const WordID kUNKNOWN;
+ const WordID kNONE;
+ const WordID kSTAR;
+};
+
+LanguageModel::LanguageModel(const string& param) :
+ fid_(FD::Convert("LanguageModel")) {
+ vector<string> argv;
+ int argc = SplitOnWhitespace(param, &argv);
+ int order = 3;
+ // TODO add support for -n FeatureName
+ string filename;
+ if (argc < 1) { cerr << "LanguageModel requires a filename, minimally!\n"; abort(); }
+ else if (argc == 1) { filename = argv[0]; }
+ else if (argc == 2 || argc > 3) { cerr << "Don't understand 'LanguageModel " << param << "'\n"; }
+ else if (argc == 3) {
+ if (argv[0] == "-o") {
+ order = atoi(argv[1].c_str());
+ filename = argv[2];
+ } else if (argv[1] == "-o") {
+ order = atoi(argv[2].c_str());
+ filename = argv[0];
+ }
+ }
+ SetStateSize(LanguageModelImpl::OrderToStateSize(order));
+ pimpl_ = new LanguageModelImpl(order, filename);
+}
+
+LanguageModel::~LanguageModel() {
+ delete pimpl_;
+}
+
+string LanguageModel::DebugStateToString(const void* state) const{
+ return pimpl_->DebugStateToString(state);
+}
+
+void LanguageModel::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_states,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* state) const {
+ (void) smeta;
+ features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state));
+ estimated_features->set_value(fid_, pimpl_->EstimateProb(state));
+}
+
+void LanguageModel::FinalTraversalFeatures(const void* ant_state,
+ SparseVector<double>* features) const {
+ features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state));
+}
+
diff --git a/src/lm_ff.h b/src/lm_ff.h
new file mode 100644
index 00000000..cd717360
--- /dev/null
+++ b/src/lm_ff.h
@@ -0,0 +1,32 @@
+#ifndef _LM_FF_H_
+#define _LM_FF_H_
+
+#include <vector>
+#include <string>
+
+#include "hg.h"
+#include "ff.h"
+
+class LanguageModelImpl;
+
+class LanguageModel : public FeatureFunction {
+ public:
+ // param = "filename.lm [-o n]"
+ LanguageModel(const std::string& param);
+ ~LanguageModel();
+ virtual void FinalTraversalFeatures(const void* context,
+ SparseVector<double>* features) const;
+ std::string DebugStateToString(const void* state) const;
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ const int fid_;
+ mutable LanguageModelImpl* pimpl_;
+};
+
+#endif
diff --git a/src/logval.h b/src/logval.h
new file mode 100644
index 00000000..a8ca620c
--- /dev/null
+++ b/src/logval.h
@@ -0,0 +1,136 @@
+#ifndef LOGVAL_H_
+#define LOGVAL_H_
+
+#include <cmath>
+#include <limits>
+
+template <typename T>
+class LogVal {
+ public:
+ LogVal() : v_(-std::numeric_limits<T>::infinity()) {}
+ explicit LogVal(double x) : v_(std::log(x)) {}
+ LogVal<T>(const LogVal<T>& o) : v_(o.v_) {}
+ static LogVal<T> One() { return LogVal(1); }
+ static LogVal<T> Zero() { return LogVal(); }
+
+ void logeq(const T& v) { v_ = v; }
+
+ LogVal& operator+=(const LogVal& a) {
+ if (a.v_ == -std::numeric_limits<T>::infinity()) return *this;
+ if (a.v_ < v_) {
+ v_ = v_ + log1p(std::exp(a.v_ - v_));
+ } else {
+ v_ = a.v_ + log1p(std::exp(v_ - a.v_));
+ }
+ return *this;
+ }
+
+ LogVal& operator*=(const LogVal& a) {
+ v_ += a.v_;
+ return *this;
+ }
+
+ LogVal& operator*=(const T& a) {
+ v_ += log(a);
+ return *this;
+ }
+
+ LogVal& operator/=(const LogVal& a) {
+ v_ -= a.v_;
+ return *this;
+ }
+
+ LogVal& poweq(const T& power) {
+ if (power == 0) v_ = 0; else v_ *= power;
+ return *this;
+ }
+
+ LogVal pow(const T& power) const {
+ LogVal res = *this;
+ res.poweq(power);
+ return res;
+ }
+
+ operator T() const {
+ return std::exp(v_);
+ }
+
+ T v_;
+};
+
+template<typename T>
+LogVal<T> operator+(const LogVal<T>& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o1);
+ res += o2;
+ return res;
+}
+
+template<typename T>
+LogVal<T> operator*(const LogVal<T>& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o1);
+ res *= o2;
+ return res;
+}
+
+template<typename T>
+LogVal<T> operator*(const LogVal<T>& o1, const T& o2) {
+ LogVal<T> res(o1);
+ res *= o2;
+ return res;
+}
+
+template<typename T>
+LogVal<T> operator*(const T& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o2);
+ res *= o1;
+ return res;
+}
+
+template<typename T>
+LogVal<T> operator/(const LogVal<T>& o1, const LogVal<T>& o2) {
+ LogVal<T> res(o1);
+ res /= o2;
+ return res;
+}
+
+template<typename T>
+T log(const LogVal<T>& o) {
+ return o.v_;
+}
+
+template <typename T>
+LogVal<T> pow(const LogVal<T>& b, const T& e) {
+ return b.pow(e);
+}
+
+template <typename T>
+bool operator<(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ < rhs.v_);
+}
+
+template <typename T>
+bool operator<=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ <= rhs.v_);
+}
+
+template <typename T>
+bool operator>(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ > rhs.v_);
+}
+
+template <typename T>
+bool operator>=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ >= rhs.v_);
+}
+
+template <typename T>
+bool operator==(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ == rhs.v_);
+}
+
+template <typename T>
+bool operator!=(const LogVal<T>& lhs, const LogVal<T>& rhs) {
+ return (lhs.v_ != rhs.v_);
+}
+
+#endif
diff --git a/src/maxtrans_blunsom.cc b/src/maxtrans_blunsom.cc
new file mode 100644
index 00000000..4a6680e0
--- /dev/null
+++ b/src/maxtrans_blunsom.cc
@@ -0,0 +1,287 @@
+#include "apply_models.h"
+
+#include <vector>
+#include <algorithm>
+#include <tr1/unordered_map>
+#include <tr1/unordered_set>
+
+#include <boost/tuple/tuple.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "tdict.h"
+#include "hg.h"
+#include "ff.h"
+
+using boost::tuple;
+using namespace std;
+using namespace std::tr1;
+
+namespace Hack {
+
+struct Candidate;
+typedef SmallVector JVector;
+typedef vector<Candidate*> CandidateHeap;
+typedef vector<Candidate*> CandidateList;
+
+// life cycle: candidates are created, placed on the heap
+// and retrieved by their estimated cost, when they're
+// retrieved, they're incorporated into the +LM hypergraph
+// where they also know the head node index they are
+// attached to. After they are added to the +LM hypergraph
+// inside_prob_ and est_prob_ fields may be updated as better
+// derivations are found (this happens since the successor's
+// of derivation d may have a better score- they are
+// explored lazily). However, the updates don't happen
+// when a candidate is in the heap so maintaining the heap
+// property is not an issue.
+struct Candidate {
+ int node_index_; // -1 until incorporated
+ // into the +LM forest
+ const Hypergraph::Edge* in_edge_; // in -LM forest
+ Hypergraph::Edge out_edge_;
+ vector<WordID> state_;
+ const JVector j_;
+ prob_t inside_prob_; // these are fixed until the cand
+ // is popped, then they may be updated
+ prob_t est_prob_;
+
+ Candidate(const Hypergraph::Edge& e,
+ const JVector& j,
+ const vector<CandidateList>& D,
+ bool is_goal) :
+ node_index_(-1),
+ in_edge_(&e),
+ j_(j) {
+ InitializeCandidate(D, is_goal);
+ }
+
+ // used to query uniqueness
+ Candidate(const Hypergraph::Edge& e,
+ const JVector& j) : in_edge_(&e), j_(j) {}
+
+ bool IsIncorporatedIntoHypergraph() const {
+ return node_index_ >= 0;
+ }
+
+ void InitializeCandidate(const vector<vector<Candidate*> >& D,
+ const bool is_goal) {
+ const Hypergraph::Edge& in_edge = *in_edge_;
+ out_edge_.rule_ = in_edge.rule_;
+ out_edge_.feature_values_ = in_edge.feature_values_;
+ Hypergraph::TailNodeVector& tail = out_edge_.tail_nodes_;
+ tail.resize(j_.size());
+ prob_t p = prob_t::One();
+ // cerr << "\nEstimating application of " << in_edge.rule_->AsString() << endl;
+ vector<const vector<WordID>* > ants(tail.size());
+ for (int i = 0; i < tail.size(); ++i) {
+ const Candidate& ant = *D[in_edge.tail_nodes_[i]][j_[i]];
+ ants[i] = &ant.state_;
+ assert(ant.IsIncorporatedIntoHypergraph());
+ tail[i] = ant.node_index_;
+ p *= ant.inside_prob_;
+ }
+ prob_t edge_estimate = prob_t::One();
+ if (is_goal) {
+ assert(tail.size() == 1);
+ out_edge_.edge_prob_ = in_edge.edge_prob_;
+ } else {
+ in_edge.rule_->ESubstitute(ants, &state_);
+ out_edge_.edge_prob_ = in_edge.edge_prob_;
+ }
+ inside_prob_ = out_edge_.edge_prob_ * p;
+ est_prob_ = inside_prob_ * edge_estimate;
+ }
+};
+
+ostream& operator<<(ostream& os, const Candidate& cand) {
+ os << "CAND[";
+ if (!cand.IsIncorporatedIntoHypergraph()) { os << "PENDING "; }
+ else { os << "+LM_node=" << cand.node_index_; }
+ os << " edge=" << cand.in_edge_->id_;
+ os << " j=<";
+ for (int i = 0; i < cand.j_.size(); ++i)
+ os << (i==0 ? "" : " ") << cand.j_[i];
+ os << "> vit=" << log(cand.inside_prob_);
+ os << " est=" << log(cand.est_prob_);
+ return os << ']';
+}
+
+struct HeapCandCompare {
+ bool operator()(const Candidate* l, const Candidate* r) const {
+ return l->est_prob_ < r->est_prob_;
+ }
+};
+
+struct EstProbSorter {
+ bool operator()(const Candidate* l, const Candidate* r) const {
+ return l->est_prob_ > r->est_prob_;
+ }
+};
+
+// the same candidate <edge, j> can be added multiple times if
+// j is multidimensional (if you're going NW in Manhattan, you
+// can first go north, then west, or you can go west then north)
+// this is a hash function on the relevant variables from
+// Candidate to enforce this.
+struct CandidateUniquenessHash {
+ size_t operator()(const Candidate* c) const {
+ size_t x = 5381;
+ x = ((x << 5) + x) ^ c->in_edge_->id_;
+ for (int i = 0; i < c->j_.size(); ++i)
+ x = ((x << 5) + x) ^ c->j_[i];
+ return x;
+ }
+};
+
+struct CandidateUniquenessEquals {
+ bool operator()(const Candidate* a, const Candidate* b) const {
+ return (a->in_edge_ == b->in_edge_) && (a->j_ == b->j_);
+ }
+};
+
+typedef unordered_set<const Candidate*, CandidateUniquenessHash, CandidateUniquenessEquals> UniqueCandidateSet;
+typedef unordered_map<vector<WordID>, Candidate*, boost::hash<vector<WordID> > > State2Node;
+
+class MaxTransBeamSearch {
+
+public:
+ MaxTransBeamSearch(const Hypergraph& i, int pop_limit, Hypergraph* o) :
+ in(i),
+ out(*o),
+ D(in.nodes_.size()),
+ pop_limit_(pop_limit) {
+ cerr << " Finding max translation (cube pruning, pop_limit = " << pop_limit_ << ')' << endl;
+ }
+
+ void Apply() {
+ int num_nodes = in.nodes_.size();
+ int goal_id = num_nodes - 1;
+ int pregoal = goal_id - 1;
+ assert(in.nodes_[pregoal].out_edges_.size() == 1);
+ cerr << " ";
+ for (int i = 0; i < in.nodes_.size(); ++i) {
+ cerr << '.';
+ KBest(i, i == goal_id);
+ }
+ cerr << endl;
+ int best_node = D[goal_id].front()->in_edge_->tail_nodes_.front();
+ Candidate& best = *D[best_node].front();
+ cerr << " Best path: " << log(best.inside_prob_)
+ << "\t" << log(best.est_prob_) << endl;
+ cout << TD::GetString(D[best_node].front()->state_) << endl;
+ FreeAll();
+ }
+
+ private:
+ void FreeAll() {
+ for (int i = 0; i < D.size(); ++i) {
+ CandidateList& D_i = D[i];
+ for (int j = 0; j < D_i.size(); ++j)
+ delete D_i[j];
+ }
+ D.clear();
+ }
+
+ void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) {
+ Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_.rule_, item->out_edge_.tail_nodes_);
+ new_edge->feature_values_ = item->out_edge_.feature_values_;
+ new_edge->edge_prob_ = item->out_edge_.edge_prob_;
+ Candidate*& o_item = (*s2n)[item->state_];
+ if (!o_item) o_item = item;
+
+ int& node_id = o_item->node_index_;
+ if (node_id < 0) {
+ Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_, "");
+ node_id = new_node->id_;
+ }
+ Hypergraph::Node* node = &out.nodes_[node_id];
+ out.ConnectEdgeToHeadNode(new_edge, node);
+
+ if (item != o_item) {
+ assert(o_item->state_ == item->state_); // sanity check!
+ o_item->est_prob_ += item->est_prob_;
+ o_item->inside_prob_ += item->inside_prob_;
+ freelist->push_back(item);
+ }
+ }
+
+ void KBest(const int vert_index, const bool is_goal) {
+ // cerr << "KBest(" << vert_index << ")\n";
+ CandidateList& D_v = D[vert_index];
+ assert(D_v.empty());
+ const Hypergraph::Node& v = in.nodes_[vert_index];
+ // cerr << " has " << v.in_edges_.size() << " in-coming edges\n";
+ const vector<int>& in_edges = v.in_edges_;
+ CandidateHeap cand;
+ CandidateList freelist;
+ cand.reserve(in_edges.size());
+ UniqueCandidateSet unique_cands;
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[in_edges[i]];
+ const JVector j(edge.tail_nodes_.size(), 0);
+ cand.push_back(new Candidate(edge, j, D, is_goal));
+ assert(unique_cands.insert(cand.back()).second); // these should all be unique!
+ }
+// cerr << " making heap of " << cand.size() << " candidates\n";
+ make_heap(cand.begin(), cand.end(), HeapCandCompare());
+ State2Node state2node; // "buf" in Figure 2
+ int pops = 0;
+ while(!cand.empty() && pops < pop_limit_) {
+ pop_heap(cand.begin(), cand.end(), HeapCandCompare());
+ Candidate* item = cand.back();
+ cand.pop_back();
+ // cerr << "POPPED: " << *item << endl;
+ PushSucc(*item, is_goal, &cand, &unique_cands);
+ IncorporateIntoPlusLMForest(item, &state2node, &freelist);
+ ++pops;
+ }
+ D_v.resize(state2node.size());
+ int c = 0;
+ for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i)
+ D_v[c++] = i->second;
+ sort(D_v.begin(), D_v.end(), EstProbSorter());
+ // cerr << " expanded to " << D_v.size() << " nodes\n";
+
+ for (int i = 0; i < cand.size(); ++i)
+ delete cand[i];
+ // freelist is necessary since even after an item merged, it still stays in
+ // the unique set so it can't be deleted til now
+ for (int i = 0; i < freelist.size(); ++i)
+ delete freelist[i];
+ }
+
+ void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) {
+ CandidateHeap& cand = *pcand;
+ for (int i = 0; i < item.j_.size(); ++i) {
+ JVector j = item.j_;
+ ++j[i];
+ if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) {
+ Candidate query_unique(*item.in_edge_, j);
+ if (cs->count(&query_unique) == 0) {
+ Candidate* new_cand = new Candidate(*item.in_edge_, j, D, is_goal);
+ cand.push_back(new_cand);
+ push_heap(cand.begin(), cand.end(), HeapCandCompare());
+ assert(cs->insert(new_cand).second); // insert into uniqueness set, sanity check
+ }
+ }
+ }
+ }
+
+ const Hypergraph& in;
+ Hypergraph& out;
+
+ vector<CandidateList> D; // maps nodes in in-HG to the
+ // equivalent nodes (many due to state
+ // splits) in the out-HG.
+ const int pop_limit_;
+};
+
+// each node in the graph has one of these, it keeps track of
+void MaxTrans(const Hypergraph& in,
+ int beam_size) {
+ Hypergraph out;
+ MaxTransBeamSearch ma(in, beam_size, &out);
+ ma.Apply();
+}
+
+}
diff --git a/src/parser_test.cc b/src/parser_test.cc
new file mode 100644
index 00000000..da1fbd89
--- /dev/null
+++ b/src/parser_test.cc
@@ -0,0 +1,35 @@
+#include <cassert>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <gtest/gtest.h>
+#include "hg.h"
+#include "trule.h"
+#include "bottom_up_parser.h"
+#include "tdict.h"
+
+using namespace std;
+
+class ChartTest : public testing::Test {
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+};
+
+TEST_F(ChartTest,LanguageModel) {
+ LatticeArc a(TD::Convert("ein"), 0.0, 1);
+ LatticeArc b(TD::Convert("haus"), 0.0, 1);
+ Lattice lattice(2);
+ lattice[0].push_back(a);
+ lattice[1].push_back(b);
+ Hypergraph forest;
+ GrammarPtr g(new TextGrammar);
+ vector<GrammarPtr> grammars(1, g);
+ ExhaustiveBottomUpParser parser("PHRASE", grammars);
+ parser.Parse(lattice, &forest);
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/phrasebased_translator.cc b/src/phrasebased_translator.cc
new file mode 100644
index 00000000..5eb70876
--- /dev/null
+++ b/src/phrasebased_translator.cc
@@ -0,0 +1,206 @@
+#include "phrasebased_translator.h"
+
+#include <queue>
+#include <iostream>
+#include <tr1/unordered_map>
+#include <tr1/unordered_set>
+
+#include <boost/tuple/tuple.hpp>
+#include <boost/functional/hash.hpp>
+
+#include "sentence_metadata.h"
+#include "tdict.h"
+#include "hg.h"
+#include "filelib.h"
+#include "lattice.h"
+#include "phrasetable_fst.h"
+#include "array2d.h"
+
+using namespace std;
+using namespace std::tr1;
+using namespace boost::tuples;
+
+struct Coverage : public vector<bool> {
+ explicit Coverage(int n, bool v = false) : vector<bool>(n, v), first_gap() {}
+ void Cover(int i, int j) {
+ vector<bool>::iterator it = this->begin() + i;
+ vector<bool>::iterator end = this->begin() + j;
+ while (it != end)
+ *it++ = true;
+ if (first_gap == i) {
+ first_gap = j;
+ it = end;
+ while (*it && it != this->end()) {
+ ++it;
+ ++first_gap;
+ }
+ }
+ }
+ bool Collides(int i, int j) const {
+ vector<bool>::const_iterator it = this->begin() + i;
+ vector<bool>::const_iterator end = this->begin() + j;
+ while (it != end)
+ if (*it++) return true;
+ return false;
+ }
+ int GetFirstGap() const { return first_gap; }
+ private:
+ int first_gap;
+};
+struct CoverageHash {
+ size_t operator()(const Coverage& cov) const {
+ return hasher_(static_cast<const vector<bool>&>(cov));
+ }
+ private:
+ boost::hash<vector<bool> > hasher_;
+};
+ostream& operator<<(ostream& os, const Coverage& cov) {
+ os << '[';
+ for (int i = 0; i < cov.size(); ++i)
+ os << (cov[i] ? '*' : '.');
+ return os << " gap=" << cov.GetFirstGap() << ']';
+}
+
+typedef unordered_map<Coverage, int, CoverageHash> CoverageNodeMap;
+typedef unordered_set<Coverage, CoverageHash> UniqueCoverageSet;
+
+struct PhraseBasedTranslatorImpl {
+ PhraseBasedTranslatorImpl(const boost::program_options::variables_map& conf) :
+ add_pass_through_rules(conf.count("add_pass_through_rules")),
+ max_distortion(conf["pb_max_distortion"].as<int>()),
+ kSOURCE_RULE(new TRule("[X] ||| [X,1] ||| [X,1]", true)),
+ kCONCAT_RULE(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]", true)),
+ kNT_TYPE(TD::Convert("X") * -1) {
+ assert(max_distortion >= 0);
+ vector<string> gfiles = conf["grammar"].as<vector<string> >();
+ assert(gfiles.size() == 1);
+ cerr << "Reading phrasetable from " << gfiles.front() << endl;
+ ReadFile in(gfiles.front());
+ fst.reset(LoadTextPhrasetable(in.stream()));
+ }
+
+ struct State {
+ State(const Coverage& c, int _i, int _j, const FSTNode* q) :
+ coverage(c), i(_i), j(_j), fst(q) {}
+ Coverage coverage;
+ int i;
+ int j;
+ const FSTNode* fst;
+ };
+
+ // we keep track of unique coverages that have been extended since it's
+ // possible to "extend" the same coverage twice, e.g. translate "a b c"
+ // with phrases "a" "b" "a b" and "c". There are two ways to cover "a b"
+ void EnqueuePossibleContinuations(const Coverage& coverage, queue<State>* q, UniqueCoverageSet* ucs) {
+ if (ucs->insert(coverage).second) {
+ const int gap = coverage.GetFirstGap();
+ const int end = min(static_cast<int>(coverage.size()), gap + max_distortion + 1);
+ for (int i = gap; i < end; ++i)
+ if (!coverage[i]) q->push(State(coverage, i, i, fst.get()));
+ }
+ }
+
+ bool Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ Lattice lattice;
+ LatticeTools::ConvertTextOrPLF(input, &lattice);
+ smeta->SetSourceLength(lattice.size());
+ size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion);
+ minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100);
+ if (add_pass_through_rules) {
+ SparseVector<double> feats;
+ feats.set_value(FD::Convert("PassThrough"), 1);
+ for (int i = 0; i < lattice.size(); ++i) {
+ const vector<LatticeArc>& arcs = lattice[i];
+ for (int j = 0; j < arcs.size(); ++j) {
+ fst->AddPassThroughTranslation(arcs[j].label, feats);
+ // TODO handle lattice edge features
+ }
+ }
+ }
+ CoverageNodeMap c;
+ queue<State> q;
+ UniqueCoverageSet ucs;
+ const Coverage empty_cov(lattice.size(), false);
+ const Coverage goal_cov(lattice.size(), true);
+ EnqueuePossibleContinuations(empty_cov, &q, &ucs);
+ c[empty_cov] = 0; // have to handle the left edge specially
+ while(!q.empty()) {
+ const State s = q.front();
+ q.pop();
+ // cerr << "(" << s.i << "," << s.j << " ptr=" << s.fst << ") cov=" << s.coverage << endl;
+ const vector<LatticeArc>& arcs = lattice[s.j];
+ if (s.fst->HasData()) {
+ Coverage new_cov = s.coverage;
+ new_cov.Cover(s.i, s.j);
+ EnqueuePossibleContinuations(new_cov, &q, &ucs);
+ const vector<TRulePtr>& phrases = s.fst->GetTranslations()->GetRules();
+ const int phrase_head_index = minus_lm_forest->AddNode(kNT_TYPE)->id_;
+ for (int i = 0; i < phrases.size(); ++i) {
+ Hypergraph::Edge* edge = minus_lm_forest->AddEdge(phrases[i], Hypergraph::TailNodeVector());
+ edge->feature_values_ = edge->rule_->scores_;
+ minus_lm_forest->ConnectEdgeToHeadNode(edge->id_, phrase_head_index);
+ }
+ CoverageNodeMap::iterator cit = c.find(s.coverage);
+ assert(cit != c.end());
+ const int tail_node_plus1 = cit->second;
+ if (tail_node_plus1 == 0) { // left edge
+ c[new_cov] = phrase_head_index + 1;
+ } else { // not left edge
+ int& head_node_plus1 = c[new_cov];
+ if (!head_node_plus1)
+ head_node_plus1 = minus_lm_forest->AddNode(kNT_TYPE)->id_ + 1;
+ Hypergraph::TailNodeVector tail(2, tail_node_plus1 - 1);
+ tail[1] = phrase_head_index;
+ const int concat_edge = minus_lm_forest->AddEdge(kCONCAT_RULE, tail)->id_;
+ minus_lm_forest->ConnectEdgeToHeadNode(concat_edge, head_node_plus1 - 1);
+ }
+ }
+ if (s.j == lattice.size()) continue;
+ for (int l = 0; l < arcs.size(); ++l) {
+ const LatticeArc& arc = arcs[l];
+
+ const FSTNode* next_fst_state = s.fst->Extend(arc.label);
+ const int next_j = s.j + arc.dist2next;
+ if (next_fst_state &&
+ !s.coverage.Collides(s.i, next_j)) {
+ q.push(State(s.coverage, s.i, next_j, next_fst_state));
+ }
+ }
+ }
+ if (add_pass_through_rules)
+ fst->ClearPassThroughTranslations();
+ int pregoal_plus1 = c[goal_cov];
+ if (pregoal_plus1 > 0) {
+ TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [X,1]"));
+ int goal = minus_lm_forest->AddNode(TD::Convert("Goal") * -1)->id_;
+ int gedge = minus_lm_forest->AddEdge(kGOAL_RULE, Hypergraph::TailNodeVector(1, pregoal_plus1 - 1))->id_;
+ minus_lm_forest->ConnectEdgeToHeadNode(gedge, goal);
+ // they are almost topo, but not quite always
+ minus_lm_forest->TopologicallySortNodesAndEdges(goal);
+ minus_lm_forest->Reweight(weights);
+ return true;
+ } else {
+ return false; // composition failed
+ }
+ }
+
+ const bool add_pass_through_rules;
+ const int max_distortion;
+ TRulePtr kSOURCE_RULE;
+ const TRulePtr kCONCAT_RULE;
+ const WordID kNT_TYPE;
+ boost::shared_ptr<FSTNode> fst;
+};
+
+PhraseBasedTranslator::PhraseBasedTranslator(const boost::program_options::variables_map& conf) :
+ pimpl_(new PhraseBasedTranslatorImpl(conf)) {}
+
+bool PhraseBasedTranslator::Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ return pimpl_->Translate(input, smeta, weights, minus_lm_forest);
+}
diff --git a/src/phrasebased_translator.h b/src/phrasebased_translator.h
new file mode 100644
index 00000000..d42ce79c
--- /dev/null
+++ b/src/phrasebased_translator.h
@@ -0,0 +1,18 @@
+#ifndef _PHRASEBASED_TRANSLATOR_H_
+#define _PHRASEBASED_TRANSLATOR_H_
+
+#include "translator.h"
+
+class PhraseBasedTranslatorImpl;
+class PhraseBasedTranslator : public Translator {
+ public:
+ PhraseBasedTranslator(const boost::program_options::variables_map& conf);
+ bool Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest);
+ private:
+ boost::shared_ptr<PhraseBasedTranslatorImpl> pimpl_;
+};
+
+#endif
diff --git a/src/phrasetable_fst.cc b/src/phrasetable_fst.cc
new file mode 100644
index 00000000..f421e941
--- /dev/null
+++ b/src/phrasetable_fst.cc
@@ -0,0 +1,141 @@
+#include "phrasetable_fst.h"
+
+#include <cassert>
+#include <iostream>
+#include <map>
+
+#include <boost/shared_ptr.hpp>
+
+#include "filelib.h"
+#include "tdict.h"
+
+using boost::shared_ptr;
+using namespace std;
+
+TargetPhraseSet::~TargetPhraseSet() {}
+FSTNode::~FSTNode() {}
+
+class TextTargetPhraseSet : public TargetPhraseSet {
+ public:
+ void AddRule(TRulePtr rule) {
+ rules_.push_back(rule);
+ }
+ const vector<TRulePtr>& GetRules() const {
+ return rules_;
+ }
+
+ private:
+ // all rules must have arity 0
+ vector<TRulePtr> rules_;
+};
+
+class TextFSTNode : public FSTNode {
+ public:
+ const TargetPhraseSet* GetTranslations() const { return data.get(); }
+ bool HasData() const { return (bool)data; }
+ bool HasOutgoingNonEpsilonEdges() const { return !ptr.empty(); }
+ const FSTNode* Extend(const WordID& t) const {
+ map<WordID, TextFSTNode>::const_iterator it = ptr.find(t);
+ if (it == ptr.end()) return NULL;
+ return &it->second;
+ }
+
+ void AddPhrase(const string& phrase);
+
+ void AddPassThroughTranslation(const WordID& w, const SparseVector<double>& feats);
+ void ClearPassThroughTranslations();
+ private:
+ vector<WordID> passthroughs;
+ shared_ptr<TargetPhraseSet> data;
+ map<WordID, TextFSTNode> ptr;
+};
+
+#ifdef DEBUG_CHART_PARSER
+static string TrimRule(const string& r) {
+ size_t start = r.find(" |||") + 5;
+ size_t end = r.rfind(" |||");
+ return r.substr(start, end - start);
+}
+#endif
+
+void TextFSTNode::AddPhrase(const string& phrase) {
+ vector<WordID> words;
+ TRulePtr rule(TRule::CreateRulePhrasetable(phrase));
+ if (!rule) {
+ static int err = 0;
+ ++err;
+ if (err > 2) { cerr << "TOO MANY PHRASETABLE ERRORS\n"; exit(1); }
+ return;
+ }
+
+ TextFSTNode* fsa = this;
+ for (int i = 0; i < rule->FLength(); ++i)
+ fsa = &fsa->ptr[rule->f_[i]];
+
+ if (!fsa->data)
+ fsa->data.reset(new TextTargetPhraseSet);
+ static_cast<TextTargetPhraseSet*>(fsa->data.get())->AddRule(rule);
+}
+
+void TextFSTNode::AddPassThroughTranslation(const WordID& w, const SparseVector<double>& feats) {
+ TextFSTNode* next = &ptr[w];
+ // current, rules are only added if the symbol is completely missing as a
+ // word starting the phrase. As a result, it is possible that some sentences
+ // won't parse. If this becomes a problem, fix it here.
+ if (!next->data) {
+ TextTargetPhraseSet* tps = new TextTargetPhraseSet;
+ next->data.reset(tps);
+ TRule* rule = new TRule;
+ rule->e_.resize(1, w);
+ rule->f_.resize(1, w);
+ rule->lhs_ = TD::Convert("___PHRASE") * -1;
+ rule->scores_ = feats;
+ rule->arity_ = 0;
+ tps->AddRule(TRulePtr(rule));
+ passthroughs.push_back(w);
+ }
+}
+
+void TextFSTNode::ClearPassThroughTranslations() {
+ for (int i = 0; i < passthroughs.size(); ++i)
+ ptr.erase(passthroughs[i]);
+ passthroughs.clear();
+}
+
+static void AddPhrasetableToFST(istream* in, TextFSTNode* fst) {
+ int lc = 0;
+ bool flag = false;
+ while(*in) {
+ string line;
+ getline(*in, line);
+ if (line.empty()) continue;
+ ++lc;
+ fst->AddPhrase(line);
+ if (lc % 10000 == 0) { flag = true; cerr << '.' << flush; }
+ if (lc % 500000 == 0) { flag = false; cerr << " [" << lc << ']' << endl << flush; }
+ }
+ if (flag) cerr << endl;
+ cerr << "Loaded " << lc << " source phrases\n";
+}
+
+FSTNode* LoadTextPhrasetable(istream* in) {
+ TextFSTNode *fst = new TextFSTNode;
+ AddPhrasetableToFST(in, fst);
+ return fst;
+}
+
+FSTNode* LoadTextPhrasetable(const vector<string>& filenames) {
+ TextFSTNode* fst = new TextFSTNode;
+ for (int i = 0; i < filenames.size(); ++i) {
+ ReadFile rf(filenames[i]);
+ cerr << "Reading phrase from " << filenames[i] << endl;
+ AddPhrasetableToFST(rf.stream(), fst);
+ }
+ return fst;
+}
+
+FSTNode* LoadBinaryPhrasetable(const string& fname_prefix) {
+ (void) fname_prefix;
+ assert(!"not implemented yet");
+}
+
diff --git a/src/phrasetable_fst.h b/src/phrasetable_fst.h
new file mode 100644
index 00000000..477de1f7
--- /dev/null
+++ b/src/phrasetable_fst.h
@@ -0,0 +1,34 @@
+#ifndef _PHRASETABLE_FST_H_
+#define _PHRASETABLE_FST_H_
+
+#include <vector>
+#include <string>
+
+#include "sparse_vector.h"
+#include "trule.h"
+
+class TargetPhraseSet {
+ public:
+ virtual ~TargetPhraseSet();
+ virtual const std::vector<TRulePtr>& GetRules() const = 0;
+};
+
+class FSTNode {
+ public:
+ virtual ~FSTNode();
+ virtual const TargetPhraseSet* GetTranslations() const = 0;
+ virtual bool HasData() const = 0;
+ virtual bool HasOutgoingNonEpsilonEdges() const = 0;
+ virtual const FSTNode* Extend(const WordID& t) const = 0;
+
+ // these should only be called on q_0:
+ virtual void AddPassThroughTranslation(const WordID& w, const SparseVector<double>& feats) = 0;
+ virtual void ClearPassThroughTranslations() = 0;
+};
+
+// attn caller: you own the memory
+FSTNode* LoadTextPhrasetable(const std::vector<std::string>& filenames);
+FSTNode* LoadTextPhrasetable(std::istream* in);
+FSTNode* LoadBinaryPhrasetable(const std::string& fname_prefix);
+
+#endif
diff --git a/src/prob.h b/src/prob.h
new file mode 100644
index 00000000..bc297870
--- /dev/null
+++ b/src/prob.h
@@ -0,0 +1,8 @@
+#ifndef _PROB_H_
+#define _PROB_H_
+
+#include "logval.h"
+
+typedef LogVal<double> prob_t;
+
+#endif
diff --git a/src/sampler.h b/src/sampler.h
new file mode 100644
index 00000000..e5840f41
--- /dev/null
+++ b/src/sampler.h
@@ -0,0 +1,136 @@
+#ifndef SAMPLER_H_
+#define SAMPLER_H_
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <iostream>
+#include <fstream>
+#include <vector>
+
+#include <boost/random/mersenne_twister.hpp>
+#include <boost/random/uniform_real.hpp>
+#include <boost/random/variate_generator.hpp>
+#include <boost/random/normal_distribution.hpp>
+#include <boost/random/poisson_distribution.hpp>
+
+#include "prob.h"
+
+struct SampleSet;
+
+template <typename RNG>
+struct RandomNumberGenerator {
+ static uint32_t GetTrulyRandomSeed() {
+ uint32_t seed;
+ std::ifstream r("/dev/urandom");
+ if (r) {
+ r.read((char*)&seed,sizeof(uint32_t));
+ }
+ if (r.fail() || !r) {
+ std::cerr << "Warning: could not read from /dev/urandom. Seeding from clock" << std::endl;
+ seed = time(NULL);
+ }
+ std::cerr << "Seeding random number sequence to " << seed << std::endl;
+ return seed;
+ }
+
+ RandomNumberGenerator() : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) {
+ uint32_t seed = GetTrulyRandomSeed();
+ m_generator.seed(seed);
+ }
+ explicit RandomNumberGenerator(uint32_t seed) : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) {
+ if (!seed) seed = GetTrulyRandomSeed();
+ m_generator.seed(seed);
+ }
+
+ size_t SelectSample(const prob_t& a, const prob_t& b, double T = 1.0) {
+ if (T == 1.0) {
+ if (this->next() > (a / (a + b))) return 1; else return 0;
+ } else {
+ assert(!"not implemented");
+ }
+ }
+
+ // T is the annealing temperature, if desired
+ size_t SelectSample(const SampleSet& ss, double T = 1.0);
+
+ // draw a value from U(0,1)
+ double next() {return m_random();}
+
+ // draw a value from N(mean,var)
+ double NextNormal(double mean, double var) {
+ return boost::normal_distribution<double>(mean, var)(m_random);
+ }
+
+ // draw a value from a Poisson distribution
+ // lambda must be greater than 0
+ int NextPoisson(int lambda) {
+ return boost::poisson_distribution<int>(lambda)(m_random);
+ }
+
+ bool AcceptMetropolisHastings(const prob_t& p_cur,
+ const prob_t& p_prev,
+ const prob_t& q_cur,
+ const prob_t& q_prev) {
+ const prob_t a = (p_cur / p_prev) * (q_prev / q_cur);
+ if (log(a) >= 0.0) return true;
+ return (prob_t(this->next()) < a);
+ }
+
+ private:
+ boost::uniform_real<> m_dist;
+ RNG m_generator;
+ boost::variate_generator<RNG&, boost::uniform_real<> > m_random;
+};
+
+typedef RandomNumberGenerator<boost::mt19937> MT19937;
+
+class SampleSet {
+ public:
+ const prob_t& operator[](int i) const { return m_scores[i]; }
+ bool empty() const { return m_scores.empty(); }
+ void add(const prob_t& s) { m_scores.push_back(s); }
+ void clear() { m_scores.clear(); }
+ size_t size() const { return m_scores.size(); }
+ std::vector<prob_t> m_scores;
+};
+
+template <typename RNG>
+size_t RandomNumberGenerator<RNG>::SelectSample(const SampleSet& ss, double T) {
+ assert(T > 0.0);
+ assert(ss.m_scores.size() > 0);
+ if (ss.m_scores.size() == 1) return 0;
+ const prob_t annealing_factor(1.0 / T);
+ const bool anneal = (annealing_factor != prob_t::One());
+ prob_t sum = prob_t::Zero();
+ if (anneal) {
+ for (int i = 0; i < ss.m_scores.size(); ++i)
+ sum += ss.m_scores[i].pow(annealing_factor); // p^(1/T)
+ } else {
+ sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), prob_t::Zero());
+ }
+ //for (size_t i = 0; i < ss.m_scores.size(); ++i) std::cerr << ss.m_scores[i] << ",";
+ //std::cerr << std::endl;
+
+ prob_t random(this->next()); // random number between 0 and 1
+ random *= sum; // scale with normalization factor
+ //std::cerr << "Random number " << random << std::endl;
+
+ //now figure out which sample
+ size_t position = 1;
+ sum = ss.m_scores[0];
+ if (anneal) {
+ sum.poweq(annealing_factor);
+ for (; position < ss.m_scores.size() && sum < random; ++position)
+ sum += ss.m_scores[position].pow(annealing_factor);
+ } else {
+ for (; position < ss.m_scores.size() && sum < random; ++position)
+ sum += ss.m_scores[position];
+ }
+ //std::cout << "random: " << random << " sample: " << position << std::endl;
+ //std::cerr << "Sample: " << position-1 << std::endl;
+ //exit(1);
+ return position-1;
+}
+
+#endif
diff --git a/src/scfg_translator.cc b/src/scfg_translator.cc
new file mode 100644
index 00000000..03602c6b
--- /dev/null
+++ b/src/scfg_translator.cc
@@ -0,0 +1,66 @@
+#include "translator.h"
+
+#include <vector>
+
+#include "hg.h"
+#include "grammar.h"
+#include "bottom_up_parser.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+Translator::~Translator() {}
+
+struct SCFGTranslatorImpl {
+ SCFGTranslatorImpl(const boost::program_options::variables_map& conf) :
+ max_span_limit(conf["scfg_max_span_limit"].as<int>()),
+ add_pass_through_rules(conf.count("add_pass_through_rules")),
+ goal(conf["goal"].as<string>()),
+ default_nt(conf["scfg_default_nt"].as<string>()) {
+ vector<string> gfiles = conf["grammar"].as<vector<string> >();
+ for (int i = 0; i < gfiles.size(); ++i) {
+ cerr << "Reading SCFG grammar from " << gfiles[i] << endl;
+ TextGrammar* g = new TextGrammar(gfiles[i]);
+ g->SetMaxSpan(max_span_limit);
+ grammars.push_back(GrammarPtr(g));
+ }
+ if (!conf.count("scfg_no_hiero_glue_grammar"))
+ grammars.push_back(GrammarPtr(new GlueGrammar(goal, default_nt)));
+ if (conf.count("scfg_extra_glue_grammar"))
+ grammars.push_back(GrammarPtr(new GlueGrammar(conf["scfg_extra_glue_grammar"].as<string>())));
+ }
+
+ const int max_span_limit;
+ const bool add_pass_through_rules;
+ const string goal;
+ const string default_nt;
+ vector<GrammarPtr> grammars;
+
+ bool Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* forest) {
+ vector<GrammarPtr> glist = grammars;
+ Lattice lattice;
+ LatticeTools::ConvertTextOrPLF(input, &lattice);
+ smeta->SetSourceLength(lattice.size());
+ if (add_pass_through_rules)
+ glist.push_back(GrammarPtr(new PassThroughGrammar(lattice, default_nt)));
+ ExhaustiveBottomUpParser parser(goal, glist);
+ if (!parser.Parse(lattice, forest))
+ return false;
+ forest->Reweight(weights);
+ return true;
+ }
+};
+
+SCFGTranslator::SCFGTranslator(const boost::program_options::variables_map& conf) :
+ pimpl_(new SCFGTranslatorImpl(conf)) {}
+
+bool SCFGTranslator::Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ return pimpl_->Translate(input, smeta, weights, minus_lm_forest);
+}
+
diff --git a/src/sentence_metadata.h b/src/sentence_metadata.h
new file mode 100644
index 00000000..0178f1f5
--- /dev/null
+++ b/src/sentence_metadata.h
@@ -0,0 +1,42 @@
+#ifndef _SENTENCE_METADATA_H_
+#define _SENTENCE_METADATA_H_
+
+#include <cassert>
+#include "lattice.h"
+
+struct SentenceMetadata {
+ SentenceMetadata(int id, const Lattice& ref) :
+ sent_id_(id),
+ src_len_(-1),
+ has_reference_(ref.size() > 0),
+ trg_len_(ref.size()),
+ ref_(has_reference_ ? &ref : NULL) {}
+
+ // this should be called by the Translator object after
+ // it has parsed the source
+ void SetSourceLength(int sl) { src_len_ = sl; }
+
+ // this should be called if a separate model needs to
+ // specify how long the target sentence should be
+ void SetTargetLength(int tl) {
+ assert(!has_reference_);
+ trg_len_ = tl;
+ }
+ bool HasReference() const { return has_reference_; }
+ const Lattice& GetReference() const { return *ref_; }
+ int GetSourceLength() const { return src_len_; }
+ int GetTargetLength() const { return trg_len_; }
+ int GetSentenceID() const { return sent_id_; }
+
+ private:
+ const int sent_id_;
+ int src_len_;
+
+ // you need to be very careful when depending on these values
+ // they will only be set during training / alignment contexts
+ const bool has_reference_;
+ int trg_len_;
+ const Lattice* const ref_;
+};
+
+#endif
diff --git a/src/small_vector.h b/src/small_vector.h
new file mode 100644
index 00000000..800c1df1
--- /dev/null
+++ b/src/small_vector.h
@@ -0,0 +1,187 @@
+#ifndef _SMALL_VECTOR_H_
+
+#include <streambuf> // std::max - where to get this?
+#include <cstring>
+#include <cassert>
+
+#define __SV_MAX_STATIC 2
+
+class SmallVector {
+
+ public:
+ SmallVector() : size_(0) {}
+
+ explicit SmallVector(size_t s, int v = 0) : size_(s) {
+ assert(s < 0x80);
+ if (s <= __SV_MAX_STATIC) {
+ for (int i = 0; i < s; ++i) data_.vals[i] = v;
+ } else {
+ capacity_ = s;
+ size_ = s;
+ data_.ptr = new int[s];
+ for (int i = 0; i < size_; ++i) data_.ptr[i] = v;
+ }
+ }
+
+ SmallVector(const SmallVector& o) : size_(o.size_) {
+ if (size_ <= __SV_MAX_STATIC) {
+ for (int i = 0; i < __SV_MAX_STATIC; ++i) data_.vals[i] = o.data_.vals[i];
+ } else {
+ capacity_ = size_ = o.size_;
+ data_.ptr = new int[capacity_];
+ std::memcpy(data_.ptr, o.data_.ptr, size_ * sizeof(int));
+ }
+ }
+
+ const SmallVector& operator=(const SmallVector& o) {
+ if (size_ <= __SV_MAX_STATIC) {
+ if (o.size_ <= __SV_MAX_STATIC) {
+ size_ = o.size_;
+ for (int i = 0; i < __SV_MAX_STATIC; ++i) data_.vals[i] = o.data_.vals[i];
+ } else {
+ capacity_ = size_ = o.size_;
+ data_.ptr = new int[capacity_];
+ std::memcpy(data_.ptr, o.data_.ptr, size_ * sizeof(int));
+ }
+ } else {
+ if (o.size_ <= __SV_MAX_STATIC) {
+ delete[] data_.ptr;
+ size_ = o.size_;
+ for (int i = 0; i < size_; ++i) data_.vals[i] = o.data_.vals[i];
+ } else {
+ if (capacity_ < o.size_) {
+ delete[] data_.ptr;
+ capacity_ = o.size_;
+ data_.ptr = new int[capacity_];
+ }
+ size_ = o.size_;
+ for (int i = 0; i < size_; ++i)
+ data_.ptr[i] = o.data_.ptr[i];
+ }
+ }
+ return *this;
+ }
+
+ ~SmallVector() {
+ if (size_ <= __SV_MAX_STATIC) return;
+ delete[] data_.ptr;
+ }
+
+ void clear() {
+ if (size_ > __SV_MAX_STATIC) {
+ delete[] data_.ptr;
+ }
+ size_ = 0;
+ }
+
+ bool empty() const { return size_ == 0; }
+ size_t size() const { return size_; }
+
+ inline void ensure_capacity(unsigned char min_size) {
+ assert(min_size > __SV_MAX_STATIC);
+ if (min_size < capacity_) return;
+ unsigned char new_cap = std::max(static_cast<unsigned char>(capacity_ << 1), min_size);
+ int* tmp = new int[new_cap];
+ std::memcpy(tmp, data_.ptr, capacity_ * sizeof(int));
+ delete[] data_.ptr;
+ data_.ptr = tmp;
+ capacity_ = new_cap;
+ }
+
+ inline void copy_vals_to_ptr() {
+ capacity_ = __SV_MAX_STATIC * 2;
+ int* tmp = new int[capacity_];
+ for (int i = 0; i < __SV_MAX_STATIC; ++i) tmp[i] = data_.vals[i];
+ data_.ptr = tmp;
+ }
+
+ inline void push_back(int v) {
+ if (size_ < __SV_MAX_STATIC) {
+ data_.vals[size_] = v;
+ ++size_;
+ return;
+ } else if (size_ == __SV_MAX_STATIC) {
+ copy_vals_to_ptr();
+ } else if (size_ == capacity_) {
+ ensure_capacity(size_ + 1);
+ }
+ data_.ptr[size_] = v;
+ ++size_;
+ }
+
+ int& back() { return this->operator[](size_ - 1); }
+ const int& back() const { return this->operator[](size_ - 1); }
+ int& front() { return this->operator[](0); }
+ const int& front() const { return this->operator[](0); }
+
+ void resize(size_t s, int v = 0) {
+ if (s <= __SV_MAX_STATIC) {
+ if (size_ > __SV_MAX_STATIC) {
+ int tmp[__SV_MAX_STATIC];
+ for (int i = 0; i < s; ++i) tmp[i] = data_.ptr[i];
+ delete[] data_.ptr;
+ for (int i = 0; i < s; ++i) data_.vals[i] = tmp[i];
+ size_ = s;
+ return;
+ }
+ if (s <= size_) {
+ size_ = s;
+ return;
+ } else {
+ for (int i = size_; i < s; ++i)
+ data_.vals[i] = v;
+ size_ = s;
+ return;
+ }
+ } else {
+ if (size_ <= __SV_MAX_STATIC)
+ copy_vals_to_ptr();
+ if (s > capacity_)
+ ensure_capacity(s);
+ if (s > size_) {
+ for (int i = size_; i < s; ++i)
+ data_.ptr[i] = v;
+ }
+ size_ = s;
+ }
+ }
+
+ int& operator[](size_t i) {
+ if (size_ <= __SV_MAX_STATIC) return data_.vals[i];
+ return data_.ptr[i];
+ }
+
+ const int& operator[](size_t i) const {
+ if (size_ <= __SV_MAX_STATIC) return data_.vals[i];
+ return data_.ptr[i];
+ }
+
+ bool operator==(const SmallVector& o) const {
+ if (size_ != o.size_) return false;
+ if (size_ <= __SV_MAX_STATIC) {
+ for (size_t i = 0; i < size_; ++i)
+ if (data_.vals[i] != o.data_.vals[i]) return false;
+ return true;
+ } else {
+ for (size_t i = 0; i < size_; ++i)
+ if (data_.ptr[i] != o.data_.ptr[i]) return false;
+ return true;
+ }
+ }
+
+ private:
+ unsigned char capacity_; // only defined when size_ >= __SV_MAX_STATIC
+ unsigned char size_;
+ union StorageType {
+ int vals[__SV_MAX_STATIC];
+ int* ptr;
+ };
+ StorageType data_;
+
+};
+
+inline bool operator!=(const SmallVector& a, const SmallVector& b) {
+ return !(a==b);
+}
+
+#endif
diff --git a/src/small_vector_test.cc b/src/small_vector_test.cc
new file mode 100644
index 00000000..84237791
--- /dev/null
+++ b/src/small_vector_test.cc
@@ -0,0 +1,129 @@
+#include "small_vector.h"
+
+#include <gtest/gtest.h>
+#include <iostream>
+#include <cassert>
+#include <vector>
+
+using namespace std;
+
+class SVTest : public testing::Test {
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+};
+
+TEST_F(SVTest, LargerThan2) {
+ SmallVector v;
+ SmallVector v2;
+ v.push_back(0);
+ v.push_back(1);
+ v.push_back(2);
+ assert(v.size() == 3);
+ assert(v[2] == 2);
+ assert(v[1] == 1);
+ assert(v[0] == 0);
+ v2 = v;
+ SmallVector copy(v);
+ assert(copy.size() == 3);
+ assert(copy[0] == 0);
+ assert(copy[1] == 1);
+ assert(copy[2] == 2);
+ assert(copy == v2);
+ copy[1] = 99;
+ assert(copy != v2);
+ assert(v2.size() == 3);
+ assert(v2[2] == 2);
+ assert(v2[1] == 1);
+ assert(v2[0] == 0);
+ v2[0] = -2;
+ v2[1] = -1;
+ v2[2] = 0;
+ assert(v2[2] == 0);
+ assert(v2[1] == -1);
+ assert(v2[0] == -2);
+ SmallVector v3(1,1);
+ assert(v3[0] == 1);
+ v2 = v3;
+ assert(v2.size() == 1);
+ assert(v2[0] == 1);
+ SmallVector v4(10, 1);
+ assert(v4.size() == 10);
+ assert(v4[5] == 1);
+ assert(v4[9] == 1);
+ v4 = v;
+ assert(v4.size() == 3);
+ assert(v4[2] == 2);
+ assert(v4[1] == 1);
+ assert(v4[0] == 0);
+ SmallVector v5(10, 2);
+ assert(v5.size() == 10);
+ assert(v5[7] == 2);
+ assert(v5[0] == 2);
+ assert(v.size() == 3);
+ v = v5;
+ assert(v.size() == 10);
+ assert(v[2] == 2);
+ assert(v[9] == 2);
+ SmallVector cc;
+ for (int i = 0; i < 33; ++i)
+ cc.push_back(i);
+ for (int i = 0; i < 33; ++i)
+ assert(cc[i] == i);
+ cc.resize(20);
+ assert(cc.size() == 20);
+ for (int i = 0; i < 20; ++i)
+ assert(cc[i] == i);
+ cc[0]=-1;
+ cc.resize(1, 999);
+ assert(cc.size() == 1);
+ assert(cc[0] == -1);
+ cc.resize(99, 99);
+ for (int i = 1; i < 99; ++i) {
+ cerr << i << " " << cc[i] << endl;
+ assert(cc[i] == 99);
+ }
+ cc.clear();
+ assert(cc.size() == 0);
+}
+
+TEST_F(SVTest, Small) {
+ SmallVector v;
+ SmallVector v1(1,0);
+ SmallVector v2(2,10);
+ SmallVector v1a(2,0);
+ EXPECT_TRUE(v1 != v1a);
+ EXPECT_TRUE(v1 == v1);
+ EXPECT_EQ(v1[0], 0);
+ EXPECT_EQ(v2[1], 10);
+ EXPECT_EQ(v2[0], 10);
+ ++v2[1];
+ --v2[0];
+ EXPECT_EQ(v2[0], 9);
+ EXPECT_EQ(v2[1], 11);
+ SmallVector v3(v2);
+ assert(v3[0] == 9);
+ assert(v3[1] == 11);
+ assert(!v3.empty());
+ assert(v3.size() == 2);
+ v3.clear();
+ assert(v3.empty());
+ assert(v3.size() == 0);
+ assert(v3 != v2);
+ assert(v2 != v3);
+ v3 = v2;
+ assert(v3 == v2);
+ assert(v2 == v3);
+ assert(v3[0] == 9);
+ assert(v3[1] == 11);
+ assert(!v3.empty());
+ assert(v3.size() == 2);
+ cerr << sizeof(SmallVector) << endl;
+ cerr << sizeof(vector<int>) << endl;
+}
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
diff --git a/src/sparse_vector.cc b/src/sparse_vector.cc
new file mode 100644
index 00000000..4035b9ef
--- /dev/null
+++ b/src/sparse_vector.cc
@@ -0,0 +1,98 @@
+#include "sparse_vector.h"
+
+#include <iostream>
+#include <cstring>
+
+#include "hg_io.h"
+
+using namespace std;
+
+namespace B64 {
+
+void Encode(double objective, const SparseVector<double>& v, ostream* out) {
+ const int num_feats = v.num_active();
+ size_t tot_size = 0;
+ const size_t off_objective = tot_size;
+ tot_size += sizeof(double); // objective
+ const size_t off_num_feats = tot_size;
+ tot_size += sizeof(int); // num_feats
+ const size_t off_data = tot_size;
+ tot_size += sizeof(unsigned char) * num_feats; // lengths of feature names;
+ typedef SparseVector<double>::const_iterator const_iterator;
+ for (const_iterator it = v.begin(); it != v.end(); ++it)
+ tot_size += FD::Convert(it->first).size(); // feature names;
+ tot_size += sizeof(double) * num_feats; // gradient
+ const size_t off_magic = tot_size;
+ tot_size += 4; // magic
+
+ // size_t b64_size = tot_size * 4 / 3;
+ // cerr << "Sparse vector binary size: " << tot_size << " (b64 size=" << b64_size << ")\n";
+ char* data = new char[tot_size];
+ *reinterpret_cast<double*>(&data[off_objective]) = objective;
+ *reinterpret_cast<int*>(&data[off_num_feats]) = num_feats;
+ char* cur = &data[off_data];
+ assert(cur - data == off_data);
+ for (const_iterator it = v.begin(); it != v.end(); ++it) {
+ const string& fname = FD::Convert(it->first);
+ *cur++ = static_cast<char>(fname.size()); // name len
+ memcpy(cur, &fname[0], fname.size());
+ cur += fname.size();
+ *reinterpret_cast<double*>(cur) = it->second;
+ cur += sizeof(double);
+ }
+ assert(cur - data == off_magic);
+ *reinterpret_cast<unsigned int*>(cur) = 0xBAABABBAu;
+ cur += sizeof(unsigned int);
+ assert(cur - data == tot_size);
+ b64encode(data, tot_size, out);
+ delete[] data;
+}
+
+bool Decode(double* objective, SparseVector<double>* v, const char* in, size_t size) {
+ v->clear();
+ if (size % 4 != 0) {
+ cerr << "B64 error - line % 4 != 0\n";
+ return false;
+ }
+ const size_t decoded_size = size * 3 / 4 - sizeof(unsigned int);
+ const size_t buf_size = decoded_size + sizeof(unsigned int);
+ if (decoded_size < 6) { cerr << "SparseVector decoding error: too short!\n"; return false; }
+ char* data = new char[buf_size];
+ if (!b64decode(reinterpret_cast<const unsigned char*>(in), size, data, buf_size)) {
+ delete[] data;
+ return false;
+ }
+ size_t cur = 0;
+ *objective = *reinterpret_cast<double*>(data);
+ cur += sizeof(double);
+ const int num_feats = *reinterpret_cast<int*>(&data[cur]);
+ cur += sizeof(int);
+ int fc = 0;
+ while(fc < num_feats && cur < decoded_size) {
+ ++fc;
+ const int fname_len = data[cur++];
+ assert(fname_len > 0);
+ assert(fname_len < 256);
+ string fname(fname_len, '\0');
+ memcpy(&fname[0], &data[cur], fname_len);
+ cur += fname_len;
+ const double val = *reinterpret_cast<double*>(&data[cur]);
+ cur += sizeof(double);
+ int fid = FD::Convert(fname);
+ v->set_value(fid, val);
+ }
+ if(num_feats != fc) {
+ cerr << "Expected " << num_feats << " but only decoded " << fc << "!\n";
+ delete[] data;
+ return false;
+ }
+ if (*reinterpret_cast<unsigned int*>(&data[cur]) != 0xBAABABBAu) {
+ cerr << "SparseVector decodeding error : magic does not match!\n";
+ delete[] data;
+ return false;
+ }
+ delete[] data;
+ return true;
+}
+
+}
diff --git a/src/sparse_vector.h b/src/sparse_vector.h
new file mode 100644
index 00000000..6a8c9bf4
--- /dev/null
+++ b/src/sparse_vector.h
@@ -0,0 +1,264 @@
+#ifndef _SPARSE_VECTOR_H_
+#define _SPARSE_VECTOR_H_
+
+// this is a modified version of code originally written
+// by Phil Blunsom
+
+#include <iostream>
+#include <map>
+#include <vector>
+#include <valarray>
+
+#include "fdict.h"
+
+template <typename T>
+class SparseVector {
+public:
+ SparseVector() {}
+
+ const T operator[](int index) const {
+ typename std::map<int, T>::const_iterator found = _values.find(index);
+ if (found == _values.end())
+ return T(0);
+ else
+ return found->second;
+ }
+
+ void set_value(int index, const T &value) {
+ _values[index] = value;
+ }
+
+ void add_value(int index, const T &value) {
+ _values[index] += value;
+ }
+
+ T value(int index) const {
+ typename std::map<int, T>::const_iterator found = _values.find(index);
+ if (found != _values.end())
+ return found->second;
+ else
+ return T(0);
+ }
+
+ void store(std::valarray<T>* target) const {
+ (*target) *= 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it) {
+ if (it->first >= target->size()) break;
+ (*target)[it->first] = it->second;
+ }
+ }
+
+ int max_index() const {
+ if (_values.empty()) return 0;
+ typename std::map<int, T>::const_iterator found =_values.end();
+ --found;
+ return found->first;
+ }
+
+ // dot product with a unit vector of the same length
+ // as the sparse vector
+ T dot() const {
+ T sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ sum += it->second;
+ return sum;
+ }
+
+ template<typename S>
+ S dot(const SparseVector<S> &vec) const {
+ S sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ {
+ typename std::map<int, T>::const_iterator
+ found = vec._values.find(it->first);
+ if (found != vec._values.end())
+ sum += it->second * found->second;
+ }
+ return sum;
+ }
+
+ template<typename S>
+ S dot(const std::vector<S> &vec) const {
+ S sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ {
+ if (it->first < static_cast<int>(vec.size()))
+ sum += it->second * vec[it->first];
+ }
+ return sum;
+ }
+
+ template<typename S>
+ S dot(const S *vec) const {
+ // this is not range checked!
+ S sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ sum += it->second * vec[it->first];
+ std::cout << "dot(*vec) " << sum << std::endl;
+ return sum;
+ }
+
+ T l1norm() const {
+ T sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ sum += fabs(it->second);
+ return sum;
+ }
+
+ T l2norm() const {
+ T sum = 0;
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ sum += it->second * it->second;
+ return sqrt(sum);
+ }
+
+ SparseVector<T> &operator+=(const SparseVector<T> &other) {
+ for (typename std::map<int, T>::const_iterator
+ it = other._values.begin(); it != other._values.end(); ++it)
+ {
+ T v = (_values[it->first] += it->second);
+ if (v == 0)
+ _values.erase(it->first);
+ }
+ return *this;
+ }
+
+ SparseVector<T> &operator-=(const SparseVector<T> &other) {
+ for (typename std::map<int, T>::const_iterator
+ it = other._values.begin(); it != other._values.end(); ++it)
+ {
+ T v = (_values[it->first] -= it->second);
+ if (v == 0)
+ _values.erase(it->first);
+ }
+ return *this;
+ }
+
+ SparseVector<T> &operator-=(const double &x) {
+ for (typename std::map<int, T>::iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ it->second -= x;
+ return *this;
+ }
+
+ SparseVector<T> &operator+=(const double &x) {
+ for (typename std::map<int, T>::iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ it->second += x;
+ return *this;
+ }
+
+ SparseVector<T> &operator/=(const double &x) {
+ for (typename std::map<int, T>::iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ it->second /= x;
+ return *this;
+ }
+
+ SparseVector<T> &operator*=(const T& x) {
+ for (typename std::map<int, T>::iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ it->second *= x;
+ return *this;
+ }
+
+ SparseVector<T> operator+(const double &x) const {
+ SparseVector<T> result = *this;
+ return result += x;
+ }
+
+ SparseVector<T> operator-(const double &x) const {
+ SparseVector<T> result = *this;
+ return result -= x;
+ }
+
+ SparseVector<T> operator/(const double &x) const {
+ SparseVector<T> result = *this;
+ return result /= x;
+ }
+
+ std::ostream &operator<<(std::ostream &out) const {
+ for (typename std::map<int, T>::const_iterator
+ it = _values.begin(); it != _values.end(); ++it)
+ out << (it == _values.begin() ? "" : ";")
+ << FD::Convert(it->first) << '=' << it->second;
+ return out;
+ }
+
+ bool operator<(const SparseVector<T> &other) const {
+ typename std::map<int, T>::const_iterator it = _values.begin();
+ typename std::map<int, T>::const_iterator other_it = other._values.begin();
+
+ for (; it != _values.end() && other_it != other._values.end(); ++it, ++other_it)
+ {
+ if (it->first < other_it->first) return true;
+ if (it->first > other_it->first) return false;
+ if (it->second < other_it->second) return true;
+ if (it->second > other_it->second) return false;
+ }
+ return _values.size() < other._values.size();
+ }
+
+ int num_active() const { return _values.size(); }
+ bool empty() const { return _values.empty(); }
+
+ typedef typename std::map<int, T>::const_iterator const_iterator;
+ const_iterator begin() const { return _values.begin(); }
+ const_iterator end() const { return _values.end(); }
+
+ void clear() {
+ _values.clear();
+ }
+
+ void swap(SparseVector<T>& other) {
+ _values.swap(other._values);
+ }
+
+private:
+ std::map<int, T> _values;
+};
+
+template <typename T>
+SparseVector<T> operator+(const SparseVector<T>& a, const SparseVector<T>& b) {
+ SparseVector<T> result = a;
+ return result += b;
+}
+
+template <typename T>
+SparseVector<T> operator*(const SparseVector<T>& a, const double& b) {
+ SparseVector<T> result = a;
+ return result *= b;
+}
+
+template <typename T>
+SparseVector<T> operator*(const SparseVector<T>& a, const T& b) {
+ SparseVector<T> result = a;
+ return result *= b;
+}
+
+template <typename T>
+SparseVector<T> operator*(const double& a, const SparseVector<T>& b) {
+ SparseVector<T> result = b;
+ return result *= a;
+}
+
+template <typename T>
+std::ostream &operator<<(std::ostream &out, const SparseVector<T> &vec)
+{
+ return vec.operator<<(out);
+}
+
+namespace B64 {
+ void Encode(double objective, const SparseVector<double>& v, std::ostream* out);
+ // returns false if failed to decode
+ bool Decode(double* objective, SparseVector<double>* v, const char* data, size_t size);
+}
+
+#endif
diff --git a/src/stringlib.cc b/src/stringlib.cc
new file mode 100644
index 00000000..3ed74bef
--- /dev/null
+++ b/src/stringlib.cc
@@ -0,0 +1,97 @@
+#include "stringlib.h"
+
+#include <cstdlib>
+#include <cassert>
+#include <iostream>
+#include <map>
+
+#include "lattice.h"
+
+using namespace std;
+
+void ParseTranslatorInput(const string& line, string* input, string* ref) {
+ size_t hint = 0;
+ if (line.find("{\"rules\":") == 0) {
+ hint = line.find("}}");
+ if (hint == string::npos) {
+ cerr << "Syntax error: " << line << endl;
+ abort();
+ }
+ hint += 2;
+ }
+ size_t pos = line.find("|||", hint);
+ if (pos == string::npos) { *input = line; return; }
+ ref->clear();
+ *input = line.substr(0, pos - 1);
+ string rline = line.substr(pos + 4);
+ if (rline.size() > 0) {
+ assert(ref);
+ *ref = rline;
+ }
+}
+
+void ParseTranslatorInputLattice(const string& line, string* input, Lattice* ref) {
+ string sref;
+ ParseTranslatorInput(line, input, &sref);
+ if (sref.size() > 0) {
+ assert(ref);
+ LatticeTools::ConvertTextOrPLF(sref, ref);
+ }
+}
+
+void ProcessAndStripSGML(string* pline, map<string, string>* out) {
+ map<string, string>& meta = *out;
+ string& line = *pline;
+ string lline = LowercaseString(line);
+ if (lline.find("<seg")!=0) return;
+ size_t close = lline.find(">");
+ if (close == string::npos) return; // error
+ size_t end = lline.find("</seg>");
+ string seg = Trim(lline.substr(4, close-4));
+ string text = line.substr(close+1, end - close - 1);
+ for (size_t i = 1; i < seg.size(); i++) {
+ if (seg[i] == '=' && seg[i-1] == ' ') {
+ string less = seg.substr(0, i-1) + seg.substr(i);
+ seg = less; i = 0; continue;
+ }
+ if (seg[i] == '=' && seg[i+1] == ' ') {
+ string less = seg.substr(0, i+1);
+ if (i+2 < seg.size()) less += seg.substr(i+2);
+ seg = less; i = 0; continue;
+ }
+ }
+ line = Trim(text);
+ if (seg == "") return;
+ for (size_t i = 1; i < seg.size(); i++) {
+ if (seg[i] == '=') {
+ string label = seg.substr(0, i);
+ string val = seg.substr(i+1);
+ if (val[0] == '"') {
+ val = val.substr(1);
+ size_t close = val.find('"');
+ if (close == string::npos) {
+ cerr << "SGML parse error: missing \"\n";
+ seg = "";
+ i = 0;
+ } else {
+ seg = val.substr(close+1);
+ val = val.substr(0, close);
+ i = 0;
+ }
+ } else {
+ size_t close = val.find(' ');
+ if (close == string::npos) {
+ seg = "";
+ i = 0;
+ } else {
+ seg = val.substr(close+1);
+ val = val.substr(0, close);
+ }
+ }
+ label = Trim(label);
+ seg = Trim(seg);
+ meta[label] = val;
+ }
+ }
+}
+
diff --git a/src/stringlib.h b/src/stringlib.h
new file mode 100644
index 00000000..d26952c7
--- /dev/null
+++ b/src/stringlib.h
@@ -0,0 +1,91 @@
+#ifndef _STRINGLIB_H_
+
+#include <map>
+#include <vector>
+#include <cctype>
+#include <string>
+
+// read line in the form of either:
+// source
+// source ||| target
+// source will be returned as a string, target must be a sentence or
+// a lattice (in PLF format) and will be returned as a Lattice object
+void ParseTranslatorInput(const std::string& line, std::string* input, std::string* ref);
+struct Lattice;
+void ParseTranslatorInputLattice(const std::string& line, std::string* input, Lattice* ref);
+
+inline const std::string Trim(const std::string& str, const std::string& dropChars = " \t") {
+ std::string res = str;
+ res.erase(str.find_last_not_of(dropChars)+1);
+ return res.erase(0, res.find_first_not_of(dropChars));
+}
+
+inline void Tokenize(const std::string& str, char delimiter, std::vector<std::string>* res) {
+ std::string s = str;
+ int last = 0;
+ res->clear();
+ for (int i=0; i < s.size(); ++i)
+ if (s[i] == delimiter) {
+ s[i]=0;
+ if (last != i) {
+ res->push_back(&s[last]);
+ }
+ last = i + 1;
+ }
+ if (last != s.size())
+ res->push_back(&s[last]);
+}
+
+inline std::string LowercaseString(const std::string& in) {
+ std::string res(in.size(),' ');
+ for (int i = 0; i < in.size(); ++i)
+ res[i] = tolower(in[i]);
+ return res;
+}
+
+inline int CountSubstrings(const std::string& str, const std::string& sub) {
+ size_t p = 0;
+ int res = 0;
+ while (p < str.size()) {
+ p = str.find(sub, p);
+ if (p == std::string::npos) break;
+ ++res;
+ p += sub.size();
+ }
+ return res;
+}
+
+inline int SplitOnWhitespace(const std::string& in, std::vector<std::string>* out) {
+ out->clear();
+ int i = 0;
+ int start = 0;
+ std::string cur;
+ while(i < in.size()) {
+ if (in[i] == ' ' || in[i] == '\t') {
+ if (i - start > 0)
+ out->push_back(in.substr(start, i - start));
+ start = i + 1;
+ }
+ ++i;
+ }
+ if (i > start)
+ out->push_back(in.substr(start, i - start));
+ return out->size();
+}
+
+inline void SplitCommandAndParam(const std::string& in, std::string* cmd, std::string* param) {
+ cmd->clear();
+ param->clear();
+ std::vector<std::string> x;
+ SplitOnWhitespace(in, &x);
+ if (x.size() == 0) return;
+ *cmd = x[0];
+ for (int i = 1; i < x.size(); ++i) {
+ if (i > 1) { *param += " "; }
+ *param += x[i];
+ }
+}
+
+void ProcessAndStripSGML(std::string* line, std::map<std::string, std::string>* out);
+
+#endif
diff --git a/src/synparse.cc b/src/synparse.cc
new file mode 100644
index 00000000..96588f1e
--- /dev/null
+++ b/src/synparse.cc
@@ -0,0 +1,212 @@
+#include <iostream>
+#include <ext/hash_map>
+#include <ext/hash_set>
+#include <utility>
+
+#include <boost/multi_array.hpp>
+#include <boost/functional/hash.hpp>
+#include <boost/program_options.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+#include "prob.h"
+#include "tdict.h"
+#include "filelib.h"
+
+using namespace std;
+using namespace __gnu_cxx;
+namespace po = boost::program_options;
+
+const prob_t kMONO(1.0); // 0.6
+const prob_t kINV(1.0); // 0.1
+const prob_t kLEX(1.0); // 0.3
+
+typedef hash_map<vector<WordID>, hash_map<vector<WordID>, prob_t, boost::hash<vector<WordID> > >, boost::hash<vector<WordID> > > PTable;
+typedef boost::multi_array<prob_t, 4> CChart;
+typedef pair<int,int> SpanType;
+
+void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
+ po::options_description opts("Configuration options");
+ opts.add_options()
+ ("phrasetable,p",po::value<string>(), "[REQD] Phrase pairs for ITG alignment")
+ ("input,i",po::value<string>()->default_value("-"), "Input file")
+ ("help,h", "Help");
+ po::options_description dcmdline_options;
+ dcmdline_options.add(opts);
+ po::store(parse_command_line(argc, argv, dcmdline_options), *conf);
+ bool flag = false;
+ if (!conf->count("phrasetable")) {
+ cerr << "Please specify a grammar file with -p <GRAMMAR.TXT>\n";
+ flag = true;
+ }
+ if (flag || conf->count("help")) {
+ cerr << dcmdline_options << endl;
+ exit(1);
+ }
+}
+
+void LoadITGPhrasetable(const string& fname, PTable* ptable) {
+ const WordID sep = TD::Convert("|||");
+ ReadFile rf(fname);
+ istream& in = *rf.stream();
+ assert(in);
+ int lc = 0;
+ while(in) {
+ string line;
+ getline(in, line);
+ if (line.empty()) continue;
+ ++lc;
+ vector<WordID> full, f, e;
+ TD::ConvertSentence(line, &full);
+ int i = 0;
+ for (; i < full.size(); ++i) {
+ if (full[i] == sep) break;
+ f.push_back(full[i]);
+ }
+ ++i;
+ for (; i < full.size(); ++i) {
+ if (full[i] == sep) break;
+ e.push_back(full[i]);
+ }
+ ++i;
+ prob_t prob(0.000001);
+ if (i < full.size()) { prob = prob_t(atof(TD::Convert(full[i]))); ++i; }
+
+ if (i < full.size()) { cerr << "Warning line " << lc << " has extra stuff.\n"; }
+ assert(f.size() > 0);
+ assert(e.size() > 0);
+ (*ptable)[f][e] = prob;
+ }
+ cerr << "Read " << lc << " phrase pairs\n";
+}
+
+void FindPhrases(const vector<WordID>& e, const vector<WordID>& f, const PTable& pt, CChart* pcc) {
+ CChart& cc = *pcc;
+ const size_t n = f.size();
+ const size_t m = e.size();
+ typedef hash_map<vector<WordID>, vector<SpanType>, boost::hash<vector<WordID> > > PhraseToSpan;
+ PhraseToSpan e_locations;
+ for (int i = 0; i < m; ++i) {
+ const int mel = m - i;
+ vector<WordID> e_phrase;
+ for (int el = 0; el < mel; ++el) {
+ e_phrase.push_back(e[i + el]);
+ e_locations[e_phrase].push_back(make_pair(i, i + el + 1));
+ }
+ }
+ //cerr << "Cached the locations of " << e_locations.size() << " e-phrases\n";
+
+ for (int s = 0; s < n; ++s) {
+ const int mfl = n - s;
+ vector<WordID> f_phrase;
+ for (int fl = 0; fl < mfl; ++fl) {
+ f_phrase.push_back(f[s + fl]);
+ PTable::const_iterator it = pt.find(f_phrase);
+ if (it == pt.end()) continue;
+ const hash_map<vector<WordID>, prob_t, boost::hash<vector<WordID> > >& es = it->second;
+ for (hash_map<vector<WordID>, prob_t, boost::hash<vector<WordID> > >::const_iterator eit = es.begin(); eit != es.end(); ++eit) {
+ PhraseToSpan::iterator loc = e_locations.find(eit->first);
+ if (loc == e_locations.end()) continue;
+ const vector<SpanType>& espans = loc->second;
+ for (int j = 0; j < espans.size(); ++j) {
+ cc[s][s + fl + 1][espans[j].first][espans[j].second] = eit->second;
+ //cerr << '[' << s << ',' << (s + fl + 1) << ',' << espans[j].first << ',' << espans[j].second << "] is C\n";
+ }
+ }
+ }
+ }
+}
+
+long long int evals = 0;
+
+void ProcessSynchronousCell(const int s,
+ const int t,
+ const int u,
+ const int v,
+ const prob_t& lex,
+ const prob_t& mono,
+ const prob_t& inv,
+ const CChart& tc, CChart* ntc) {
+ prob_t& inside = (*ntc)[s][t][u][v];
+ // cerr << log(tc[s][t][u][v]) << " + " << log(lex) << endl;
+ inside = tc[s][t][u][v] * lex;
+ // cerr << " terminal span: " << log(inside) << endl;
+ if (t - s == 1) return;
+ if (v - u == 1) return;
+ for (int x = s+1; x < t; ++x) {
+ for (int y = u+1; y < v; ++y) {
+ const prob_t m = (*ntc)[s][x][u][y] * (*ntc)[x][t][y][v] * mono;
+ const prob_t i = (*ntc)[s][x][y][v] * (*ntc)[x][t][u][y] * inv;
+ // cerr << log(i) << "\t" << log(m) << endl;
+ inside += m;
+ inside += i;
+ evals++;
+ }
+ }
+ // cerr << " span: " << log(inside) << endl;
+}
+
+prob_t SynchronousParse(const int n, const int m, const prob_t& lex, const prob_t& mono, const prob_t& inv, const CChart& tc, CChart* ntc) {
+ for (int fl = 0; fl < n; ++fl) {
+ for (int el = 0; el < m; ++el) {
+ const int ms = n - fl;
+ for (int s = 0; s < ms; ++s) {
+ const int t = s + fl + 1;
+ const int mu = m - el;
+ for (int u = 0; u < mu; ++u) {
+ const int v = u + el + 1;
+ //cerr << "Processing cell [" << s << ',' << t << ',' << u << ',' << v << "]\n";
+ ProcessSynchronousCell(s, t, u, v, lex, mono, inv, tc, ntc);
+ }
+ }
+ }
+ }
+ return (*ntc)[0][n][0][m];
+}
+
+int main(int argc, char** argv) {
+ po::variables_map conf;
+ InitCommandLine(argc, argv, &conf);
+ PTable ptable;
+ LoadITGPhrasetable(conf["phrasetable"].as<string>(), &ptable);
+ ReadFile rf(conf["input"].as<string>());
+ istream& in = *rf.stream();
+ int lc = 0;
+ const WordID sep = TD::Convert("|||");
+ while(in) {
+ string line;
+ getline(in, line);
+ if (line.empty()) continue;
+ ++lc;
+ vector<WordID> full, f, e;
+ TD::ConvertSentence(line, &full);
+ int i = 0;
+ for (; i < full.size(); ++i) {
+ if (full[i] == sep) break;
+ f.push_back(full[i]);
+ }
+ ++i;
+ for (; i < full.size(); ++i) {
+ if (full[i] == sep) break;
+ e.push_back(full[i]);
+ }
+ if (e.empty()) cerr << "E is empty!\n";
+ if (f.empty()) cerr << "F is empty!\n";
+ if (e.empty() || f.empty()) continue;
+ int n = f.size();
+ int m = e.size();
+ cerr << "Synchronous chart has " << (n * n * m * m) << " cells\n";
+ clock_t start = clock();
+ CChart cc(boost::extents[n+1][n+1][m+1][m+1]);
+ FindPhrases(e, f, ptable, &cc);
+ CChart ntc(boost::extents[n+1][n+1][m+1][m+1]);
+ prob_t likelihood = SynchronousParse(n, m, kLEX, kMONO, kINV, cc, &ntc);
+ clock_t end = clock();
+ cerr << "log Z: " << log(likelihood) << endl;
+ cerr << " Z: " << likelihood << endl;
+ double etime = (end - start) / 1000000.0;
+ cout << " time: " << etime << endl;
+ cout << "evals: " << evals << endl;
+ }
+ return 0;
+}
+
diff --git a/src/tdict.cc b/src/tdict.cc
new file mode 100644
index 00000000..c00d20b8
--- /dev/null
+++ b/src/tdict.cc
@@ -0,0 +1,49 @@
+#include "Ngram.h"
+#include "dict.h"
+#include "tdict.h"
+#include "Vocab.h"
+
+using namespace std;
+
+Vocab* TD::dict_ = new Vocab;
+
+static const string empty;
+static const string space = " ";
+
+WordID TD::Convert(const std::string& s) {
+ return dict_->addWord((VocabString)s.c_str());
+}
+
+const char* TD::Convert(const WordID& w) {
+ return dict_->getWord((VocabIndex)w);
+}
+
+void TD::GetWordIDs(const std::vector<std::string>& strings, std::vector<WordID>* ids) {
+ ids->clear();
+ for (vector<string>::const_iterator i = strings.begin(); i != strings.end(); ++i)
+ ids->push_back(TD::Convert(*i));
+}
+
+std::string TD::GetString(const std::vector<WordID>& str) {
+ string res;
+ for (vector<WordID>::const_iterator i = str.begin(); i != str.end(); ++i)
+ res += (i == str.begin() ? empty : space) + TD::Convert(*i);
+ return res;
+}
+
+void TD::ConvertSentence(const std::string& sent, std::vector<WordID>* ids) {
+ string s = sent;
+ int last = 0;
+ ids->clear();
+ for (int i=0; i < s.size(); ++i)
+ if (s[i] == 32 || s[i] == '\t') {
+ s[i]=0;
+ if (last != i) {
+ ids->push_back(Convert(&s[last]));
+ }
+ last = i + 1;
+ }
+ if (last != s.size())
+ ids->push_back(Convert(&s[last]));
+}
+
diff --git a/src/tdict.h b/src/tdict.h
new file mode 100644
index 00000000..9d4318fe
--- /dev/null
+++ b/src/tdict.h
@@ -0,0 +1,19 @@
+#ifndef _TDICT_H_
+#define _TDICT_H_
+
+#include <string>
+#include <vector>
+#include "wordid.h"
+
+class Vocab;
+
+struct TD {
+ static Vocab* dict_;
+ static void ConvertSentence(const std::string& sent, std::vector<WordID>* ids);
+ static void GetWordIDs(const std::vector<std::string>& strings, std::vector<WordID>* ids);
+ static std::string GetString(const std::vector<WordID>& str);
+ static WordID Convert(const std::string& s);
+ static const char* Convert(const WordID& w);
+};
+
+#endif
diff --git a/src/timing_stats.cc b/src/timing_stats.cc
new file mode 100644
index 00000000..85b95de5
--- /dev/null
+++ b/src/timing_stats.cc
@@ -0,0 +1,24 @@
+#include "timing_stats.h"
+
+#include <iostream>
+
+using namespace std;
+
+map<string, TimerInfo> Timer::stats;
+
+Timer::Timer(const string& timername) : start_t(clock()), cur(stats[timername]) {}
+
+Timer::~Timer() {
+ ++cur.calls;
+ const clock_t end_t = clock();
+ const double elapsed = (end_t - start_t) / 1000000.0;
+ cur.total_time += elapsed;
+}
+
+void Timer::Summarize() {
+ for (map<string, TimerInfo>::iterator it = stats.begin(); it != stats.end(); ++it) {
+ cerr << it->first << ": " << it->second.total_time << " secs (" << it->second.calls << " calls)\n";
+ }
+ stats.clear();
+}
+
diff --git a/src/timing_stats.h b/src/timing_stats.h
new file mode 100644
index 00000000..0a9f7656
--- /dev/null
+++ b/src/timing_stats.h
@@ -0,0 +1,25 @@
+#ifndef _TIMING_STATS_H_
+#define _TIMING_STATS_H_
+
+#include <string>
+#include <map>
+
+struct TimerInfo {
+ int calls;
+ double total_time;
+ TimerInfo() : calls(), total_time() {}
+};
+
+struct Timer {
+ Timer(const std::string& info);
+ ~Timer();
+ static void Summarize();
+ private:
+ static std::map<std::string, TimerInfo> stats;
+ clock_t start_t;
+ TimerInfo& cur;
+ Timer(const Timer& other);
+ const Timer& operator=(const Timer& other);
+};
+
+#endif
diff --git a/src/translator.h b/src/translator.h
new file mode 100644
index 00000000..194efbaa
--- /dev/null
+++ b/src/translator.h
@@ -0,0 +1,54 @@
+#ifndef _TRANSLATOR_H_
+#define _TRANSLATOR_H_
+
+#include <string>
+#include <vector>
+#include <boost/shared_ptr.hpp>
+#include <boost/program_options/variables_map.hpp>
+
+class Hypergraph;
+class SentenceMetadata;
+
+class Translator {
+ public:
+ virtual ~Translator();
+ // returns true if goal reached, false otherwise
+ // minus_lm_forest will contain the unpruned forest. the
+ // feature values from the phrase table / grammar / etc
+ // should be in the forest already - the "late" features
+ // should not just copy values that are available without
+ // any context or computation.
+ // SentenceMetadata contains information about the sentence,
+ // but it is an input/output parameter since the Translator
+ // is also responsible for setting the value of src_len.
+ virtual bool Translate(const std::string& src,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest) = 0;
+};
+
+class SCFGTranslatorImpl;
+class SCFGTranslator : public Translator {
+ public:
+ SCFGTranslator(const boost::program_options::variables_map& conf);
+ bool Translate(const std::string& src,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest);
+ private:
+ boost::shared_ptr<SCFGTranslatorImpl> pimpl_;
+};
+
+class FSTTranslatorImpl;
+class FSTTranslator : public Translator {
+ public:
+ FSTTranslator(const boost::program_options::variables_map& conf);
+ bool Translate(const std::string& src,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest);
+ private:
+ boost::shared_ptr<FSTTranslatorImpl> pimpl_;
+};
+
+#endif
diff --git a/src/trule.cc b/src/trule.cc
new file mode 100644
index 00000000..b8f6995e
--- /dev/null
+++ b/src/trule.cc
@@ -0,0 +1,237 @@
+#include "trule.h"
+
+#include <sstream>
+
+#include "stringlib.h"
+#include "tdict.h"
+
+using namespace std;
+
+static WordID ConvertTrgString(const string& w) {
+ int len = w.size();
+ WordID id = 0;
+ // [X,0] or [0]
+ // for target rules, we ignore the category, just keep the index
+ if (len > 2 && w[0]=='[' && w[len-1]==']' && w[len-2] > '0' && w[len-2] <= '9' &&
+ (len == 3 || (len > 4 && w[len-3] == ','))) {
+ id = w[len-2] - '0';
+ id = 1 - id;
+ } else {
+ id = TD::Convert(w);
+ }
+ return id;
+}
+
+static WordID ConvertSrcString(const string& w, bool mono = false) {
+ int len = w.size();
+ // [X,0]
+ // for source rules, we keep the category and ignore the index (source rules are
+ // always numbered 1, 2, 3...
+ if (mono) {
+ if (len > 2 && w[0]=='[' && w[len-1]==']') {
+ if (len > 4 && w[len-3] == ',') {
+ cerr << "[ERROR] Monolingual rules mut not have non-terminal indices:\n "
+ << w << endl;
+ exit(1);
+ }
+ // TODO check that source indices go 1,2,3,etc.
+ return TD::Convert(w.substr(1, len-2)) * -1;
+ } else {
+ return TD::Convert(w);
+ }
+ } else {
+ if (len > 4 && w[0]=='[' && w[len-1]==']' && w[len-3] == ',' && w[len-2] > '0' && w[len-2] <= '9') {
+ return TD::Convert(w.substr(1, len-4)) * -1;
+ } else {
+ return TD::Convert(w);
+ }
+ }
+}
+
+static WordID ConvertLHS(const string& w) {
+ if (w[0] == '[') {
+ int len = w.size();
+ if (len < 3) { cerr << "Format error: " << w << endl; exit(1); }
+ return TD::Convert(w.substr(1, len-2)) * -1;
+ } else {
+ return TD::Convert(w) * -1;
+ }
+}
+
+TRule* TRule::CreateRuleSynchronous(const std::string& rule) {
+ TRule* res = new TRule;
+ if (res->ReadFromString(rule, true, false)) return res;
+ cerr << "[ERROR] Failed to creating rule from: " << rule << endl;
+ delete res;
+ return NULL;
+}
+
+TRule* TRule::CreateRulePhrasetable(const string& rule) {
+ // TODO make this faster
+ // TODO add configuration for default NT type
+ if (rule[0] == '[') {
+ cerr << "Phrasetable rules shouldn't have a LHS / non-terminals:\n " << rule << endl;
+ return NULL;
+ }
+ TRule* res = new TRule("[X] ||| " + rule, true, false);
+ if (res->Arity() != 0) {
+ cerr << "Phrasetable rules should have arity 0:\n " << rule << endl;
+ delete res;
+ return NULL;
+ }
+ return res;
+}
+
+TRule* TRule::CreateRuleMonolingual(const string& rule) {
+ return new TRule(rule, false, true);
+}
+
+bool TRule::ReadFromString(const string& line, bool strict, bool mono) {
+ e_.clear();
+ f_.clear();
+ scores_.clear();
+
+ string w;
+ istringstream is(line);
+ int format = CountSubstrings(line, "|||");
+ if (strict && format < 2) {
+ cerr << "Bad rule format in strict mode:\n" << line << endl;
+ return false;
+ }
+ if (format >= 2 || (mono && format == 1)) {
+ while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); }
+ while(is>>w && w!="|||") { f_.push_back(ConvertSrcString(w, mono)); }
+ if (!mono) {
+ while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); }
+ }
+ int fv = 0;
+ if (is) {
+ string ss;
+ getline(is, ss);
+ //cerr << "L: " << ss << endl;
+ int start = 0;
+ const int len = ss.size();
+ while (start < len) {
+ while(start < len && (ss[start] == ' ' || ss[start] == ';'))
+ ++start;
+ if (start == len) break;
+ int end = start + 1;
+ while(end < len && (ss[end] != '=' && ss[end] != ' ' && ss[end] != ';'))
+ ++end;
+ if (end == len || ss[end] == ' ' || ss[end] == ';') {
+ //cerr << "PROC: '" << ss.substr(start, end - start) << "'\n";
+ // non-named features
+ if (end != len) { ss[end] = 0; }
+ string fname = "PhraseModel_X";
+ if (fv > 9) { cerr << "Too many phrasetable scores - used named format\n"; abort(); }
+ fname[12]='0' + fv;
+ ++fv;
+ scores_.set_value(FD::Convert(fname), atof(&ss[start]));
+ //cerr << "F: " << fname << " VAL=" << scores_.value(FD::Convert(fname)) << endl;
+ } else {
+ const int fid = FD::Convert(ss.substr(start, end - start));
+ start = end + 1;
+ end = start + 1;
+ while(end < len && (ss[end] != ' ' && ss[end] != ';'))
+ ++end;
+ if (end < len) { ss[end] = 0; }
+ assert(start < len);
+ scores_.set_value(fid, atof(&ss[start]));
+ //cerr << "F: " << FD::Convert(fid) << " VAL=" << scores_.value(fid) << endl;
+ }
+ start = end + 1;
+ }
+ }
+ } else if (format == 1) {
+ while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); }
+ while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); }
+ f_ = e_;
+ int x = ConvertLHS("[X]");
+ for (int i = 0; i < f_.size(); ++i)
+ if (f_[i] <= 0) { f_[i] = x; }
+ } else {
+ cerr << "F: " << format << endl;
+ cerr << "[ERROR] Don't know how to read:\n" << line << endl;
+ }
+ if (mono) {
+ e_ = f_;
+ int ci = 0;
+ for (int i = 0; i < e_.size(); ++i)
+ if (e_[i] < 0)
+ e_[i] = ci--;
+ }
+ ComputeArity();
+ return SanityCheck();
+}
+
+bool TRule::SanityCheck() const {
+ vector<int> used(f_.size(), 0);
+ int ac = 0;
+ for (int i = 0; i < e_.size(); ++i) {
+ int ind = e_[i];
+ if (ind > 0) continue;
+ ind = -ind;
+ if ((++used[ind]) != 1) {
+ cerr << "[ERROR] e-side variable index " << (ind+1) << " used more than once!\n";
+ return false;
+ }
+ ac++;
+ }
+ if (ac != Arity()) {
+ cerr << "[ERROR] e-side arity mismatches f-side\n";
+ return false;
+ }
+ return true;
+}
+
+void TRule::ComputeArity() {
+ int min = 1;
+ for (vector<WordID>::const_iterator i = e_.begin(); i != e_.end(); ++i)
+ if (*i < min) min = *i;
+ arity_ = 1 - min;
+}
+
+static string AnonymousStrVar(int i) {
+ string res("[v]");
+ if(!(i <= 0 && i >= -8)) {
+ cerr << "Can't handle more than 9 non-terminals: index=" << (-i) << endl;
+ abort();
+ }
+ res[1] = '1' - i;
+ return res;
+}
+
+string TRule::AsString(bool verbose) const {
+ ostringstream os;
+ int idx = 0;
+ if (lhs_ && verbose) {
+ os << '[' << TD::Convert(lhs_ * -1) << "] |||";
+ for (int i = 0; i < f_.size(); ++i) {
+ const WordID& w = f_[i];
+ if (w < 0) {
+ int wi = w * -1;
+ ++idx;
+ os << " [" << TD::Convert(wi) << ',' << idx << ']';
+ } else {
+ os << ' ' << TD::Convert(w);
+ }
+ }
+ os << " ||| ";
+ }
+ if (idx > 9) {
+ cerr << "Too many non-terminals!\n partial: " << os.str() << endl;
+ exit(1);
+ }
+ for (int i =0; i<e_.size(); ++i) {
+ if (i) os << ' ';
+ const WordID& w = e_[i];
+ if (w < 1)
+ os << AnonymousStrVar(w);
+ else
+ os << TD::Convert(w);
+ }
+ if (!scores_.empty() && verbose) {
+ os << " ||| " << scores_;
+ }
+ return os.str();
+}
diff --git a/src/trule.h b/src/trule.h
new file mode 100644
index 00000000..d2b1babe
--- /dev/null
+++ b/src/trule.h
@@ -0,0 +1,122 @@
+#ifndef _RULE_H_
+#define _RULE_H_
+
+#include <algorithm>
+#include <vector>
+#include <cassert>
+#include <boost/shared_ptr.hpp>
+
+#include "sparse_vector.h"
+#include "wordid.h"
+
+class TRule;
+typedef boost::shared_ptr<TRule> TRulePtr;
+struct SpanInfo;
+
+// Translation rule
+class TRule {
+ public:
+ TRule() : lhs_(0), prev_i(-1), prev_j(-1) { }
+ explicit TRule(const std::vector<WordID>& e) : e_(e), lhs_(0), prev_i(-1), prev_j(-1) {}
+ TRule(const std::vector<WordID>& e, const std::vector<WordID>& f, const WordID& lhs) :
+ e_(e), f_(f), lhs_(lhs), prev_i(-1), prev_j(-1) {}
+
+ // deprecated - this will be private soon
+ explicit TRule(const std::string& text, bool strict = false, bool mono = false) {
+ ReadFromString(text, strict, mono);
+ }
+
+ // make a rule from a hiero-like rule table, e.g.
+ // [X] ||| [X,1] DE [X,2] ||| [X,2] of the [X,1]
+ // if misformatted, returns NULL
+ static TRule* CreateRuleSynchronous(const std::string& rule);
+
+ // make a rule from a phrasetable entry (i.e., one that has no LHS type), e.g:
+ // el gato ||| the cat ||| Feature_2=0.34
+ static TRule* CreateRulePhrasetable(const std::string& rule);
+
+ // make a rule from a non-synchrnous CFG representation, e.g.:
+ // [LHS] ||| term1 [NT] term2 [OTHER_NT] [YET_ANOTHER_NT]
+ static TRule* CreateRuleMonolingual(const std::string& rule);
+
+ void ESubstitute(const std::vector<const std::vector<WordID>* >& var_values,
+ std::vector<WordID>* result) const {
+ int vc = 0;
+ result->clear();
+ for (std::vector<WordID>::const_iterator i = e_.begin(); i != e_.end(); ++i) {
+ const WordID& c = *i;
+ if (c < 1) {
+ ++vc;
+ const std::vector<WordID>& var_value = *var_values[-c];
+ std::copy(var_value.begin(),
+ var_value.end(),
+ std::back_inserter(*result));
+ } else {
+ result->push_back(c);
+ }
+ }
+ assert(vc == var_values.size());
+ }
+
+ void FSubstitute(const std::vector<const std::vector<WordID>* >& var_values,
+ std::vector<WordID>* result) const {
+ int vc = 0;
+ result->clear();
+ for (std::vector<WordID>::const_iterator i = f_.begin(); i != f_.end(); ++i) {
+ const WordID& c = *i;
+ if (c < 1) {
+ const std::vector<WordID>& var_value = *var_values[vc++];
+ std::copy(var_value.begin(),
+ var_value.end(),
+ std::back_inserter(*result));
+ } else {
+ result->push_back(c);
+ }
+ }
+ assert(vc == var_values.size());
+ }
+
+ bool ReadFromString(const std::string& line, bool strict = false, bool monolingual = false);
+
+ bool Initialized() const { return e_.size(); }
+
+ std::string AsString(bool verbose = true) const;
+
+ static TRule DummyRule() {
+ TRule res;
+ res.e_.resize(1, 0);
+ return res;
+ }
+
+ const std::vector<WordID>& f() const { return f_; }
+ const std::vector<WordID>& e() const { return e_; }
+
+ int EWords() const { return ELength() - Arity(); }
+ int FWords() const { return FLength() - Arity(); }
+ int FLength() const { return f_.size(); }
+ int ELength() const { return e_.size(); }
+ int Arity() const { return arity_; }
+ bool IsUnary() const { return (Arity() == 1) && (f_.size() == 1); }
+ const SparseVector<double>& GetFeatureValues() const { return scores_; }
+ double Score(int i) const { return scores_[i]; }
+ WordID GetLHS() const { return lhs_; }
+ void ComputeArity();
+
+ // 0 = first variable, -1 = second variable, -2 = third ...
+ std::vector<WordID> e_;
+ // < 0: *-1 = encoding of category of variable
+ std::vector<WordID> f_;
+ WordID lhs_;
+ SparseVector<double> scores_;
+ char arity_;
+ TRulePtr parent_rule_; // usually NULL, except when doing constrained decoding
+
+ // this is only used when doing synchronous parsing
+ short int prev_i;
+ short int prev_j;
+
+ private:
+ bool SanityCheck() const;
+};
+
+#endif
diff --git a/src/trule_test.cc b/src/trule_test.cc
new file mode 100644
index 00000000..02a70764
--- /dev/null
+++ b/src/trule_test.cc
@@ -0,0 +1,65 @@
+#include "trule.h"
+
+#include <gtest/gtest.h>
+#include <cassert>
+#include <iostream>
+#include "tdict.h"
+
+using namespace std;
+
+class TRuleTest : public testing::Test {
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+};
+
+TEST_F(TRuleTest,TestFSubstitute) {
+ TRule r1("[X] ||| ob [X,1] [X,2] sah . ||| whether [X,1] saw [X,2] . ||| 0.99");
+ TRule r2("[X] ||| ich ||| i ||| 1.0");
+ TRule r3("[X] ||| ihn ||| him ||| 1.0");
+ vector<const vector<WordID>*> ants;
+ vector<WordID> res2;
+ r2.FSubstitute(ants, &res2);
+ assert(TD::GetString(res2) == "ich");
+ vector<WordID> res3;
+ r3.FSubstitute(ants, &res3);
+ assert(TD::GetString(res3) == "ihn");
+ ants.push_back(&res2);
+ ants.push_back(&res3);
+ vector<WordID> res;
+ r1.FSubstitute(ants, &res);
+ cerr << TD::GetString(res) << endl;
+ assert(TD::GetString(res) == "ob ich ihn sah .");
+}
+
+TEST_F(TRuleTest,TestPhrasetableRule) {
+ TRulePtr t(TRule::CreateRulePhrasetable("gato ||| cat ||| PhraseModel_0=-23.2;Foo=1;Bar=12"));
+ cerr << t->AsString() << endl;
+ assert(t->scores_.num_active() == 3);
+};
+
+
+TEST_F(TRuleTest,TestMonoRule) {
+ TRulePtr m(TRule::CreateRuleMonolingual("[LHS] ||| term1 [NT] term2 [NT2] [NT3]"));
+ assert(m->Arity() == 3);
+ cerr << m->AsString() << endl;
+ TRulePtr m2(TRule::CreateRuleMonolingual("[LHS] ||| term1 [NT] term2 [NT2] [NT3] ||| Feature1=0.23"));
+ assert(m2->Arity() == 3);
+ cerr << m2->AsString() << endl;
+ EXPECT_FLOAT_EQ(m2->scores_.value(FD::Convert("Feature1")), 0.23);
+}
+
+TEST_F(TRuleTest,TestRuleR) {
+ TRule t6;
+ t6.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [X,2] saw the [X,1] . ||| 0.12321 0.23232 0.121");
+ cerr << "TEXT: " << t6.AsString() << endl;
+ EXPECT_EQ(t6.Arity(), 2);
+ EXPECT_EQ(t6.e_[0], -1);
+ EXPECT_EQ(t6.e_[3], 0);
+}
+
+int main(int argc, char** argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
+
diff --git a/src/ttables.cc b/src/ttables.cc
new file mode 100644
index 00000000..2ea960f0
--- /dev/null
+++ b/src/ttables.cc
@@ -0,0 +1,31 @@
+#include "ttables.h"
+
+#include <cassert>
+
+#include "dict.h"
+
+using namespace std;
+using namespace std::tr1;
+
+void TTable::DeserializeProbsFromText(std::istream* in) {
+ int c = 0;
+ while(*in) {
+ string e;
+ string f;
+ double p;
+ (*in) >> e >> f >> p;
+ if (e.empty()) break;
+ ++c;
+ ttable[TD::Convert(e)][TD::Convert(f)] = prob_t(p);
+ }
+ cerr << "Loaded " << c << " translation parameters.\n";
+}
+
+void TTable::SerializeHelper(string* out, const Word2Word2Double& o) {
+ assert(!"not implemented");
+}
+
+void TTable::DeserializeHelper(const string& in, Word2Word2Double* o) {
+ assert(!"not implemented");
+}
+
diff --git a/src/ttables.h b/src/ttables.h
new file mode 100644
index 00000000..3ffc238a
--- /dev/null
+++ b/src/ttables.h
@@ -0,0 +1,87 @@
+#ifndef _TTABLES_H_
+#define _TTABLES_H_
+
+#include <iostream>
+#include <map>
+
+#include "wordid.h"
+#include "prob.h"
+#include "tdict.h"
+
+class TTable {
+ public:
+ TTable() {}
+ typedef std::map<WordID, double> Word2Double;
+ typedef std::map<WordID, Word2Double> Word2Word2Double;
+ inline const prob_t prob(const int& e, const int& f) const {
+ const Word2Word2Double::const_iterator cit = ttable.find(e);
+ if (cit != ttable.end()) {
+ const Word2Double& cpd = cit->second;
+ const Word2Double::const_iterator it = cpd.find(f);
+ if (it == cpd.end()) return prob_t(0.00001);
+ return prob_t(it->second);
+ } else {
+ return prob_t(0.00001);
+ }
+ }
+ inline void Increment(const int& e, const int& f) {
+ counts[e][f] += 1.0;
+ }
+ inline void Increment(const int& e, const int& f, double x) {
+ counts[e][f] += x;
+ }
+ void Normalize() {
+ ttable.swap(counts);
+ for (Word2Word2Double::iterator cit = ttable.begin();
+ cit != ttable.end(); ++cit) {
+ double tot = 0;
+ Word2Double& cpd = cit->second;
+ for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+ tot += it->second;
+ for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it)
+ it->second /= tot;
+ }
+ counts.clear();
+ }
+ // adds counts from another TTable - probabilities remain unchanged
+ TTable& operator+=(const TTable& rhs) {
+ for (Word2Word2Double::const_iterator it = rhs.counts.begin();
+ it != rhs.counts.end(); ++it) {
+ const Word2Double& cpd = it->second;
+ Word2Double& tgt = counts[it->first];
+ for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) {
+ tgt[j->first] += j->second;
+ }
+ }
+ return *this;
+ }
+ void ShowTTable() {
+ for (Word2Word2Double::iterator it = ttable.begin(); it != ttable.end(); ++it) {
+ Word2Double& cpd = it->second;
+ for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) {
+ std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+ }
+ }
+ }
+ void ShowCounts() {
+ for (Word2Word2Double::iterator it = counts.begin(); it != counts.end(); ++it) {
+ Word2Double& cpd = it->second;
+ for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) {
+ std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl;
+ }
+ }
+ }
+ void DeserializeProbsFromText(std::istream* in);
+ void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); }
+ void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); }
+ void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); }
+ void DeserializeProbs(const std::string& in) { DeserializeHelper(in, &ttable); }
+ private:
+ static void SerializeHelper(std::string*, const Word2Word2Double& o);
+ static void DeserializeHelper(const std::string&, Word2Word2Double* o);
+ public:
+ Word2Word2Double ttable;
+ Word2Word2Double counts;
+};
+
+#endif
diff --git a/src/viterbi.cc b/src/viterbi.cc
new file mode 100644
index 00000000..82b2ce6d
--- /dev/null
+++ b/src/viterbi.cc
@@ -0,0 +1,39 @@
+#include "viterbi.h"
+
+#include <vector>
+#include "hg.h"
+
+using namespace std;
+
+string ViterbiETree(const Hypergraph& hg) {
+ vector<WordID> tmp;
+ const prob_t p = Viterbi<vector<WordID>, ETreeTraversal, prob_t, EdgeProb>(hg, &tmp);
+ return TD::GetString(tmp);
+}
+
+string ViterbiFTree(const Hypergraph& hg) {
+ vector<WordID> tmp;
+ const prob_t p = Viterbi<vector<WordID>, FTreeTraversal, prob_t, EdgeProb>(hg, &tmp);
+ return TD::GetString(tmp);
+}
+
+prob_t ViterbiESentence(const Hypergraph& hg, vector<WordID>* result) {
+ return Viterbi<vector<WordID>, ESentenceTraversal, prob_t, EdgeProb>(hg, result);
+}
+
+prob_t ViterbiFSentence(const Hypergraph& hg, vector<WordID>* result) {
+ return Viterbi<vector<WordID>, FSentenceTraversal, prob_t, EdgeProb>(hg, result);
+}
+
+int ViterbiELength(const Hypergraph& hg) {
+ int len = -1;
+ Viterbi<int, ELengthTraversal, prob_t, EdgeProb>(hg, &len);
+ return len;
+}
+
+int ViterbiPathLength(const Hypergraph& hg) {
+ int len = -1;
+ Viterbi<int, PathLengthTraversal, prob_t, EdgeProb>(hg, &len);
+ return len;
+}
+
diff --git a/src/viterbi.h b/src/viterbi.h
new file mode 100644
index 00000000..46a4f528
--- /dev/null
+++ b/src/viterbi.h
@@ -0,0 +1,130 @@
+#ifndef _VITERBI_H_
+#define _VITERBI_H_
+
+#include <vector>
+#include "prob.h"
+#include "hg.h"
+#include "tdict.h"
+
+// V must implement:
+// void operator()(const vector<const T*>& ants, T* result);
+template<typename T, typename Traversal, typename WeightType, typename WeightFunction>
+WeightType Viterbi(const Hypergraph& hg,
+ T* result,
+ const Traversal& traverse = Traversal(),
+ const WeightFunction& weight = WeightFunction()) {
+ const int num_nodes = hg.nodes_.size();
+ std::vector<T> vit_result(num_nodes);
+ std::vector<WeightType> vit_weight(num_nodes, WeightType::Zero());
+
+ for (int i = 0; i < num_nodes; ++i) {
+ const Hypergraph::Node& cur_node = hg.nodes_[i];
+ WeightType* const cur_node_best_weight = &vit_weight[i];
+ T* const cur_node_best_result = &vit_result[i];
+
+ const int num_in_edges = cur_node.in_edges_.size();
+ if (num_in_edges == 0) {
+ *cur_node_best_weight = WeightType(1);
+ continue;
+ }
+ for (int j = 0; j < num_in_edges; ++j) {
+ const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]];
+ WeightType score = weight(edge);
+ std::vector<const T*> ants(edge.tail_nodes_.size());
+ for (int k = 0; k < edge.tail_nodes_.size(); ++k) {
+ const int tail_node_index = edge.tail_nodes_[k];
+ score *= vit_weight[tail_node_index];
+ ants[k] = &vit_result[tail_node_index];
+ }
+ if (*cur_node_best_weight < score) {
+ *cur_node_best_weight = score;
+ traverse(edge, ants, cur_node_best_result);
+ }
+ }
+ }
+ std::swap(*result, vit_result.back());
+ return vit_weight.back();
+}
+
+struct PathLengthTraversal {
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const int*>& ants,
+ int* result) const {
+ (void) edge;
+ *result = 1;
+ for (int i = 0; i < ants.size(); ++i) *result += *ants[i];
+ }
+};
+
+struct ESentenceTraversal {
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const std::vector<WordID>*>& ants,
+ std::vector<WordID>* result) const {
+ edge.rule_->ESubstitute(ants, result);
+ }
+};
+
+struct ELengthTraversal {
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const int*>& ants,
+ int* result) const {
+ *result = edge.rule_->ELength() - edge.rule_->Arity();
+ for (int i = 0; i < ants.size(); ++i) *result += *ants[i];
+ }
+};
+
+struct FSentenceTraversal {
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const std::vector<WordID>*>& ants,
+ std::vector<WordID>* result) const {
+ edge.rule_->FSubstitute(ants, result);
+ }
+};
+
+// create a strings of the form (S (X the man) (X said (X he (X would (X go)))))
+struct ETreeTraversal {
+ ETreeTraversal() : left("("), space(" "), right(")") {}
+ const std::string left;
+ const std::string space;
+ const std::string right;
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const std::vector<WordID>*>& ants,
+ std::vector<WordID>* result) const {
+ std::vector<WordID> tmp;
+ edge.rule_->ESubstitute(ants, &tmp);
+ const std::string cat = TD::Convert(edge.rule_->GetLHS() * -1);
+ if (cat == "Goal")
+ result->swap(tmp);
+ else
+ TD::ConvertSentence(left + cat + space + TD::GetString(tmp) + right,
+ result);
+ }
+};
+
+struct FTreeTraversal {
+ FTreeTraversal() : left("("), space(" "), right(")") {}
+ const std::string left;
+ const std::string space;
+ const std::string right;
+ void operator()(const Hypergraph::Edge& edge,
+ const std::vector<const std::vector<WordID>*>& ants,
+ std::vector<WordID>* result) const {
+ std::vector<WordID> tmp;
+ edge.rule_->FSubstitute(ants, &tmp);
+ const std::string cat = TD::Convert(edge.rule_->GetLHS() * -1);
+ if (cat == "Goal")
+ result->swap(tmp);
+ else
+ TD::ConvertSentence(left + cat + space + TD::GetString(tmp) + right,
+ result);
+ }
+};
+
+prob_t ViterbiESentence(const Hypergraph& hg, std::vector<WordID>* result);
+std::string ViterbiETree(const Hypergraph& hg);
+prob_t ViterbiFSentence(const Hypergraph& hg, std::vector<WordID>* result);
+std::string ViterbiFTree(const Hypergraph& hg);
+int ViterbiELength(const Hypergraph& hg);
+int ViterbiPathLength(const Hypergraph& hg);
+
+#endif
diff --git a/src/weights.cc b/src/weights.cc
new file mode 100644
index 00000000..bb0a878f
--- /dev/null
+++ b/src/weights.cc
@@ -0,0 +1,73 @@
+#include "weights.h"
+
+#include <sstream>
+
+#include "fdict.h"
+#include "filelib.h"
+
+using namespace std;
+
+void Weights::InitFromFile(const std::string& filename, vector<string>* feature_list) {
+ cerr << "Reading weights from " << filename << endl;
+ ReadFile in_file(filename);
+ istream& in = *in_file.stream();
+ assert(in);
+ int weight_count = 0;
+ bool fl = false;
+ while (in) {
+ double val = 0;
+ string buf;
+ getline(in, buf);
+ if (buf.size() == 0) continue;
+ if (buf[0] == '#') continue;
+ for (int i = 0; i < buf.size(); ++i)
+ if (buf[i] == '=') buf[i] = ' ';
+ int start = 0;
+ while(start < buf.size() && buf[start] == ' ') ++start;
+ int end = 0;
+ while(end < buf.size() && buf[end] != ' ') ++end;
+ int fid = FD::Convert(buf.substr(start, end - start));
+ while(end < buf.size() && buf[end] == ' ') ++end;
+ val = strtod(&buf.c_str()[end], NULL);
+ if (wv_.size() <= fid)
+ wv_.resize(fid + 1);
+ wv_[fid] = val;
+ if (feature_list) { feature_list->push_back(FD::Convert(fid)); }
+ ++weight_count;
+ if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; }
+ if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; }
+ }
+ if (fl) { cerr << endl; }
+ cerr << "Loaded " << weight_count << " feature weights\n";
+}
+
+void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_features) const {
+ WriteFile out(fname);
+ ostream& o = *out.stream();
+ assert(o);
+ o.precision(17);
+ const int num_feats = FD::NumFeats();
+ for (int i = 1; i < num_feats; ++i) {
+ const double val = (i < wv_.size() ? wv_[i] : 0.0);
+ if (hide_zero_value_features && val == 0.0) continue;
+ o << FD::Convert(i) << ' ' << val << endl;
+ }
+}
+
+void Weights::InitVector(std::vector<double>* w) const {
+ *w = wv_;
+}
+
+void Weights::InitSparseVector(SparseVector<double>* w) const {
+ for (int i = 1; i < wv_.size(); ++i) {
+ const double& weight = wv_[i];
+ if (weight) w->set_value(i, weight);
+ }
+}
+
+void Weights::InitFromVector(const std::vector<double>& w) {
+ wv_ = w;
+ if (wv_.size() > FD::NumFeats())
+ cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n";
+ wv_.resize(FD::NumFeats(), 0);
+}
diff --git a/src/weights.h b/src/weights.h
new file mode 100644
index 00000000..f19aa3ce
--- /dev/null
+++ b/src/weights.h
@@ -0,0 +1,21 @@
+#ifndef _WEIGHTS_H_
+#define _WEIGHTS_H_
+
+#include <string>
+#include <map>
+#include <vector>
+#include "sparse_vector.h"
+
+class Weights {
+ public:
+ Weights() {}
+ void InitFromFile(const std::string& fname, std::vector<std::string>* feature_list = NULL);
+ void WriteToFile(const std::string& fname, bool hide_zero_value_features = true) const;
+ void InitVector(std::vector<double>* w) const;
+ void InitSparseVector(SparseVector<double>* w) const;
+ void InitFromVector(const std::vector<double>& w);
+ private:
+ std::vector<double> wv_;
+};
+
+#endif
diff --git a/src/weights_test.cc b/src/weights_test.cc
new file mode 100644
index 00000000..aa6b3db2
--- /dev/null
+++ b/src/weights_test.cc
@@ -0,0 +1,28 @@
+#include <cassert>
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <gtest/gtest.h>
+#include "weights.h"
+#include "tdict.h"
+#include "hg.h"
+
+using namespace std;
+
+class WeightsTest : public testing::Test {
+ protected:
+ virtual void SetUp() { }
+ virtual void TearDown() { }
+};
+
+
+TEST_F(WeightsTest,Load) {
+ Weights w;
+ w.InitFromFile("test_data/weights");
+ w.WriteToFile("-");
+}
+
+int main(int argc, char **argv) {
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/src/wordid.h b/src/wordid.h
new file mode 100644
index 00000000..fb50bcc1
--- /dev/null
+++ b/src/wordid.h
@@ -0,0 +1,6 @@
+#ifndef _WORD_ID_H_
+#define _WORD_ID_H_
+
+typedef int WordID;
+
+#endif