From 671c21451542e2dd20e45b4033d44d8e8735f87b Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 3 Dec 2009 16:33:55 -0500 Subject: initial check in --- src/JSON_parser.h | 152 +++++++++ src/Makefile.am | 67 ++++ src/aligner.cc | 204 ++++++++++++ src/aligner.h | 23 ++ src/apply_models.cc | 344 ++++++++++++++++++++ src/apply_models.h | 20 ++ src/array2d.h | 171 ++++++++++ src/bottom_up_parser.cc | 260 +++++++++++++++ src/bottom_up_parser.h | 27 ++ src/cdec.cc | 474 +++++++++++++++++++++++++++ src/cdec_ff.cc | 18 ++ src/collapse_weights.cc | 102 ++++++ src/dict.h | 40 +++ src/dict_test.cc | 30 ++ src/earley_composer.cc | 726 ++++++++++++++++++++++++++++++++++++++++++ src/earley_composer.h | 29 ++ src/exp_semiring.h | 71 +++++ src/fdict.cc | 4 + src/fdict.h | 21 ++ src/ff.cc | 93 ++++++ src/ff.h | 121 +++++++ src/ff_factory.cc | 35 ++ src/ff_factory.h | 39 +++ src/ff_itg_span.h | 7 + src/ff_test.cc | 134 ++++++++ src/ff_wordalign.cc | 221 +++++++++++++ src/ff_wordalign.h | 133 ++++++++ src/filelib.cc | 22 ++ src/filelib.h | 66 ++++ src/forest_writer.cc | 23 ++ src/forest_writer.h | 16 + src/freqdict.cc | 23 ++ src/freqdict.h | 19 ++ src/fst_translator.cc | 91 ++++++ src/grammar.cc | 163 ++++++++++ src/grammar.h | 83 +++++ src/grammar_test.cc | 59 ++++ src/gzstream.cc | 165 ++++++++++ src/gzstream.h | 121 +++++++ src/hg.cc | 483 ++++++++++++++++++++++++++++ src/hg.h | 225 +++++++++++++ src/hg_intersect.cc | 121 +++++++ src/hg_intersect.h | 13 + src/hg_io.cc | 585 ++++++++++++++++++++++++++++++++++ src/hg_io.h | 37 +++ src/hg_test.cc | 441 +++++++++++++++++++++++++ src/ibm_model1.cc | 4 + src/inside_outside.h | 111 +++++++ src/json_parse.cc | 50 +++ src/json_parse.h | 58 ++++ src/kbest.h | 207 ++++++++++++ src/lattice.cc | 27 ++ src/lattice.h | 31 ++ src/lexcrf.cc | 112 +++++++ src/lexcrf.h | 18 ++ src/lm_ff.cc | 328 +++++++++++++++++++ src/lm_ff.h | 32 ++ src/logval.h | 136 ++++++++ src/maxtrans_blunsom.cc | 287 +++++++++++++++++ src/parser_test.cc | 35 ++ src/phrasebased_translator.cc | 206 ++++++++++++ src/phrasebased_translator.h | 18 ++ src/phrasetable_fst.cc | 141 ++++++++ src/phrasetable_fst.h | 34 ++ src/prob.h | 8 + src/sampler.h | 136 ++++++++ src/scfg_translator.cc | 66 ++++ src/sentence_metadata.h | 42 +++ src/small_vector.h | 187 +++++++++++ src/small_vector_test.cc | 129 ++++++++ src/sparse_vector.cc | 98 ++++++ src/sparse_vector.h | 264 +++++++++++++++ src/stringlib.cc | 97 ++++++ src/stringlib.h | 91 ++++++ src/synparse.cc | 212 ++++++++++++ src/tdict.cc | 49 +++ src/tdict.h | 19 ++ src/timing_stats.cc | 24 ++ src/timing_stats.h | 25 ++ src/translator.h | 54 ++++ src/trule.cc | 237 ++++++++++++++ src/trule.h | 122 +++++++ src/trule_test.cc | 65 ++++ src/ttables.cc | 31 ++ src/ttables.h | 87 +++++ src/viterbi.cc | 39 +++ src/viterbi.h | 130 ++++++++ src/weights.cc | 73 +++++ src/weights.h | 21 ++ src/weights_test.cc | 28 ++ src/wordid.h | 6 + 91 files changed, 10497 insertions(+) create mode 100644 src/JSON_parser.h create mode 100644 src/Makefile.am create mode 100644 src/aligner.cc create mode 100644 src/aligner.h create mode 100644 src/apply_models.cc create mode 100644 src/apply_models.h create mode 100644 src/array2d.h create mode 100644 src/bottom_up_parser.cc create mode 100644 src/bottom_up_parser.h create mode 100644 src/cdec.cc create mode 100644 src/cdec_ff.cc create mode 100644 src/collapse_weights.cc create mode 100644 src/dict.h create mode 100644 src/dict_test.cc create mode 100644 src/earley_composer.cc create mode 100644 src/earley_composer.h create mode 100644 src/exp_semiring.h create mode 100644 src/fdict.cc create mode 100644 src/fdict.h create mode 100644 src/ff.cc create mode 100644 src/ff.h create mode 100644 src/ff_factory.cc create mode 100644 src/ff_factory.h create mode 100644 src/ff_itg_span.h create mode 100644 src/ff_test.cc create mode 100644 src/ff_wordalign.cc create mode 100644 src/ff_wordalign.h create mode 100644 src/filelib.cc create mode 100644 src/filelib.h create mode 100644 src/forest_writer.cc create mode 100644 src/forest_writer.h create mode 100644 src/freqdict.cc create mode 100644 src/freqdict.h create mode 100644 src/fst_translator.cc create mode 100644 src/grammar.cc create mode 100644 src/grammar.h create mode 100644 src/grammar_test.cc create mode 100644 src/gzstream.cc create mode 100644 src/gzstream.h create mode 100644 src/hg.cc create mode 100644 src/hg.h create mode 100644 src/hg_intersect.cc create mode 100644 src/hg_intersect.h create mode 100644 src/hg_io.cc create mode 100644 src/hg_io.h create mode 100644 src/hg_test.cc create mode 100644 src/ibm_model1.cc create mode 100644 src/inside_outside.h create mode 100644 src/json_parse.cc create mode 100644 src/json_parse.h create mode 100644 src/kbest.h create mode 100644 src/lattice.cc create mode 100644 src/lattice.h create mode 100644 src/lexcrf.cc create mode 100644 src/lexcrf.h create mode 100644 src/lm_ff.cc create mode 100644 src/lm_ff.h create mode 100644 src/logval.h create mode 100644 src/maxtrans_blunsom.cc create mode 100644 src/parser_test.cc create mode 100644 src/phrasebased_translator.cc create mode 100644 src/phrasebased_translator.h create mode 100644 src/phrasetable_fst.cc create mode 100644 src/phrasetable_fst.h create mode 100644 src/prob.h create mode 100644 src/sampler.h create mode 100644 src/scfg_translator.cc create mode 100644 src/sentence_metadata.h create mode 100644 src/small_vector.h create mode 100644 src/small_vector_test.cc create mode 100644 src/sparse_vector.cc create mode 100644 src/sparse_vector.h create mode 100644 src/stringlib.cc create mode 100644 src/stringlib.h create mode 100644 src/synparse.cc create mode 100644 src/tdict.cc create mode 100644 src/tdict.h create mode 100644 src/timing_stats.cc create mode 100644 src/timing_stats.h create mode 100644 src/translator.h create mode 100644 src/trule.cc create mode 100644 src/trule.h create mode 100644 src/trule_test.cc create mode 100644 src/ttables.cc create mode 100644 src/ttables.h create mode 100644 src/viterbi.cc create mode 100644 src/viterbi.h create mode 100644 src/weights.cc create mode 100644 src/weights.h create mode 100644 src/weights_test.cc create mode 100644 src/wordid.h (limited to 'src') diff --git a/src/JSON_parser.h b/src/JSON_parser.h new file mode 100644 index 00000000..ceb5b24b --- /dev/null +++ b/src/JSON_parser.h @@ -0,0 +1,152 @@ +#ifndef JSON_PARSER_H +#define JSON_PARSER_H + +/* JSON_parser.h */ + + +#include + +/* Windows DLL stuff */ +#ifdef _WIN32 +# ifdef JSON_PARSER_DLL_EXPORTS +# define JSON_PARSER_DLL_API __declspec(dllexport) +# else +# define JSON_PARSER_DLL_API __declspec(dllimport) +# endif +#else +# define JSON_PARSER_DLL_API +#endif + +/* Determine the integer type use to parse non-floating point numbers */ +#if __STDC_VERSION__ >= 199901L || HAVE_LONG_LONG == 1 +typedef long long JSON_int_t; +#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%lld" +#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%lld" +#else +typedef long JSON_int_t; +#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%ld" +#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%ld" +#endif + + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum +{ + JSON_T_NONE = 0, + JSON_T_ARRAY_BEGIN, // 1 + JSON_T_ARRAY_END, // 2 + JSON_T_OBJECT_BEGIN, // 3 + JSON_T_OBJECT_END, // 4 + JSON_T_INTEGER, // 5 + JSON_T_FLOAT, // 6 + JSON_T_NULL, // 7 + JSON_T_TRUE, // 8 + JSON_T_FALSE, // 9 + JSON_T_STRING, // 10 + JSON_T_KEY, // 11 + JSON_T_MAX // 12 +} JSON_type; + +typedef struct JSON_value_struct { + union { + JSON_int_t integer_value; + + long double float_value; + + struct { + const char* value; + size_t length; + } str; + } vu; +} JSON_value; + +typedef struct JSON_parser_struct* JSON_parser; + +/*! \brief JSON parser callback + + \param ctx The pointer passed to new_JSON_parser. + \param type An element of JSON_type but not JSON_T_NONE. + \param value A representation of the parsed value. This parameter is NULL for + JSON_T_ARRAY_BEGIN, JSON_T_ARRAY_END, JSON_T_OBJECT_BEGIN, JSON_T_OBJECT_END, + JSON_T_NULL, JSON_T_TRUE, and SON_T_FALSE. String values are always returned + as zero-terminated C strings. + + \return Non-zero if parsing should continue, else zero. +*/ +typedef int (*JSON_parser_callback)(void* ctx, int type, const struct JSON_value_struct* value); + + +/*! \brief The structure used to configure a JSON parser object + + \param depth If negative, the parser can parse arbitrary levels of JSON, otherwise + the depth is the limit + \param Pointer to a callback. This parameter may be NULL. In this case the input is merely checked for validity. + \param Callback context. This parameter may be NULL. + \param depth. Specifies the levels of nested JSON to allow. Negative numbers yield unlimited nesting. + \param allowComments. To allow C style comments in JSON, set to non-zero. + \param handleFloatsManually. To decode floating point numbers manually set this parameter to non-zero. + + \return The parser object. +*/ +typedef struct { + JSON_parser_callback callback; + void* callback_ctx; + int depth; + int allow_comments; + int handle_floats_manually; +} JSON_config; + + +/*! \brief Initializes the JSON parser configuration structure to default values. + + The default configuration is + - 127 levels of nested JSON (depends on JSON_PARSER_STACK_SIZE, see json_parser.c) + - no parsing, just checking for JSON syntax + - no comments + + \param config. Used to configure the parser. +*/ +JSON_PARSER_DLL_API void init_JSON_config(JSON_config* config); + +/*! \brief Create a JSON parser object + + \param config. Used to configure the parser. Set to NULL to use the default configuration. + See init_JSON_config + + \return The parser object. +*/ +JSON_PARSER_DLL_API extern JSON_parser new_JSON_parser(JSON_config* config); + +/*! \brief Destroy a previously created JSON parser object. */ +JSON_PARSER_DLL_API extern void delete_JSON_parser(JSON_parser jc); + +/*! \brief Parse a character. + + \return Non-zero, if all characters passed to this function are part of are valid JSON. +*/ +JSON_PARSER_DLL_API extern int JSON_parser_char(JSON_parser jc, int next_char); + +/*! \brief Finalize parsing. + + Call this method once after all input characters have been consumed. + + \return Non-zero, if all parsed characters are valid JSON, zero otherwise. +*/ +JSON_PARSER_DLL_API extern int JSON_parser_done(JSON_parser jc); + +/*! \brief Determine if a given string is valid JSON white space + + \return Non-zero if the string is valid, zero otherwise. +*/ +JSON_PARSER_DLL_API extern int JSON_parser_is_legal_white_space_string(const char* s); + + +#ifdef __cplusplus +} +#endif + + +#endif /* JSON_PARSER_H */ diff --git a/src/Makefile.am b/src/Makefile.am new file mode 100644 index 00000000..2af6cab7 --- /dev/null +++ b/src/Makefile.am @@ -0,0 +1,67 @@ +bin_PROGRAMS = \ + dict_test \ + weights_test \ + trule_test \ + hg_test \ + ff_test \ + parser_test \ + grammar_test \ + cdec \ + small_vector_test + +cdec_SOURCES = cdec.cc forest_writer.cc maxtrans_blunsom.cc cdec_ff.cc ff_factory.cc timing_stats.cc +small_vector_test_SOURCES = small_vector_test.cc +small_vector_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +parser_test_SOURCES = parser_test.cc +parser_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +dict_test_SOURCES = dict_test.cc +dict_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +ff_test_SOURCES = ff_test.cc +ff_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +grammar_test_SOURCES = grammar_test.cc +grammar_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +hg_test_SOURCES = hg_test.cc +hg_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +trule_test_SOURCES = trule_test.cc +trule_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a +weights_test_SOURCES = weights_test.cc +weights_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libhg.a + +LDADD = libhg.a + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) +AM_LDFLAGS = $(BOOST_LDFLAGS) $(BOOST_PROGRAM_OPTIONS_LIB) -lz + +noinst_LIBRARIES = libhg.a + +libhg_a_SOURCES = \ + fst_translator.cc \ + scfg_translator.cc \ + hg.cc \ + hg_io.cc \ + viterbi.cc \ + lattice.cc \ + logging.cc \ + aligner.cc \ + gzstream.cc \ + apply_models.cc \ + earley_composer.cc \ + phrasetable_fst.cc \ + sparse_vector.cc \ + trule.cc \ + filelib.cc \ + stringlib.cc \ + fdict.cc \ + tdict.cc \ + weights.cc \ + ttables.cc \ + ff.cc \ + lm_ff.cc \ + ff_wordalign.cc \ + hg_intersect.cc \ + lexcrf.cc \ + bottom_up_parser.cc \ + phrasebased_translator.cc \ + JSON_parser.c \ + json_parse.cc \ + grammar.cc diff --git a/src/aligner.cc b/src/aligner.cc new file mode 100644 index 00000000..d9d067e5 --- /dev/null +++ b/src/aligner.cc @@ -0,0 +1,204 @@ +#include "aligner.h" + +#include "array2d.h" +#include "hg.h" +#include "inside_outside.h" +#include + +using namespace std; + +struct EdgeCoverageInfo { + set src_indices; + set trg_indices; +}; + +static bool is_digit(char x) { return x >= '0' && x <= '9'; } + +boost::shared_ptr > AlignerTools::ReadPharaohAlignmentGrid(const string& al) { + int max_x = 0; + int max_y = 0; + int i = 0; + while (i < al.size()) { + int x = 0; + while(i < al.size() && is_digit(al[i])) { + x *= 10; + x += al[i] - '0'; + ++i; + } + if (x > max_x) max_x = x; + assert(i < al.size()); + assert(al[i] == '-'); + ++i; + int y = 0; + while(i < al.size() && is_digit(al[i])) { + y *= 10; + y += al[i] - '0'; + ++i; + } + if (y > max_y) max_y = y; + while(i < al.size() && al[i] == ' ') { ++i; } + } + + boost::shared_ptr > grid(new Array2D(max_x + 1, max_y + 1)); + i = 0; + while (i < al.size()) { + int x = 0; + while(i < al.size() && is_digit(al[i])) { + x *= 10; + x += al[i] - '0'; + ++i; + } + assert(i < al.size()); + assert(al[i] == '-'); + ++i; + int y = 0; + while(i < al.size() && is_digit(al[i])) { + y *= 10; + y += al[i] - '0'; + ++i; + } + (*grid)(x, y) = true; + while(i < al.size() && al[i] == ' ') { ++i; } + } + // cerr << *grid << endl; + return grid; +} + +void AlignerTools::SerializePharaohFormat(const Array2D& alignment, ostream* out) { + bool need_space = false; + for (int i = 0; i < alignment.width(); ++i) + for (int j = 0; j < alignment.height(); ++j) + if (alignment(i,j)) { + if (need_space) (*out) << ' '; else need_space = true; + (*out) << i << '-' << j; + } + (*out) << endl; +} + +// compute the coverage vectors of each edge +// prereq: all derivations yield the same string pair +void ComputeCoverages(const Hypergraph& g, + vector* pcovs) { + for (int i = 0; i < g.edges_.size(); ++i) { + const Hypergraph::Edge& edge = g.edges_[i]; + EdgeCoverageInfo& cov = (*pcovs)[i]; + // no words + if (edge.rule_->EWords() == 0 || edge.rule_->FWords() == 0) + continue; + // aligned to NULL (crf ibm variant only) + if (edge.prev_i_ == -1 || edge.i_ == -1) + continue; + assert(edge.j_ >= 0); + assert(edge.prev_j_ >= 0); + if (edge.Arity() == 0) { + for (int k = edge.i_; k < edge.j_; ++k) + cov.trg_indices.insert(k); + for (int k = edge.prev_i_; k < edge.prev_j_; ++k) + cov.src_indices.insert(k); + } else { + // note: this code, which handles mixed NT and terminal + // rules assumes that nodes uniquely define a src and trg + // span. + int k = edge.prev_i_; + int j = 0; + const vector& f = edge.rule_->e(); // rules are inverted + while (k < edge.prev_j_) { + if (f[j] > 0) { + cov.src_indices.insert(k); + // cerr << "src: " << k << endl; + ++k; + ++j; + } else { + const Hypergraph::Node& tailnode = g.nodes_[edge.tail_nodes_[-f[j]]]; + assert(tailnode.in_edges_.size() > 0); + // any edge will do: + const Hypergraph::Edge& rep_edge = g.edges_[tailnode.in_edges_.front()]; + //cerr << "skip " << (rep_edge.prev_j_ - rep_edge.prev_i_) << endl; // src span + k += (rep_edge.prev_j_ - rep_edge.prev_i_); // src span + ++j; + } + } + int tc = 0; + const vector& e = edge.rule_->f(); // rules are inverted + k = edge.i_; + j = 0; + // cerr << edge.rule_->AsString() << endl; + // cerr << "i=" << k << " j=" << edge.j_ << endl; + while (k < edge.j_) { + //cerr << " k=" << k << endl; + if (e[j] > 0) { + cov.trg_indices.insert(k); + // cerr << "trg: " << k << endl; + ++k; + ++j; + } else { + assert(tc < edge.tail_nodes_.size()); + const Hypergraph::Node& tailnode = g.nodes_[edge.tail_nodes_[tc]]; + assert(tailnode.in_edges_.size() > 0); + // any edge will do: + const Hypergraph::Edge& rep_edge = g.edges_[tailnode.in_edges_.front()]; + // cerr << "t skip " << (rep_edge.j_ - rep_edge.i_) << endl; // src span + k += (rep_edge.j_ - rep_edge.i_); // src span + ++j; + ++tc; + } + } + //abort(); + } + } +} + +void AlignerTools::WriteAlignment(const string& input, + const Lattice& ref, + const Hypergraph& g, + bool map_instead_of_viterbi) { + if (!map_instead_of_viterbi) { + assert(!"not implemented!"); + } + vector edge_posteriors(g.edges_.size()); + { + SparseVector posts; + InsideOutside, TransitionEventWeightFunction>(g, &posts); + for (int i = 0; i < edge_posteriors.size(); ++i) + edge_posteriors[i] = posts[i]; + } + vector edge2cov(g.edges_.size()); + ComputeCoverages(g, &edge2cov); + + Lattice src; + // currently only dealing with src text, even if the + // model supports lattice translation (which it probably does) + LatticeTools::ConvertTextToLattice(input, &src); + // TODO assert that src is a "real lattice" + + Array2D align(src.size(), ref.size(), prob_t::Zero()); + for (int c = 0; c < g.edges_.size(); ++c) { + const prob_t& p = edge_posteriors[c]; + const EdgeCoverageInfo& eci = edge2cov[c]; + for (set::const_iterator si = eci.src_indices.begin(); + si != eci.src_indices.end(); ++si) { + for (set::const_iterator ti = eci.trg_indices.begin(); + ti != eci.trg_indices.end(); ++ti) { + align(*si, *ti) += p; + } + } + } + prob_t threshold(0.9); + const bool use_soft_threshold = true; // TODO configure + + Array2D grid(src.size(), ref.size(), false); + for (int j = 0; j < ref.size(); ++j) { + if (use_soft_threshold) { + threshold = prob_t::Zero(); + for (int i = 0; i < src.size(); ++i) + if (align(i, j) > threshold) threshold = align(i, j); + //threshold *= prob_t(0.99); + } + for (int i = 0; i < src.size(); ++i) + grid(i, j) = align(i, j) >= threshold; + } + cerr << align << endl; + cerr << grid << endl; + SerializePharaohFormat(grid, &cout); +}; + diff --git a/src/aligner.h b/src/aligner.h new file mode 100644 index 00000000..970c72f2 --- /dev/null +++ b/src/aligner.h @@ -0,0 +1,23 @@ +#ifndef _ALIGNER_H_ + +#include +#include +#include +#include "array2d.h" +#include "lattice.h" + +class Hypergraph; + +struct AlignerTools { + static boost::shared_ptr > ReadPharaohAlignmentGrid(const std::string& al); + static void SerializePharaohFormat(const Array2D& alignment, std::ostream* out); + + // assumption: g contains derivations of input/ref and + // ONLY input/ref. + static void WriteAlignment(const std::string& input, + const Lattice& ref, + const Hypergraph& g, + bool map_instead_of_viterbi = true); +}; + +#endif diff --git a/src/apply_models.cc b/src/apply_models.cc new file mode 100644 index 00000000..8efb331b --- /dev/null +++ b/src/apply_models.cc @@ -0,0 +1,344 @@ +#include "apply_models.h" + +#include +#include +#include +#include + +#include + +#include "hg.h" +#include "ff.h" + +using namespace std; +using namespace std::tr1; + +struct Candidate; +typedef SmallVector JVector; +typedef vector CandidateHeap; +typedef vector CandidateList; + +// life cycle: candidates are created, placed on the heap +// and retrieved by their estimated cost, when they're +// retrieved, they're incorporated into the +LM hypergraph +// where they also know the head node index they are +// attached to. After they are added to the +LM hypergraph +// vit_prob_ and est_prob_ fields may be updated as better +// derivations are found (this happens since the successor's +// of derivation d may have a better score- they are +// explored lazily). However, the updates don't happen +// when a candidate is in the heap so maintaining the heap +// property is not an issue. +struct Candidate { + int node_index_; // -1 until incorporated + // into the +LM forest + const Hypergraph::Edge* in_edge_; // in -LM forest + Hypergraph::Edge out_edge_; + string state_; + const JVector j_; + prob_t vit_prob_; // these are fixed until the cand + // is popped, then they may be updated + prob_t est_prob_; + + Candidate(const Hypergraph::Edge& e, + const JVector& j, + const Hypergraph& out_hg, + const vector& D, + const SentenceMetadata& smeta, + const ModelSet& models, + bool is_goal) : + node_index_(-1), + in_edge_(&e), + j_(j) { + InitializeCandidate(out_hg, smeta, D, models, is_goal); + } + + // used to query uniqueness + Candidate(const Hypergraph::Edge& e, + const JVector& j) : in_edge_(&e), j_(j) {} + + bool IsIncorporatedIntoHypergraph() const { + return node_index_ >= 0; + } + + void InitializeCandidate(const Hypergraph& out_hg, + const SentenceMetadata& smeta, + const vector >& D, + const ModelSet& models, + const bool is_goal) { + const Hypergraph::Edge& in_edge = *in_edge_; + out_edge_.rule_ = in_edge.rule_; + out_edge_.feature_values_ = in_edge.feature_values_; + out_edge_.i_ = in_edge.i_; + out_edge_.j_ = in_edge.j_; + out_edge_.prev_i_ = in_edge.prev_i_; + out_edge_.prev_j_ = in_edge.prev_j_; + Hypergraph::TailNodeVector& tail = out_edge_.tail_nodes_; + tail.resize(j_.size()); + prob_t p = prob_t::One(); + // cerr << "\nEstimating application of " << in_edge.rule_->AsString() << endl; + for (int i = 0; i < tail.size(); ++i) { + const Candidate& ant = *D[in_edge.tail_nodes_[i]][j_[i]]; + assert(ant.IsIncorporatedIntoHypergraph()); + tail[i] = ant.node_index_; + p *= ant.vit_prob_; + } + prob_t edge_estimate = prob_t::One(); + if (is_goal) { + assert(tail.size() == 1); + const string& ant_state = out_hg.nodes_[tail.front()].state_; + models.AddFinalFeatures(ant_state, &out_edge_); + } else { + models.AddFeaturesToEdge(smeta, out_hg, &out_edge_, &state_, &edge_estimate); + } + vit_prob_ = out_edge_.edge_prob_ * p; + est_prob_ = vit_prob_ * edge_estimate; + } +}; + +ostream& operator<<(ostream& os, const Candidate& cand) { + os << "CAND["; + if (!cand.IsIncorporatedIntoHypergraph()) { os << "PENDING "; } + else { os << "+LM_node=" << cand.node_index_; } + os << " edge=" << cand.in_edge_->id_; + os << " j=<"; + for (int i = 0; i < cand.j_.size(); ++i) + os << (i==0 ? "" : " ") << cand.j_[i]; + os << "> vit=" << log(cand.vit_prob_); + os << " est=" << log(cand.est_prob_); + return os << ']'; +} + +struct HeapCandCompare { + bool operator()(const Candidate* l, const Candidate* r) const { + return l->est_prob_ < r->est_prob_; + } +}; + +struct EstProbSorter { + bool operator()(const Candidate* l, const Candidate* r) const { + return l->est_prob_ > r->est_prob_; + } +}; + +// the same candidate can be added multiple times if +// j is multidimensional (if you're going NW in Manhattan, you +// can first go north, then west, or you can go west then north) +// this is a hash function on the relevant variables from +// Candidate to enforce this. +struct CandidateUniquenessHash { + size_t operator()(const Candidate* c) const { + size_t x = 5381; + x = ((x << 5) + x) ^ c->in_edge_->id_; + for (int i = 0; i < c->j_.size(); ++i) + x = ((x << 5) + x) ^ c->j_[i]; + return x; + } +}; + +struct CandidateUniquenessEquals { + bool operator()(const Candidate* a, const Candidate* b) const { + return (a->in_edge_ == b->in_edge_) && (a->j_ == b->j_); + } +}; + +typedef unordered_set UniqueCandidateSet; +typedef unordered_map > State2Node; + +class CubePruningRescorer { + +public: + CubePruningRescorer(const ModelSet& m, + const SentenceMetadata& sm, + const Hypergraph& i, + int pop_limit, + Hypergraph* o) : + models(m), + smeta(sm), + in(i), + out(*o), + D(in.nodes_.size()), + pop_limit_(pop_limit) { + cerr << " Rescoring forest (cube pruning, pop_limit = " << pop_limit_ << ')' << endl; + } + + void Apply() { + int num_nodes = in.nodes_.size(); + int goal_id = num_nodes - 1; + int pregoal = goal_id - 1; + int every = 1; + if (num_nodes > 100) every = 10; + assert(in.nodes_[pregoal].out_edges_.size() == 1); + cerr << " "; + for (int i = 0; i < in.nodes_.size(); ++i) { + if (i % every == 0) cerr << '.'; + KBest(i, i == goal_id); + } + cerr << endl; + cerr << " Best path: " << log(D[goal_id].front()->vit_prob_) + << "\t" << log(D[goal_id].front()->est_prob_) << endl; + out.PruneUnreachable(D[goal_id].front()->node_index_); + FreeAll(); + } + + private: + void FreeAll() { + for (int i = 0; i < D.size(); ++i) { + CandidateList& D_i = D[i]; + for (int j = 0; j < D_i.size(); ++j) + delete D_i[j]; + } + D.clear(); + } + + void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) { + Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_.rule_, item->out_edge_.tail_nodes_); + new_edge->feature_values_ = item->out_edge_.feature_values_; + new_edge->edge_prob_ = item->out_edge_.edge_prob_; + new_edge->i_ = item->out_edge_.i_; + new_edge->j_ = item->out_edge_.j_; + new_edge->prev_i_ = item->out_edge_.prev_i_; + new_edge->prev_j_ = item->out_edge_.prev_j_; + Candidate*& o_item = (*s2n)[item->state_]; + if (!o_item) o_item = item; + + int& node_id = o_item->node_index_; + if (node_id < 0) { + Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_, item->state_); + node_id = new_node->id_; + } + Hypergraph::Node* node = &out.nodes_[node_id]; + out.ConnectEdgeToHeadNode(new_edge, node); + + // update candidate if we have a better derivation + // note: the difference between the vit score and the estimated + // score is the same for all items with a common residual DP + // state + if (item->vit_prob_ > o_item->vit_prob_) { + assert(o_item->state_ == item->state_); // sanity check! + o_item->est_prob_ = item->est_prob_; + o_item->vit_prob_ = item->vit_prob_; + } + if (item != o_item) freelist->push_back(item); + } + + void KBest(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_cands; + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, smeta, models, is_goal)); + assert(unique_cands.insert(cand.back()).second); // these should all be unique! + } +// cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + PushSucc(*item, is_goal, &cand, &unique_cands); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i) + D_v[c++] = i->second; + sort(D_v.begin(), D_v.end(), EstProbSorter()); + // cerr << " expanded to " << D_v.size() << " nodes\n"; + + for (int i = 0; i < cand.size(); ++i) + delete cand[i]; + // freelist is necessary since even after an item merged, it still stays in + // the unique set so it can't be deleted til now + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + + void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) { + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (cs->count(&query_unique) == 0) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + assert(cs->insert(new_cand).second); // insert into uniqueness set, sanity check + } + } + } + } + + const ModelSet& models; + const SentenceMetadata& smeta; + const Hypergraph& in; + Hypergraph& out; + + vector D; // maps nodes in in-HG to the + // equivalent nodes (many due to state + // splits) in the out-HG. + const int pop_limit_; +}; + +struct NoPruningRescorer { + NoPruningRescorer(const ModelSet& m, const Hypergraph& i, Hypergraph* o) : + models(m), + in(i), + out(*o) { + cerr << " Rescoring forest (full intersection)\n"; + } + + void RescoreNode(const int node_num, const bool is_goal) { + } + + void Apply() { + int num_nodes = in.nodes_.size(); + int goal_id = num_nodes - 1; + int pregoal = goal_id - 1; + int every = 1; + if (num_nodes > 100) every = 10; + assert(in.nodes_[pregoal].out_edges_.size() == 1); + cerr << " "; + for (int i = 0; i < in.nodes_.size(); ++i) { + if (i % every == 0) cerr << '.'; + RescoreNode(i, i == goal_id); + } + cerr << endl; + } + + private: + const ModelSet& models; + const Hypergraph& in; + Hypergraph& out; +}; + +// each node in the graph has one of these, it keeps track of +void ApplyModelSet(const Hypergraph& in, + const SentenceMetadata& smeta, + const ModelSet& models, + const PruningConfiguration& config, + Hypergraph* out) { + int pl = config.pop_limit; + if (pl > 100 && in.nodes_.size() > 80000) { + cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; + pl = 30; + } + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); +} + diff --git a/src/apply_models.h b/src/apply_models.h new file mode 100644 index 00000000..08fce037 --- /dev/null +++ b/src/apply_models.h @@ -0,0 +1,20 @@ +#ifndef _APPLY_MODELS_H_ +#define _APPLY_MODELS_H_ + +struct ModelSet; +struct Hypergraph; +struct SentenceMetadata; + +struct PruningConfiguration { + const int algorithm; // 0 = full intersection, 1 = cube pruning + const int pop_limit; // max number of pops off the heap at each node + explicit PruningConfiguration(int k) : algorithm(1), pop_limit(k) {} +}; + +void ApplyModelSet(const Hypergraph& in, + const SentenceMetadata& smeta, + const ModelSet& models, + const PruningConfiguration& config, + Hypergraph* out); + +#endif diff --git a/src/array2d.h b/src/array2d.h new file mode 100644 index 00000000..09d84d0b --- /dev/null +++ b/src/array2d.h @@ -0,0 +1,171 @@ +#ifndef ARRAY2D_H_ +#define ARRAY2D_H_ + +#include +#include +#include +#include +#include + +template +class Array2D { + public: + typedef typename std::vector::reference reference; + typedef typename std::vector::const_reference const_reference; + typedef typename std::vector::iterator iterator; + typedef typename std::vector::const_iterator const_iterator; + Array2D() : width_(0), height_(0) {} + Array2D(int w, int h, const T& d = T()) : + width_(w), height_(h), data_(w*h, d) {} + Array2D(const Array2D& rhs) : + width_(rhs.width_), height_(rhs.height_), data_(rhs.data_) {} + void resize(int w, int h, const T& d = T()) { + data_.resize(w * h, d); + width_ = w; + height_ = h; + } + const Array2D& operator=(const Array2D& rhs) { + data_ = rhs.data_; + width_ = rhs.width_; + height_ = rhs.height_; + return *this; + } + void fill(const T& v) { data_.assign(data_.size(), v); } + int width() const { return width_; } + int height() const { return height_; } + reference operator()(int i, int j) { + return data_[offset(i, j)]; + } + void clear() { data_.clear(); width_=0; height_=0; } + const_reference operator()(int i, int j) const { + return data_[offset(i, j)]; + } + iterator begin_col(int j) { + return data_.begin() + offset(0,j); + } + const_iterator begin_col(int j) const { + return data_.begin() + offset(0,j); + } + iterator end_col(int j) { + return data_.begin() + offset(0,j) + width_; + } + const_iterator end_col(int j) const { + return data_.begin() + offset(0,j) + width_; + } + iterator end() { return data_.end(); } + const_iterator end() const { return data_.end(); } + const Array2D& operator*=(const T& x) { + std::transform(data_.begin(), data_.end(), data_.begin(), + std::bind2nd(std::multiplies(), x)); + } + const Array2D& operator/=(const T& x) { + std::transform(data_.begin(), data_.end(), data_.begin(), + std::bind2nd(std::divides(), x)); + } + const Array2D& operator+=(const Array2D& m) { + std::transform(m.data_.begin(), m.data_.end(), data_.begin(), data_.begin(), std::plus()); + } + const Array2D& operator-=(const Array2D& m) { + std::transform(m.data_.begin(), m.data_.end(), data_.begin(), data_.begin(), std::minus()); + } + + private: + inline int offset(int i, int j) const { + assert(i data_; +}; + +template +Array2D operator*(const Array2D& l, const T& scalar) { + Array2D res(l); + res *= scalar; + return res; +} + +template +Array2D operator*(const T& scalar, const Array2D& l) { + Array2D res(l); + res *= scalar; + return res; +} + +template +Array2D operator/(const Array2D& l, const T& scalar) { + Array2D res(l); + res /= scalar; + return res; +} + +template +Array2D operator+(const Array2D& l, const Array2D& r) { + Array2D res(l); + res += r; + return res; +} + +template +Array2D operator-(const Array2D& l, const Array2D& r) { + Array2D res(l); + res -= r; + return res; +} + +template +inline std::ostream& operator<<(std::ostream& os, const Array2D& m) { + for (int i=0; i& m) { + os << ' '; + for (int j=0; j >& m) { + os << ' '; + for (int j=0; j& ar = m(i,j); + for (int k=0; k + +#include "hg.h" +#include "array2d.h" +#include "tdict.h" + +using namespace std; + +class ActiveChart; +class PassiveChart { + public: + PassiveChart(const string& goal, + const vector& grammars, + const Lattice& input, + Hypergraph* forest); + ~PassiveChart(); + + inline const vector& operator()(int i, int j) const { return chart_(i,j); } + bool Parse(); + inline int size() const { return chart_.width(); } + inline bool GoalFound() const { return goal_idx_ >= 0; } + inline int GetGoalIndex() const { return goal_idx_; } + + private: + void ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail); + void ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes); + void ApplyUnaryRules(const int i, const int j); + + const vector& grammars_; + const Lattice& input_; + Hypergraph* forest_; + Array2D > chart_; // chart_(i,j) is the list of nodes derived spanning i,j + typedef map Cat2NodeMap; + Array2D nodemap_; + vector act_chart_; + const WordID goal_cat_; // category that is being searched for at [0,n] + TRulePtr goal_rule_; + int goal_idx_; // index of goal node, if found + + static WordID kGOAL; // [Goal] +}; + +WordID PassiveChart::kGOAL = 0; + +class ActiveChart { + public: + ActiveChart(const Hypergraph* hg, const PassiveChart& psv_chart) : + hg_(hg), + act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {} + + struct ActiveItem { + ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, double lcost) : + gptr_(g), ant_nodes_(a), lattice_cost(lcost) {} + explicit ActiveItem(const GrammarIter* g) : + gptr_(g), ant_nodes_(), lattice_cost() {} + + void ExtendTerminal(int symbol, double src_cost, vector* out_cell) const { + const GrammarIter* ni = gptr_->Extend(symbol); + if (ni) out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost)); + } + void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector* out_cell) const { + int symbol = hg->nodes_[node_index].cat_; + const GrammarIter* ni = gptr_->Extend(symbol); + if (!ni) return; + Hypergraph::TailNodeVector na(ant_nodes_.size() + 1); + for (int i = 0; i < ant_nodes_.size(); ++i) + na[i] = ant_nodes_[i]; + na[ant_nodes_.size()] = node_index; + out_cell->push_back(ActiveItem(ni, na, lattice_cost)); + } + + const GrammarIter* gptr_; + Hypergraph::TailNodeVector ant_nodes_; + double lattice_cost; // TODO? use SparseVector + }; + + inline const vector& operator()(int i, int j) const { return act_chart_(i,j); } + void SeedActiveChart(const Grammar& g) { + int size = act_chart_.width(); + for (int i = 0; i < size; ++i) + if (g.HasRuleForSpan(i,i)) + act_chart_(i,i).push_back(ActiveItem(g.GetRoot())); + } + + void ExtendActiveItems(int i, int k, int j) { + //cerr << " LOOK(" << i << "," << k << ") for completed items in (" << k << "," << j << ")\n"; + vector& cell = act_chart_(i,j); + const vector& icell = act_chart_(i,k); + const vector& idxs = psv_chart_(k, j); + //if (!idxs.empty()) { cerr << "FOUND IN (" << k << "," << j << ")\n"; } + for (vector::const_iterator di = icell.begin(); di != icell.end(); ++di) { + for (vector::const_iterator ni = idxs.begin(); ni != idxs.end(); ++ni) { + di->ExtendNonTerminal(hg_, *ni, &cell); + } + } + } + + void AdvanceDotsForAllItemsInCell(int i, int j, const vector >& input) { + //cerr << "ADVANCE(" << i << "," << j << ")\n"; + for (int k=i+1; k < j; ++k) + ExtendActiveItems(i, k, j); + + const vector& out_arcs = input[j-1]; + for (vector::const_iterator ai = out_arcs.begin(); + ai != out_arcs.end(); ++ai) { + const WordID& f = ai->label; + const double& c = ai->cost; + const int& len = ai->dist2next; + //VLOG(1) << "F: " << TD::Convert(f) << endl; + const vector& ec = act_chart_(i, j-1); + for (vector::const_iterator di = ec.begin(); di != ec.end(); ++di) + di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1)); + } + } + + private: + const Hypergraph* hg_; + Array2D > act_chart_; + const PassiveChart& psv_chart_; +}; + +PassiveChart::PassiveChart(const string& goal, + const vector& grammars, + const Lattice& input, + Hypergraph* forest) : + grammars_(grammars), + input_(input), + forest_(forest), + chart_(input.size()+1, input.size()+1), + nodemap_(input.size()+1, input.size()+1), + goal_cat_(TD::Convert(goal) * -1), + goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")), + goal_idx_(-1) { + act_chart_.resize(grammars_.size()); + for (int i = 0; i < grammars_.size(); ++i) + act_chart_[i] = new ActiveChart(forest, *this); + if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; + cerr << " Goal category: [" << goal << ']' << endl; +} + +void PassiveChart::ApplyRule(const int i, const int j, TRulePtr r, const Hypergraph::TailNodeVector& ant_nodes) { + Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); + new_edge->prev_i_ = r->prev_i; + new_edge->prev_j_ = r->prev_j; + new_edge->i_ = i; + new_edge->j_ = j; + new_edge->feature_values_ = r->GetFeatureValues(); + Cat2NodeMap& c2n = nodemap_(i,j); + const bool is_goal = (r->GetLHS() == kGOAL); + const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); + Hypergraph::Node* node = NULL; + if (ni == c2n.end()) { + node = forest_->AddNode(r->GetLHS(), ""); + c2n[r->GetLHS()] = node->id_; + if (is_goal) { + assert(goal_idx_ == -1); + goal_idx_ = node->id_; + } else { + chart_(i,j).push_back(node->id_); + } + } else { + node = &forest_->nodes_[ni->second]; + } + forest_->ConnectEdgeToHeadNode(new_edge, node); +} + +void PassiveChart::ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail) { + const int n = rules->GetNumRules(); + for (int k = 0; k < n; ++k) + ApplyRule(i, j, rules->GetIthRule(k), tail); +} + +void PassiveChart::ApplyUnaryRules(const int i, const int j) { + const vector& nodes = chart_(i,j); // reference is important! + for (int gi = 0; gi < grammars_.size(); ++gi) { + if (!grammars_[gi]->HasRuleForSpan(i,j)) continue; + for (int di = 0; di < nodes.size(); ++di) { + const WordID& cat = forest_->nodes_[nodes[di]].cat_; + const vector& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat); + for (int ri = 0; ri < unaries.size(); ++ri) { + // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl; + const Hypergraph::TailNodeVector ant(1, nodes[di]); + ApplyRule(i, j, unaries[ri], ant); // may update nodes + } + } + } +} + +bool PassiveChart::Parse() { + forest_->nodes_.reserve(input_.size() * input_.size() * 2); + forest_->edges_.reserve(input_.size() * input_.size() * 1000); // TODO: reservation?? + goal_idx_ = -1; + for (int gi = 0; gi < grammars_.size(); ++gi) + act_chart_[gi]->SeedActiveChart(*grammars_[gi]); + + cerr << " "; + for (int l=1; lAdvanceDotsForAllItemsInCell(i, j, input_); + + const vector& cell = (*act_chart_[gi])(i,j); + for (vector::const_iterator ai = cell.begin(); + ai != cell.end(); ++ai) { + const RuleBin* rules = (ai->gptr_->GetRules()); + if (!rules) continue; + ApplyRules(i, j, rules, ai->ant_nodes_); + } + } + } + ApplyUnaryRules(i,j); + + for (int gi = 0; gi < grammars_.size(); ++gi) { + const Grammar& g = *grammars_[gi]; + // deal with non-terminals that were just proved + if (g.HasRuleForSpan(i, j)) + act_chart_[gi]->ExtendActiveItems(i, i, j); + } + } + const vector& dh = chart_(0, input_.size()); + for (int di = 0; di < dh.size(); ++di) { + const Hypergraph::Node& node = forest_->nodes_[dh[di]]; + if (node.cat_ == goal_cat_) { + Hypergraph::TailNodeVector ant(1, node.id_); + ApplyRule(0, input_.size(), goal_rule_, ant); + } + } + } + cerr << endl; + + if (GoalFound()) + forest_->PruneUnreachable(forest_->nodes_.size() - 1); + return GoalFound(); +} + +PassiveChart::~PassiveChart() { + for (int i = 0; i < act_chart_.size(); ++i) + delete act_chart_[i]; +} + +ExhaustiveBottomUpParser::ExhaustiveBottomUpParser( + const string& goal_sym, + const vector& grammars) : + goal_sym_(goal_sym), + grammars_(grammars) {} + +bool ExhaustiveBottomUpParser::Parse(const Lattice& input, + Hypergraph* forest) const { + PassiveChart chart(goal_sym_, grammars_, input, forest); + return chart.Parse(); +} diff --git a/src/bottom_up_parser.h b/src/bottom_up_parser.h new file mode 100644 index 00000000..546bfb54 --- /dev/null +++ b/src/bottom_up_parser.h @@ -0,0 +1,27 @@ +#ifndef _BOTTOM_UP_PARSER_H_ +#define _BOTTOM_UP_PARSER_H_ + +#include +#include + +#include "lattice.h" +#include "grammar.h" + +class Hypergraph; + +class ExhaustiveBottomUpParser { + public: + ExhaustiveBottomUpParser(const std::string& goal_sym, + const std::vector& grammars); + + // returns true if goal reached spanning the full input + // forest contains the full (i.e., unpruned) parse forest + bool Parse(const Lattice& input, + Hypergraph* forest) const; + + private: + const std::string goal_sym_; + const std::vector grammars_; +}; + +#endif diff --git a/src/cdec.cc b/src/cdec.cc new file mode 100644 index 00000000..c5780cef --- /dev/null +++ b/src/cdec.cc @@ -0,0 +1,474 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "timing_stats.h" +#include "translator.h" +#include "phrasebased_translator.h" +#include "aligner.h" +#include "stringlib.h" +#include "forest_writer.h" +#include "filelib.h" +#include "sampler.h" +#include "sparse_vector.h" +#include "lexcrf.h" +#include "weights.h" +#include "tdict.h" +#include "ff.h" +#include "ff_factory.h" +#include "hg_intersect.h" +#include "apply_models.h" +#include "viterbi.h" +#include "kbest.h" +#include "inside_outside.h" +#include "exp_semiring.h" +#include "sentence_metadata.h" + +using namespace std; +using namespace std::tr1; +using boost::shared_ptr; +namespace po = boost::program_options; + +// some globals ... +boost::shared_ptr > rng; + +namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); } + +void ShowBanner() { + cerr << "cdec v1.0 (c) 2009 by Chris Dyer\n"; +} + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("formalism,f",po::value()->default_value("scfg"),"Translation formalism; values include SCFG, FST, or PB. Specify LexicalCRF for experimental unsupervised CRF word alignment") + ("input,i",po::value()->default_value("-"),"Source file") + ("grammar,g",po::value >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") + ("weights,w",po::value(),"Feature weights file") + ("feature_function,F",po::value >()->composing(), "Additional feature function(s) (-L for list)") + ("list_feature_functions,L","List available feature functions") + ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") + ("k_best,k",po::value(),"Extract the k best derivations") + ("unique_k_best,r", "Unique k-best translation list") + ("aligner,a", "Run as a word/phrase aligner (src & ref required)") + ("cubepruning_pop_limit,K",po::value()->default_value(200), "Max number of pops from the candidate heap at each node") + ("goal",po::value()->default_value("S"),"Goal symbol (SCFG & FST)") + ("scfg_extra_glue_grammar", po::value(), "Extra glue grammar file (Glue grammars apply when i=0 but have no other span restrictions)") + ("scfg_no_hiero_glue_grammar,n", "No Hiero glue grammar (nb. by default the SCFG decoder adds Hiero glue rules)") + ("scfg_default_nt,d",po::value()->default_value("X"),"Default non-terminal symbol in SCFG") + ("scfg_max_span_limit,S",po::value()->default_value(10),"Maximum non-terminal span limit (except \"glue\" grammar)") + ("show_tree_structure,T", "Show the Viterbi derivation structure") + ("show_expected_length", "Show the expected translation length under the model") + ("show_partition,z", "Compute and show the partition (inside score)") + ("extract_rules", po::value(), "Extract the rules used in translation (de-duped) to this file") + ("graphviz","Show (constrained) translation forest in GraphViz format") + ("max_translation_beam,x", po::value(), "Beam approximation to get max translation from the chart") + ("max_translation_sample,X", po::value(), "Sample the max translation from the chart") + ("pb_max_distortion,D", po::value()->default_value(4), "Phrase-based decoder: maximum distortion") + ("gradient,G","Compute d log p(e|f) / d lambda_i and write to STDOUT (src & ref required)") + ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)") + ("vector_format",po::value()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") + ("combine_size,C",po::value()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") + ("forest_output,O",po::value(),"Directory to write forests to") + ("minimal_forests,m","Write minimal forests (excludes Rule information). Such forests can be used for ML/MAP training, but not rescoring, etc."); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + const string cfg = (*conf)["config"].as(); + cerr << "Configuration file: " << cfg << endl; + ifstream config(cfg.c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("list_feature_functions")) { + cerr << "Available feature functions (specify with -F):\n"; + global_ff_registry->DisplayList(); + cerr << endl; + exit(1); + } + + if (conf->count("help") || conf->count("grammar") == 0) { + cerr << dcmdline_options << endl; + exit(1); + } + + const string formalism = LowercaseString((*conf)["formalism"].as()); + if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', or 'lexcrf'\n"; + cerr << dcmdline_options << endl; + exit(1); + } +} + +// TODO move out of cdec into some sampling decoder file +void SampleRecurse(const Hypergraph& hg, const vector& ss, int n, vector* out) { + const SampleSet& s = ss[n]; + int i = rng->SelectSample(s); + const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]]; + vector > ants(edge.tail_nodes_.size()); + for (int j = 0; j < ants.size(); ++j) + SampleRecurse(hg, ss, edge.tail_nodes_[j], &ants[j]); + + vector*> pants(ants.size()); + for (int j = 0; j < ants.size(); ++j) pants[j] = &ants[j]; + edge.rule_->ESubstitute(pants, out); +} + +struct SampleSort { + bool operator()(const pair& a, const pair& b) const { + return a.first > b.first; + } +}; + +// TODO move out of cdec into some sampling decoder file +void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) { + unordered_map > m; + hg->PushWeightsToGoal(); + const int num_nodes = hg->nodes_.size(); + vector ss(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + SampleSet& s = ss[i]; + const vector& in_edges = hg->nodes_[i].in_edges_; + for (int j = 0; j < in_edges.size(); ++j) { + s.add(hg->edges_[in_edges[j]].edge_prob_); + } + } + for (int i = 0; i < samples; ++i) { + vector yield; + SampleRecurse(*hg, ss, hg->nodes_.size() - 1, &yield); + const string trans = TD::GetString(yield); + ++m[trans]; + } + vector > dist; + for (unordered_map >::iterator i = m.begin(); + i != m.end(); ++i) { + dist.push_back(make_pair(i->second, i->first)); + } + sort(dist.begin(), dist.end(), SampleSort()); + if (k) { + for (int i = 0; i < k; ++i) + cout << dist[i].first << " ||| " << dist[i].second << endl; + } else { + cout << dist[0].second << endl; + } +} + +// TODO decoder output should probably be moved to another file +void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique) { + if (unique) { + KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique> kbest(forest, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + } else { + KBest::KBestDerivations, ESentenceTraversal> kbest(forest, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + cout << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + } +} + +struct ELengthWeightFunction { + double operator()(const Hypergraph::Edge& e) const { + return e.rule_->ELength() - e.rule_->Arity(); + } +}; + + +struct TRPHash { + size_t operator()(const TRulePtr& o) const { return reinterpret_cast(o.get()); } +}; +static void ExtractRulesDedupe(const Hypergraph& hg, ostream* os) { + static unordered_set written; + for (int i = 0; i < hg.edges_.size(); ++i) { + const TRulePtr& rule = hg.edges_[i].rule_; + if (written.insert(rule).second) { + (*os) << rule->AsString() << endl; + } + } +} + +void register_feature_functions(); + +int main(int argc, char** argv) { + global_ff_registry.reset(new FFRegistry); + register_feature_functions(); + ShowBanner(); + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const bool write_gradient = conf.count("gradient"); + const bool feature_expectations = conf.count("feature_expectations"); + if (write_gradient && feature_expectations) { + cerr << "You can only specify --gradient or --feature_expectations, not both!\n"; + exit(1); + } + const bool output_training_vector = (write_gradient || feature_expectations); + + boost::shared_ptr translator; + const string formalism = LowercaseString(conf["formalism"].as()); + if (formalism == "scfg") + translator.reset(new SCFGTranslator(conf)); + else if (formalism == "fst") + translator.reset(new FSTTranslator(conf)); + else if (formalism == "pb") + translator.reset(new PhraseBasedTranslator(conf)); + else if (formalism == "lexcrf") + translator.reset(new LexicalCRF(conf)); + else + assert(!"error"); + + vector wv; + Weights w; + if (conf.count("weights")) { + w.InitFromFile(conf["weights"].as()); + wv.resize(FD::NumFeats()); + w.InitVector(&wv); + } + + // set up additional scoring features + vector > pffs; + vector late_ffs; + if (conf.count("feature_function") > 0) { + const vector& add_ffs = conf["feature_function"].as >(); + for (int i = 0; i < add_ffs.size(); ++i) { + string ff, param; + SplitCommandAndParam(add_ffs[i], &ff, ¶m); + if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; + else cerr << " (no config parameters)\n"; + shared_ptr pff = global_ff_registry->Create(ff, param); + if (!pff) { exit(1); } + // TODO check that multiple features aren't trying to set the same fid + pffs.push_back(pff); + late_ffs.push_back(pff.get()); + } + } + ModelSet late_models(wv, late_ffs); + + const int sample_max_trans = conf.count("max_translation_sample") ? + conf["max_translation_sample"].as() : 0; + if (sample_max_trans) + rng.reset(new RandomNumberGenerator); + const bool aligner_mode = conf.count("aligner"); + const bool minimal_forests = conf.count("minimal_forests"); + const bool graphviz = conf.count("graphviz"); + const bool encode_b64 = conf["vector_format"].as() == "b64"; + const bool kbest = conf.count("k_best"); + const bool unique_kbest = conf.count("unique_k_best"); + shared_ptr extract_file; + if (conf.count("extract_rules")) + extract_file.reset(new WriteFile(conf["extract_rules"].as())); + + int combine_size = conf["combine_size"].as(); + if (combine_size < 1) combine_size = 1; + const string input = conf["input"].as(); + cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl; + ReadFile in_read(input); + istream *in = in_read.stream(); + assert(*in); + + SparseVector acc_vec; // accumulate gradient + double acc_obj = 0; // accumulate objective + int g_count = 0; // number of gradient pieces computed + int sent_id = -1; // line counter + + while(*in) { + Timer::Summarize(); + ++sent_id; + string buf; + getline(*in, buf); + if (buf.empty()) continue; + map sgml; + ProcessAndStripSGML(&buf, &sgml); + if (sgml.find("id") != sgml.end()) + sent_id = atoi(sgml["id"].c_str()); + + cerr << "\nINPUT: "; + if (buf.size() < 100) + cerr << buf << endl; + else { + size_t x = buf.rfind(" ", 100); + if (x == string::npos) x = 100; + cerr << buf.substr(0, x) << " ..." << endl; + } + cerr << " id = " << sent_id << endl; + string to_translate; + Lattice ref; + ParseTranslatorInputLattice(buf, &to_translate, &ref); + const bool has_ref = ref.size() > 0; + SentenceMetadata smeta(sent_id, ref); + const bool hadoop_counters = (write_gradient); + Hypergraph forest; // -LM forest + Timer t("Translation"); + if (!translator->Translate(to_translate, &smeta, wv, &forest)) { + cerr << " NO PARSE FOUND.\n"; + if (hadoop_counters) + cerr << "reporter:counter:UserCounters,FParseFailed,1" << endl; + cout << endl << flush; + continue; + } + cerr << " -LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + cerr << " -LM forest (paths): " << forest.NumberOfPaths() << endl; + if (conf.count("show_expected_length")) { + const PRPair res = + Inside, + PRWeightFunction >(forest); + cerr << " Expected length (words): " << res.r / res.p << "\t" << res << endl; + } + if (conf.count("show_partition")) { + const prob_t z = Inside(forest); + cerr << " -LM partition log(Z): " << log(z) << endl; + } + if (extract_file) + ExtractRulesDedupe(forest, extract_file->stream()); + vector trans; + const prob_t vs = ViterbiESentence(forest, &trans); + cerr << " -LM Viterbi: " << TD::GetString(trans) << endl; + if (conf.count("show_tree_structure")) + cerr << " -LM tree: " << ViterbiETree(forest) << endl;; + cerr << " -LM Viterbi: " << log(vs) << endl; + + bool has_late_models = !late_models.empty(); + if (has_late_models) { + forest.Reweight(wv); + forest.SortInEdgesByEdgeWeights(); + Hypergraph lm_forest; + int cubepruning_pop_limit = conf["cubepruning_pop_limit"].as(); + ApplyModelSet(forest, + smeta, + late_models, + PruningConfiguration(cubepruning_pop_limit), + &lm_forest); + forest.swap(lm_forest); + forest.Reweight(wv); + trans.clear(); + ViterbiESentence(forest, &trans); + cerr << " +LM forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + cerr << " +LM forest (paths): " << forest.NumberOfPaths() << endl; + cerr << " +LM Viterbi: " << TD::GetString(trans) << endl; + } + if (conf.count("forest_output") && !has_ref) { + ForestWriter writer(conf["forest_output"].as(), sent_id); + assert(writer.Write(forest, minimal_forests)); + } + + if (sample_max_trans) { + MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as() : 0); + } else { + if (kbest) { + DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest); + } else { + if (!graphviz && !has_ref) { + cout << TD::GetString(trans) << endl << flush; + } + } + } + + const int max_trans_beam_size = conf.count("max_translation_beam") ? + conf["max_translation_beam"].as() : 0; + if (max_trans_beam_size) { + Hack::MaxTrans(forest, max_trans_beam_size); + continue; + } + + if (graphviz && !has_ref) forest.PrintGraphviz(); + + // the following are only used if write_gradient is true! + SparseVector full_exp, ref_exp, gradient; + double log_z = 0, log_ref_z = 0; + if (write_gradient) + log_z = log( + InsideOutside, EdgeFeaturesWeightFunction>(forest, &full_exp)); + + if (has_ref) { + if (HG::Intersect(ref, &forest)) { + cerr << " Constr. forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + cerr << " Constr. forest (paths): " << forest.NumberOfPaths() << endl; + forest.Reweight(wv); + cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; + if (hadoop_counters) + cerr << "reporter:counter:UserCounters,SentencePairsParsed,1" << endl; + if (conf.count("show_partition")) { + const prob_t z = Inside(forest); + cerr << " Contst. partition log(Z): " << log(z) << endl; + } + //DumpKBest(sent_id, forest, 1000); + if (conf.count("forest_output")) { + ForestWriter writer(conf["forest_output"].as(), sent_id); + assert(writer.Write(forest, minimal_forests)); + } + if (aligner_mode && !output_training_vector) + AlignerTools::WriteAlignment(to_translate, ref, forest); + if (write_gradient) { + log_ref_z = log( + InsideOutside, EdgeFeaturesWeightFunction>(forest, &ref_exp)); + if (log_z < log_ref_z) { + cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl; + exit(1); + } + //cerr << "FULL: " << full_exp << endl; + //cerr << " REF: " << ref_exp << endl; + ref_exp -= full_exp; + acc_vec += ref_exp; + acc_obj += (log_z - log_ref_z); + } + if (feature_expectations) { + acc_obj += log( + InsideOutside, EdgeFeaturesWeightFunction>(forest, &ref_exp)); + acc_vec += ref_exp; + } + + if (output_training_vector) { + ++g_count; + if (g_count % combine_size == 0) { + if (encode_b64) { + cout << "0\t"; + B64::Encode(acc_obj, acc_vec, &cout); + cout << endl << flush; + } else { + cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; + } + acc_vec.clear(); + acc_obj = 0; + } + } + if (conf.count("graphviz")) forest.PrintGraphviz(); + } else { + cerr << " REFERENCE UNREACHABLE.\n"; + if (write_gradient) { + if (hadoop_counters) + cerr << "reporter:counter:UserCounters,EFParseFailed,1" << endl; + cout << endl << flush; + } + } + } + } + if (output_training_vector && !acc_vec.empty()) { + if (encode_b64) { + cout << "0\t"; + B64::Encode(acc_obj, acc_vec, &cout); + cout << endl << flush; + } else { + cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; + } + } +} + diff --git a/src/cdec_ff.cc b/src/cdec_ff.cc new file mode 100644 index 00000000..86abfb4a --- /dev/null +++ b/src/cdec_ff.cc @@ -0,0 +1,18 @@ +#include + +#include "ff.h" +#include "lm_ff.h" +#include "ff_factory.h" +#include "ff_wordalign.h" + +boost::shared_ptr global_ff_registry; + +void register_feature_functions() { + global_ff_registry->Register("LanguageModel", new FFFactory); + global_ff_registry->Register("WordPenalty", new FFFactory); + global_ff_registry->Register("RelativeSentencePosition", new FFFactory); + global_ff_registry->Register("MarkovJump", new FFFactory); + global_ff_registry->Register("BlunsomSynchronousParseHack", new FFFactory); + global_ff_registry->Register("AlignerResults", new FFFactory); +}; + diff --git a/src/collapse_weights.cc b/src/collapse_weights.cc new file mode 100644 index 00000000..5e0f3f72 --- /dev/null +++ b/src/collapse_weights.cc @@ -0,0 +1,102 @@ +#include +#include +#include + +#include +#include +#include + +#include "prob.h" +#include "filelib.h" +#include "trule.h" +#include "weights.h" + +namespace po = boost::program_options; +using namespace std; + +typedef std::tr1::unordered_map, prob_t, boost::hash > > MarginalMap; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("grammar,g", po::value(), "Grammar file") + ("weights,w", po::value(), "Weights file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + const string cfg = (*conf)["config"].as(); + cerr << "Configuration file: " << cfg << endl; + ifstream config(cfg.c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("grammar") || !conf->count("weights")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string wfile = conf["weights"].as(); + const string gfile = conf["grammar"].as(); + Weights wm; + wm.InitFromFile(wfile); + vector w; + wm.InitVector(&w); + MarginalMap e_tots; + MarginalMap f_tots; + prob_t tot; + { + ReadFile rf(gfile); + assert(*rf.stream()); + istream& in = *rf.stream(); + cerr << "Computing marginals...\n"; + int lc = 0; + while(in) { + string line; + getline(in, line); + ++lc; + if (line.empty()) continue; + TRule tr(line, true); + if (tr.GetFeatureValues().empty()) + cerr << "Line " << lc << ": empty features - may introduce bias\n"; + prob_t prob; + prob.logeq(tr.GetFeatureValues().dot(w)); + e_tots[tr.e_] += prob; + f_tots[tr.f_] += prob; + tot += prob; + } + } + bool normalized = (fabs(log(tot)) < 0.001); + cerr << "Total: " << tot << (normalized ? " [normalized]" : " [scaled]") << endl; + ReadFile rf(gfile); + istream&in = *rf.stream(); + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + TRule tr(line, true); + const double lp = tr.GetFeatureValues().dot(w); + if (isinf(lp)) { continue; } + tr.scores_.clear(); + + cout << tr.AsString() << " ||| F_and_E=" << lp - log(tot); + if (!normalized) { + cout << ";ZF_and_E=" << lp; + } + cout << ";F_given_E=" << lp - log(e_tots[tr.e_]) + << ";E_given_F=" << lp - log(f_tots[tr.f_]) << endl; + } + return 0; +} + diff --git a/src/dict.h b/src/dict.h new file mode 100644 index 00000000..bae9debe --- /dev/null +++ b/src/dict.h @@ -0,0 +1,40 @@ +#ifndef DICT_H_ +#define DICT_H_ + +#include +#include +#include +#include +#include + +#include + +#include "wordid.h" + +class Dict { + typedef std::tr1::unordered_map > Map; + public: + Dict() : b0_("") { words_.reserve(1000); } + inline int max() const { return words_.size(); } + inline WordID Convert(const std::string& word) { + Map::iterator i = d_.find(word); + if (i == d_.end()) { + words_.push_back(word); + d_[word] = words_.size(); + return words_.size(); + } else { + return i->second; + } + } + inline const std::string& Convert(const WordID& id) const { + if (id == 0) return b0_; + assert(id <= words_.size()); + return words_[id-1]; + } + private: + const std::string b0_; + std::vector words_; + Map d_; +}; + +#endif diff --git a/src/dict_test.cc b/src/dict_test.cc new file mode 100644 index 00000000..5c5d84f0 --- /dev/null +++ b/src/dict_test.cc @@ -0,0 +1,30 @@ +#include "dict.h" + +#include +#include + +class DTest : public testing::Test { + public: + DTest() {} + protected: + virtual void SetUp() { } + virtual void TearDown() { } +}; + +TEST_F(DTest, Convert) { + Dict d; + WordID a = d.Convert("foo"); + WordID b = d.Convert("bar"); + std::string x = "foo"; + WordID c = d.Convert(x); + EXPECT_NE(a, b); + EXPECT_EQ(a, c); + EXPECT_EQ(d.Convert(a), "foo"); + EXPECT_EQ(d.Convert(b), "bar"); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + diff --git a/src/earley_composer.cc b/src/earley_composer.cc new file mode 100644 index 00000000..a59686e0 --- /dev/null +++ b/src/earley_composer.cc @@ -0,0 +1,726 @@ +#include "earley_composer.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "phrasetable_fst.h" +#include "sparse_vector.h" +#include "tdict.h" +#include "hg.h" + +using boost::shared_ptr; +namespace po = boost::program_options; +using namespace std; +using namespace std::tr1; + +// Define the following macro if you want to see lots of debugging output +// when you run the chart parser +#undef DEBUG_CHART_PARSER + +// A few constants used by the chart parser /////////////// +static const int kMAX_NODES = 2000000; +static const string kPHRASE_STRING = "X"; +static bool constants_need_init = true; +static WordID kUNIQUE_START; +static WordID kPHRASE; +static TRulePtr kX1X2; +static TRulePtr kX1; +static WordID kEPS; +static TRulePtr kEPSRule; + +static void InitializeConstants() { + if (constants_need_init) { + kPHRASE = TD::Convert(kPHRASE_STRING) * -1; + kUNIQUE_START = TD::Convert("S") * -1; + kX1X2.reset(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]")); + kX1.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); + kEPSRule.reset(new TRule("[X] ||| ||| ")); + kEPS = TD::Convert(""); + constants_need_init = false; + } +} +//////////////////////////////////////////////////////////// + +class EGrammarNode { + friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest); + friend void AddGrammarRule(const string& r, map* g); + public: +#ifdef DEBUG_CHART_PARSER + string hint; +#endif + EGrammarNode() : is_some_rule_complete(false), is_root(false) {} + const map& GetTerminals() const { return tptr; } + const map& GetNonTerminals() const { return ntptr; } + bool HasNonTerminals() const { return (!ntptr.empty()); } + bool HasTerminals() const { return (!tptr.empty()); } + bool RuleCompletes() const { + return (is_some_rule_complete || (ntptr.empty() && tptr.empty())); + } + bool GrammarContinues() const { + return !(ntptr.empty() && tptr.empty()); + } + bool IsRoot() const { + return is_root; + } + // these are the features associated with the rule from the start + // node up to this point. If you use these features, you must + // not Extend() this rule. + const SparseVector& GetCFGProductionFeatures() const { + return input_features; + } + + const EGrammarNode* Extend(const WordID& t) const { + if (t < 0) { + map::const_iterator it = ntptr.find(t); + if (it == ntptr.end()) return NULL; + return &it->second; + } else { + map::const_iterator it = tptr.find(t); + if (it == tptr.end()) return NULL; + return &it->second; + } + } + + private: + map tptr; + map ntptr; + SparseVector input_features; + bool is_some_rule_complete; + bool is_root; +}; +typedef map EGrammar; // indexed by the rule LHS + +// edges are immutable once created +struct Edge { +#ifdef DEBUG_CHART_PARSER + static int id_count; + const int id; +#endif + const WordID cat; // lhs side of rule proved/being proved + const EGrammarNode* const dot; // dot position + const FSTNode* const q; // start of span + const FSTNode* const r; // end of span + const Edge* const active_parent; // back pointer, NULL for PREDICT items + const Edge* const passive_parent; // back pointer, NULL for SCAN and PREDICT items + const TargetPhraseSet* const tps; // translations + shared_ptr > features; // features from CFG rule + + bool IsPassive() const { + // when a rule is completed, this value will be set + return static_cast(features); + } + bool IsActive() const { return !IsPassive(); } + bool IsInitial() const { + return !(active_parent || passive_parent); + } + bool IsCreatedByScan() const { + return active_parent && !passive_parent && !dot->IsRoot(); + } + bool IsCreatedByPredict() const { + return dot->IsRoot(); + } + bool IsCreatedByComplete() const { + return active_parent && passive_parent; + } + + // constructor for PREDICT + Edge(WordID c, const EGrammarNode* d, const FSTNode* q_and_r) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(NULL), passive_parent(NULL), tps(NULL) {} + Edge(WordID c, const EGrammarNode* d, const FSTNode* q_and_r, const Edge* act_parent) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(act_parent), passive_parent(NULL), tps(NULL) {} + + // constructors for SCAN + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const TargetPhraseSet* translations) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations) {} + + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const TargetPhraseSet* translations, + const SparseVector& feats) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations), + features(new SparseVector(feats)) {} + + // constructors for COMPLETE + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const Edge *pas_par) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps(NULL) { + assert(pas_par->IsPassive()); + assert(act_par->IsActive()); + } + + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const Edge *pas_par, const SparseVector& feats) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps(NULL), + features(new SparseVector(feats)) { + assert(pas_par->IsPassive()); + assert(act_par->IsActive()); + } + + // constructor for COMPLETE query + Edge(const FSTNode* _r) : +#ifdef DEBUG_CHART_PARSER + id(0), +#endif + cat(0), dot(NULL), q(NULL), + r(_r), active_parent(NULL), passive_parent(NULL), tps(NULL) {} + // constructor for MERGE quere + Edge(const FSTNode* _q, int) : +#ifdef DEBUG_CHART_PARSER + id(0), +#endif + cat(0), dot(NULL), q(_q), + r(NULL), active_parent(NULL), passive_parent(NULL), tps(NULL) {} +}; +#ifdef DEBUG_CHART_PARSER +int Edge::id_count = 0; +#endif + +ostream& operator<<(ostream& os, const Edge& e) { + string type = "PREDICT"; + if (e.IsCreatedByScan()) + type = "SCAN"; + else if (e.IsCreatedByComplete()) + type = "COMPLETE"; + os << "[" +#ifdef DEBUG_CHART_PARSER + << '(' << e.id << ") " +#else + << '(' << &e << ") " +#endif + << "q=" << e.q << ", r=" << e.r + << ", cat="<< TD::Convert(e.cat*-1) << ", dot=" + << e.dot +#ifdef DEBUG_CHART_PARSER + << e.dot->hint +#endif + << (e.IsActive() ? ", Active" : ", Passive") + << ", " << type; +#ifdef DEBUG_CHART_PARSER + if (e.active_parent) { os << ", act.parent=(" << e.active_parent->id << ')'; } + if (e.passive_parent) { os << ", psv.parent=(" << e.passive_parent->id << ')'; } +#endif + if (e.tps) { os << ", tps=" << e.tps; } + return os << ']'; +} + +struct Traversal { + const Edge* const edge; // result from the active / passive combination + const Edge* const active; + const Edge* const passive; + Traversal(const Edge* me, const Edge* a, const Edge* p) : edge(me), active(a), passive(p) {} +}; + +struct UniqueTraversalHash { + size_t operator()(const Traversal* t) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(t->active); + x = ((x << 5) + x) ^ reinterpret_cast(t->passive); + x = ((x << 5) + x) ^ t->edge->IsActive(); + return x; + } +}; + +struct UniqueTraversalEquals { + size_t operator()(const Traversal* a, const Traversal* b) const { + return (a->passive == b->passive && a->active == b->active && a->edge->IsActive() == b->edge->IsActive()); + } +}; + +struct UniqueEdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + if (e->IsActive()) { + x = ((x << 5) + x) ^ reinterpret_cast(e->dot); + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + x = ((x << 5) + x) ^ static_cast(e->cat); + x += 13; + } else { // with passive edges, we don't care about the dot + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + x = ((x << 5) + x) ^ static_cast(e->cat); + } + return x; + } +}; + +struct UniqueEdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + if (a->IsActive() != b->IsActive()) return false; + if (a->IsActive()) { + return (a->cat == b->cat) && (a->dot == b->dot) && (a->q == b->q) && (a->r == b->r); + } else { + return (a->cat == b->cat) && (a->q == b->q) && (a->r == b->r); + } + } +}; + +struct REdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + return x; + } +}; + +struct REdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + return (a->r == b->r); + } +}; + +struct QEdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + return x; + } +}; + +struct QEdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + return (a->q == b->q); + } +}; + +struct EdgeQueue { + queue q; + EdgeQueue() {} + void clear() { while(!q.empty()) q.pop(); } + bool HasWork() const { return !q.empty(); } + const Edge* Next() { const Edge* res = q.front(); q.pop(); return res; } + void AddEdge(const Edge* s) { q.push(s); } +}; + +class EarleyComposerImpl { + public: + EarleyComposerImpl(WordID start_cat, const FSTNode& q_0) : start_cat_(start_cat), q_0_(&q_0) {} + + // returns false if the intersection is empty + bool Compose(const EGrammar& g, Hypergraph* forest) { + goal_node = NULL; + EGrammar::const_iterator sit = g.find(start_cat_); + forest->ReserveNodes(kMAX_NODES); + assert(sit != g.end()); + Edge* init = new Edge(start_cat_, &sit->second, q_0_); + assert(IncorporateNewEdge(init)); + while (exp_agenda.HasWork() || agenda.HasWork()) { + while(exp_agenda.HasWork()) { + const Edge* edge = exp_agenda.Next(); + FinishEdge(edge, forest); + } + if (agenda.HasWork()) { + const Edge* edge = agenda.Next(); +#ifdef DEBUG_CHART_PARSER + cerr << "processing (" << edge->id << ')' << endl; +#endif + if (edge->IsActive()) { + if (edge->dot->HasTerminals()) + DoScan(edge); + if (edge->dot->HasNonTerminals()) { + DoMergeWithPassives(edge); + DoPredict(edge, g); + } + } else { + DoComplete(edge); + } + } + } + if (goal_node) { + forest->PruneUnreachable(goal_node->id_); + forest->EpsilonRemove(kEPS); + } + FreeAll(); + return goal_node; + } + + void FreeAll() { + for (int i = 0; i < free_list_.size(); ++i) + delete free_list_[i]; + free_list_.clear(); + for (int i = 0; i < traversal_free_list_.size(); ++i) + delete traversal_free_list_[i]; + traversal_free_list_.clear(); + all_traversals.clear(); + exp_agenda.clear(); + agenda.clear(); + tps2node.clear(); + edge2node.clear(); + all_edges.clear(); + passive_edges.clear(); + active_edges.clear(); + } + + ~EarleyComposerImpl() { + FreeAll(); + } + + // returns the total number of edges created during composition + int EdgesCreated() const { + return free_list_.size(); + } + + private: + void DoScan(const Edge* edge) { + // here, we assume that the FST will potentially have many more outgoing + // edges than the grammar, which will be just a couple. If you want to + // efficiently handle the case where both are relatively large, this code + // will need to change how the intersection is done. The best general + // solution would probably be the Baeza-Yates double binary search. + + const EGrammarNode* dot = edge->dot; + const FSTNode* r = edge->r; + const map& terms = dot->GetTerminals(); + for (map::const_iterator git = terms.begin(); + git != terms.end(); ++git) { + const FSTNode* next_r = r->Extend(git->first); + if (!next_r) continue; + const EGrammarNode* next_dot = &git->second; + const bool grammar_continues = next_dot->GrammarContinues(); + const bool rule_completes = next_dot->RuleCompletes(); + assert(grammar_continues || rule_completes); + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + // create up to 4 new edges! + if (next_r->HasOutgoingNonEpsilonEdges()) { // are there further symbols in the FST? + const TargetPhraseSet* translations = NULL; + if (rule_completes) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, translations, input_features)); + if (grammar_continues) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, translations)); + } + if (next_r->HasData()) { // indicates a loop back to q_0 in the FST + const TargetPhraseSet* translations = next_r->GetTranslations(); + if (rule_completes) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, q_0_, edge, translations, input_features)); + if (grammar_continues) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, q_0_, edge, translations)); + } + } + } + + void DoPredict(const Edge* edge, const EGrammar& g) { + const EGrammarNode* dot = edge->dot; + const map& non_terms = dot->GetNonTerminals(); + for (map::const_iterator git = non_terms.begin(); + git != non_terms.end(); ++git) { + const WordID nt_to_predict = git->first; + //cerr << edge->id << " -- " << TD::Convert(nt_to_predict*-1) << endl; + EGrammar::const_iterator egi = g.find(nt_to_predict); + if (egi == g.end()) { + cerr << "[ERROR] Can't find any grammar rules with a LHS of type " + << TD::Convert(-1*nt_to_predict) << '!' << endl; + continue; + } + assert(edge->IsActive()); + const EGrammarNode* new_dot = &egi->second; + Edge* new_edge = new Edge(nt_to_predict, new_dot, edge->r, edge); + IncorporateNewEdge(new_edge); + } + } + + void DoComplete(const Edge* passive) { +#ifdef DEBUG_CHART_PARSER + cerr << " complete: " << *passive << endl; +#endif + const WordID completed_nt = passive->cat; + const FSTNode* q = passive->q; + const FSTNode* next_r = passive->r; + const Edge query(q); + const pair::iterator, + unordered_multiset::iterator > p = + active_edges.equal_range(&query); + for (unordered_multiset::iterator it = p.first; + it != p.second; ++it) { + const Edge* active = *it; +#ifdef DEBUG_CHART_PARSER + cerr << " pos: " << *active << endl; +#endif + const EGrammarNode* next_dot = active->dot->Extend(completed_nt); + if (!next_dot) continue; + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + // add up to 2 rules + if (next_dot->RuleCompletes()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features)); + if (next_dot->GrammarContinues()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive)); + } + } + + void DoMergeWithPassives(const Edge* active) { + // edge is active, has non-terminals, we need to find the passives that can extend it + assert(active->IsActive()); + assert(active->dot->HasNonTerminals()); +#ifdef DEBUG_CHART_PARSER + cerr << " merge active with passives: ACT=" << *active << endl; +#endif + const Edge query(active->r, 1); + const pair::iterator, + unordered_multiset::iterator > p = + passive_edges.equal_range(&query); + for (unordered_multiset::iterator it = p.first; + it != p.second; ++it) { + const Edge* passive = *it; + const EGrammarNode* next_dot = active->dot->Extend(passive->cat); + if (!next_dot) continue; + const FSTNode* next_r = passive->r; + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + if (next_dot->RuleCompletes()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features)); + if (next_dot->GrammarContinues()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive)); + } + } + + // take ownership of edge memory, add to various indexes, etc + // returns true if this edge is new + bool IncorporateNewEdge(Edge* edge) { + free_list_.push_back(edge); + if (edge->passive_parent && edge->active_parent) { + Traversal* t = new Traversal(edge, edge->active_parent, edge->passive_parent); + traversal_free_list_.push_back(t); + if (all_traversals.find(t) != all_traversals.end()) { + return false; + } else { + all_traversals.insert(t); + } + } + exp_agenda.AddEdge(edge); + return true; + } + + bool FinishEdge(const Edge* edge, Hypergraph* hg) { + bool is_new = false; + if (all_edges.find(edge) == all_edges.end()) { +#ifdef DEBUG_CHART_PARSER + cerr << *edge << " is NEW\n"; +#endif + all_edges.insert(edge); + is_new = true; + if (edge->IsPassive()) passive_edges.insert(edge); + if (edge->IsActive()) active_edges.insert(edge); + agenda.AddEdge(edge); + } else { +#ifdef DEBUG_CHART_PARSER + cerr << *edge << " is NOT NEW.\n"; +#endif + } + AddEdgeToTranslationForest(edge, hg); + return is_new; + } + + // build the translation forest + void AddEdgeToTranslationForest(const Edge* edge, Hypergraph* hg) { + assert(hg->nodes_.size() < kMAX_NODES); + Hypergraph::Node* tps = NULL; + // first add any target language rules + if (edge->tps) { + Hypergraph::Node*& node = tps2node[(size_t)edge->tps]; + if (!node) { + // cerr << "Creating phrases for " << edge->tps << endl; + const vector& rules = edge->tps->GetRules(); + node = hg->AddNode(kPHRASE, ""); + for (int i = 0; i < rules.size(); ++i) { + Hypergraph::Edge* hg_edge = hg->AddEdge(rules[i], Hypergraph::TailNodeVector()); + hg_edge->feature_values_ += rules[i]->GetFeatureValues(); + hg->ConnectEdgeToHeadNode(hg_edge, node); + } + } + tps = node; + } + Hypergraph::Node*& head_node = edge2node[edge]; + if (!head_node) + head_node = hg->AddNode(kPHRASE, ""); + if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) { + assert(goal_node == NULL || goal_node == head_node); + goal_node = head_node; + } + Hypergraph::TailNodeVector tail; + SparseVector extra; + if (edge->IsCreatedByPredict()) { + // extra.set_value(FD::Convert("predict"), 1); + } else if (edge->IsCreatedByScan()) { + tail.push_back(edge2node[edge->active_parent]->id_); + if (tps) { + tail.push_back(tps->id_); + } + //extra.set_value(FD::Convert("scan"), 1); + } else if (edge->IsCreatedByComplete()) { + tail.push_back(edge2node[edge->active_parent]->id_); + tail.push_back(edge2node[edge->passive_parent]->id_); + //extra.set_value(FD::Convert("complete"), 1); + } else { + assert(!"unexpected edge type!"); + } + //cerr << head_node->id_ << "<--" << *edge << endl; + +#ifdef DEBUG_CHART_PARSER + for (int i = 0; i < tail.size(); ++i) + if (tail[i] == head_node->id_) { + cerr << "ERROR: " << *edge << "\n i=" << i << endl; + if (i == 1) { cerr << "\tP: " << *edge->passive_parent << endl; } + if (i == 0) { cerr << "\tA: " << *edge->active_parent << endl; } + assert(!"self-loop found!"); + } +#endif + Hypergraph::Edge* hg_edge = NULL; + if (tail.size() == 0) { + hg_edge = hg->AddEdge(kEPSRule, tail); + } else if (tail.size() == 1) { + hg_edge = hg->AddEdge(kX1, tail); + } else if (tail.size() == 2) { + hg_edge = hg->AddEdge(kX1X2, tail); + } + if (edge->features) + hg_edge->feature_values_ += *edge->features; + hg_edge->feature_values_ += extra; + hg->ConnectEdgeToHeadNode(hg_edge, head_node); + } + + Hypergraph::Node* goal_node; + EdgeQueue exp_agenda; + EdgeQueue agenda; + unordered_map tps2node; + unordered_map edge2node; + unordered_set all_traversals; + unordered_set all_edges; + unordered_multiset passive_edges; + unordered_multiset active_edges; + vector free_list_; + vector traversal_free_list_; + const WordID start_cat_; + const FSTNode* const q_0_; +}; + +#ifdef DEBUG_CHART_PARSER +static string TrimRule(const string& r) { + size_t start = r.find(" |||") + 5; + size_t end = r.rfind(" |||"); + return r.substr(start, end - start); +} +#endif + +void AddGrammarRule(const string& r, EGrammar* g) { + const size_t pos = r.find(" ||| "); + if (pos == string::npos || r[0] != '[') { + cerr << "Bad rule: " << r << endl; + return; + } + const size_t rpos = r.rfind(" ||| "); + string feats; + string rs = r; + if (rpos != pos) { + feats = r.substr(rpos + 5); + rs = r.substr(0, rpos); + } + string rhs = rs.substr(pos + 5); + string trule = rs + " ||| " + rhs + " ||| " + feats; + TRule tr(trule); +#ifdef DEBUG_CHART_PARSER + string hint_last_rule; +#endif + EGrammarNode* cur = &(*g)[tr.GetLHS()]; + cur->is_root = true; + for (int i = 0; i < tr.FLength(); ++i) { + WordID sym = tr.f()[i]; +#ifdef DEBUG_CHART_PARSER + hint_last_rule = TD::Convert(sym < 0 ? -sym : sym); + cur->hint += " <@@> (*" + hint_last_rule + ") " + TrimRule(tr.AsString()); +#endif + if (sym < 0) + cur = &cur->ntptr[sym]; + else + cur = &cur->tptr[sym]; + } +#ifdef DEBUG_CHART_PARSER + cur->hint += " <@@> (" + hint_last_rule + "*) " + TrimRule(tr.AsString()); +#endif + cur->is_some_rule_complete = true; + cur->input_features = tr.GetFeatureValues(); +} + +EarleyComposer::~EarleyComposer() { + delete pimpl_; +} + +EarleyComposer::EarleyComposer(const FSTNode* fst) { + InitializeConstants(); + pimpl_ = new EarleyComposerImpl(kUNIQUE_START, *fst); +} + +bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest) { + // first, convert the src forest into an EGrammar + EGrammar g; + const int nedges = src_forest.edges_.size(); + const int nnodes = src_forest.nodes_.size(); + vector cats(nnodes); + bool assign_cats = false; + for (int i = 0; i < nnodes; ++i) + if (assign_cats) { + cats[i] = TD::Convert("CAT_" + boost::lexical_cast(i)) * -1; + } else { + cats[i] = src_forest.nodes_[i].cat_; + } + // construct the grammar + for (int i = 0; i < nedges; ++i) { + const Hypergraph::Edge& edge = src_forest.edges_[i]; + const vector& src = edge.rule_->f(); + EGrammarNode* cur = &g[cats[edge.head_node_]]; + cur->is_root = true; + int ntc = 0; + for (int j = 0; j < src.size(); ++j) { + WordID sym = src[j]; + if (sym <= 0) { + sym = cats[edge.tail_nodes_[ntc]]; + ++ntc; + cur = &cur->ntptr[sym]; + } else { + cur = &cur->tptr[sym]; + } + } + cur->is_some_rule_complete = true; + cur->input_features = edge.feature_values_; + } + EGrammarNode& goal_rule = g[kUNIQUE_START]; + assert((goal_rule.ntptr.size() == 1 && goal_rule.tptr.size() == 0) || + (goal_rule.ntptr.size() == 0 && goal_rule.tptr.size() == 1)); + + return pimpl_->Compose(g, trg_forest); +} + +bool EarleyComposer::Compose(istream* in, Hypergraph* trg_forest) { + EGrammar g; + while(*in) { + string line; + getline(*in, line); + if (line.empty()) continue; + AddGrammarRule(line, &g); + } + + return pimpl_->Compose(g, trg_forest); +} diff --git a/src/earley_composer.h b/src/earley_composer.h new file mode 100644 index 00000000..9f786bf6 --- /dev/null +++ b/src/earley_composer.h @@ -0,0 +1,29 @@ +#ifndef _EARLEY_COMPOSER_H_ +#define _EARLEY_COMPOSER_H_ + +#include + +class EarleyComposerImpl; +class FSTNode; +class Hypergraph; + +class EarleyComposer { + public: + ~EarleyComposer(); + EarleyComposer(const FSTNode* phrasetable_root); + bool Compose(const Hypergraph& src_forest, Hypergraph* trg_forest); + + // reads the grammar from a file. There must be a single top-level + // S -> X rule. Anything else is possible. Format is: + // [S] ||| [SS,1] + // [SS] ||| [NP,1] [VP,2] ||| Feature1=0.2 Feature2=-2.3 + // [SS] ||| [VP,1] [NP,2] ||| Feature1=0.8 + // [NP] ||| [DET,1] [N,2] ||| Feature3=2 + // ... + bool Compose(std::istream* grammar_file, Hypergraph* trg_forest); + + private: + EarleyComposerImpl* pimpl_; +}; + +#endif diff --git a/src/exp_semiring.h b/src/exp_semiring.h new file mode 100644 index 00000000..f91beee4 --- /dev/null +++ b/src/exp_semiring.h @@ -0,0 +1,71 @@ +#ifndef _EXP_SEMIRING_H_ +#define _EXP_SEMIRING_H_ + +#include + +// this file implements the first-order expectation semiring described +// in Li & Eisner (EMNLP 2009) + +// requirements: +// RType * RType ==> RType +// PType * PType ==> PType +// RType * PType ==> RType +// good examples: +// PType scalar, RType vector +// BAD examples: +// PType vector, RType scalar +template +struct PRPair { + PRPair() : p(), r() {} + // Inside algorithm requires that T(0) and T(1) + // return the 0 and 1 values of the semiring + explicit PRPair(double x) : p(x), r() {} + PRPair(const PType& p, const RType& r) : p(p), r(r) {} + PRPair& operator+=(const PRPair& o) { + p += o.p; + r += o.r; + return *this; + } + PRPair& operator*=(const PRPair& o) { + r = (o.r * p) + (o.p * r); + p *= o.p; + return *this; + } + PType p; + RType r; +}; + +template +std::ostream& operator<<(std::ostream& o, const PRPair& x) { + return o << '<' << x.p << ", " << x.r << '>'; +} + +template +const PRPair operator+(const PRPair& a, const PRPair& b) { + PRPair result = a; + result += b; + return result; +} + +template +const PRPair operator*(const PRPair& a, const PRPair& b) { + PRPair result = a; + result *= b; + return result; +} + +template +struct PRWeightFunction { + explicit PRWeightFunction(const PWeightFunction& pwf = PWeightFunction(), + const RWeightFunction& rwf = RWeightFunction()) : + pweight(pwf), rweight(rwf) {} + PRPair operator()(const Hypergraph::Edge& e) const { + const P p = pweight(e); + const R r = rweight(e); + return PRPair(p, r * p); + } + const PWeightFunction pweight; + const RWeightFunction rweight; +}; + +#endif diff --git a/src/fdict.cc b/src/fdict.cc new file mode 100644 index 00000000..83aa7cea --- /dev/null +++ b/src/fdict.cc @@ -0,0 +1,4 @@ +#include "fdict.h" + +Dict FD::dict_; + diff --git a/src/fdict.h b/src/fdict.h new file mode 100644 index 00000000..ff491cfb --- /dev/null +++ b/src/fdict.h @@ -0,0 +1,21 @@ +#ifndef _FDICT_H_ +#define _FDICT_H_ + +#include +#include +#include "dict.h" + +struct FD { + static Dict dict_; + static inline int NumFeats() { + return dict_.max() + 1; + } + static inline WordID Convert(const std::string& s) { + return dict_.Convert(s); + } + static inline const std::string& Convert(const WordID& w) { + return dict_.Convert(w); + } +}; + +#endif diff --git a/src/ff.cc b/src/ff.cc new file mode 100644 index 00000000..488e6468 --- /dev/null +++ b/src/ff.cc @@ -0,0 +1,93 @@ +#include "ff.h" + +#include "tdict.h" +#include "hg.h" + +using namespace std; + +FeatureFunction::~FeatureFunction() {} + + +void FeatureFunction::FinalTraversalFeatures(const void* ant_state, + SparseVector* features) const { + (void) ant_state; + (void) features; +} + +// Hiero and Joshua use log_10(e) as the value, so I do to +WordPenalty::WordPenalty(const string& param) : + fid_(FD::Convert("WordPenalty")), + value_(-1.0 / log(10)) { + if (!param.empty()) { + cerr << "Warning WordPenalty ignoring parameter: " << param << endl; + } +} + +void WordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + (void) smeta; + (void) ant_states; + (void) state; + (void) estimated_features; + features->set_value(fid_, edge.rule_->EWords() * value_); +} + +ModelSet::ModelSet(const vector& w, const vector& models) : + models_(models), + weights_(w), + state_size_(0), + model_state_pos_(models.size()) { + for (int i = 0; i < models_.size(); ++i) { + model_state_pos_[i] = state_size_; + state_size_ += models_[i]->NumBytesContext(); + } +} + +void ModelSet::AddFeaturesToEdge(const SentenceMetadata& smeta, + const Hypergraph& hg, + Hypergraph::Edge* edge, + string* context, + prob_t* combination_cost_estimate) const { + context->resize(state_size_); + memset(&(*context)[0], 0, state_size_); + SparseVector est_vals; // only computed if combination_cost_estimate is non-NULL + if (combination_cost_estimate) *combination_cost_estimate = prob_t::One(); + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + void* cur_ff_context = NULL; + vector ants(edge->tail_nodes_.size()); + bool has_context = ff.NumBytesContext() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + cur_ff_context = &(*context)[spos]; + for (int i = 0; i < ants.size(); ++i) { + ants[i] = &hg.nodes_[edge->tail_nodes_[i]].state_[spos]; + } + } + ff.TraversalFeatures(smeta, *edge, ants, &edge->feature_values_, &est_vals, cur_ff_context); + } + if (combination_cost_estimate) + combination_cost_estimate->logeq(est_vals.dot(weights_)); + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + +void ModelSet::AddFinalFeatures(const std::string& state, Hypergraph::Edge* edge) const { + assert(1 == edge->rule_->Arity()); + + for (int i = 0; i < models_.size(); ++i) { + const FeatureFunction& ff = *models_[i]; + const void* ant_state = NULL; + bool has_context = ff.NumBytesContext() > 0; + if (has_context) { + int spos = model_state_pos_[i]; + ant_state = &state[spos]; + } + ff.FinalTraversalFeatures(ant_state, &edge->feature_values_); + } + edge->edge_prob_.logeq(edge->feature_values_.dot(weights_)); +} + diff --git a/src/ff.h b/src/ff.h new file mode 100644 index 00000000..c97e2fe2 --- /dev/null +++ b/src/ff.h @@ -0,0 +1,121 @@ +#ifndef _FF_H_ +#define _FF_H_ + +#include + +#include "fdict.h" +#include "hg.h" + +class SentenceMetadata; +class FeatureFunction; // see definition below + +// if you want to develop a new feature, inherit from this class and +// override TraversalFeaturesImpl(...). If it's a feature that returns / +// depends on context, you may also need to implement +// FinalTraversalFeatures(...) +class FeatureFunction { + public: + FeatureFunction() : state_size_() {} + explicit FeatureFunction(int state_size) : state_size_(state_size) {} + virtual ~FeatureFunction(); + + // returns the number of bytes of context that this feature function will + // (maximally) use. By default, 0 ("stateless" models in Hiero/Joshua). + // NOTE: this value is fixed for the instance of your class, you cannot + // use different amounts of memory for different nodes in the forest. + inline int NumBytesContext() const { return state_size_; } + + // Compute the feature values and (if this applies) the estimates of the + // feature values when this edge is used incorporated into a larger context + inline void TraversalFeatures(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_state) const { + TraversalFeaturesImpl(smeta, edge, ant_contexts, + features, estimated_features, out_state); + // TODO it's easy for careless feature function developers to overwrite + // the end of their state and clobber someone else's memory. These bugs + // will be horrendously painful to track down. There should be some + // optional strict mode that's enforced here that adds some kind of + // barrier between the blocks reserved for the residual contexts + } + + // if there's some state left when you transition to the goal state, score + // it here. For example, the language model computes the cost of adding + // and . + virtual void FinalTraversalFeatures(const void* residual_state, + SparseVector* final_features) const; + + protected: + // context is a pointer to a buffer of size NumBytesContext() that the + // feature function can write its state to. It's up to the feature function + // to determine how much space it needs and to determine how to encode its + // residual contextual information since it is OPAQUE to all clients outside + // of the particular FeatureFunction class. There is one exception: + // equality of the contents (i.e., memcmp) is required to determine whether + // two states can be combined. + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const = 0; + + // !!! ONLY call this from subclass *CONSTRUCTORS* !!! + void SetStateSize(size_t state_size) { + state_size_ = state_size; + } + + private: + int state_size_; +}; + +// word penalty feature, for each word on the E side of a rule, +// add value_ +class WordPenalty : public FeatureFunction { + public: + WordPenalty(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + private: + const int fid_; + const double value_; +}; + +// this class is a set of FeatureFunctions that can be used to score, rescore, +// etc. a (translation?) forest +class ModelSet { + public: + ModelSet() : state_size_(0) {} + + ModelSet(const std::vector& weights, + const std::vector& models); + + // sets edge->feature_values_ and edge->edge_prob_ + // NOTE: edge must not necessarily be in hg.edges_ but its TAIL nodes + // must be. + void AddFeaturesToEdge(const SentenceMetadata& smeta, + const Hypergraph& hg, + Hypergraph::Edge* edge, + std::string* residual_context, + prob_t* combination_cost_estimate = NULL) const; + + void AddFinalFeatures(const std::string& residual_context, + Hypergraph::Edge* edge) const; + + bool empty() const { return models_.empty(); } + private: + std::vector models_; + std::vector weights_; + int state_size_; + std::vector model_state_pos_; +}; + +#endif diff --git a/src/ff_factory.cc b/src/ff_factory.cc new file mode 100644 index 00000000..1854e0bb --- /dev/null +++ b/src/ff_factory.cc @@ -0,0 +1,35 @@ +#include "ff_factory.h" + +#include "ff.h" + +using boost::shared_ptr; +using namespace std; + +FFFactoryBase::~FFFactoryBase() {} + +void FFRegistry::DisplayList() const { + for (map >::const_iterator it = reg_.begin(); + it != reg_.end(); ++it) { + cerr << " " << it->first << endl; + } +} + +shared_ptr FFRegistry::Create(const string& ffname, const string& param) const { + map >::const_iterator it = reg_.find(ffname); + shared_ptr res; + if (it == reg_.end()) { + cerr << "I don't know how to create feature " << ffname << endl; + } else { + res = it->second->Create(param); + } + return res; +} + +void FFRegistry::Register(const string& ffname, FFFactoryBase* factory) { + if (reg_.find(ffname) != reg_.end()) { + cerr << "Duplicate registration of FeatureFunction with name " << ffname << "!\n"; + abort(); + } + reg_[ffname].reset(factory); +} + diff --git a/src/ff_factory.h b/src/ff_factory.h new file mode 100644 index 00000000..bc586567 --- /dev/null +++ b/src/ff_factory.h @@ -0,0 +1,39 @@ +#ifndef _FF_FACTORY_H_ +#define _FF_FACTORY_H_ + +#include +#include +#include + +#include + +class FeatureFunction; +class FFRegistry; +class FFFactoryBase; +extern boost::shared_ptr global_ff_registry; + +class FFRegistry { + friend int main(int argc, char** argv); + friend class FFFactoryBase; + public: + boost::shared_ptr Create(const std::string& ffname, const std::string& param) const; + void DisplayList() const; + void Register(const std::string& ffname, FFFactoryBase* factory); + private: + FFRegistry() {} + std::map > reg_; +}; + +struct FFFactoryBase { + virtual ~FFFactoryBase(); + virtual boost::shared_ptr Create(const std::string& param) const = 0; +}; + +template +class FFFactory : public FFFactoryBase { + boost::shared_ptr Create(const std::string& param) const { + return boost::shared_ptr(new FF(param)); + } +}; + +#endif diff --git a/src/ff_itg_span.h b/src/ff_itg_span.h new file mode 100644 index 00000000..b990f86a --- /dev/null +++ b/src/ff_itg_span.h @@ -0,0 +1,7 @@ +#ifndef _FF_ITG_SPAN_H_ +#define _FF_ITG_SPAN_H_ + +class ITGSpanFeatures : public FeatureFunction { +}; + +#endif diff --git a/src/ff_test.cc b/src/ff_test.cc new file mode 100644 index 00000000..1c20f9ac --- /dev/null +++ b/src/ff_test.cc @@ -0,0 +1,134 @@ +#include +#include +#include +#include +#include +#include "hg.h" +#include "lm_ff.h" +#include "ff.h" +#include "trule.h" +#include "sentence_metadata.h" + +using namespace std; + +LanguageModel* lm_ = NULL; +LanguageModel* lm3_ = NULL; + +class FFTest : public testing::Test { + public: + FFTest() : smeta(0,Lattice()) { + if (!lm_) { + static LanguageModel slm("-o 2 ./test_data/brown.lm.gz"); + lm_ = &slm; + static LanguageModel slm3("./test_data/dummy.3gram.lm -o 3"); + lm3_ = &slm3; + } + } + protected: + virtual void SetUp() { } + virtual void TearDown() { } + SentenceMetadata smeta; +}; + +TEST_F(FFTest,LanguageModel) { + vector ms(1, lm_); + TRulePtr tr1(new TRule("[X] ||| [X,1] said")); + TRulePtr tr2(new TRule("[X] ||| the man said")); + TRulePtr tr3(new TRule("[X] ||| the fat man")); + Hypergraph hg; + const int lm_fid = FD::Convert("LanguageModel"); + vector w(lm_fid + 1,1); + ModelSet models(w, ms); + string state; + Hypergraph::Edge edge; + edge.rule_ = tr2; + models.AddFeaturesToEdge(smeta, hg, &edge, &state); + double lm1 = edge.feature_values_.dot(w); + cerr << "lm=" << edge.feature_values_[lm_fid] << endl; + + hg.nodes_.resize(1); + hg.edges_.resize(2); + hg.edges_[0].rule_ = tr3; + models.AddFeaturesToEdge(smeta, hg, &hg.edges_[0], &hg.nodes_[0].state_); + hg.edges_[1].tail_nodes_.push_back(0); + hg.edges_[1].rule_ = tr1; + string state2; + models.AddFeaturesToEdge(smeta, hg, &hg.edges_[1], &state2); + double tot = hg.edges_[1].feature_values_[lm_fid] + hg.edges_[0].feature_values_[lm_fid]; + cerr << "lm=" << tot << endl; + EXPECT_TRUE(state2 == state); + EXPECT_FALSE(state == hg.nodes_[0].state_); +} + +TEST_F(FFTest, Small) { + WordPenalty wp(""); + vector ms(2, lm_); + ms[1] = ℘ + TRulePtr tr1(new TRule("[X] ||| [X,1] said")); + TRulePtr tr2(new TRule("[X] ||| john said")); + TRulePtr tr3(new TRule("[X] ||| john")); + cerr << "RULE: " << tr1->AsString() << endl; + Hypergraph hg; + vector w(2); w[0]=1.0; w[1]=-2.0; + ModelSet models(w, ms); + string state; + Hypergraph::Edge edge; + edge.rule_ = tr2; + cerr << tr2->AsString() << endl; + models.AddFeaturesToEdge(smeta, hg, &edge, &state); + double s1 = edge.feature_values_.dot(w); + cerr << "lm=" << edge.feature_values_[0] << endl; + cerr << "wp=" << edge.feature_values_[1] << endl; + + hg.nodes_.resize(1); + hg.edges_.resize(2); + hg.edges_[0].rule_ = tr3; + models.AddFeaturesToEdge(smeta, hg, &hg.edges_[0], &hg.nodes_[0].state_); + double acc = hg.edges_[0].feature_values_.dot(w); + cerr << hg.edges_[0].feature_values_[0] << endl; + hg.edges_[1].tail_nodes_.push_back(0); + hg.edges_[1].rule_ = tr1; + string state2; + models.AddFeaturesToEdge(smeta, hg, &hg.edges_[1], &state2); + acc += hg.edges_[1].feature_values_.dot(w); + double tot = hg.edges_[1].feature_values_[0] + hg.edges_[0].feature_values_[0]; + cerr << "lm=" << tot << endl; + cerr << "acc=" << acc << endl; + cerr << " s1=" << s1 << endl; + EXPECT_TRUE(state2 == state); + EXPECT_FALSE(state == hg.nodes_[0].state_); + EXPECT_FLOAT_EQ(acc, s1); +} + +TEST_F(FFTest, LM3) { + int x = lm3_->NumBytesContext(); + Hypergraph::Edge edge1; + edge1.rule_.reset(new TRule("[X] ||| x y ||| one ||| 1.0 -2.4 3.0")); + Hypergraph::Edge edge2; + edge2.rule_.reset(new TRule("[X] ||| [X,1] a ||| [X,1] two ||| 1.0 -2.4 3.0")); + Hypergraph::Edge edge3; + edge3.rule_.reset(new TRule("[X] ||| [X,1] a ||| zero [X,1] two ||| 1.0 -2.4 3.0")); + vector ants1; + string state(x, '\0'); + SparseVector feats; + SparseVector est; + lm3_->TraversalFeatures(smeta, edge1, ants1, &feats, &est, (void *)&state[0]); + cerr << "returned " << feats << endl; + cerr << edge1.feature_values_ << endl; + cerr << lm3_->DebugStateToString((const void*)&state[0]) << endl; + EXPECT_EQ("[ one ]", lm3_->DebugStateToString((const void*)&state[0])); + ants1.push_back((const void*)&state[0]); + string state2(x, '\0'); + lm3_->TraversalFeatures(smeta, edge2, ants1, &feats, &est, (void *)&state2[0]); + cerr << lm3_->DebugStateToString((const void*)&state2[0]) << endl; + EXPECT_EQ("[ one two ]", lm3_->DebugStateToString((const void*)&state2[0])); + string state3(x, '\0'); + lm3_->TraversalFeatures(smeta, edge3, ants1, &feats, &est, (void *)&state3[0]); + cerr << lm3_->DebugStateToString((const void*)&state3[0]) << endl; + EXPECT_EQ("[ zero one <{STAR}> one two ]", lm3_->DebugStateToString((const void*)&state3[0])); +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ff_wordalign.cc b/src/ff_wordalign.cc new file mode 100644 index 00000000..e605ac8d --- /dev/null +++ b/src/ff_wordalign.cc @@ -0,0 +1,221 @@ +#include "ff_wordalign.h" + +#include +#include + +#include "stringlib.h" +#include "sentence_metadata.h" +#include "hg.h" +#include "fdict.h" +#include "aligner.h" +#include "tdict.h" // Blunsom hack +#include "filelib.h" // Blunsom hack + +using namespace std; + +RelativeSentencePosition::RelativeSentencePosition(const string& param) : + fid_(FD::Convert("RelativeSentencePosition")) {} + +void RelativeSentencePosition::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + // if the source word is either null or the generated word + // has no position in the reference + if (edge.i_ == -1 || edge.prev_i_ == -1) + return; + + assert(smeta.GetTargetLength() > 0); + const double val = fabs(static_cast(edge.i_) / smeta.GetSourceLength() - + static_cast(edge.prev_i_) / smeta.GetTargetLength()); + features->set_value(fid_, val); +// cerr << f_len_ << " " << e_len_ << " [" << edge.i_ << "," << edge.j_ << "|" << edge.prev_i_ << "," << edge.prev_j_ << "]\t" << edge.rule_->AsString() << "\tVAL=" << val << endl; +} + +MarkovJump::MarkovJump(const string& param) : + FeatureFunction(1), + fid_(FD::Convert("MarkovJump")), + individual_params_per_jumpsize_(false), + condition_on_flen_(false) { + cerr << " MarkovJump: Blunsom&Cohn feature"; + vector argv; + int argc = SplitOnWhitespace(param, &argv); + if (argc > 0) { + if (argc != 1 || !(argv[0] == "-f" || argv[0] == "-i" || argv[0] == "-if")) { + cerr << "MarkovJump: expected parameters to be -f, -i, or -if\n"; + exit(1); + } + individual_params_per_jumpsize_ = (argv[0][1] == 'i'); + condition_on_flen_ = (argv[0][argv[0].size() - 1] == 'f'); + if (individual_params_per_jumpsize_) { + template_ = "Jump:000"; + cerr << ", individual jump parameters"; + if (condition_on_flen_) { + template_ += ":F00"; + cerr << " (split by f-length)"; + } + } + } + cerr << endl; +} + +void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + unsigned char& dpstate = *((unsigned char*)state); + if (edge.Arity() == 0) { + dpstate = static_cast(edge.i_); + } else if (edge.Arity() == 1) { + dpstate = *((unsigned char*)ant_states[0]); + } else if (edge.Arity() == 2) { + int left_index = *((unsigned char*)ant_states[0]); + int right_index = *((unsigned char*)ant_states[1]); + if (right_index == -1) + dpstate = static_cast(left_index); + else + dpstate = static_cast(right_index); + const int jumpsize = right_index - left_index; + features->set_value(fid_, fabs(jumpsize - 1)); // Blunsom and Cohn def + + if (individual_params_per_jumpsize_) { + string fname = template_; + int param = jumpsize; + if (jumpsize < 0) { + param *= -1; + fname[5]='L'; + } else if (jumpsize > 0) { + fname[5]='R'; + } + if (param) { + fname[6] = '0' + (param / 10); + fname[7] = '0' + (param % 10); + } + if (condition_on_flen_) { + const int flen = smeta.GetSourceLength(); + fname[10] = '0' + (flen / 10); + fname[11] = '0' + (flen % 10); + } + features->set_value(FD::Convert(fname), 1.0); + } + } else { + assert(!"something really unexpected is happening"); + } +} + +AlignerResults::AlignerResults(const std::string& param) : + cur_sent_(-1), + cur_grid_(NULL) { + vector argv; + int argc = SplitOnWhitespace(param, &argv); + if (argc != 2) { + cerr << "Required format: AlignerResults [FeatureName] [file.pharaoh]\n"; + exit(1); + } + cerr << " feature: " << argv[0] << "\talignments: " << argv[1] << endl; + fid_ = FD::Convert(argv[0]); + ReadFile rf(argv[1]); + istream& in = *rf.stream(); int lc = 0; + while(in) { + string line; + getline(in, line); + if (!in) break; + ++lc; + is_aligned_.push_back(AlignerTools::ReadPharaohAlignmentGrid(line)); + } + cerr << " Loaded " << lc << " refs\n"; +} + +void AlignerResults::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + if (edge.i_ == -1 || edge.prev_i_ == -1) + return; + + if (cur_sent_ != smeta.GetSentenceID()) { + assert(smeta.HasReference()); + cur_sent_ = smeta.GetSentenceID(); + assert(cur_sent_ < is_aligned_.size()); + cur_grid_ = is_aligned_[cur_sent_].get(); + } + + //cerr << edge.rule_->AsString() << endl; + + int j = edge.i_; // source side (f) + int i = edge.prev_i_; // target side (e) + if (j < cur_grid_->height() && i < cur_grid_->width() && (*cur_grid_)(i, j)) { +// if (edge.rule_->e_[0] == smeta.GetReference()[i][0].label) { + features->set_value(fid_, 1.0); +// cerr << edge.rule_->AsString() << " (" << i << "," << j << ")\n"; +// } + } +} + +BlunsomSynchronousParseHack::BlunsomSynchronousParseHack(const string& param) : + FeatureFunction((100 / 8) + 1), fid_(FD::Convert("NotRef")), cur_sent_(-1) { + ReadFile rf(param); + istream& in = *rf.stream(); int lc = 0; + while(in) { + string line; + getline(in, line); + if (!in) break; + ++lc; + refs_.push_back(vector()); + TD::ConvertSentence(line, &refs_.back()); + } + cerr << " Loaded " << lc << " refs\n"; +} + +void BlunsomSynchronousParseHack::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + if (cur_sent_ != smeta.GetSentenceID()) { + // assert(smeta.HasReference()); + cur_sent_ = smeta.GetSentenceID(); + assert(cur_sent_ < refs_.size()); + cur_ref_ = &refs_[cur_sent_]; + cur_map_.clear(); + for (int i = 0; i < cur_ref_->size(); ++i) { + vector phrase; + for (int j = i; j < cur_ref_->size(); ++j) { + phrase.push_back((*cur_ref_)[j]); + cur_map_[phrase] = i; + } + } + } + //cerr << edge.rule_->AsString() << endl; + for (int i = 0; i < ant_states.size(); ++i) { + if (DoesNotBelong(ant_states[i])) { + //cerr << " ant " << i << " does not belong\n"; + return; + } + } + vector > ants(ant_states.size()); + vector* > pants(ant_states.size()); + for (int i = 0; i < ant_states.size(); ++i) { + AppendAntecedentString(ant_states[i], &ants[i]); + //cerr << " ant[" << i << "]: " << ((int)*(static_cast(ant_states[i]))) << " " << TD::GetString(ants[i]) << endl; + pants[i] = &ants[i]; + } + vector yield; + edge.rule_->ESubstitute(pants, &yield); + //cerr << "YIELD: " << TD::GetString(yield) << endl; + Vec2Int::iterator it = cur_map_.find(yield); + if (it == cur_map_.end()) { + features->set_value(fid_, 1); + //cerr << " BAD!\n"; + return; + } + SetStateMask(it->second, it->second + yield.size(), state); +} + diff --git a/src/ff_wordalign.h b/src/ff_wordalign.h new file mode 100644 index 00000000..1581641c --- /dev/null +++ b/src/ff_wordalign.h @@ -0,0 +1,133 @@ +#ifndef _FF_WORD_ALIGN_H_ +#define _FF_WORD_ALIGN_H_ + +#include "ff.h" +#include "array2d.h" + +class RelativeSentencePosition : public FeatureFunction { + public: + RelativeSentencePosition(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + const int fid_; +}; + +class MarkovJump : public FeatureFunction { + public: + MarkovJump(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + const int fid_; + bool individual_params_per_jumpsize_; + bool condition_on_flen_; + std::string template_; +}; + +class AlignerResults : public FeatureFunction { + public: + AlignerResults(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + int fid_; + std::vector > > is_aligned_; + mutable int cur_sent_; + const Array2D mutable* cur_grid_; +}; + +#include +#include +#include +class BlunsomSynchronousParseHack : public FeatureFunction { + public: + BlunsomSynchronousParseHack(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + inline bool DoesNotBelong(const void* state) const { + for (int i = 0; i < NumBytesContext(); ++i) { + if (*(static_cast(state) + i)) return false; + } + return true; + } + + inline void AppendAntecedentString(const void* state, std::vector* yield) const { + int i = 0; + int ind = 0; + while (i < NumBytesContext() && !(*(static_cast(state) + i))) { ++i; ind += 8; } + // std::cerr << i << " " << NumBytesContext() << std::endl; + assert(i != NumBytesContext()); + assert(ind < cur_ref_->size()); + int cur = *(static_cast(state) + i); + int comp = 1; + while (comp < 256 && (comp & cur) == 0) { comp <<= 1; ++ind; } + assert(ind < cur_ref_->size()); + assert(comp < 256); + do { + assert(ind < cur_ref_->size()); + yield->push_back((*cur_ref_)[ind]); + ++ind; + comp <<= 1; + if (comp == 256) { + comp = 1; + ++i; + cur = *(static_cast(state) + i); + } + } while (comp & cur); + } + + inline void SetStateMask(int start, int end, void* state) const { + assert((end / 8) < NumBytesContext()); + int i = 0; + int comp = 1; + for (int j = 0; j < start; ++j) { + comp <<= 1; + if (comp == 256) { + ++i; + comp = 1; + } + } + //std::cerr << "SM: " << i << "\n"; + for (int j = start; j < end; ++j) { + *(static_cast(state) + i) |= comp; + //std::cerr << " " << comp << "\n"; + comp <<= 1; + if (comp == 256) { + ++i; + comp = 1; + } + } + //std::cerr << " MASK: " << ((int)*(static_cast(state))) << "\n"; + } + + const int fid_; + mutable int cur_sent_; + typedef std::tr1::unordered_map, int, boost::hash > > Vec2Int; + mutable Vec2Int cur_map_; + const std::vector mutable * cur_ref_; + mutable std::vector > refs_; +}; + +#endif diff --git a/src/filelib.cc b/src/filelib.cc new file mode 100644 index 00000000..79ad2847 --- /dev/null +++ b/src/filelib.cc @@ -0,0 +1,22 @@ +#include "filelib.h" + +#include +#include + +using namespace std; + +bool FileExists(const std::string& fn) { + struct stat info; + int s = stat(fn.c_str(), &info); + return (s==0); +} + +bool DirectoryExists(const string& dir) { + if (access(dir.c_str(),0) == 0) { + struct stat status; + stat(dir.c_str(), &status); + if (status.st_mode & S_IFDIR) return true; + } + return false; +} + diff --git a/src/filelib.h b/src/filelib.h new file mode 100644 index 00000000..62cb9427 --- /dev/null +++ b/src/filelib.h @@ -0,0 +1,66 @@ +#ifndef _FILELIB_H_ +#define _FILELIB_H_ + +#include +#include +#include +#include +#include "gzstream.h" + +// reads from standard in if filename is - +// uncompresses if file ends with .gz +// otherwise, reads from a normal file +class ReadFile { + public: + ReadFile(const std::string& filename) : + no_delete_on_exit_(filename == "-"), + in_(no_delete_on_exit_ ? static_cast(&std::cin) : + (EndsWith(filename, ".gz") ? + static_cast(new igzstream(filename.c_str())) : + static_cast(new std::ifstream(filename.c_str())))) { + if (!*in_) { + std::cerr << "Failed to open " << filename << std::endl; + abort(); + } + } + ~ReadFile() { + if (!no_delete_on_exit_) delete in_; + } + + inline std::istream* stream() { return in_; } + + private: + static bool EndsWith(const std::string& f, const std::string& suf) { + return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size()); + } + const bool no_delete_on_exit_; + std::istream* const in_; +}; + +class WriteFile { + public: + WriteFile(const std::string& filename) : + no_delete_on_exit_(filename == "-"), + out_(no_delete_on_exit_ ? static_cast(&std::cout) : + (EndsWith(filename, ".gz") ? + static_cast(new ogzstream(filename.c_str())) : + static_cast(new std::ofstream(filename.c_str())))) {} + ~WriteFile() { + (*out_) << std::flush; + if (!no_delete_on_exit_) delete out_; + } + + inline std::ostream* stream() { return out_; } + + private: + static bool EndsWith(const std::string& f, const std::string& suf) { + return (f.size() > suf.size()) && (f.rfind(suf) == f.size() - suf.size()); + } + const bool no_delete_on_exit_; + std::ostream* const out_; +}; + +bool FileExists(const std::string& file_name); +bool DirectoryExists(const std::string& dir_name); + +#endif diff --git a/src/forest_writer.cc b/src/forest_writer.cc new file mode 100644 index 00000000..a9117d18 --- /dev/null +++ b/src/forest_writer.cc @@ -0,0 +1,23 @@ +#include "forest_writer.h" + +#include + +#include + +#include "filelib.h" +#include "hg_io.h" +#include "hg.h" + +using namespace std; + +ForestWriter::ForestWriter(const std::string& path, int num) : + fname_(path + '/' + boost::lexical_cast(num) + ".json.gz"), used_(false) {} + +bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) { + assert(!used_); + used_ = true; + cerr << " Writing forest to " << fname_ << endl; + WriteFile wf(fname_); + return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream()); +} + diff --git a/src/forest_writer.h b/src/forest_writer.h new file mode 100644 index 00000000..819a8940 --- /dev/null +++ b/src/forest_writer.h @@ -0,0 +1,16 @@ +#ifndef _FOREST_WRITER_H_ +#define _FOREST_WRITER_H_ + +#include + +class Hypergraph; + +struct ForestWriter { + ForestWriter(const std::string& path, int num); + bool Write(const Hypergraph& forest, bool minimal_rules); + + const std::string fname_; + bool used_; +}; + +#endif diff --git a/src/freqdict.cc b/src/freqdict.cc new file mode 100644 index 00000000..4cfffe58 --- /dev/null +++ b/src/freqdict.cc @@ -0,0 +1,23 @@ +#include +#include +#include +#include "freqdict.h" + +void FreqDict::load(const std::string& fname) { + std::ifstream ifs(fname.c_str()); + int cc=0; + while (!ifs.eof()) { + std::string word; + ifs >> word; + if (word.size() == 0) continue; + if (word[0] == '#') continue; + double count = 0; + ifs >> count; + assert(count > 0.0); // use -log(f) + counts_[word]=count; + ++cc; + if (cc % 10000 == 0) { std::cerr << "."; } + } + std::cerr << "\n"; + std::cerr << "Loaded " << cc << " words\n"; +} diff --git a/src/freqdict.h b/src/freqdict.h new file mode 100644 index 00000000..c9bb4c42 --- /dev/null +++ b/src/freqdict.h @@ -0,0 +1,19 @@ +#ifndef _FREQDICT_H_ +#define _FREQDICT_H_ + +#include +#include + +class FreqDict { + public: + void load(const std::string& fname); + float frequency(const std::string& word) const { + std::map::const_iterator i = counts_.find(word); + if (i == counts_.end()) return 0; + return i->second; + } + private: + std::map counts_; +}; + +#endif diff --git a/src/fst_translator.cc b/src/fst_translator.cc new file mode 100644 index 00000000..57feb227 --- /dev/null +++ b/src/fst_translator.cc @@ -0,0 +1,91 @@ +#include "translator.h" + +#include +#include + +#include "sentence_metadata.h" +#include "filelib.h" +#include "hg.h" +#include "hg_io.h" +#include "earley_composer.h" +#include "phrasetable_fst.h" +#include "tdict.h" + +using namespace std; + +struct FSTTranslatorImpl { + FSTTranslatorImpl(const boost::program_options::variables_map& conf) : + goal_sym(conf["goal"].as()), + kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")), + kGOAL(TD::Convert("Goal") * -1), + add_pass_through_rules(conf.count("add_pass_through_rules")) { + fst.reset(LoadTextPhrasetable(conf["grammar"].as >())); + ec.reset(new EarleyComposer(fst.get())); + } + + bool Translate(const string& input, + const vector& weights, + Hypergraph* forest) { + bool composed = false; + if (input.find("{\"rules\"") == 0) { + istringstream is(input); + Hypergraph src_cfg_hg; + assert(HypergraphIO::ReadFromJSON(&is, &src_cfg_hg)); + if (add_pass_through_rules) { + SparseVector feats; + feats.set_value(FD::Convert("PassThrough"), 1); + for (int i = 0; i < src_cfg_hg.edges_.size(); ++i) { + const vector& f = src_cfg_hg.edges_[i].rule_->f_; + for (int j = 0; j < f.size(); ++j) { + if (f[j] > 0) { + fst->AddPassThroughTranslation(f[j], feats); + } + } + } + } + composed = ec->Compose(src_cfg_hg, forest); + } else { + const string dummy_grammar("[" + goal_sym + "] ||| " + input + " ||| TOP=1"); + cerr << " Dummy grammar: " << dummy_grammar << endl; + istringstream is(dummy_grammar); + if (add_pass_through_rules) { + vector words; + TD::ConvertSentence(input, &words); + SparseVector feats; + feats.set_value(FD::Convert("PassThrough"), 1); + for (int i = 0; i < words.size(); ++i) + fst->AddPassThroughTranslation(words[i], feats); + } + composed = ec->Compose(&is, forest); + } + if (composed) { + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1, ""); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + forest->Reweight(weights); + } + if (add_pass_through_rules) + fst->ClearPassThroughTranslations(); + return composed; + } + + const string goal_sym; + const TRulePtr kGOAL_RULE; + const WordID kGOAL; + const bool add_pass_through_rules; + boost::shared_ptr ec; + boost::shared_ptr fst; +}; + +FSTTranslator::FSTTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new FSTTranslatorImpl(conf)) {} + +bool FSTTranslator::Translate(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* minus_lm_forest) { + smeta->SetSourceLength(0); // don't know how to compute this + return pimpl_->Translate(input, weights, minus_lm_forest); +} + diff --git a/src/grammar.cc b/src/grammar.cc new file mode 100644 index 00000000..69e38320 --- /dev/null +++ b/src/grammar.cc @@ -0,0 +1,163 @@ +#include "grammar.h" + +#include +#include +#include + +#include "filelib.h" +#include "tdict.h" + +using namespace std; + +const vector Grammar::NO_RULES; + +RuleBin::~RuleBin() {} +GrammarIter::~GrammarIter() {} +Grammar::~Grammar() {} + +bool Grammar::HasRuleForSpan(int i, int j) const { + (void) i; + (void) j; + return true; // always true by default +} + +struct TextRuleBin : public RuleBin { + int GetNumRules() const { + return rules_.size(); + } + TRulePtr GetIthRule(int i) const { + return rules_[i]; + } + void AddRule(TRulePtr t) { + rules_.push_back(t); + } + int Arity() const { + return rules_.front()->Arity(); + } + void Dump() const { + for (int i = 0; i < rules_.size(); ++i) + VLOG(1) << rules_[i]->AsString() << endl; + } + private: + vector rules_; +}; + +struct TextGrammarNode : public GrammarIter { + TextGrammarNode() : rb_(NULL) {} + ~TextGrammarNode() { + delete rb_; + } + const GrammarIter* Extend(int symbol) const { + map::const_iterator i = tree_.find(symbol); + if (i == tree_.end()) return NULL; + return &i->second; + } + + const RuleBin* GetRules() const { + if (rb_) { + //rb_->Dump(); + } + return rb_; + } + + map tree_; + TextRuleBin* rb_; +}; + +struct TGImpl { + TextGrammarNode root_; +}; + +TextGrammar::TextGrammar() : max_span_(10), pimpl_(new TGImpl) {} +TextGrammar::TextGrammar(const string& file) : + max_span_(10), + pimpl_(new TGImpl) { + ReadFromFile(file); +} + +const GrammarIter* TextGrammar::GetRoot() const { + return &pimpl_->root_; +} + +void TextGrammar::AddRule(const TRulePtr& rule) { + if (rule->IsUnary()) { + rhs2unaries_[rule->f().front()].push_back(rule); + unaries_.push_back(rule); + } else { + TextGrammarNode* cur = &pimpl_->root_; + for (int i = 0; i < rule->f_.size(); ++i) + cur = &cur->tree_[rule->f_[i]]; + if (cur->rb_ == NULL) + cur->rb_ = new TextRuleBin; + cur->rb_->AddRule(rule); + } +} + +void TextGrammar::ReadFromFile(const string& filename) { + ReadFile in(filename); + istream& in_file = *in.stream(); + assert(in_file); + long long int rule_count = 0; + bool fl = false; + while(in_file) { + string line; + getline(in_file, line); + if (line.empty()) continue; + ++rule_count; + if (rule_count % 50000 == 0) { cerr << '.' << flush; fl = true; } + if (rule_count % 2000000 == 0) { cerr << " [" << rule_count << "]\n"; fl = false; } + TRulePtr rule(TRule::CreateRuleSynchronous(line)); + if (rule) { + AddRule(rule); + } else { + if (fl) { cerr << endl; } + cerr << "Skipping badly formatted rule in line " << rule_count << " of " << filename << endl; + fl = false; + } + } + if (fl) cerr << endl; + cerr << " " << rule_count << " rules read.\n"; +} + +bool TextGrammar::HasRuleForSpan(int i, int j) const { + return (max_span_ >= (j - i)); +} + +GlueGrammar::GlueGrammar(const string& file) : TextGrammar(file) {} + +GlueGrammar::GlueGrammar(const string& goal_nt, const string& default_nt) { + TRulePtr stop_glue(new TRule("[" + goal_nt + "] ||| [" + default_nt + ",1] ||| [" + default_nt + ",1]")); + TRulePtr glue(new TRule("[" + goal_nt + "] ||| [" + goal_nt + ",1] [" + + default_nt + ",2] ||| [" + goal_nt + ",1] [" + default_nt + ",2] ||| Glue=1")); + + AddRule(stop_glue); + AddRule(glue); + //cerr << "GLUE: " << stop_glue->AsString() << endl; + //cerr << "GLUE: " << glue->AsString() << endl; +} + +bool GlueGrammar::HasRuleForSpan(int i, int j) const { + (void) j; + return (i == 0); +} + +PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat) : + has_rule_(input.size() + 1) { + for (int i = 0; i < input.size(); ++i) { + const vector& alts = input[i]; + for (int k = 0; k < alts.size(); ++k) { + const int j = alts[k].dist2next + i; + has_rule_[i].insert(j); + const string& src = TD::Convert(alts[k].label); + TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1")); + AddRule(pt); +// cerr << "PT: " << pt->AsString() << endl; + } + } +} + +bool PassThroughGrammar::HasRuleForSpan(int i, int j) const { + const set& hr = has_rule_[i]; + if (i == j) { return !hr.empty(); } + return (hr.find(j) != hr.end()); +} diff --git a/src/grammar.h b/src/grammar.h new file mode 100644 index 00000000..4a03c505 --- /dev/null +++ b/src/grammar.h @@ -0,0 +1,83 @@ +#ifndef GRAMMAR_H_ +#define GRAMMAR_H_ + +#include +#include +#include +#include + +#include "lattice.h" +#include "trule.h" + +struct RuleBin { + virtual ~RuleBin(); + virtual int GetNumRules() const = 0; + virtual TRulePtr GetIthRule(int i) const = 0; + virtual int Arity() const = 0; +}; + +struct GrammarIter { + virtual ~GrammarIter(); + virtual const RuleBin* GetRules() const = 0; + virtual const GrammarIter* Extend(int symbol) const = 0; +}; + +struct Grammar { + typedef std::map > Cat2Rules; + static const std::vector NO_RULES; + + virtual ~Grammar(); + virtual const GrammarIter* GetRoot() const = 0; + virtual bool HasRuleForSpan(int i, int j) const; + + // cat is the category to be rewritten + inline const std::vector& GetAllUnaryRules() const { + return unaries_; + } + + // get all the unary rules that rewrite category cat + inline const std::vector& GetUnaryRulesForRHS(const WordID& cat) const { + Cat2Rules::const_iterator found = rhs2unaries_.find(cat); + if (found == rhs2unaries_.end()) + return NO_RULES; + else + return found->second; + } + + protected: + Cat2Rules rhs2unaries_; // these must be filled in by subclasses! + std::vector unaries_; +}; + +typedef boost::shared_ptr GrammarPtr; + +class TGImpl; +struct TextGrammar : public Grammar { + TextGrammar(); + TextGrammar(const std::string& file); + void SetMaxSpan(int m) { max_span_ = m; } + virtual const GrammarIter* GetRoot() const; + void AddRule(const TRulePtr& rule); + void ReadFromFile(const std::string& filename); + virtual bool HasRuleForSpan(int i, int j) const; + const std::vector& GetUnaryRules(const WordID& cat) const; + private: + int max_span_; + boost::shared_ptr pimpl_; +}; + +struct GlueGrammar : public TextGrammar { + // read glue grammar from file + explicit GlueGrammar(const std::string& file); + GlueGrammar(const std::string& goal_nt, const std::string& default_nt); // "S", "X" + virtual bool HasRuleForSpan(int i, int j) const; +}; + +struct PassThroughGrammar : public TextGrammar { + PassThroughGrammar(const Lattice& input, const std::string& cat); + virtual bool HasRuleForSpan(int i, int j) const; + private: + std::vector > has_rule_; // index by [i][j] +}; + +#endif diff --git a/src/grammar_test.cc b/src/grammar_test.cc new file mode 100644 index 00000000..62b8f958 --- /dev/null +++ b/src/grammar_test.cc @@ -0,0 +1,59 @@ +#include +#include +#include +#include +#include +#include "trule.h" +#include "tdict.h" +#include "grammar.h" +#include "bottom_up_parser.h" +#include "ff.h" +#include "weights.h" + +using namespace std; + +class GrammarTest : public testing::Test { + public: + GrammarTest() { + wts.InitFromFile("test_data/weights.gt"); + } + protected: + virtual void SetUp() { } + virtual void TearDown() { } + Weights wts; +}; + +TEST_F(GrammarTest,TestTextGrammar) { + vector w; + vector ms; + ModelSet models(w, ms); + + TextGrammar g; + TRulePtr r1(new TRule("[X] ||| a b c ||| A B C ||| 0.1 0.2 0.3", true)); + TRulePtr r2(new TRule("[X] ||| a b c ||| 1 2 3 ||| 0.2 0.3 0.4", true)); + TRulePtr r3(new TRule("[X] ||| a b c d ||| A B C D ||| 0.1 0.2 0.3", true)); + cerr << r1->AsString() << endl; + g.AddRule(r1); + g.AddRule(r2); + g.AddRule(r3); +} + +TEST_F(GrammarTest,TestTextGrammarFile) { + GrammarPtr g(new TextGrammar("./test_data/grammar.prune")); + vector grammars(1, g); + + LatticeArc a(TD::Convert("ein"), 0.0, 1); + LatticeArc b(TD::Convert("haus"), 0.0, 1); + Lattice lattice(2); + lattice[0].push_back(a); + lattice[1].push_back(b); + Hypergraph forest; + ExhaustiveBottomUpParser parser("PHRASE", grammars); + parser.Parse(lattice, &forest); + forest.PrintGraphviz(); +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/gzstream.cc b/src/gzstream.cc new file mode 100644 index 00000000..9703e6ad --- /dev/null +++ b/src/gzstream.cc @@ -0,0 +1,165 @@ +// ============================================================================ +// gzstream, C++ iostream classes wrapping the zlib compression library. +// Copyright (C) 2001 Deepak Bandyopadhyay, Lutz Kettner +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 2.1 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// ============================================================================ +// +// File : gzstream.C +// Revision : $Revision: 1.1 $ +// Revision_date : $Date: 2006/03/30 04:05:52 $ +// Author(s) : Deepak Bandyopadhyay, Lutz Kettner +// +// Standard streambuf implementation following Nicolai Josuttis, "The +// Standard C++ Library". +// ============================================================================ + +#include "gzstream.h" +#include +#include + +#ifdef GZSTREAM_NAMESPACE +namespace GZSTREAM_NAMESPACE { +#endif + +// ---------------------------------------------------------------------------- +// Internal classes to implement gzstream. See header file for user classes. +// ---------------------------------------------------------------------------- + +// -------------------------------------- +// class gzstreambuf: +// -------------------------------------- + +gzstreambuf* gzstreambuf::open( const char* name, int open_mode) { + if ( is_open()) + return (gzstreambuf*)0; + mode = open_mode; + // no append nor read/write mode + if ((mode & std::ios::ate) || (mode & std::ios::app) + || ((mode & std::ios::in) && (mode & std::ios::out))) + return (gzstreambuf*)0; + char fmode[10]; + char* fmodeptr = fmode; + if ( mode & std::ios::in) + *fmodeptr++ = 'r'; + else if ( mode & std::ios::out) + *fmodeptr++ = 'w'; + *fmodeptr++ = 'b'; + *fmodeptr = '\0'; + file = gzopen( name, fmode); + if (file == 0) + return (gzstreambuf*)0; + opened = 1; + return this; +} + +gzstreambuf * gzstreambuf::close() { + if ( is_open()) { + sync(); + opened = 0; + if ( gzclose( file) == Z_OK) + return this; + } + return (gzstreambuf*)0; +} + +int gzstreambuf::underflow() { // used for input buffer only + if ( gptr() && ( gptr() < egptr())) + return * reinterpret_cast( gptr()); + + if ( ! (mode & std::ios::in) || ! opened) + return EOF; + // Josuttis' implementation of inbuf + int n_putback = gptr() - eback(); + if ( n_putback > 4) + n_putback = 4; + memcpy( buffer + (4 - n_putback), gptr() - n_putback, n_putback); + + int num = gzread( file, buffer+4, bufferSize-4); + if (num <= 0) // ERROR or EOF + return EOF; + + // reset buffer pointers + setg( buffer + (4 - n_putback), // beginning of putback area + buffer + 4, // read position + buffer + 4 + num); // end of buffer + + // return next character + return * reinterpret_cast( gptr()); +} + +int gzstreambuf::flush_buffer() { + // Separate the writing of the buffer from overflow() and + // sync() operation. + int w = pptr() - pbase(); + if ( gzwrite( file, pbase(), w) != w) + return EOF; + pbump( -w); + return w; +} + +int gzstreambuf::overflow( int c) { // used for output buffer only + if ( ! ( mode & std::ios::out) || ! opened) + return EOF; + if (c != EOF) { + *pptr() = c; + pbump(1); + } + if ( flush_buffer() == EOF) + return EOF; + return c; +} + +int gzstreambuf::sync() { + // Changed to use flush_buffer() instead of overflow( EOF) + // which caused improper behavior with std::endl and flush(), + // bug reported by Vincent Ricard. + if ( pptr() && pptr() > pbase()) { + if ( flush_buffer() == EOF) + return -1; + } + return 0; +} + +// -------------------------------------- +// class gzstreambase: +// -------------------------------------- + +gzstreambase::gzstreambase( const char* name, int mode) { + init( &buf); + open( name, mode); +} + +gzstreambase::~gzstreambase() { + buf.close(); +} + +void gzstreambase::open( const char* name, int open_mode) { + if ( ! buf.open( name, open_mode)) + clear( rdstate() | std::ios::badbit); +} + +void gzstreambase::close() { + if ( buf.is_open()) + if ( ! buf.close()) + clear( rdstate() | std::ios::badbit); +} + +#ifdef GZSTREAM_NAMESPACE +} // namespace GZSTREAM_NAMESPACE +#endif + +// ============================================================================ +// EOF // diff --git a/src/gzstream.h b/src/gzstream.h new file mode 100644 index 00000000..ad9785fd --- /dev/null +++ b/src/gzstream.h @@ -0,0 +1,121 @@ +// ============================================================================ +// gzstream, C++ iostream classes wrapping the zlib compression library. +// Copyright (C) 2001 Deepak Bandyopadhyay, Lutz Kettner +// +// This library is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 2.1 of the License, or (at your option) any later version. +// +// This library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License along with this library; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// ============================================================================ +// +// File : gzstream.h +// Revision : $Revision: 1.1 $ +// Revision_date : $Date: 2006/03/30 04:05:52 $ +// Author(s) : Deepak Bandyopadhyay, Lutz Kettner +// +// Standard streambuf implementation following Nicolai Josuttis, "The +// Standard C++ Library". +// ============================================================================ + +#ifndef GZSTREAM_H +#define GZSTREAM_H 1 + +// standard C++ with new header file names and std:: namespace +#include +#include +#include + +#ifdef GZSTREAM_NAMESPACE +namespace GZSTREAM_NAMESPACE { +#endif + +// ---------------------------------------------------------------------------- +// Internal classes to implement gzstream. See below for user classes. +// ---------------------------------------------------------------------------- + +class gzstreambuf : public std::streambuf { +private: + static const int bufferSize = 47+256; // size of data buff + // totals 512 bytes under g++ for igzstream at the end. + + gzFile file; // file handle for compressed file + char buffer[bufferSize]; // data buffer + char opened; // open/close state of stream + int mode; // I/O mode + + int flush_buffer(); +public: + gzstreambuf() : opened(0) { + setp( buffer, buffer + (bufferSize-1)); + setg( buffer + 4, // beginning of putback area + buffer + 4, // read position + buffer + 4); // end position + // ASSERT: both input & output capabilities will not be used together + } + int is_open() { return opened; } + gzstreambuf* open( const char* name, int open_mode); + gzstreambuf* close(); + ~gzstreambuf() { close(); } + + virtual int overflow( int c = EOF); + virtual int underflow(); + virtual int sync(); +}; + +class gzstreambase : virtual public std::ios { +protected: + gzstreambuf buf; +public: + gzstreambase() { init(&buf); } + gzstreambase( const char* name, int open_mode); + ~gzstreambase(); + void open( const char* name, int open_mode); + void close(); + gzstreambuf* rdbuf() { return &buf; } +}; + +// ---------------------------------------------------------------------------- +// User classes. Use igzstream and ogzstream analogously to ifstream and +// ofstream respectively. They read and write files based on the gz* +// function interface of the zlib. Files are compatible with gzip compression. +// ---------------------------------------------------------------------------- + +class igzstream : public gzstreambase, public std::istream { +public: + igzstream() : std::istream( &buf) {} + igzstream( const char* name, int open_mode = std::ios::in) + : gzstreambase( name, open_mode), std::istream( &buf) {} + gzstreambuf* rdbuf() { return gzstreambase::rdbuf(); } + void open( const char* name, int open_mode = std::ios::in) { + gzstreambase::open( name, open_mode); + } +}; + +class ogzstream : public gzstreambase, public std::ostream { +public: + ogzstream() : std::ostream( &buf) {} + ogzstream( const char* name, int mode = std::ios::out) + : gzstreambase( name, mode), std::ostream( &buf) {} + gzstreambuf* rdbuf() { return gzstreambase::rdbuf(); } + void open( const char* name, int open_mode = std::ios::out) { + gzstreambase::open( name, open_mode); + } +}; + +#ifdef GZSTREAM_NAMESPACE +} // namespace GZSTREAM_NAMESPACE +#endif + +#endif // GZSTREAM_H +// ============================================================================ +// EOF // + diff --git a/src/hg.cc b/src/hg.cc new file mode 100644 index 00000000..dd8f8eba --- /dev/null +++ b/src/hg.cc @@ -0,0 +1,483 @@ +#include "hg.h" + +#include +#include +#include +#include +#include + +#include "viterbi.h" +#include "inside_outside.h" +#include "tdict.h" + +using namespace std; + +double Hypergraph::NumberOfPaths() const { + return Inside(*this); +} + +prob_t Hypergraph::ComputeEdgePosteriors(double scale, vector* posts) const { + const ScaledEdgeProb weight(scale); + SparseVector pv; + const double inside = InsideOutside, + EdgeFeaturesWeightFunction>(*this, &pv, weight); + posts->resize(edges_.size()); + for (int i = 0; i < edges_.size(); ++i) + (*posts)[i] = prob_t(pv.value(i)); + return prob_t(inside); +} + +prob_t Hypergraph::ComputeBestPathThroughEdges(vector* post) const { + vector in(edges_.size()); + vector out(edges_.size()); + post->resize(edges_.size()); + + vector ins_node_best(nodes_.size()); + for (int i = 0; i < nodes_.size(); ++i) { + const Node& node = nodes_[i]; + prob_t& node_ins_best = ins_node_best[i]; + if (node.in_edges_.empty()) node_ins_best = prob_t::One(); + for (int j = 0; j < node.in_edges_.size(); ++j) { + const Edge& edge = edges_[node.in_edges_[j]]; + prob_t& in_edge_sco = in[node.in_edges_[j]]; + in_edge_sco = edge.edge_prob_; + for (int k = 0; k < edge.tail_nodes_.size(); ++k) + in_edge_sco *= ins_node_best[edge.tail_nodes_[k]]; + if (in_edge_sco > node_ins_best) node_ins_best = in_edge_sco; + } + } + const prob_t ins_sco = ins_node_best[nodes_.size() - 1]; + + // sanity check + int tots = 0; + for (int i = 0; i < nodes_.size(); ++i) { if (nodes_[i].out_edges_.empty()) tots++; } + assert(tots == 1); + + // compute outside scores, potentially using inside scores + vector out_node_best(nodes_.size()); + for (int i = nodes_.size() - 1; i >= 0; --i) { + const Node& node = nodes_[i]; + prob_t& node_out_best = out_node_best[node.id_]; + if (node.out_edges_.empty()) node_out_best = prob_t::One(); + for (int j = 0; j < node.out_edges_.size(); ++j) { + const Edge& edge = edges_[node.out_edges_[j]]; + prob_t sco = edge.edge_prob_ * out_node_best[edge.head_node_]; + for (int k = 0; k < edge.tail_nodes_.size(); ++k) { + if (edge.tail_nodes_[k] != i) + sco *= ins_node_best[edge.tail_nodes_[k]]; + } + if (sco > node_out_best) node_out_best = sco; + } + for (int j = 0; j < node.in_edges_.size(); ++j) { + out[node.in_edges_[j]] = node_out_best; + } + } + + for (int i = 0; i < in.size(); ++i) + (*post)[i] = in[i] * out[i]; + + return ins_sco; +} + +void Hypergraph::PushWeightsToSource(double scale) { + vector posts; + ComputeEdgePosteriors(scale, &posts); + for (int i = 0; i < nodes_.size(); ++i) { + const Hypergraph::Node& node = nodes_[i]; + prob_t z = prob_t::Zero(); + for (int j = 0; j < node.out_edges_.size(); ++j) + z += posts[node.out_edges_[j]]; + for (int j = 0; j < node.out_edges_.size(); ++j) { + edges_[node.out_edges_[j]].edge_prob_ = posts[node.out_edges_[j]] / z; + } + } +} + +void Hypergraph::PushWeightsToGoal(double scale) { + vector posts; + ComputeEdgePosteriors(scale, &posts); + for (int i = 0; i < nodes_.size(); ++i) { + const Hypergraph::Node& node = nodes_[i]; + prob_t z = prob_t::Zero(); + for (int j = 0; j < node.in_edges_.size(); ++j) + z += posts[node.in_edges_[j]]; + for (int j = 0; j < node.in_edges_.size(); ++j) { + edges_[node.in_edges_[j]].edge_prob_ = posts[node.in_edges_[j]] / z; + } + } +} + +void Hypergraph::PruneEdges(const std::vector& prune_edge) { + assert(prune_edge.size() == edges_.size()); + TopologicallySortNodesAndEdges(nodes_.size() - 1, &prune_edge); +} + +void Hypergraph::DensityPruneInsideOutside(const double scale, + const bool use_sum_prod_semiring, + const double density, + const vector* preserve_mask) { + assert(density >= 1.0); + const int plen = ViterbiPathLength(*this); + vector bp; + int rnum = min(static_cast(edges_.size()), static_cast(density * static_cast(plen))); + if (rnum == edges_.size()) { + cerr << "No pruning required: denisty already sufficient"; + return; + } + vector io(edges_.size()); + if (use_sum_prod_semiring) + ComputeEdgePosteriors(scale, &io); + else + ComputeBestPathThroughEdges(&io); + assert(edges_.size() == io.size()); + vector sorted = io; + nth_element(sorted.begin(), sorted.begin() + rnum, sorted.end(), greater()); + const double cutoff = sorted[rnum]; + vector prune(edges_.size()); + for (int i = 0; i < edges_.size(); ++i) { + prune[i] = (io[i] < cutoff); + if (preserve_mask && (*preserve_mask)[i]) prune[i] = false; + } + PruneEdges(prune); +} + +void Hypergraph::BeamPruneInsideOutside( + const double scale, + const bool use_sum_prod_semiring, + const double alpha, + const vector* preserve_mask) { + assert(alpha > 0.0); + assert(scale > 0.0); + vector io(edges_.size()); + if (use_sum_prod_semiring) + ComputeEdgePosteriors(scale, &io); + else + ComputeBestPathThroughEdges(&io); + assert(edges_.size() == io.size()); + prob_t best; // initializes to zero + for (int i = 0; i < io.size(); ++i) + if (io[i] > best) best = io[i]; + const prob_t aprob(exp(-alpha)); + const prob_t cutoff = best * aprob; + vector prune(edges_.size()); + //cerr << preserve_mask.size() << " " << edges_.size() << endl; + int pc = 0; + for (int i = 0; i < io.size(); ++i) { + const bool prune_edge = (io[i] < cutoff); + if (prune_edge) ++pc; + prune[i] = (io[i] < cutoff); + if (preserve_mask && (*preserve_mask)[i]) prune[i] = false; + } + cerr << "Beam pruning " << pc << "/" << io.size() << " edges\n"; + PruneEdges(prune); +} + +void Hypergraph::PrintGraphviz() const { + int ei = 0; + cerr << "digraph G {\n rankdir=LR;\n nodesep=.05;\n"; + for (vector::const_iterator i = edges_.begin(); + i != edges_.end(); ++i) { + const Edge& edge=*i; + ++ei; + static const string none = ""; + string rule = (edge.rule_ ? edge.rule_->AsString(false) : none); + + cerr << " A_" << ei << " [label=\"" << rule << " p=" << edge.edge_prob_ + << " F:" << edge.feature_values_ + << "\" shape=\"rect\"];\n"; + for (int i = 0; i < edge.tail_nodes_.size(); ++i) { + cerr << " " << edge.tail_nodes_[i] << " -> A_" << ei << ";\n"; + } + cerr << " A_" << ei << " -> " << edge.head_node_ << ";\n"; + } + for (vector::const_iterator ni = nodes_.begin(); + ni != nodes_.end(); ++ni) { + cerr << " " << ni->id_ << "[label=\"" << (ni->cat_ < 0 ? TD::Convert(ni->cat_ * -1) : "") + //cerr << " " << ni->id_ << "[label=\"" << ni->cat_ + << " n=" << ni->id_ +// << ",x=" << &*ni +// << ",in=" << ni->in_edges_.size() +// << ",out=" << ni->out_edges_.size() + << "\"];\n"; + } + cerr << "}\n"; +} + +void Hypergraph::Union(const Hypergraph& other) { + if (&other == this) return; + if (nodes_.empty()) { nodes_ = other.nodes_; edges_ = other.edges_; return; } + int noff = nodes_.size(); + int eoff = edges_.size(); + int ogoal = other.nodes_.size() - 1; + int cgoal = noff - 1; + // keep a single goal node, so add nodes.size - 1 + nodes_.resize(nodes_.size() + ogoal); + // add all edges + edges_.resize(edges_.size() + other.edges_.size()); + + for (int i = 0; i < ogoal; ++i) { + const Node& on = other.nodes_[i]; + Node& cn = nodes_[i + noff]; + cn.id_ = i + noff; + cn.in_edges_.resize(on.in_edges_.size()); + for (int j = 0; j < on.in_edges_.size(); ++j) + cn.in_edges_[j] = on.in_edges_[j] + eoff; + + cn.out_edges_.resize(on.out_edges_.size()); + for (int j = 0; j < on.out_edges_.size(); ++j) + cn.out_edges_[j] = on.out_edges_[j] + eoff; + } + + for (int i = 0; i < other.edges_.size(); ++i) { + const Edge& oe = other.edges_[i]; + Edge& ce = edges_[i + eoff]; + ce.id_ = i + eoff; + ce.rule_ = oe.rule_; + ce.feature_values_ = oe.feature_values_; + if (oe.head_node_ == ogoal) { + ce.head_node_ = cgoal; + nodes_[cgoal].in_edges_.push_back(ce.id_); + } else { + ce.head_node_ = oe.head_node_ + noff; + } + ce.tail_nodes_.resize(oe.tail_nodes_.size()); + for (int j = 0; j < oe.tail_nodes_.size(); ++j) + ce.tail_nodes_[j] = oe.tail_nodes_[j] + noff; + } + + TopologicallySortNodesAndEdges(cgoal); +} + +int Hypergraph::MarkReachable(const Node& node, + vector* rmap, + const vector* prune_edges) const { + int total = 0; + if (!(*rmap)[node.id_]) { + total = 1; + (*rmap)[node.id_] = true; + for (int i = 0; i < node.in_edges_.size(); ++i) { + if (!(prune_edges && (*prune_edges)[node.in_edges_[i]])) { + for (int j = 0; j < edges_[node.in_edges_[i]].tail_nodes_.size(); ++j) + total += MarkReachable(nodes_[edges_[node.in_edges_[i]].tail_nodes_[j]], rmap, prune_edges); + } + } + } + return total; +} + +void Hypergraph::PruneUnreachable(int goal_node_id) { + TopologicallySortNodesAndEdges(goal_node_id, NULL); +} + +void Hypergraph::RemoveNoncoaccessibleStates(int goal_node_id) { + if (goal_node_id < 0) goal_node_id += nodes_.size(); + assert(goal_node_id >= 0); + assert(goal_node_id < nodes_.size()); + + // TODO finish implementation + abort(); +} + +void Hypergraph::TopologicallySortNodesAndEdges(int goal_index, + const vector* prune_edges) { + vector sedges(edges_.size()); + // figure out which nodes are reachable from the goal + vector reachable(nodes_.size(), false); + int num_reachable = MarkReachable(nodes_[goal_index], &reachable, prune_edges); + vector snodes(num_reachable); snodes.clear(); + + // enumerate all reachable nodes in topologically sorted order + vector old_node_to_new_id(nodes_.size(), -1); + vector node_to_incount(nodes_.size(), -1); + vector node_processed(nodes_.size(), false); + typedef map > PQueue; + PQueue pri_q; + for (int i = 0; i < nodes_.size(); ++i) { + if (!reachable[i]) + continue; + const int inedges = nodes_[i].in_edges_.size(); + int incount = inedges; + for (int j = 0; j < inedges; ++j) + if (edges_[nodes_[i].in_edges_[j]].tail_nodes_.size() == 0 || + (prune_edges && (*prune_edges)[nodes_[i].in_edges_[j]])) + --incount; + // cerr << &nodes_[i] <<" : incount=" << incount << "\tout=" << nodes_[i].out_edges_.size() << "\t(in-edges=" << inedges << ")\n"; + assert(node_to_incount[i] == -1); + node_to_incount[i] = incount; + pri_q[incount].insert(i); + } + + int edge_count = 0; + while (!pri_q.empty()) { + PQueue::iterator iter = pri_q.find(0); + assert(iter != pri_q.end()); + assert(!iter->second.empty()); + + // get first node with incount = 0 + const int cur_index = *iter->second.begin(); + const Node& node = nodes_[cur_index]; + assert(reachable[cur_index]); + //cerr << "node: " << node << endl; + const int new_node_index = snodes.size(); + old_node_to_new_id[cur_index] = new_node_index; + snodes.push_back(node); + Node& new_node = snodes.back(); + new_node.id_ = new_node_index; + new_node.out_edges_.clear(); + + // fix up edges - we can now process the in edges and + // the out edges of their tails + int oi = 0; + for (int i = 0; i < node.in_edges_.size(); ++i, ++oi) { + if (prune_edges && (*prune_edges)[node.in_edges_[i]]) { + --oi; + continue; + } + new_node.in_edges_[oi] = edge_count; + Edge& edge = sedges[edge_count]; + edge.id_ = edge_count; + ++edge_count; + const Edge& old_edge = edges_[node.in_edges_[i]]; + edge.rule_ = old_edge.rule_; + edge.feature_values_ = old_edge.feature_values_; + edge.head_node_ = new_node_index; + edge.tail_nodes_.resize(old_edge.tail_nodes_.size()); + edge.edge_prob_ = old_edge.edge_prob_; + edge.i_ = old_edge.i_; + edge.j_ = old_edge.j_; + edge.prev_i_ = old_edge.prev_i_; + edge.prev_j_ = old_edge.prev_j_; + for (int j = 0; j < old_edge.tail_nodes_.size(); ++j) { + const Node& old_tail_node = nodes_[old_edge.tail_nodes_[j]]; + edge.tail_nodes_[j] = old_node_to_new_id[old_tail_node.id_]; + snodes[edge.tail_nodes_[j]].out_edges_.push_back(edge_count-1); + assert(edge.tail_nodes_[j] != new_node_index); + } + } + assert(oi <= new_node.in_edges_.size()); + new_node.in_edges_.resize(oi); + + for (int i = 0; i < node.out_edges_.size(); ++i) { + const Edge& edge = edges_[node.out_edges_[i]]; + const int next_index = edge.head_node_; + assert(cur_index != next_index); + if (!reachable[next_index]) continue; + if (prune_edges && (*prune_edges)[edge.id_]) continue; + + bool dontReduce = false; + for (int j = 0; j < edge.tail_nodes_.size() && !dontReduce; ++j) { + int tail_index = edge.tail_nodes_[j]; + dontReduce = (tail_index != cur_index) && !node_processed[tail_index]; + } + if (dontReduce) + continue; + + const int incount = node_to_incount[next_index]; + if (incount <= 0) { + cerr << "incount = " << incount << ", should be > 0!\n"; + cerr << "do you have a cycle in your hypergraph?\n"; + abort(); + } + PQueue::iterator it = pri_q.find(incount); + assert(it != pri_q.end()); + it->second.erase(next_index); + if (it->second.empty()) pri_q.erase(it); + + // reinsert node with reduced incount + pri_q[incount-1].insert(next_index); + --node_to_incount[next_index]; + } + + // remove node from set + iter->second.erase(cur_index); + if (iter->second.empty()) + pri_q.erase(iter); + node_processed[cur_index] = true; + } + + sedges.resize(edge_count); + nodes_.swap(snodes); + edges_.swap(sedges); + assert(nodes_.back().out_edges_.size() == 0); +} + +TRulePtr Hypergraph::kEPSRule; +TRulePtr Hypergraph::kUnaryRule; + +void Hypergraph::EpsilonRemove(WordID eps) { + if (!kEPSRule) { + kEPSRule.reset(new TRule("[X] ||| ||| ")); + kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); + } + vector kill(edges_.size(), false); + for (int i = 0; i < edges_.size(); ++i) { + const Edge& edge = edges_[i]; + if (edge.tail_nodes_.empty() && + edge.rule_->f_.size() == 1 && + edge.rule_->f_[0] == eps) { + kill[i] = true; + if (!edge.feature_values_.empty()) { + Node& node = nodes_[edge.head_node_]; + if (node.in_edges_.size() != 1) { + cerr << "[WARNING] edge with features going into non-empty node - can't promote\n"; + // this *probably* means that there are multiple derivations of the + // same sequence via different paths through the input forest + // this needs to be investigated and fixed + } else { + for (int j = 0; j < node.out_edges_.size(); ++j) + edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_; + // cerr << "PROMOTED " << edge.feature_values_ << endl; + } + } + } + } + bool created_eps = false; + PruneEdges(kill); + for (int i = 0; i < nodes_.size(); ++i) { + const Node& node = nodes_[i]; + if (node.in_edges_.empty()) { + for (int j = 0; j < node.out_edges_.size(); ++j) { + Edge& edge = edges_[node.out_edges_[j]]; + if (edge.rule_->Arity() == 2) { + assert(edge.rule_->f_.size() == 2); + assert(edge.rule_->e_.size() == 2); + edge.rule_ = kUnaryRule; + int cur = node.id_; + int t = -1; + assert(edge.tail_nodes_.size() == 2); + for (int i = 0; i < 2; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; } + assert(t != -1); + edge.tail_nodes_.resize(1); + edge.tail_nodes_[0] = t; + } else { + edge.rule_ = kEPSRule; + edge.rule_->f_[0] = eps; + edge.rule_->e_[0] = eps; + edge.tail_nodes_.clear(); + created_eps = true; + } + } + } + } + vector k2(edges_.size(), false); + PruneEdges(k2); + if (created_eps) EpsilonRemove(eps); +} + +struct EdgeWeightSorter { + const Hypergraph& hg; + EdgeWeightSorter(const Hypergraph& h) : hg(h) {} + bool operator()(int a, int b) const { + return hg.edges_[a].edge_prob_ > hg.edges_[b].edge_prob_; + } +}; + +void Hypergraph::SortInEdgesByEdgeWeights() { + for (int i = 0; i < nodes_.size(); ++i) { + Node& node = nodes_[i]; + sort(node.in_edges_.begin(), node.in_edges_.end(), EdgeWeightSorter(*this)); + } +} + diff --git a/src/hg.h b/src/hg.h new file mode 100644 index 00000000..7a2658b8 --- /dev/null +++ b/src/hg.h @@ -0,0 +1,225 @@ +#ifndef _HG_H_ +#define _HG_H_ + +#include +#include + +#include "small_vector.h" +#include "sparse_vector.h" +#include "wordid.h" +#include "trule.h" +#include "prob.h" + +// class representing an acyclic hypergraph +// - edges have 1 head, 0..n tails +class Hypergraph { + public: + Hypergraph() {} + + // SmallVector is a fast, small vector implementation for sizes <= 2 + typedef SmallVector TailNodeVector; + + // TODO get rid of state_ and cat_? + struct Node { + Node() : id_(), cat_() {} + int id_; // equal to this object's position in the nodes_ vector + WordID cat_; // non-terminal category if <0, 0 if not set + std::vector in_edges_; // contents refer to positions in edges_ + std::vector out_edges_; // contents refer to positions in edges_ + std::string state_; // opaque state + }; + + // TODO get rid of edge_prob_? (can be computed on the fly as the dot + // product of the weight vector and the feature values) + struct Edge { + Edge() : i_(-1), j_(-1), prev_i_(-1), prev_j_(-1) {} + inline int Arity() const { return tail_nodes_.size(); } + int head_node_; // refers to a position in nodes_ + TailNodeVector tail_nodes_; // contents refer to positions in nodes_ + TRulePtr rule_; + SparseVector feature_values_; + prob_t edge_prob_; // dot product of weights and feat_values + int id_; // equal to this object's position in the edges_ vector + + // span info. typically, i_ and j_ refer to indices in the source sentence + // if a synchronous parse has been executed i_ and j_ will refer to indices + // in the target sentence / lattice and prev_i_ prev_j_ will refer to + // positions in the source. Note: it is up to the translator implementation + // to properly set these values. For some models (like the Forest-input + // phrase based model) it may not be straightforward to do. if these values + // are not properly set, most things will work but alignment and any features + // that depend on them will be broken. + short int i_; + short int j_; + short int prev_i_; + short int prev_j_; + }; + + void swap(Hypergraph& other) { + other.nodes_.swap(nodes_); + other.edges_.swap(edges_); + } + + void ResizeNodes(int size) { + nodes_.resize(size); + for (int i = 0; i < size; ++i) nodes_[i].id_ = i; + } + + // reserves space in the nodes vector to prevent memory locations + // from changing + void ReserveNodes(size_t n, size_t e = 0) { + nodes_.reserve(n); + if (e) edges_.reserve(e); + } + + Edge* AddEdge(const TRulePtr& rule, const TailNodeVector& tail) { + edges_.push_back(Edge()); + Edge* edge = &edges_.back(); + edge->rule_ = rule; + edge->tail_nodes_ = tail; + edge->id_ = edges_.size() - 1; + for (int i = 0; i < edge->tail_nodes_.size(); ++i) + nodes_[edge->tail_nodes_[i]].out_edges_.push_back(edge->id_); + return edge; + } + + Node* AddNode(const WordID& cat, const std::string& state = "") { + nodes_.push_back(Node()); + nodes_.back().cat_ = cat; + nodes_.back().state_ = state; + nodes_.back().id_ = nodes_.size() - 1; + return &nodes_.back(); + } + + void ConnectEdgeToHeadNode(const int edge_id, const int head_id) { + edges_[edge_id].head_node_ = head_id; + nodes_[head_id].in_edges_.push_back(edge_id); + } + + // TODO remove this - use the version that takes indices + void ConnectEdgeToHeadNode(Edge* edge, Node* head) { + edge->head_node_ = head->id_; + head->in_edges_.push_back(edge->id_); + } + + // merge the goal node from other with this goal node + void Union(const Hypergraph& other); + + void PrintGraphviz() const; + + // compute the total number of paths in the forest + double NumberOfPaths() const; + + // BEWARE. this assumes that the source and target language + // strings are identical and that there are no loops. + // It assumes a bunch of other things about where the + // epsilons will be. It tries to assert failure if you + // break these assumptions, but it may not. + // TODO - make this work + void EpsilonRemove(WordID eps); + + // multiple the weights vector by the edge feature vector + // (inner product) to set the edge probabilities + template + void Reweight(const V& weights) { + for (int i = 0; i < edges_.size(); ++i) { + Edge& e = edges_[i]; + e.edge_prob_.logeq(e.feature_values_.dot(weights)); + } + } + + // computes inside and outside scores for each + // edge in the hypergraph + // alpha->size = edges_.size = beta->size + // returns inside prob of goal node + prob_t ComputeEdgePosteriors(double scale, + std::vector* posts) const; + + // find the score of the very best path passing through each edge + prob_t ComputeBestPathThroughEdges(std::vector* posts) const; + + // move weights as near to the source as possible, resulting in a + // stochastic automaton. ONLY FUNCTIONAL FOR *LATTICES*. + // See M. Mohri and M. Riley. A Weight Pushing Algorithm for Large + // Vocabulary Speech Recognition. 2001. + // the log semiring (NOT tropical) is used + void PushWeightsToSource(double scale = 1.0); + // same, except weights are pushed to the goal, works for HGs, + // not just lattices + void PushWeightsToGoal(double scale = 1.0); + + void SortInEdgesByEdgeWeights(); + + void PruneUnreachable(int goal_node_id); // DEPRECATED + + void RemoveNoncoaccessibleStates(int goal_node_id = -1); + + // remove edges from the hypergraph if prune_edge[edge_id] is true + void PruneEdges(const std::vector& prune_edge); + + // if you don't know, use_sum_prod_semiring should be false + void DensityPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double density, + const std::vector* preserve_mask = NULL); + + // prunes any edge whose score on the best path taking that edge is more than alpha away + // from the score of the global best past (or the highest edge posterior) + void BeamPruneInsideOutside(const double scale, const bool use_sum_prod_semiring, const double alpha, + const std::vector* preserve_mask = NULL); + + void clear() { + nodes_.clear(); + edges_.clear(); + } + + inline size_t NumberOfEdges() const { return edges_.size(); } + inline size_t NumberOfNodes() const { return nodes_.size(); } + inline bool empty() const { return nodes_.empty(); } + + // nodes_ is sorted in topological order + std::vector nodes_; + // edges_ is not guaranteed to be in any particular order + std::vector edges_; + + // reorder nodes_ so they are in topological order + // source nodes at 0 sink nodes at size-1 + void TopologicallySortNodesAndEdges(int goal_idx, + const std::vector* prune_edges = NULL); + private: + // returns total nodes reachable + int MarkReachable(const Node& node, + std::vector* rmap, + const std::vector* prune_edges) const; + + static TRulePtr kEPSRule; + static TRulePtr kUnaryRule; +}; + +// common WeightFunctions, map an edge -> WeightType +// for generic Viterbi/Inside algorithms +struct EdgeProb { + inline const prob_t& operator()(const Hypergraph::Edge& e) const { return e.edge_prob_; } +}; + +struct ScaledEdgeProb { + ScaledEdgeProb(const double& alpha) : alpha_(alpha) {} + inline prob_t operator()(const Hypergraph::Edge& e) const { return e.edge_prob_.pow(alpha_); } + const double alpha_; +}; + +struct EdgeFeaturesWeightFunction { + inline const SparseVector& operator()(const Hypergraph::Edge& e) const { return e.feature_values_; } +}; + +struct TransitionEventWeightFunction { + inline SparseVector operator()(const Hypergraph::Edge& e) const { + SparseVector result; + result.set_value(e.id_, prob_t::One()); + return result; + } +}; + +struct TransitionCountWeightFunction { + inline double operator()(const Hypergraph::Edge& e) const { (void)e; return 1.0; } +}; + +#endif diff --git a/src/hg_intersect.cc b/src/hg_intersect.cc new file mode 100644 index 00000000..a5e8913a --- /dev/null +++ b/src/hg_intersect.cc @@ -0,0 +1,121 @@ +#include "hg_intersect.h" + +#include +#include +#include +#include + +#include "tdict.h" +#include "hg.h" +#include "trule.h" +#include "wordid.h" +#include "bottom_up_parser.h" + +using boost::lexical_cast; +using namespace std::tr1; +using namespace std; + +struct RuleFilter { + unordered_map, bool, boost::hash > > exists_; + bool true_lattice; + RuleFilter(const Lattice& target, int max_phrase_size) { + true_lattice = false; + for (int i = 0; i < target.size(); ++i) { + vector phrase; + int lim = min(static_cast(target.size()), i + max_phrase_size); + for (int j = i; j < lim; ++j) { + if (target[j].size() > 1) { true_lattice = true; break; } + phrase.push_back(target[j][0].label); + exists_[phrase] = true; + } + } + vector sos(1, TD::Convert("")); + exists_[sos] = true; + } + bool operator()(const TRule& r) const { + // TODO do some smarter filtering for lattices + if (true_lattice) return false; // don't filter "true lattice" input + const vector& e = r.e(); + for (int i = 0; i < e.size(); ++i) { + if (e[i] <= 0) continue; + vector phrase; + for (int j = i; j < e.size(); ++j) { + if (e[j] <= 0) break; + phrase.push_back(e[j]); + if (exists_.count(phrase) == 0) return true; + } + } + return false; + } +}; + +bool HG::Intersect(const Lattice& target, Hypergraph* hg) { + vector rem(hg->edges_.size(), false); + const RuleFilter filter(target, 15); // TODO make configurable + for (int i = 0; i < rem.size(); ++i) + rem[i] = filter(*hg->edges_[i].rule_); + hg->PruneEdges(rem); + + const int nedges = hg->edges_.size(); + const int nnodes = hg->nodes_.size(); + + TextGrammar* g = new TextGrammar; + GrammarPtr gp(g); + vector cats(nnodes); + // each node in the translation forest becomes a "non-terminal" in the new + // grammar, create the labels here + for (int i = 0; i < nnodes; ++i) + cats[i] = TD::Convert("CAT_" + lexical_cast(i)) * -1; + + // construct the grammar + for (int i = 0; i < nedges; ++i) { + const Hypergraph::Edge& edge = hg->edges_[i]; + const vector& tgt = edge.rule_->e(); + const vector& src = edge.rule_->f(); + TRulePtr rule(new TRule); + rule->prev_i = edge.i_; + rule->prev_j = edge.j_; + rule->lhs_ = cats[edge.head_node_]; + vector& f = rule->f_; + vector& e = rule->e_; + f.resize(tgt.size()); // swap source and target, since the parser + e.resize(src.size()); // parses using the source side! + Hypergraph::TailNodeVector tn(edge.tail_nodes_.size()); + int ntc = 0; + for (int j = 0; j < tgt.size(); ++j) { + const WordID& cur = tgt[j]; + if (cur > 0) { + f[j] = cur; + } else { + tn[ntc++] = cur; + f[j] = cats[edge.tail_nodes_[-cur]]; + } + } + ntc = 0; + for (int j = 0; j < src.size(); ++j) { + const WordID& cur = src[j]; + if (cur > 0) { + e[j] = cur; + } else { + e[j] = tn[ntc++]; + } + } + rule->scores_ = edge.feature_values_; + rule->parent_rule_ = edge.rule_; + rule->ComputeArity(); + //cerr << "ADD: " << rule->AsString() << endl; + + g->AddRule(rule); + } + g->SetMaxSpan(target.size() + 1); + const string& new_goal = TD::Convert(cats.back() * -1); + vector grammars(1, gp); + Hypergraph tforest; + ExhaustiveBottomUpParser parser(new_goal, grammars); + if (!parser.Parse(target, &tforest)) + return false; + else + hg->swap(tforest); + return true; +} + diff --git a/src/hg_intersect.h b/src/hg_intersect.h new file mode 100644 index 00000000..826bdaae --- /dev/null +++ b/src/hg_intersect.h @@ -0,0 +1,13 @@ +#ifndef _HG_INTERSECT_H_ +#define _HG_INTERSECT_H_ + +#include + +#include "lattice.h" + +class Hypergraph; +struct HG { + static bool Intersect(const Lattice& target, Hypergraph* hg); +}; + +#endif diff --git a/src/hg_io.cc b/src/hg_io.cc new file mode 100644 index 00000000..629e65f1 --- /dev/null +++ b/src/hg_io.cc @@ -0,0 +1,585 @@ +#include "hg_io.h" + +#include +#include + +#include + +#include "tdict.h" +#include "json_parse.h" +#include "hg.h" + +using namespace std; + +struct HGReader : public JSONParser { + HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; } + + void CreateNode(const string& cat, const vector& in_edges) { + WordID c = TD::Convert("X") * -1; + if (!cat.empty()) c = TD::Convert(cat) * -1; + Hypergraph::Node* node = hg.AddNode(c, ""); + for (int i = 0; i < in_edges.size(); ++i) { + if (in_edges[i] >= hg.edges_.size()) { + cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i] + << ", but hg only has " << hg.edges_.size() << " edges!\n"; + abort(); + } + hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node); + } + } + void CreateEdge(const TRulePtr& rule, SparseVector* feats, const SmallVector& tail) { + Hypergraph::Edge* edge = hg.AddEdge(rule, tail); + feats->swap(edge->feature_values_); + } + + bool HandleJSONEvent(int type, const JSON_value* value) { + switch(state) { + case -1: + assert(type == JSON_T_OBJECT_BEGIN); + state = 0; + break; + case 0: + if (type == JSON_T_OBJECT_END) { + //cerr << "HG created\n"; // TODO, signal some kind of callback + } else if (type == JSON_T_KEY) { + string val = value->vu.str.value; + if (val == "features") { assert(fdict.empty()); state = 1; } + else if (val == "is_sorted") { state = 3; } + else if (val == "rules") { assert(rules.empty()); state = 4; } + else if (val == "node") { state = 8; } + else if (val == "edges") { state = 13; } + else { cerr << "Unexpected key: " << val << endl; return false; } + } + break; + + // features + case 1: + if(type == JSON_T_NULL) { state = 0; break; } + assert(type == JSON_T_ARRAY_BEGIN); + state = 2; + break; + case 2: + if(type == JSON_T_ARRAY_END) { state = 0; break; } + assert(type == JSON_T_STRING); + fdict.push_back(FD::Convert(value->vu.str.value)); + break; + + // is_sorted + case 3: + assert(type == JSON_T_TRUE || type == JSON_T_FALSE); + is_sorted = (type == JSON_T_TRUE); + if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; } + state = 0; + break; + + // rules + case 4: + if(type == JSON_T_NULL) { state = 0; break; } + assert(type == JSON_T_ARRAY_BEGIN); + state = 5; + break; + case 5: + if(type == JSON_T_ARRAY_END) { state = 0; break; } + assert(type == JSON_T_INTEGER); + state = 6; + rule_id = value->vu.integer_value; + break; + case 6: + assert(type == JSON_T_STRING); + rules[rule_id] = TRulePtr(new TRule(value->vu.str.value)); + state = 5; + break; + + // Nodes + case 8: + assert(type == JSON_T_OBJECT_BEGIN); + ++nodes; + in_edges.clear(); + cat.clear(); + state = 9; break; + case 9: + if (type == JSON_T_OBJECT_END) { + //cerr << "Creating NODE\n"; + CreateNode(cat, in_edges); + state = 0; break; + } + assert(type == JSON_T_KEY); + cur_key = value->vu.str.value; + if (cur_key == "cat") { assert(cat.empty()); state = 10; break; } + if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; } + cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n"; + return false; + case 10: + assert(type == JSON_T_STRING || type == JSON_T_NULL); + cat = value->vu.str.value; + state = 9; break; + case 11: + if (type == JSON_T_NULL) { state = 9; break; } + assert(type == JSON_T_ARRAY_BEGIN); + state = 12; break; + case 12: + if (type == JSON_T_ARRAY_END) { state = 9; break; } + assert(type == JSON_T_INTEGER); + //cerr << "in_edges: " << value->vu.integer_value << endl; + in_edges.push_back(value->vu.integer_value); + break; + + // "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12}, + // { "tail": null, "feats" : [0,0.87,1,0.02], "rule": 17}, + // { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}] + case 13: + assert(type == JSON_T_ARRAY_BEGIN); + state = 14; + break; + case 14: + if (type == JSON_T_ARRAY_END) { state = 0; break; } + assert(type == JSON_T_OBJECT_BEGIN); + //cerr << "New edge\n"; + ++edges; + cur_rule.reset(); feats.clear(); tail.clear(); + state = 15; break; + case 15: + if (type == JSON_T_OBJECT_END) { + CreateEdge(cur_rule, &feats, tail); + state = 14; break; + } + assert(type == JSON_T_KEY); + cur_key = value->vu.str.value; + //cerr << "edge key " << cur_key << endl; + if (cur_key == "rule") { assert(!cur_rule); state = 16; break; } + if (cur_key == "feats") { assert(feats.empty()); state = 17; break; } + if (cur_key == "tail") { assert(tail.empty()); state = 20; break; } + cerr << "Unexpected key " << cur_key << " in edge specification\n"; + return false; + case 16: // edge.rule + if (type == JSON_T_INTEGER) { + int rule_id = value->vu.integer_value; + if (rules.find(rule_id) == rules.end()) { + // rules list must come before the edge definitions! + cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n"; + return false; + } + cur_rule = rules[rule_id]; + } else if (type == JSON_T_STRING) { + cur_rule.reset(new TRule(value->vu.str.value)); + } else { + cerr << "Rule must be either a rule id or a rule string" << endl; + return false; + } + // cerr << "Edge: rule=" << cur_rule->AsString() << endl; + state = 15; + break; + case 17: // edge.feats + if (type == JSON_T_NULL) { state = 15; break; } + assert(type == JSON_T_ARRAY_BEGIN); + state = 18; break; + case 18: + if (type == JSON_T_ARRAY_END) { state = 15; break; } + if (type != JSON_T_INTEGER && type != JSON_T_STRING) { + cerr << "Unexpected feature id type\n"; return false; + } + if (type == JSON_T_INTEGER) { + fid = value->vu.integer_value; + assert(fid < fdict.size()); + fid = fdict[fid]; + } else if (JSON_T_STRING) { + fid = FD::Convert(value->vu.str.value); + } else { abort(); } + state = 19; + break; + case 19: + { + assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT); + double val = (type == JSON_T_INTEGER ? static_cast(value->vu.integer_value) : + strtod(value->vu.str.value, NULL)); + feats.set_value(fid, val); + state = 18; + break; + } + case 20: // edge.tail + if (type == JSON_T_NULL) { state = 15; break; } + assert(type == JSON_T_ARRAY_BEGIN); + state = 21; break; + case 21: + if (type == JSON_T_ARRAY_END) { state = 15; break; } + assert(type == JSON_T_INTEGER); + tail.push_back(value->vu.integer_value); + break; + } + return true; + } + string rp; + string cat; + SmallVector tail; + vector in_edges; + TRulePtr cur_rule; + map rules; + vector fdict; + SparseVector feats; + int state; + int fid; + int nodes; + int edges; + string cur_key; + Hypergraph& hg; + int rule_id; + bool nodes_needed; + bool edges_needed; + bool is_sorted; +}; + +bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) { + hg->clear(); + HGReader reader(hg); + return reader.Parse(in); +} + +static void WriteRule(const TRule& r, ostream* out) { + if (!r.lhs_) { (*out) << "[X] ||| "; } + JSONParser::WriteEscapedString(r.AsString(), out); +} + +bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) { + map rid; + ostream& o = *out; + rid[NULL] = 0; + o << '{'; + if (!remove_rules) { + o << "\"rules\":["; + for (int i = 0; i < hg.edges_.size(); ++i) { + const TRule* r = hg.edges_[i].rule_.get(); + int &id = rid[r]; + if (!id) { + id=rid.size() - 1; + if (id > 1) o << ','; + o << id << ','; + WriteRule(*r, &o); + }; + } + o << "],"; + } + const bool use_fdict = FD::NumFeats() < 1000; + if (use_fdict) { + o << "\"features\":["; + for (int i = 1; i < FD::NumFeats(); ++i) { + o << (i==1 ? "":",") << '"' << FD::Convert(i) << '"'; + } + o << "],"; + } + vector edgemap(hg.edges_.size(), -1); // edges may be in non-topo order + int edge_count = 0; + for (int i = 0; i < hg.nodes_.size(); ++i) { + const Hypergraph::Node& node = hg.nodes_[i]; + if (i > 0) { o << ","; } + o << "\"edges\":["; + for (int j = 0; j < node.in_edges_.size(); ++j) { + const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; + edgemap[edge.id_] = edge_count; + ++edge_count; + o << (j == 0 ? "" : ",") << "{"; + + o << "\"tail\":["; + for (int k = 0; k < edge.tail_nodes_.size(); ++k) { + o << (k > 0 ? "," : "") << edge.tail_nodes_[k]; + } + o << "],"; + + o << "\"feats\":["; + bool first = true; + for (SparseVector::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) { + if (!it->second) continue; + if (!first) o << ','; + if (use_fdict) + o << (it->first - 1); + else + o << '"' << FD::Convert(it->first) << '"'; + o << ',' << it->second; + first = false; + } + o << "]"; + if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; } + o << "}"; + } + o << "],"; + + o << "\"node\":{\"in_edges\":["; + for (int j = 0; j < node.in_edges_.size(); ++j) { + int mapped_edge = edgemap[node.in_edges_[j]]; + assert(mapped_edge >= 0); + o << (j == 0 ? "" : ",") << mapped_edge; + } + o << "]"; + if (node.cat_ < 0) { o << ",\"cat\":\"" << TD::Convert(node.cat_ * -1) << '"'; } + o << "}"; + } + o << "}\n"; + return true; +} + +bool needs_escape[128]; +void InitEscapes() { + memset(needs_escape, false, 128); + needs_escape[static_cast('\'')] = true; + needs_escape[static_cast('\\')] = true; +} + +string HypergraphIO::Escape(const string& s) { + size_t len = s.size(); + for (int i = 0; i < s.size(); ++i) { + unsigned char c = s[i]; + if (c < 128 && needs_escape[c]) ++len; + } + if (len == s.size()) return s; + string res(len, ' '); + size_t o = 0; + for (int i = 0; i < s.size(); ++i) { + unsigned char c = s[i]; + if (c < 128 && needs_escape[c]) + res[o++] = '\\'; + res[o++] = c; + } + assert(o == len); + return res; +} + +string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses) { + static bool first = true; + if (first) { InitEscapes(); first = false; } + if (hg.nodes_.empty()) return "()"; + ostringstream os; + if (include_global_parentheses) os << '('; + static const string EPS="*EPS*"; + for (int i = 0; i < hg.nodes_.size()-1; ++i) { + os << '('; + if (hg.nodes_[i].out_edges_.empty()) abort(); + for (int j = 0; j < hg.nodes_[i].out_edges_.size(); ++j) { + const Hypergraph::Edge& e = hg.edges_[hg.nodes_[i].out_edges_[j]]; + const string output = e.rule_->e_.size() ==2 ? Escape(TD::Convert(e.rule_->e_[1])) : EPS; + double prob = log(e.edge_prob_); + if (isinf(prob)) { prob = -9e20; } + if (isnan(prob)) { prob = 0; } + os << "('" << output << "'," << prob << "," << e.head_node_ - i << "),"; + } + os << "),"; + } + if (include_global_parentheses) os << ')'; + return os.str(); +} + +namespace PLF { + +const string chars = "'\\"; +const char& quote = chars[0]; +const char& slash = chars[1]; + +// safe get +inline char get(const std::string& in, int c) { + if (c < 0 || c >= (int)in.size()) return 0; + else return in[(size_t)c]; +} + +// consume whitespace +inline void eatws(const std::string& in, int& c) { + while (get(in,c) == ' ') { c++; } +} + +// from 'foo' return foo +std::string getEscapedString(const std::string& in, int &c) +{ + eatws(in,c); + if (get(in,c++) != quote) return "ERROR"; + std::string res; + char cur = 0; + do { + cur = get(in,c++); + if (cur == slash) { res += get(in,c++); } + else if (cur != quote) { res += cur; } + } while (get(in,c) != quote && (c < (int)in.size())); + c++; + eatws(in,c); + return res; +} + +// basically atof +float getFloat(const std::string& in, int &c) +{ + std::string tmp; + eatws(in,c); + while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') { + tmp += get(in,c++); + } + eatws(in,c); + return atof(tmp.c_str()); +} + +// basically atoi +int getInt(const std::string& in, int &c) +{ + std::string tmp; + eatws(in,c); + while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') { + tmp += get(in,c++); + } + eatws(in,c); + return atoi(tmp.c_str()); +} + +// maximum number of nodes permitted +#define MAX_NODES 100000000 +// parse ('foo', 0.23) +void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) { + if (get(in,c++) != '(') { assert(!"PCN/PLF parse error: expected ( at start of cn alt block\n"); } + vector ewords(2, 0); + ewords[1] = TD::Convert(getEscapedString(in,c)); + TRulePtr r(new TRule(ewords)); + //cerr << "RULE: " << r->AsString() << endl; + if (get(in,c++) != ',') { assert(!"PCN/PLF parse error: expected , after string\n"); } + size_t cnNext = 1; + std::vector probs; + probs.push_back(getFloat(in,c)); + while (get(in,c) == ',') { + c++; + float val = getFloat(in,c); + probs.push_back(val); + } + //if we read more than one prob, this was a lattice, last item was column increment + if (probs.size()>1) { + cnNext = static_cast(probs.back()); + probs.pop_back(); + if (cnNext < 1) { assert(!"PCN/PLF parse error: bad link length at last element of cn alt block\n"); } + } + if (get(in,c++) != ')') { assert(!"PCN/PLF parse error: expected ) at end of cn alt block\n"); } + eatws(in,c); + Hypergraph::TailNodeVector tail(1, cur_node); + Hypergraph::Edge* edge = hg->AddEdge(r, tail); + //cerr << " <--" << cur_node << endl; + int head_node = cur_node + cnNext; + assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory + if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); } + hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]); + for (int i = 0; i < probs.size(); ++i) + edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast(i)), probs[i]); +} + +// parse (('foo', 0.23), ('bar', 0.77)) +void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) { + //cerr << "PLF READING NODE " << cur_node << endl; + if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); } + if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); } + eatws(in,c); + while (1) { + if (c > (int)in.size()) { break; } + if (get(in,c) == ')') { + c++; + eatws(in,c); + break; + } + if (get(in,c) == ',' && get(in,c+1) == ')') { + c+=2; + eatws(in,c); + break; + } + if (get(in,c) == ',') { c++; eatws(in,c); } + ReadPLFEdge(in, c, cur_node, hg); + } +} + +} // namespace PLF + +void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line) { + hg->clear(); + int c = 0; + int cur_node = 0; + if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); } + while (1) { + if (c > (int)in.size()) { break; } + if (PLF::get(in,c) == ')') { + c++; + PLF::eatws(in,c); + break; + } + if (PLF::get(in,c) == ',' && PLF::get(in,c+1) == ')') { + c+=2; + PLF::eatws(in,c); + break; + } + if (PLF::get(in,c) == ',') { c++; PLF::eatws(in,c); } + PLF::ReadPLFNode(in, c, cur_node, line, hg); + ++cur_node; + } + assert(cur_node == hg->nodes_.size() - 1); +} + +void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) { + Lattice& l = *pl; + Hypergraph g; + ReadFromPLF(plf, &g, 0); + const int num_nodes = g.nodes_.size() - 1; + l.resize(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + vector& alts = l[i]; + const Hypergraph::Node& node = g.nodes_[i]; + const int num_alts = node.out_edges_.size(); + alts.resize(num_alts); + for (int j = 0; j < num_alts; ++j) { + const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]]; + alts[j].label = edge.rule_->e_[1]; + alts[j].cost = edge.feature_values_.value(FD::Convert("Feature_0")); + alts[j].dist2next = edge.head_node_ - node.id_; + } + } +} + +namespace B64 { + +static const char cb64[]="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +static const char cd64[]="|$$$}rstuvwxyz{$$$$$$$>?@ABCDEFGHIJKLMNOPQRSTUVW$$$$$$XYZ[\\]^_`abcdefghijklmnopq"; + +static void encodeblock(const unsigned char* in, ostream* os, int len) { + char out[4]; + out[0] = cb64[ in[0] >> 2 ]; + out[1] = cb64[ ((in[0] & 0x03) << 4) | ((in[1] & 0xf0) >> 4) ]; + out[2] = (len > 1 ? cb64[ ((in[1] & 0x0f) << 2) | ((in[2] & 0xc0) >> 6) ] : '='); + out[3] = (len > 2 ? cb64[ in[2] & 0x3f ] : '='); + os->write(out, 4); +} + +void b64encode(const char* data, const size_t size, ostream* out) { + size_t cur = 0; + while(cur < size) { + int len = min(static_cast(3), size - cur); + encodeblock(reinterpret_cast(&data[cur]), out, len); + cur += len; + } +} + +static void decodeblock(const unsigned char* in, unsigned char* out) { + out[0] = (unsigned char ) (in[0] << 2 | in[1] >> 4); + out[1] = (unsigned char ) (in[1] << 4 | in[2] >> 2); + out[2] = (unsigned char ) (((in[2] << 6) & 0xc0) | in[3]); +} + +bool b64decode(const unsigned char* data, const size_t insize, char* out, const size_t outsize) { + size_t cur = 0; + size_t ocur = 0; + unsigned char in[4]; + while(cur < insize) { + assert(ocur < outsize); + for (int i = 0; i < 4; ++i) { + unsigned char v = data[cur]; + v = (unsigned char) ((v < 43 || v > 122) ? '\0' : cd64[ v - 43 ]); + if (!v) { + cerr << "B64 decode error at offset " << cur << " offending character: " << (int)data[cur] << endl; + return false; + } + v = (unsigned char) ((v == '$') ? '\0' : v - 61); + if (v) in[i] = v - 1; else in[i] = 0; + ++cur; + } + decodeblock(in, reinterpret_cast(&out[ocur])); + ocur += 3; + } + return true; +} +} + diff --git a/src/hg_io.h b/src/hg_io.h new file mode 100644 index 00000000..69a516c1 --- /dev/null +++ b/src/hg_io.h @@ -0,0 +1,37 @@ +#ifndef _HG_IO_H_ +#define _HG_IO_H_ + +#include + +#include "lattice.h" +class Hypergraph; + +struct HypergraphIO { + + // the format is basically a list of nodes and edges in topological order + // any edge you read, you must have already read its tail nodes + // any node you read, you must have already read its incoming edges + // this may make writing a bit more challenging if your forest is not + // topologically sorted (but that probably doesn't happen very often), + // but it makes reading much more memory efficient. + // see test_data/small.json.gz for an email encoding + static bool ReadFromJSON(std::istream* in, Hypergraph* out); + + // if remove_rules is used, the hypergraph is serialized without rule information + // (so it only contains structure and feature information) + static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out); + + // serialization utils + static void ReadFromPLF(const std::string& in, Hypergraph* out, int line = 0); + // return PLF string representation (undefined behavior on non-lattices) + static std::string AsPLF(const Hypergraph& hg, bool include_global_parentheses = true); + static void PLFtoLattice(const std::string& plf, Lattice* pl); + static std::string Escape(const std::string& s); // PLF helper +}; + +namespace B64 { + bool b64decode(const unsigned char* data, const size_t insize, char* out, const size_t outsize); + void b64encode(const char* data, const size_t size, std::ostream* out); +} + +#endif diff --git a/src/hg_test.cc b/src/hg_test.cc new file mode 100644 index 00000000..ecd97508 --- /dev/null +++ b/src/hg_test.cc @@ -0,0 +1,441 @@ +#include +#include +#include +#include +#include +#include "tdict.h" + +#include "json_parse.h" +#include "filelib.h" +#include "hg.h" +#include "hg_io.h" +#include "hg_intersect.h" +#include "viterbi.h" +#include "kbest.h" +#include "inside_outside.h" + +using namespace std; + +class HGTest : public testing::Test { + protected: + virtual void SetUp() { } + virtual void TearDown() { } + void CreateHG(Hypergraph* hg) const; + void CreateHG_int(Hypergraph* hg) const; + void CreateHG_tiny(Hypergraph* hg) const; + void CreateHGBalanced(Hypergraph* hg) const; + void CreateLatticeHG(Hypergraph* hg) const; + void CreateTinyLatticeHG(Hypergraph* hg) const; +}; + +void HGTest::CreateTinyLatticeHG(Hypergraph* hg) const { + const string json = "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] b\",4,\"[X] ||| [1] B'\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.2],\"rule\":1},{\"tail\":[0],\"feats\":[0,-0.6],\"rule\":2}],\"node\":{\"in_edges\":[0,1]},\"edges\":[{\"tail\":[1],\"feats\":[0,-0.1],\"rule\":3},{\"tail\":[1],\"feats\":[0,-0.9],\"rule\":4}],\"node\":{\"in_edges\":[2,3]}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +void HGTest::CreateLatticeHG(Hypergraph* hg) const { + const string json = "{\"rules\":[1,\"[X] ||| [1] a\",2,\"[X] ||| [1] A\",3,\"[X] ||| [1] A A\",4,\"[X] ||| [1] b\",5,\"[X] ||| [1] c\",6,\"[X] ||| [1] B C\",7,\"[X] ||| [1] A B C\",8,\"[X] ||| [1] CC\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[],\"node\":{\"in_edges\":[]},\"edges\":[{\"tail\":[0],\"feats\":[2,-0.3],\"rule\":1},{\"tail\":[0],\"feats\":[2,-0.6],\"rule\":2},{\"tail\":[0],\"feats\":[2,-1.7],\"rule\":3}],\"node\":{\"in_edges\":[0,1,2]},\"edges\":[{\"tail\":[1],\"feats\":[2,-0.5],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[2],\"feats\":[2,-0.6],\"rule\":5},{\"tail\":[1],\"feats\":[2,-0.8],\"rule\":6},{\"tail\":[0],\"feats\":[2,-0.01],\"rule\":7},{\"tail\":[2],\"feats\":[2,-0.8],\"rule\":8}],\"node\":{\"in_edges\":[4,5,6,7]}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +void HGTest::CreateHG_tiny(Hypergraph* hg) const { + const string json = "{\"rules\":[1,\"[X] ||| \",2,\"[X] ||| X [1]\",3,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,-2,1,-99],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.5,1,-0.8],\"rule\":2},{\"tail\":[0],\"feats\":[0,-0.7,1,-0.9],\"rule\":3}],\"node\":{\"in_edges\":[1,2]}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +void HGTest::CreateHG_int(Hypergraph* hg) const { + const string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| b\",3,\"[X] ||| a [1]\",4,\"[X] ||| [1] b\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[0,0.1],\"rule\":1},{\"tail\":[],\"feats\":[0,0.1],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X\"},\"edges\":[{\"tail\":[0],\"feats\":[0,0.3],\"rule\":3},{\"tail\":[0],\"feats\":[0,0.2],\"rule\":4}],\"node\":{\"in_edges\":[2,3],\"cat\":\"Goal\"}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +void HGTest::CreateHG(Hypergraph* hg) const { + string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +void HGTest::CreateHGBalanced(Hypergraph* hg) const { + const string json = "{\"rules\":[1,\"[X] ||| i\",2,\"[X] ||| a\",3,\"[X] ||| b\",4,\"[X] ||| [1] [2]\",5,\"[X] ||| [1] [2]\",6,\"[X] ||| c\",7,\"[X] ||| d\",8,\"[X] ||| [1] [2]\",9,\"[X] ||| [1] [2]\",10,\"[X] ||| [1] [2]\",11,\"[X] ||| [1] [2]\",12,\"[X] ||| [1] [2]\",13,\"[X] ||| [1] [2]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[1,2],\"feats\":[],\"rule\":4},{\"tail\":[2,1],\"feats\":[],\"rule\":5}],\"node\":{\"in_edges\":[3,4]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":6}],\"node\":{\"in_edges\":[5]},\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":7}],\"node\":{\"in_edges\":[6]},\"edges\":[{\"tail\":[4,5],\"feats\":[],\"rule\":8},{\"tail\":[5,4],\"feats\":[],\"rule\":9}],\"node\":{\"in_edges\":[7,8]},\"edges\":[{\"tail\":[3,6],\"feats\":[],\"rule\":10},{\"tail\":[6,3],\"feats\":[],\"rule\":11}],\"node\":{\"in_edges\":[9,10]},\"edges\":[{\"tail\":[7,0],\"feats\":[],\"rule\":12},{\"tail\":[0,7],\"feats\":[],\"rule\":13}],\"node\":{\"in_edges\":[11,12]}}"; + istringstream instr(json); + EXPECT_TRUE(HypergraphIO::ReadFromJSON(&instr, hg)); +} + +TEST_F(HGTest,Controlled) { + Hypergraph hg; + CreateHG_tiny(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 0.8); + hg.Reweight(wts); + vector trans; + prob_t prob = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "prob: " << prob << "\n"; + EXPECT_FLOAT_EQ(-80.839996, log(prob)); + EXPECT_EQ("X ", TD::GetString(trans)); + vector post; + hg.PrintGraphviz(); + prob_t c2 = Inside(hg, NULL, ScaledEdgeProb(0.6)); + EXPECT_FLOAT_EQ(-47.8577, log(c2)); +} + +TEST_F(HGTest,Union) { + Hypergraph hg1; + Hypergraph hg2; + CreateHG_tiny(&hg1); + CreateHG(&hg2); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 1.0); + hg1.Reweight(wts); + hg2.Reweight(wts); + prob_t c1,c2,c3,c4; + vector t1,t2,t3,t4; + c1 = ViterbiESentence(hg1, &t1); + c2 = ViterbiESentence(hg2, &t2); + int l2 = ViterbiPathLength(hg2); + cerr << c1 << "\t" << TD::GetString(t1) << endl; + cerr << c2 << "\t" << TD::GetString(t2) << endl; + hg1.Union(hg2); + hg1.Reweight(wts); + c3 = ViterbiESentence(hg1, &t3); + int l3 = ViterbiPathLength(hg1); + cerr << c3 << "\t" << TD::GetString(t3) << endl; + EXPECT_FLOAT_EQ(c2, c3); + EXPECT_EQ(TD::GetString(t2), TD::GetString(t3)); + EXPECT_EQ(l2, l3); + + wts.set_value(FD::Convert("f2"), -1); + hg1.Reweight(wts); + c4 = ViterbiESentence(hg1, &t4); + cerr << c4 << "\t" << TD::GetString(t4) << endl; + EXPECT_EQ("Z ", TD::GetString(t4)); + EXPECT_FLOAT_EQ(98.82, log(c4)); + + vector, prob_t> > list; + KBest::KBestDerivations, ESentenceTraversal> kbest(hg1, 10); + for (int i = 0; i < 10; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg1.nodes_.size() - 1, i); + if (!d) break; + list.push_back(make_pair(d->yield, d->score)); + } + EXPECT_TRUE(list[0].first == t4); + EXPECT_FLOAT_EQ(log(list[0].second), log(c4)); + EXPECT_EQ(list.size(), 6); + EXPECT_FLOAT_EQ(log(list.back().second / list.front().second), -97.7); +} + +TEST_F(HGTest,ControlledKBest) { + Hypergraph hg; + CreateHG(&hg); + vector w(2); w[0]=0.4; w[1]=0.8; + hg.Reweight(w); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + + int best = 0; + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 10); + for (int i = 0; i < 10; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << TD::GetString(d->yield) << endl; + ++best; + } + EXPECT_EQ(4, best); +} + + +TEST_F(HGTest,InsideScore) { + SparseVector wts; + wts.set_value(FD::Convert("f1"), 1.0); + Hypergraph hg; + CreateTinyLatticeHG(&hg); + hg.Reweight(wts); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + hg.PrintGraphviz(); + prob_t inside = Inside(hg); + EXPECT_FLOAT_EQ(1.7934048, inside); // computed by hand + vector post; + inside = hg.ComputeBestPathThroughEdges(&post); + EXPECT_FLOAT_EQ(-0.3, log(inside)); // computed by hand + EXPECT_EQ(post.size(), 4); + for (int i = 0; i < 4; ++i) { + cerr << "edge post: " << log(post[i]) << '\t' << hg.edges_[i].rule_->AsString() << endl; + } +} + + +TEST_F(HGTest,PruneInsideOutside) { + SparseVector wts; + wts.set_value(FD::Convert("Feature_1"), 1.0); + Hypergraph hg; + CreateLatticeHG(&hg); + hg.Reweight(wts); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + hg.PrintGraphviz(); + //hg.DensityPruneInsideOutside(0.5, false, 2.0); + hg.BeamPruneInsideOutside(0.5, false, 0.5); + cost = ViterbiESentence(hg, &trans); + cerr << "Ncst: " << cost << endl; + cerr << TD::GetString(trans) << "\n"; + hg.PrintGraphviz(); +} + +TEST_F(HGTest,TestPruneEdges) { + Hypergraph hg; + CreateLatticeHG(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 1.0); + hg.Reweight(wts); + hg.PrintGraphviz(); + vector prune(hg.edges_.size(), true); + prune[6] = false; + hg.PruneEdges(prune); + cerr << "Pruned:\n"; + hg.PrintGraphviz(); +} + +TEST_F(HGTest,TestIntersect) { + Hypergraph hg; + CreateHG_int(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 1.0); + hg.Reweight(wts); + hg.PrintGraphviz(); + + int best = 0; + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 10); + for (int i = 0; i < 10; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << TD::GetString(d->yield) << endl; + ++best; + } + EXPECT_EQ(4, best); + + Lattice target(2); + target[0].push_back(LatticeArc(TD::Convert("a"), 0.0, 1)); + target[1].push_back(LatticeArc(TD::Convert("b"), 0.0, 1)); + HG::Intersect(target, &hg); + hg.PrintGraphviz(); +} + +TEST_F(HGTest,TestPrune2) { + Hypergraph hg; + CreateHG_int(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 1.0); + hg.Reweight(wts); + hg.PrintGraphviz(); + vector rem(hg.edges_.size(), false); + rem[0] = true; + rem[1] = true; + hg.PruneEdges(rem); + hg.PrintGraphviz(); + cerr << "TODO: fix this pruning behavior-- the resulting HG should be empty!\n"; +} + +TEST_F(HGTest,Sample) { + Hypergraph hg; + CreateLatticeHG(&hg); + SparseVector wts; + wts.set_value(FD::Convert("Feature_1"), 0.0); + hg.Reweight(wts); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + hg.PrintGraphviz(); +} + +TEST_F(HGTest,PLF) { + Hypergraph hg; + string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)"; + HypergraphIO::ReadFromPLF(inplf, &hg); + SparseVector wts; + wts.set_value(FD::Convert("Feature_0"), 1.0); + hg.Reweight(wts); + hg.PrintGraphviz(); + string outplf = HypergraphIO::AsPLF(hg); + cerr << " IN: " << inplf << endl; + cerr << "OUT: " << outplf << endl; + assert(inplf == outplf); +} + +TEST_F(HGTest,PushWeightsToGoal) { + Hypergraph hg; + CreateHG(&hg); + vector w(2); w[0]=0.4; w[1]=0.8; + hg.Reweight(w); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + hg.PrintGraphviz(); + hg.PushWeightsToGoal(); + hg.PrintGraphviz(); +} + +TEST_F(HGTest,TestSpecialKBest) { + Hypergraph hg; + CreateHGBalanced(&hg); + vector w(1); w[0]=0; + hg.Reweight(w); + vector, prob_t> > list; + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 100000); + for (int i = 0; i < 100000; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << TD::GetString(d->yield) << endl; + } + hg.PrintGraphviz(); +} + +TEST_F(HGTest, TestGenericViterbi) { + Hypergraph hg; + CreateHG_tiny(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 0.8); + hg.Reweight(wts); + vector trans; + const prob_t prob = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "prob: " << prob << "\n"; + EXPECT_FLOAT_EQ(-80.839996, log(prob)); + EXPECT_EQ("X ", TD::GetString(trans)); +} + +TEST_F(HGTest, TestGenericInside) { + Hypergraph hg; + CreateTinyLatticeHG(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 1.0); + hg.Reweight(wts); + vector inside; + prob_t ins = Inside(hg, &inside); + EXPECT_FLOAT_EQ(1.7934048, ins); // computed by hand + vector outside; + Outside(hg, inside, &outside); + EXPECT_EQ(3, outside.size()); + EXPECT_FLOAT_EQ(1.7934048, outside[0]); + EXPECT_FLOAT_EQ(1.3114071, outside[1]); + EXPECT_FLOAT_EQ(1.0, outside[2]); +} + +TEST_F(HGTest,TestGenericInside2) { + Hypergraph hg; + CreateHG(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 0.8); + hg.Reweight(wts); + vector inside, outside; + prob_t ins = Inside(hg, &inside); + Outside(hg, inside, &outside); + for (int i = 0; i < hg.nodes_.size(); ++i) + cerr << i << "\t" << log(inside[i]) << "\t" << log(outside[i]) << endl; + EXPECT_FLOAT_EQ(0, log(inside[0])); + EXPECT_FLOAT_EQ(-1.7861683, log(outside[0])); + EXPECT_FLOAT_EQ(-0.4, log(inside[1])); + EXPECT_FLOAT_EQ(-1.3861683, log(outside[1])); + EXPECT_FLOAT_EQ(-0.8, log(inside[2])); + EXPECT_FLOAT_EQ(-0.986168, log(outside[2])); + EXPECT_FLOAT_EQ(-0.96, log(inside[3])); + EXPECT_FLOAT_EQ(-0.8261683, log(outside[3])); + EXPECT_FLOAT_EQ(-1.562512, log(inside[4])); + EXPECT_FLOAT_EQ(-0.22365622, log(outside[4])); + EXPECT_FLOAT_EQ(-1.7861683, log(inside[5])); + EXPECT_FLOAT_EQ(0, log(outside[5])); +} + +TEST_F(HGTest,TestAddExpectations) { + Hypergraph hg; + CreateHG(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 0.8); + hg.Reweight(wts); + SparseVector feat_exps; + InsideOutside, EdgeFeaturesWeightFunction>(hg, &feat_exps); + EXPECT_FLOAT_EQ(-2.5439765, feat_exps[FD::Convert("f1")]); + EXPECT_FLOAT_EQ(-2.6357865, feat_exps[FD::Convert("f2")]); + cerr << feat_exps << endl; + SparseVector posts; + InsideOutside, TransitionEventWeightFunction>(hg, &posts); +} + +TEST_F(HGTest, Small) { + ReadFile rf("test_data/small.json.gz"); + Hypergraph hg; + assert(HypergraphIO::ReadFromJSON(rf.stream(), &hg)); + SparseVector wts; + wts.set_value(FD::Convert("Model_0"), -2.0); + wts.set_value(FD::Convert("Model_1"), -0.5); + wts.set_value(FD::Convert("Model_2"), -1.1); + wts.set_value(FD::Convert("Model_3"), -1.0); + wts.set_value(FD::Convert("Model_4"), -1.0); + wts.set_value(FD::Convert("Model_5"), 0.5); + wts.set_value(FD::Convert("Model_6"), 0.2); + wts.set_value(FD::Convert("Model_7"), -3.0); + hg.Reweight(wts); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + vector post; + prob_t c2 = Inside(hg, NULL, ScaledEdgeProb(0.6)); + EXPECT_FLOAT_EQ(2.1431036, log(c2)); +} + +TEST_F(HGTest, JSONTest) { + ostringstream os; + JSONParser::WriteEscapedString("\"I don't know\", she said.", &os); + EXPECT_EQ("\"\\\"I don't know\\\", she said.\"", os.str()); + ostringstream os2; + JSONParser::WriteEscapedString("yes", &os2); + EXPECT_EQ("\"yes\"", os2.str()); +} + +TEST_F(HGTest, TestGenericKBest) { + Hypergraph hg; + CreateHG(&hg); + //CreateHGBalanced(&hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 1.0); + hg.Reweight(wts); + vector trans; + prob_t cost = ViterbiESentence(hg, &trans); + cerr << TD::GetString(trans) << "\n"; + cerr << "cost: " << cost << "\n"; + + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 1000); + for (int i = 0; i < 1000; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << TD::GetString(d->yield) << " F:" << d->feature_values << endl; + } +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ibm_model1.cc b/src/ibm_model1.cc new file mode 100644 index 00000000..eb9c617c --- /dev/null +++ b/src/ibm_model1.cc @@ -0,0 +1,4 @@ +#include + + + diff --git a/src/inside_outside.h b/src/inside_outside.h new file mode 100644 index 00000000..9114c9d7 --- /dev/null +++ b/src/inside_outside.h @@ -0,0 +1,111 @@ +#ifndef _INSIDE_H_ +#define _INSIDE_H_ + +#include +#include +#include "hg.h" + +// run the inside algorithm and return the inside score +// if result is non-NULL, result will contain the inside +// score for each node +// NOTE: WeightType(0) must construct the semiring's additive identity +// WeightType(1) must construct the semiring's multiplicative identity +template +WeightType Inside(const Hypergraph& hg, + std::vector* result = NULL, + const WeightFunction& weight = WeightFunction()) { + const int num_nodes = hg.nodes_.size(); + std::vector dummy; + std::vector& inside_score = result ? *result : dummy; + inside_score.resize(num_nodes); + std::fill(inside_score.begin(), inside_score.end(), WeightType()); + for (int i = 0; i < num_nodes; ++i) { + const Hypergraph::Node& cur_node = hg.nodes_[i]; + WeightType* const cur_node_inside_score = &inside_score[i]; + const int num_in_edges = cur_node.in_edges_.size(); + if (num_in_edges == 0) { + *cur_node_inside_score = WeightType(1); + continue; + } + for (int j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]]; + WeightType score = weight(edge); + for (int k = 0; k < edge.tail_nodes_.size(); ++k) { + const int tail_node_index = edge.tail_nodes_[k]; + score *= inside_score[tail_node_index]; + } + *cur_node_inside_score += score; + } + } + return inside_score.back(); +} + +template +void Outside(const Hypergraph& hg, + std::vector& inside_score, + std::vector* result, + const WeightFunction& weight = WeightFunction()) { + assert(result); + const int num_nodes = hg.nodes_.size(); + assert(inside_score.size() == num_nodes); + std::vector& outside_score = *result; + outside_score.resize(num_nodes); + std::fill(outside_score.begin(), outside_score.end(), WeightType(0)); + outside_score.back() = WeightType(1); + for (int i = num_nodes - 1; i >= 0; --i) { + const Hypergraph::Node& cur_node = hg.nodes_[i]; + const WeightType& head_node_outside_score = outside_score[i]; + const int num_in_edges = cur_node.in_edges_.size(); + for (int j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]]; + const WeightType head_and_edge_weight = weight(edge) * head_node_outside_score; + const int num_tail_nodes = edge.tail_nodes_.size(); + for (int k = 0; k < num_tail_nodes; ++k) { + const int update_tail_node_index = edge.tail_nodes_[k]; + WeightType* const tail_outside_score = &outside_score[update_tail_node_index]; + WeightType inside_contribution = WeightType(1); + for (int l = 0; l < num_tail_nodes; ++l) { + const int other_tail_node_index = edge.tail_nodes_[l]; + if (update_tail_node_index != other_tail_node_index) + inside_contribution *= inside_score[other_tail_node_index]; + } + *tail_outside_score += head_and_edge_weight * inside_contribution; + } + } + } +} + +// this is the Inside-Outside optimization described in Li et al. (EMNLP 2009) +// for computing the inside algorithm over expensive semirings +// (such as expectations over features). See Figure 4. It is slightly different +// in that x/k is returned not (k,x) +// NOTE: RType * PType must be valid (and yield RType) +template +PType InsideOutside(const Hypergraph& hg, + RType* result_x, + const WeightFunction& weight1 = WeightFunction(), + const WeightFunction2& weight2 = WeightFunction2()) { + const int num_nodes = hg.nodes_.size(); + std::vector inside, outside; + const PType z = Inside(hg, &inside, weight1); + Outside(hg, inside, &outside, weight1); + RType& x = *result_x; + x = RType(); + for (int i = 0; i < num_nodes; ++i) { + const Hypergraph::Node& cur_node = hg.nodes_[i]; + const int num_in_edges = cur_node.in_edges_.size(); + for (int j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]]; + PType prob = outside[i]; + prob *= weight1(edge); + const int num_tail_nodes = edge.tail_nodes_.size(); + for (int k = 0; k < num_tail_nodes; ++k) + prob *= inside[edge.tail_nodes_[k]]; + prob /= z; + x += weight2(edge) * prob; + } + } + return z; +} + +#endif diff --git a/src/json_parse.cc b/src/json_parse.cc new file mode 100644 index 00000000..f6fdfea8 --- /dev/null +++ b/src/json_parse.cc @@ -0,0 +1,50 @@ +#include "json_parse.h" + +#include +#include + +using namespace std; + +static const char *json_hex_chars = "0123456789abcdef"; + +void JSONParser::WriteEscapedString(const string& in, ostream* out) { + int pos = 0; + int start_offset = 0; + unsigned char c = 0; + (*out) << '"'; + while(pos < in.size()) { + c = in[pos]; + switch(c) { + case '\b': + case '\n': + case '\r': + case '\t': + case '"': + case '\\': + case '/': + if(pos - start_offset > 0) + (*out) << in.substr(start_offset, pos - start_offset); + if(c == '\b') (*out) << "\\b"; + else if(c == '\n') (*out) << "\\n"; + else if(c == '\r') (*out) << "\\r"; + else if(c == '\t') (*out) << "\\t"; + else if(c == '"') (*out) << "\\\""; + else if(c == '\\') (*out) << "\\\\"; + else if(c == '/') (*out) << "\\/"; + start_offset = ++pos; + break; + default: + if(c < ' ') { + cerr << "Warning, bad character (" << static_cast(c) << ") in string\n"; + if(pos - start_offset > 0) + (*out) << in.substr(start_offset, pos - start_offset); + (*out) << "\\u00" << json_hex_chars[c >> 4] << json_hex_chars[c & 0xf]; + start_offset = ++pos; + } else pos++; + } + } + if(pos - start_offset > 0) + (*out) << in.substr(start_offset, pos - start_offset); + (*out) << '"'; +} + diff --git a/src/json_parse.h b/src/json_parse.h new file mode 100644 index 00000000..c3cba954 --- /dev/null +++ b/src/json_parse.h @@ -0,0 +1,58 @@ +#ifndef _JSON_WRAPPER_H_ +#define _JSON_WRAPPER_H_ + +#include +#include +#include "JSON_parser.h" + +class JSONParser { + public: + JSONParser() { + init_JSON_config(&config); + hack.mf = &JSONParser::Callback; + config.depth = 10; + config.callback_ctx = reinterpret_cast(this); + config.callback = hack.cb; + config.allow_comments = 1; + config.handle_floats_manually = 1; + jc = new_JSON_parser(&config); + } + virtual ~JSONParser() { + delete_JSON_parser(jc); + } + bool Parse(std::istream* in) { + int count = 0; + int lc = 1; + for (; in ; ++count) { + int next_char = in->get(); + if (!in->good()) break; + if (lc == '\n') { ++lc; } + if (!JSON_parser_char(jc, next_char)) { + std::cerr << "JSON_parser_char: syntax error, line " << lc << " (byte " << count << ")" << std::endl; + return false; + } + } + if (!JSON_parser_done(jc)) { + std::cerr << "JSON_parser_done: syntax error\n"; + return false; + } + return true; + } + static void WriteEscapedString(const std::string& in, std::ostream* out); + protected: + virtual bool HandleJSONEvent(int type, const JSON_value* value) = 0; + private: + int Callback(int type, const JSON_value* value) { + if (HandleJSONEvent(type, value)) return 1; + return 0; + } + JSON_parser_struct* jc; + JSON_config config; + typedef int (JSONParser::* MF)(int type, const struct JSON_value_struct* value); + union CBHack { + JSON_parser_callback cb; + MF mf; + } hack; +}; + +#endif diff --git a/src/kbest.h b/src/kbest.h new file mode 100644 index 00000000..cd9b6c2b --- /dev/null +++ b/src/kbest.h @@ -0,0 +1,207 @@ +#ifndef _HG_KBEST_H_ +#define _HG_KBEST_H_ + +#include +#include +#include + +#include + +#include "wordid.h" +#include "hg.h" + +namespace KBest { + // default, don't filter any derivations from the k-best list + struct NoFilter { + bool operator()(const std::vector& yield) { + (void) yield; + return false; + } + }; + + // optional, filter unique yield strings + struct FilterUnique { + std::tr1::unordered_set, boost::hash > > unique; + + bool operator()(const std::vector& yield) { + return !unique.insert(yield).second; + } + }; + + // utility class to lazily create the k-best derivations from a forest, uses + // the lazy k-best algorithm (Algorithm 3) from Huang and Chiang (IWPT 2005) + template + struct KBestDerivations { + KBestDerivations(const Hypergraph& hg, + const size_t k, + const Traversal& tf = Traversal(), + const WeightFunction& wf = WeightFunction()) : + traverse(tf), w(wf), g(hg), nds(g.nodes_.size()), k_prime(k) {} + + ~KBestDerivations() { + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + + struct Derivation { + Derivation(const Hypergraph::Edge& e, + const SmallVector& jv, + const WeightType& w, + const SparseVector& f) : + edge(&e), + j(jv), + score(w), + feature_values(f) {} + + // dummy constructor, just for query + Derivation(const Hypergraph::Edge& e, + const SmallVector& jv) : edge(&e), j(jv) {} + + T yield; + const Hypergraph::Edge* const edge; + const SmallVector j; + const WeightType score; + const SparseVector feature_values; + }; + struct HeapCompare { + bool operator()(const Derivation* a, const Derivation* b) const { + return a->score < b->score; + } + }; + struct DerivationCompare { + bool operator()(const Derivation* a, const Derivation* b) const { + return a->score > b->score; + } + }; + struct DerivationUniquenessHash { + size_t operator()(const Derivation* d) const { + size_t x = 5381; + x = ((x << 5) + x) ^ d->edge->id_; + for (int i = 0; i < d->j.size(); ++i) + x = ((x << 5) + x) ^ d->j[i]; + return x; + } + }; + struct DerivationUniquenessEquals { + bool operator()(const Derivation* a, const Derivation* b) const { + return (a->edge == b->edge) && (a->j == b->j); + } + }; + typedef std::vector CandidateHeap; + typedef std::vector DerivationList; + typedef std::tr1::unordered_set< + const Derivation*, DerivationUniquenessHash, DerivationUniquenessEquals> UniqueDerivationSet; + + struct NodeDerivationState { + CandidateHeap cand; + DerivationList D; + DerivationFilter filter; + UniqueDerivationSet ds; + explicit NodeDerivationState(const DerivationFilter& f = DerivationFilter()) : filter(f) {} + }; + + Derivation* LazyKthBest(int v, int k) { + NodeDerivationState& s = GetCandidates(v); + CandidateHeap& cand = s.cand; + DerivationList& D = s.D; + DerivationFilter& filter = s.filter; + bool add_next = true; + while (D.size() <= k) { + if (add_next && D.size() > 0) { + const Derivation* d = D.back(); + LazyNext(d, &cand, &s.ds); + } + add_next = false; + + if (cand.size() > 0) { + std::pop_heap(cand.begin(), cand.end(), HeapCompare()); + Derivation* d = cand.back(); + cand.pop_back(); + std::vector ants(d->edge->Arity()); + for (int j = 0; j < ants.size(); ++j) + ants[j] = &LazyKthBest(d->edge->tail_nodes_[j], d->j[j])->yield; + traverse(*d->edge, ants, &d->yield); + if (!filter(d->yield)) { + D.push_back(d); + add_next = true; + } + } else { + break; + } + } + if (k < D.size()) return D[k]; else return NULL; + } + + private: + // creates a derivation object with all fields set but the yield + // the yield is computed in LazyKthBest before the derivation is added to D + // returns NULL if j refers to derivation numbers larger than the + // antecedent structure define + Derivation* CreateDerivation(const Hypergraph::Edge& e, const SmallVector& j) { + WeightType score = w(e); + SparseVector feats = e.feature_values_; + for (int i = 0; i < e.Arity(); ++i) { + const Derivation* ant = LazyKthBest(e.tail_nodes_[i], j[i]); + if (!ant) { return NULL; } + score *= ant->score; + feats += ant->feature_values; + } + freelist.push_back(new Derivation(e, j, score, feats)); + return freelist.back(); + } + + NodeDerivationState& GetCandidates(int v) { + NodeDerivationState& s = nds[v]; + if (!s.D.empty() || !s.cand.empty()) return s; + + const Hypergraph::Node& node = g.nodes_[v]; + for (int i = 0; i < node.in_edges_.size(); ++i) { + const Hypergraph::Edge& edge = g.edges_[node.in_edges_[i]]; + SmallVector jv(edge.Arity(), 0); + Derivation* d = CreateDerivation(edge, jv); + assert(d); + s.cand.push_back(d); + } + + const int effective_k = std::min(k_prime, s.cand.size()); + const typename CandidateHeap::iterator kth = s.cand.begin() + effective_k; + std::nth_element(s.cand.begin(), kth, s.cand.end(), DerivationCompare()); + s.cand.resize(effective_k); + std::make_heap(s.cand.begin(), s.cand.end(), HeapCompare()); + + return s; + } + + void LazyNext(const Derivation* d, CandidateHeap* cand, UniqueDerivationSet* ds) { + for (int i = 0; i < d->j.size(); ++i) { + SmallVector j = d->j; + ++j[i]; + const Derivation* ant = LazyKthBest(d->edge->tail_nodes_[i], j[i]); + if (ant) { + Derivation query_unique(*d->edge, j); + if (ds->count(&query_unique) == 0) { + Derivation* new_d = CreateDerivation(*d->edge, j); + if (new_d) { + cand->push_back(new_d); + std::push_heap(cand->begin(), cand->end(), HeapCompare()); + assert(ds->insert(new_d).second); // insert into uniqueness set, sanity check + } + } + } + } + } + + const Traversal traverse; + const WeightFunction w; + const Hypergraph& g; + std::vector nds; + std::vector freelist; + const size_t k_prime; + }; +} + +#endif diff --git a/src/lattice.cc b/src/lattice.cc new file mode 100644 index 00000000..aa1df3db --- /dev/null +++ b/src/lattice.cc @@ -0,0 +1,27 @@ +#include "lattice.h" + +#include "tdict.h" +#include "hg_io.h" + +using namespace std; + +bool LatticeTools::LooksLikePLF(const string &line) { + return (line.size() > 5) && (line.substr(0,4) == "((('"); +} + +void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) { + Lattice& l = *pl; + vector ids; + TD::ConvertSentence(text, &ids); + l.resize(ids.size()); + for (int i = 0; i < l.size(); ++i) + l[i].push_back(LatticeArc(ids[i], 0.0, 1)); +} + +void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) { + if (LooksLikePLF(text_or_plf)) + HypergraphIO::PLFtoLattice(text_or_plf, pl); + else + ConvertTextToLattice(text_or_plf, pl); +} + diff --git a/src/lattice.h b/src/lattice.h new file mode 100644 index 00000000..1177e768 --- /dev/null +++ b/src/lattice.h @@ -0,0 +1,31 @@ +#ifndef __LATTICE_H_ +#define __LATTICE_H_ + +#include +#include +#include "wordid.h" + +struct LatticeArc { + WordID label; + double cost; + int dist2next; + LatticeArc() : label(), cost(), dist2next() {} + LatticeArc(WordID w, double c, int i) : label(w), cost(c), dist2next(i) {} +}; + +class Lattice : public std::vector > { + public: + Lattice() {} + explicit Lattice(size_t t, const std::vector& v = std::vector()) : + std::vector >(t, v) {} + + // TODO add distance functions +}; + +struct LatticeTools { + static bool LooksLikePLF(const std::string &line); + static void ConvertTextToLattice(const std::string& text, Lattice* pl); + static void ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl); +}; + +#endif diff --git a/src/lexcrf.cc b/src/lexcrf.cc new file mode 100644 index 00000000..33455a3d --- /dev/null +++ b/src/lexcrf.cc @@ -0,0 +1,112 @@ +#include "lexcrf.h" + +#include + +#include "filelib.h" +#include "hg.h" +#include "tdict.h" +#include "grammar.h" +#include "sentence_metadata.h" + +using namespace std; + +struct LexicalCRFImpl { + LexicalCRFImpl(const boost::program_options::variables_map& conf) : + use_null(false), + kXCAT(TD::Convert("X")*-1), + kNULL(TD::Convert("")), + kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")), + kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) { + vector gfiles = conf["grammar"].as >(); + assert(gfiles.size() == 1); + ReadFile rf(gfiles.front()); + TextGrammar *tg = new TextGrammar; + grammar.reset(tg); + istream* in = rf.stream(); + int lc = 0; + bool flag = false; + while(*in) { + string line; + getline(*in, line); + if (line.empty()) continue; + ++lc; + TRulePtr r(TRule::CreateRulePhrasetable(line)); + tg->AddRule(r); + if (lc % 50000 == 0) { cerr << '.'; flag = true; } + if (lc % 2000000 == 0) { cerr << " [" << lc << "]\n"; flag = false; } + } + if (flag) cerr << endl; + cerr << "Loaded " << lc << " rules\n"; + } + + void BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) { + const int e_len = smeta.GetTargetLength(); + assert(e_len > 0); + const int f_len = lattice.size(); + // hack to tell the feature function system how big the sentence pair is + const int f_start = (use_null ? -1 : 0); + int prev_node_id = -1; + for (int i = 0; i < e_len; ++i) { // for each word in the *ref* + Hypergraph::Node* node = forest->AddNode(kXCAT); + const int new_node_id = node->id_; + for (int j = f_start; j < f_len; ++j) { // for each word in the source + const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label); + const GrammarIter* gi = grammar->GetRoot()->Extend(src_sym); + if (!gi) { + cerr << "No translations found for: " << TD::Convert(src_sym) << "\n"; + abort(); + } + const RuleBin* rb = gi->GetRules(); + assert(rb); + for (int k = 0; k < rb->GetNumRules(); ++k) { + TRulePtr rule = rb->GetIthRule(k); + Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector()); + edge->i_ = j; + edge->j_ = j+1; + edge->prev_i_ = i; + edge->prev_j_ = i+1; + edge->feature_values_ += edge->rule_->GetFeatureValues(); + forest->ConnectEdgeToHeadNode(edge->id_, new_node_id); + } + } + if (prev_node_id >= 0) { + const int comb_node_id = forest->AddNode(kXCAT)->id_; + Hypergraph::TailNodeVector tail(2, prev_node_id); + tail[1] = new_node_id; + const int edge_id = forest->AddEdge(kBINARY, tail)->id_; + forest->ConnectEdgeToHeadNode(edge_id, comb_node_id); + prev_node_id = comb_node_id; + } else { + prev_node_id = new_node_id; + } + } + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("[Goal]")*-1); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + } + + private: + const bool use_null; + const WordID kXCAT; + const WordID kNULL; + const TRulePtr kBINARY; + const TRulePtr kGOAL_RULE; + GrammarPtr grammar; +}; + +LexicalCRF::LexicalCRF(const boost::program_options::variables_map& conf) : + pimpl_(new LexicalCRFImpl(conf)) {} + +bool LexicalCRF::Translate(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* forest) { + Lattice lattice; + LatticeTools::ConvertTextToLattice(input, &lattice); + smeta->SetSourceLength(lattice.size()); + pimpl_->BuildTrellis(lattice, *smeta, forest); + forest->Reweight(weights); + return true; +} + diff --git a/src/lexcrf.h b/src/lexcrf.h new file mode 100644 index 00000000..99362c81 --- /dev/null +++ b/src/lexcrf.h @@ -0,0 +1,18 @@ +#ifndef _LEXCRF_H_ +#define _LEXCRF_H_ + +#include "translator.h" +#include "lattice.h" + +struct LexicalCRFImpl; +struct LexicalCRF : public Translator { + LexicalCRF(const boost::program_options::variables_map& conf); + bool Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* forest); + private: + boost::shared_ptr pimpl_; +}; + +#endif diff --git a/src/lm_ff.cc b/src/lm_ff.cc new file mode 100644 index 00000000..f95140de --- /dev/null +++ b/src/lm_ff.cc @@ -0,0 +1,328 @@ +#include "lm_ff.h" + +#include +#include +#include +#include +#include +#include + +#include "tdict.h" +#include "Vocab.h" +#include "Ngram.h" +#include "hg.h" +#include "stringlib.h" + +using namespace std; + +struct LMClient { + struct Cache { + map tree; + float prob; + Cache() : prob() {} + }; + + LMClient(const char* host) : port(6666) { + s = strchr(host, ':'); + if (s != NULL) { + *s = '\0'; + ++s; + port = atoi(s); + } + sock = socket(AF_INET, SOCK_STREAM, 0); + hp = gethostbyname(host); + if (hp == NULL) { + cerr << "unknown host " << host << endl; + abort(); + } + bzero((char *)&server, sizeof(server)); + bcopy(hp->h_addr, (char *)&server.sin_addr, hp->h_length); + server.sin_family = hp->h_addrtype; + server.sin_port = htons(port); + + int errors = 0; + while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) { + cerr << "Error: connect()\n"; + sleep(1); + errors++; + if (errors > 3) exit(1); + } + cerr << "Connected to LM on " << host << " on port " << port << endl; + } + + float wordProb(int word, int* context) { + Cache* cur = &cache; + int i = 0; + while (context[i] > 0) { + cur = &cur->tree[context[i++]]; + } + cur = &cur->tree[word]; + if (cur->prob) { return cur->prob; } + + i = 0; + ostringstream os; + os << "prob " << TD::Convert(word); + while (context[i] > 0) { + os << ' ' << TD::Convert(context[i++]); + } + os << endl; + string out = os.str(); + write(sock, out.c_str(), out.size()); + int r = read(sock, res, 6); + int errors = 0; + int cnt = 0; + while (1) { + if (r < 0) { + errors++; sleep(1); + cerr << "Error: read()\n"; + if (errors > 5) exit(1); + } else if (r==0 || res[cnt] == '\n') { break; } + else { + cnt += r; + if (cnt==6) break; + read(sock, &res[cnt], 6-cnt); + } + } + cur->prob = *reinterpret_cast(res); + return cur->prob; + } + + void clear() { + cache.tree.clear(); + } + + private: + Cache cache; + int sock, port; + char *s; + struct hostent *hp; + struct sockaddr_in server; + char res[8]; +}; + +class LanguageModelImpl { + public: + LanguageModelImpl(int order, const string& f) : + ngram_(*TD::dict_), buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1), + floor_(-100.0), + client_(NULL), + kSTART(TD::Convert("")), + kSTOP(TD::Convert("")), + kUNKNOWN(TD::Convert("")), + kNONE(-1), + kSTAR(TD::Convert("<{STAR}>")) { + if (f.find("lm://") == 0) { + client_ = new LMClient(f.substr(5).c_str()); + } else { + File file(f.c_str(), "r", 0); + assert(file); + cerr << "Reading " << order_ << "-gram LM from " << f << endl; + ngram_.read(file, false); + } + } + + ~LanguageModelImpl() { + delete client_; + } + + inline int StateSize(const void* state) const { + return *(static_cast(state) + state_size_); + } + + inline void SetStateSize(int size, void* state) const { + *(static_cast(state) + state_size_) = size; + } + + inline double LookupProbForBufferContents(int i) { + double p = client_ ? + client_->wordProb(buffer_[i], &buffer_[i+1]) + : ngram_.wordProb(buffer_[i], (VocabIndex*)&buffer_[i+1]); + if (p < floor_) p = floor_; + return p; + } + + string DebugStateToString(const void* state) const { + int len = StateSize(state); + const int* astate = reinterpret_cast(state); + string res = "["; + for (int i = 0; i < len; ++i) { + res += " "; + res += TD::Convert(astate[i]); + } + res += " ]"; + return res; + } + + inline double ProbNoRemnant(int i, int len) { + int edge = len; + bool flag = true; + double sum = 0.0; + while (i >= 0) { + if (buffer_[i] == kSTAR) { + edge = i; + flag = false; + } else if (buffer_[i] <= 0) { + edge = i; + flag = true; + } else { + if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART))) + sum += LookupProbForBufferContents(i); + } + --i; + } + return sum; + } + + double EstimateProb(const vector& phrase) { + int len = phrase.size(); + buffer_.resize(len + 1); + buffer_[len] = kNONE; + int i = len - 1; + for (int j = 0; j < len; ++j,--i) + buffer_[i] = phrase[j]; + return ProbNoRemnant(len - 1, len); + } + + double EstimateProb(const void* state) { + int len = StateSize(state); + // cerr << "residual len: " << len << endl; + buffer_.resize(len + 1); + buffer_[len] = kNONE; + const int* astate = reinterpret_cast(state); + int i = len - 1; + for (int j = 0; j < len; ++j,--i) + buffer_[i] = astate[j]; + return ProbNoRemnant(len - 1, len); + } + + double FinalTraversalCost(const void* state) { + int slen = StateSize(state); + int len = slen + 2; + // cerr << "residual len: " << len << endl; + buffer_.resize(len + 1); + buffer_[len] = kNONE; + buffer_[len-1] = kSTART; + const int* astate = reinterpret_cast(state); + int i = len - 2; + for (int j = 0; j < slen; ++j,--i) + buffer_[i] = astate[j]; + buffer_[i] = kSTOP; + assert(i == 0); + return ProbNoRemnant(len - 1, len); + } + + double LookupWords(const TRule& rule, const vector& ant_states, void* vstate) { + int len = rule.ELength() - rule.Arity(); + for (int i = 0; i < ant_states.size(); ++i) + len += StateSize(ant_states[i]); + buffer_.resize(len + 1); + buffer_[len] = kNONE; + int i = len - 1; + const vector& e = rule.e(); + for (int j = 0; j < e.size(); ++j) { + if (e[j] < 1) { + const int* astate = reinterpret_cast(ant_states[-e[j]]); + int slen = StateSize(astate); + for (int k = 0; k < slen; ++k) + buffer_[i--] = astate[k]; + } else { + buffer_[i--] = e[j]; + } + } + + double sum = 0.0; + int* remnant = reinterpret_cast(vstate); + int j = 0; + i = len - 1; + int edge = len; + + while (i >= 0) { + if (buffer_[i] == kSTAR) { + edge = i; + } else if (edge-i >= order_) { + sum += LookupProbForBufferContents(i); + } else if (edge == len && remnant) { + remnant[j++] = buffer_[i]; + } + --i; + } + if (!remnant) return sum; + + if (edge != len || len >= order_) { + remnant[j++] = kSTAR; + if (order_-1 < edge) edge = order_-1; + for (int i = edge-1; i >= 0; --i) + remnant[j++] = buffer_[i]; + } + + SetStateSize(j, vstate); + return sum; + } + + static int OrderToStateSize(int order) { + return ((order-1) * 2 + 1) * sizeof(WordID) + 1; + } + + private: + Ngram ngram_; + vector buffer_; + const int order_; + const int state_size_; + const double floor_; + LMClient* client_; + + public: + const WordID kSTART; + const WordID kSTOP; + const WordID kUNKNOWN; + const WordID kNONE; + const WordID kSTAR; +}; + +LanguageModel::LanguageModel(const string& param) : + fid_(FD::Convert("LanguageModel")) { + vector argv; + int argc = SplitOnWhitespace(param, &argv); + int order = 3; + // TODO add support for -n FeatureName + string filename; + if (argc < 1) { cerr << "LanguageModel requires a filename, minimally!\n"; abort(); } + else if (argc == 1) { filename = argv[0]; } + else if (argc == 2 || argc > 3) { cerr << "Don't understand 'LanguageModel " << param << "'\n"; } + else if (argc == 3) { + if (argv[0] == "-o") { + order = atoi(argv[1].c_str()); + filename = argv[2]; + } else if (argv[1] == "-o") { + order = atoi(argv[2].c_str()); + filename = argv[0]; + } + } + SetStateSize(LanguageModelImpl::OrderToStateSize(order)); + pimpl_ = new LanguageModelImpl(order, filename); +} + +LanguageModel::~LanguageModel() { + delete pimpl_; +} + +string LanguageModel::DebugStateToString(const void* state) const{ + return pimpl_->DebugStateToString(state); +} + +void LanguageModel::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + (void) smeta; + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state)); + estimated_features->set_value(fid_, pimpl_->EstimateProb(state)); +} + +void LanguageModel::FinalTraversalFeatures(const void* ant_state, + SparseVector* features) const { + features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state)); +} + diff --git a/src/lm_ff.h b/src/lm_ff.h new file mode 100644 index 00000000..cd717360 --- /dev/null +++ b/src/lm_ff.h @@ -0,0 +1,32 @@ +#ifndef _LM_FF_H_ +#define _LM_FF_H_ + +#include +#include + +#include "hg.h" +#include "ff.h" + +class LanguageModelImpl; + +class LanguageModel : public FeatureFunction { + public: + // param = "filename.lm [-o n]" + LanguageModel(const std::string& param); + ~LanguageModel(); + virtual void FinalTraversalFeatures(const void* context, + SparseVector* features) const; + std::string DebugStateToString(const void* state) const; + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + const int fid_; + mutable LanguageModelImpl* pimpl_; +}; + +#endif diff --git a/src/logval.h b/src/logval.h new file mode 100644 index 00000000..a8ca620c --- /dev/null +++ b/src/logval.h @@ -0,0 +1,136 @@ +#ifndef LOGVAL_H_ +#define LOGVAL_H_ + +#include +#include + +template +class LogVal { + public: + LogVal() : v_(-std::numeric_limits::infinity()) {} + explicit LogVal(double x) : v_(std::log(x)) {} + LogVal(const LogVal& o) : v_(o.v_) {} + static LogVal One() { return LogVal(1); } + static LogVal Zero() { return LogVal(); } + + void logeq(const T& v) { v_ = v; } + + LogVal& operator+=(const LogVal& a) { + if (a.v_ == -std::numeric_limits::infinity()) return *this; + if (a.v_ < v_) { + v_ = v_ + log1p(std::exp(a.v_ - v_)); + } else { + v_ = a.v_ + log1p(std::exp(v_ - a.v_)); + } + return *this; + } + + LogVal& operator*=(const LogVal& a) { + v_ += a.v_; + return *this; + } + + LogVal& operator*=(const T& a) { + v_ += log(a); + return *this; + } + + LogVal& operator/=(const LogVal& a) { + v_ -= a.v_; + return *this; + } + + LogVal& poweq(const T& power) { + if (power == 0) v_ = 0; else v_ *= power; + return *this; + } + + LogVal pow(const T& power) const { + LogVal res = *this; + res.poweq(power); + return res; + } + + operator T() const { + return std::exp(v_); + } + + T v_; +}; + +template +LogVal operator+(const LogVal& o1, const LogVal& o2) { + LogVal res(o1); + res += o2; + return res; +} + +template +LogVal operator*(const LogVal& o1, const LogVal& o2) { + LogVal res(o1); + res *= o2; + return res; +} + +template +LogVal operator*(const LogVal& o1, const T& o2) { + LogVal res(o1); + res *= o2; + return res; +} + +template +LogVal operator*(const T& o1, const LogVal& o2) { + LogVal res(o2); + res *= o1; + return res; +} + +template +LogVal operator/(const LogVal& o1, const LogVal& o2) { + LogVal res(o1); + res /= o2; + return res; +} + +template +T log(const LogVal& o) { + return o.v_; +} + +template +LogVal pow(const LogVal& b, const T& e) { + return b.pow(e); +} + +template +bool operator<(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ < rhs.v_); +} + +template +bool operator<=(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ <= rhs.v_); +} + +template +bool operator>(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ > rhs.v_); +} + +template +bool operator>=(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ >= rhs.v_); +} + +template +bool operator==(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ == rhs.v_); +} + +template +bool operator!=(const LogVal& lhs, const LogVal& rhs) { + return (lhs.v_ != rhs.v_); +} + +#endif diff --git a/src/maxtrans_blunsom.cc b/src/maxtrans_blunsom.cc new file mode 100644 index 00000000..4a6680e0 --- /dev/null +++ b/src/maxtrans_blunsom.cc @@ -0,0 +1,287 @@ +#include "apply_models.h" + +#include +#include +#include +#include + +#include +#include + +#include "tdict.h" +#include "hg.h" +#include "ff.h" + +using boost::tuple; +using namespace std; +using namespace std::tr1; + +namespace Hack { + +struct Candidate; +typedef SmallVector JVector; +typedef vector CandidateHeap; +typedef vector CandidateList; + +// life cycle: candidates are created, placed on the heap +// and retrieved by their estimated cost, when they're +// retrieved, they're incorporated into the +LM hypergraph +// where they also know the head node index they are +// attached to. After they are added to the +LM hypergraph +// inside_prob_ and est_prob_ fields may be updated as better +// derivations are found (this happens since the successor's +// of derivation d may have a better score- they are +// explored lazily). However, the updates don't happen +// when a candidate is in the heap so maintaining the heap +// property is not an issue. +struct Candidate { + int node_index_; // -1 until incorporated + // into the +LM forest + const Hypergraph::Edge* in_edge_; // in -LM forest + Hypergraph::Edge out_edge_; + vector state_; + const JVector j_; + prob_t inside_prob_; // these are fixed until the cand + // is popped, then they may be updated + prob_t est_prob_; + + Candidate(const Hypergraph::Edge& e, + const JVector& j, + const vector& D, + bool is_goal) : + node_index_(-1), + in_edge_(&e), + j_(j) { + InitializeCandidate(D, is_goal); + } + + // used to query uniqueness + Candidate(const Hypergraph::Edge& e, + const JVector& j) : in_edge_(&e), j_(j) {} + + bool IsIncorporatedIntoHypergraph() const { + return node_index_ >= 0; + } + + void InitializeCandidate(const vector >& D, + const bool is_goal) { + const Hypergraph::Edge& in_edge = *in_edge_; + out_edge_.rule_ = in_edge.rule_; + out_edge_.feature_values_ = in_edge.feature_values_; + Hypergraph::TailNodeVector& tail = out_edge_.tail_nodes_; + tail.resize(j_.size()); + prob_t p = prob_t::One(); + // cerr << "\nEstimating application of " << in_edge.rule_->AsString() << endl; + vector* > ants(tail.size()); + for (int i = 0; i < tail.size(); ++i) { + const Candidate& ant = *D[in_edge.tail_nodes_[i]][j_[i]]; + ants[i] = &ant.state_; + assert(ant.IsIncorporatedIntoHypergraph()); + tail[i] = ant.node_index_; + p *= ant.inside_prob_; + } + prob_t edge_estimate = prob_t::One(); + if (is_goal) { + assert(tail.size() == 1); + out_edge_.edge_prob_ = in_edge.edge_prob_; + } else { + in_edge.rule_->ESubstitute(ants, &state_); + out_edge_.edge_prob_ = in_edge.edge_prob_; + } + inside_prob_ = out_edge_.edge_prob_ * p; + est_prob_ = inside_prob_ * edge_estimate; + } +}; + +ostream& operator<<(ostream& os, const Candidate& cand) { + os << "CAND["; + if (!cand.IsIncorporatedIntoHypergraph()) { os << "PENDING "; } + else { os << "+LM_node=" << cand.node_index_; } + os << " edge=" << cand.in_edge_->id_; + os << " j=<"; + for (int i = 0; i < cand.j_.size(); ++i) + os << (i==0 ? "" : " ") << cand.j_[i]; + os << "> vit=" << log(cand.inside_prob_); + os << " est=" << log(cand.est_prob_); + return os << ']'; +} + +struct HeapCandCompare { + bool operator()(const Candidate* l, const Candidate* r) const { + return l->est_prob_ < r->est_prob_; + } +}; + +struct EstProbSorter { + bool operator()(const Candidate* l, const Candidate* r) const { + return l->est_prob_ > r->est_prob_; + } +}; + +// the same candidate can be added multiple times if +// j is multidimensional (if you're going NW in Manhattan, you +// can first go north, then west, or you can go west then north) +// this is a hash function on the relevant variables from +// Candidate to enforce this. +struct CandidateUniquenessHash { + size_t operator()(const Candidate* c) const { + size_t x = 5381; + x = ((x << 5) + x) ^ c->in_edge_->id_; + for (int i = 0; i < c->j_.size(); ++i) + x = ((x << 5) + x) ^ c->j_[i]; + return x; + } +}; + +struct CandidateUniquenessEquals { + bool operator()(const Candidate* a, const Candidate* b) const { + return (a->in_edge_ == b->in_edge_) && (a->j_ == b->j_); + } +}; + +typedef unordered_set UniqueCandidateSet; +typedef unordered_map, Candidate*, boost::hash > > State2Node; + +class MaxTransBeamSearch { + +public: + MaxTransBeamSearch(const Hypergraph& i, int pop_limit, Hypergraph* o) : + in(i), + out(*o), + D(in.nodes_.size()), + pop_limit_(pop_limit) { + cerr << " Finding max translation (cube pruning, pop_limit = " << pop_limit_ << ')' << endl; + } + + void Apply() { + int num_nodes = in.nodes_.size(); + int goal_id = num_nodes - 1; + int pregoal = goal_id - 1; + assert(in.nodes_[pregoal].out_edges_.size() == 1); + cerr << " "; + for (int i = 0; i < in.nodes_.size(); ++i) { + cerr << '.'; + KBest(i, i == goal_id); + } + cerr << endl; + int best_node = D[goal_id].front()->in_edge_->tail_nodes_.front(); + Candidate& best = *D[best_node].front(); + cerr << " Best path: " << log(best.inside_prob_) + << "\t" << log(best.est_prob_) << endl; + cout << TD::GetString(D[best_node].front()->state_) << endl; + FreeAll(); + } + + private: + void FreeAll() { + for (int i = 0; i < D.size(); ++i) { + CandidateList& D_i = D[i]; + for (int j = 0; j < D_i.size(); ++j) + delete D_i[j]; + } + D.clear(); + } + + void IncorporateIntoPlusLMForest(Candidate* item, State2Node* s2n, CandidateList* freelist) { + Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_.rule_, item->out_edge_.tail_nodes_); + new_edge->feature_values_ = item->out_edge_.feature_values_; + new_edge->edge_prob_ = item->out_edge_.edge_prob_; + Candidate*& o_item = (*s2n)[item->state_]; + if (!o_item) o_item = item; + + int& node_id = o_item->node_index_; + if (node_id < 0) { + Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_, ""); + node_id = new_node->id_; + } + Hypergraph::Node* node = &out.nodes_[node_id]; + out.ConnectEdgeToHeadNode(new_edge, node); + + if (item != o_item) { + assert(o_item->state_ == item->state_); // sanity check! + o_item->est_prob_ += item->est_prob_; + o_item->inside_prob_ += item->inside_prob_; + freelist->push_back(item); + } + } + + void KBest(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_cands; + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, D, is_goal)); + assert(unique_cands.insert(cand.back()).second); // these should all be unique! + } +// cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + PushSucc(*item, is_goal, &cand, &unique_cands); + IncorporateIntoPlusLMForest(item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i) + D_v[c++] = i->second; + sort(D_v.begin(), D_v.end(), EstProbSorter()); + // cerr << " expanded to " << D_v.size() << " nodes\n"; + + for (int i = 0; i < cand.size(); ++i) + delete cand[i]; + // freelist is necessary since even after an item merged, it still stays in + // the unique set so it can't be deleted til now + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + + void PushSucc(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* cs) { + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (cs->count(&query_unique) == 0) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, D, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + assert(cs->insert(new_cand).second); // insert into uniqueness set, sanity check + } + } + } + } + + const Hypergraph& in; + Hypergraph& out; + + vector D; // maps nodes in in-HG to the + // equivalent nodes (many due to state + // splits) in the out-HG. + const int pop_limit_; +}; + +// each node in the graph has one of these, it keeps track of +void MaxTrans(const Hypergraph& in, + int beam_size) { + Hypergraph out; + MaxTransBeamSearch ma(in, beam_size, &out); + ma.Apply(); +} + +} diff --git a/src/parser_test.cc b/src/parser_test.cc new file mode 100644 index 00000000..da1fbd89 --- /dev/null +++ b/src/parser_test.cc @@ -0,0 +1,35 @@ +#include +#include +#include +#include +#include +#include "hg.h" +#include "trule.h" +#include "bottom_up_parser.h" +#include "tdict.h" + +using namespace std; + +class ChartTest : public testing::Test { + protected: + virtual void SetUp() { } + virtual void TearDown() { } +}; + +TEST_F(ChartTest,LanguageModel) { + LatticeArc a(TD::Convert("ein"), 0.0, 1); + LatticeArc b(TD::Convert("haus"), 0.0, 1); + Lattice lattice(2); + lattice[0].push_back(a); + lattice[1].push_back(b); + Hypergraph forest; + GrammarPtr g(new TextGrammar); + vector grammars(1, g); + ExhaustiveBottomUpParser parser("PHRASE", grammars); + parser.Parse(lattice, &forest); +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/phrasebased_translator.cc b/src/phrasebased_translator.cc new file mode 100644 index 00000000..5eb70876 --- /dev/null +++ b/src/phrasebased_translator.cc @@ -0,0 +1,206 @@ +#include "phrasebased_translator.h" + +#include +#include +#include +#include + +#include +#include + +#include "sentence_metadata.h" +#include "tdict.h" +#include "hg.h" +#include "filelib.h" +#include "lattice.h" +#include "phrasetable_fst.h" +#include "array2d.h" + +using namespace std; +using namespace std::tr1; +using namespace boost::tuples; + +struct Coverage : public vector { + explicit Coverage(int n, bool v = false) : vector(n, v), first_gap() {} + void Cover(int i, int j) { + vector::iterator it = this->begin() + i; + vector::iterator end = this->begin() + j; + while (it != end) + *it++ = true; + if (first_gap == i) { + first_gap = j; + it = end; + while (*it && it != this->end()) { + ++it; + ++first_gap; + } + } + } + bool Collides(int i, int j) const { + vector::const_iterator it = this->begin() + i; + vector::const_iterator end = this->begin() + j; + while (it != end) + if (*it++) return true; + return false; + } + int GetFirstGap() const { return first_gap; } + private: + int first_gap; +}; +struct CoverageHash { + size_t operator()(const Coverage& cov) const { + return hasher_(static_cast&>(cov)); + } + private: + boost::hash > hasher_; +}; +ostream& operator<<(ostream& os, const Coverage& cov) { + os << '['; + for (int i = 0; i < cov.size(); ++i) + os << (cov[i] ? '*' : '.'); + return os << " gap=" << cov.GetFirstGap() << ']'; +} + +typedef unordered_map CoverageNodeMap; +typedef unordered_set UniqueCoverageSet; + +struct PhraseBasedTranslatorImpl { + PhraseBasedTranslatorImpl(const boost::program_options::variables_map& conf) : + add_pass_through_rules(conf.count("add_pass_through_rules")), + max_distortion(conf["pb_max_distortion"].as()), + kSOURCE_RULE(new TRule("[X] ||| [X,1] ||| [X,1]", true)), + kCONCAT_RULE(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]", true)), + kNT_TYPE(TD::Convert("X") * -1) { + assert(max_distortion >= 0); + vector gfiles = conf["grammar"].as >(); + assert(gfiles.size() == 1); + cerr << "Reading phrasetable from " << gfiles.front() << endl; + ReadFile in(gfiles.front()); + fst.reset(LoadTextPhrasetable(in.stream())); + } + + struct State { + State(const Coverage& c, int _i, int _j, const FSTNode* q) : + coverage(c), i(_i), j(_j), fst(q) {} + Coverage coverage; + int i; + int j; + const FSTNode* fst; + }; + + // we keep track of unique coverages that have been extended since it's + // possible to "extend" the same coverage twice, e.g. translate "a b c" + // with phrases "a" "b" "a b" and "c". There are two ways to cover "a b" + void EnqueuePossibleContinuations(const Coverage& coverage, queue* q, UniqueCoverageSet* ucs) { + if (ucs->insert(coverage).second) { + const int gap = coverage.GetFirstGap(); + const int end = min(static_cast(coverage.size()), gap + max_distortion + 1); + for (int i = gap; i < end; ++i) + if (!coverage[i]) q->push(State(coverage, i, i, fst.get())); + } + } + + bool Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest) { + Lattice lattice; + LatticeTools::ConvertTextOrPLF(input, &lattice); + smeta->SetSourceLength(lattice.size()); + size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion); + minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100); + if (add_pass_through_rules) { + SparseVector feats; + feats.set_value(FD::Convert("PassThrough"), 1); + for (int i = 0; i < lattice.size(); ++i) { + const vector& arcs = lattice[i]; + for (int j = 0; j < arcs.size(); ++j) { + fst->AddPassThroughTranslation(arcs[j].label, feats); + // TODO handle lattice edge features + } + } + } + CoverageNodeMap c; + queue q; + UniqueCoverageSet ucs; + const Coverage empty_cov(lattice.size(), false); + const Coverage goal_cov(lattice.size(), true); + EnqueuePossibleContinuations(empty_cov, &q, &ucs); + c[empty_cov] = 0; // have to handle the left edge specially + while(!q.empty()) { + const State s = q.front(); + q.pop(); + // cerr << "(" << s.i << "," << s.j << " ptr=" << s.fst << ") cov=" << s.coverage << endl; + const vector& arcs = lattice[s.j]; + if (s.fst->HasData()) { + Coverage new_cov = s.coverage; + new_cov.Cover(s.i, s.j); + EnqueuePossibleContinuations(new_cov, &q, &ucs); + const vector& phrases = s.fst->GetTranslations()->GetRules(); + const int phrase_head_index = minus_lm_forest->AddNode(kNT_TYPE)->id_; + for (int i = 0; i < phrases.size(); ++i) { + Hypergraph::Edge* edge = minus_lm_forest->AddEdge(phrases[i], Hypergraph::TailNodeVector()); + edge->feature_values_ = edge->rule_->scores_; + minus_lm_forest->ConnectEdgeToHeadNode(edge->id_, phrase_head_index); + } + CoverageNodeMap::iterator cit = c.find(s.coverage); + assert(cit != c.end()); + const int tail_node_plus1 = cit->second; + if (tail_node_plus1 == 0) { // left edge + c[new_cov] = phrase_head_index + 1; + } else { // not left edge + int& head_node_plus1 = c[new_cov]; + if (!head_node_plus1) + head_node_plus1 = minus_lm_forest->AddNode(kNT_TYPE)->id_ + 1; + Hypergraph::TailNodeVector tail(2, tail_node_plus1 - 1); + tail[1] = phrase_head_index; + const int concat_edge = minus_lm_forest->AddEdge(kCONCAT_RULE, tail)->id_; + minus_lm_forest->ConnectEdgeToHeadNode(concat_edge, head_node_plus1 - 1); + } + } + if (s.j == lattice.size()) continue; + for (int l = 0; l < arcs.size(); ++l) { + const LatticeArc& arc = arcs[l]; + + const FSTNode* next_fst_state = s.fst->Extend(arc.label); + const int next_j = s.j + arc.dist2next; + if (next_fst_state && + !s.coverage.Collides(s.i, next_j)) { + q.push(State(s.coverage, s.i, next_j, next_fst_state)); + } + } + } + if (add_pass_through_rules) + fst->ClearPassThroughTranslations(); + int pregoal_plus1 = c[goal_cov]; + if (pregoal_plus1 > 0) { + TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [X,1]")); + int goal = minus_lm_forest->AddNode(TD::Convert("Goal") * -1)->id_; + int gedge = minus_lm_forest->AddEdge(kGOAL_RULE, Hypergraph::TailNodeVector(1, pregoal_plus1 - 1))->id_; + minus_lm_forest->ConnectEdgeToHeadNode(gedge, goal); + // they are almost topo, but not quite always + minus_lm_forest->TopologicallySortNodesAndEdges(goal); + minus_lm_forest->Reweight(weights); + return true; + } else { + return false; // composition failed + } + } + + const bool add_pass_through_rules; + const int max_distortion; + TRulePtr kSOURCE_RULE; + const TRulePtr kCONCAT_RULE; + const WordID kNT_TYPE; + boost::shared_ptr fst; +}; + +PhraseBasedTranslator::PhraseBasedTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new PhraseBasedTranslatorImpl(conf)) {} + +bool PhraseBasedTranslator::Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest) { + return pimpl_->Translate(input, smeta, weights, minus_lm_forest); +} diff --git a/src/phrasebased_translator.h b/src/phrasebased_translator.h new file mode 100644 index 00000000..d42ce79c --- /dev/null +++ b/src/phrasebased_translator.h @@ -0,0 +1,18 @@ +#ifndef _PHRASEBASED_TRANSLATOR_H_ +#define _PHRASEBASED_TRANSLATOR_H_ + +#include "translator.h" + +class PhraseBasedTranslatorImpl; +class PhraseBasedTranslator : public Translator { + public: + PhraseBasedTranslator(const boost::program_options::variables_map& conf); + bool Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest); + private: + boost::shared_ptr pimpl_; +}; + +#endif diff --git a/src/phrasetable_fst.cc b/src/phrasetable_fst.cc new file mode 100644 index 00000000..f421e941 --- /dev/null +++ b/src/phrasetable_fst.cc @@ -0,0 +1,141 @@ +#include "phrasetable_fst.h" + +#include +#include +#include + +#include + +#include "filelib.h" +#include "tdict.h" + +using boost::shared_ptr; +using namespace std; + +TargetPhraseSet::~TargetPhraseSet() {} +FSTNode::~FSTNode() {} + +class TextTargetPhraseSet : public TargetPhraseSet { + public: + void AddRule(TRulePtr rule) { + rules_.push_back(rule); + } + const vector& GetRules() const { + return rules_; + } + + private: + // all rules must have arity 0 + vector rules_; +}; + +class TextFSTNode : public FSTNode { + public: + const TargetPhraseSet* GetTranslations() const { return data.get(); } + bool HasData() const { return (bool)data; } + bool HasOutgoingNonEpsilonEdges() const { return !ptr.empty(); } + const FSTNode* Extend(const WordID& t) const { + map::const_iterator it = ptr.find(t); + if (it == ptr.end()) return NULL; + return &it->second; + } + + void AddPhrase(const string& phrase); + + void AddPassThroughTranslation(const WordID& w, const SparseVector& feats); + void ClearPassThroughTranslations(); + private: + vector passthroughs; + shared_ptr data; + map ptr; +}; + +#ifdef DEBUG_CHART_PARSER +static string TrimRule(const string& r) { + size_t start = r.find(" |||") + 5; + size_t end = r.rfind(" |||"); + return r.substr(start, end - start); +} +#endif + +void TextFSTNode::AddPhrase(const string& phrase) { + vector words; + TRulePtr rule(TRule::CreateRulePhrasetable(phrase)); + if (!rule) { + static int err = 0; + ++err; + if (err > 2) { cerr << "TOO MANY PHRASETABLE ERRORS\n"; exit(1); } + return; + } + + TextFSTNode* fsa = this; + for (int i = 0; i < rule->FLength(); ++i) + fsa = &fsa->ptr[rule->f_[i]]; + + if (!fsa->data) + fsa->data.reset(new TextTargetPhraseSet); + static_cast(fsa->data.get())->AddRule(rule); +} + +void TextFSTNode::AddPassThroughTranslation(const WordID& w, const SparseVector& feats) { + TextFSTNode* next = &ptr[w]; + // current, rules are only added if the symbol is completely missing as a + // word starting the phrase. As a result, it is possible that some sentences + // won't parse. If this becomes a problem, fix it here. + if (!next->data) { + TextTargetPhraseSet* tps = new TextTargetPhraseSet; + next->data.reset(tps); + TRule* rule = new TRule; + rule->e_.resize(1, w); + rule->f_.resize(1, w); + rule->lhs_ = TD::Convert("___PHRASE") * -1; + rule->scores_ = feats; + rule->arity_ = 0; + tps->AddRule(TRulePtr(rule)); + passthroughs.push_back(w); + } +} + +void TextFSTNode::ClearPassThroughTranslations() { + for (int i = 0; i < passthroughs.size(); ++i) + ptr.erase(passthroughs[i]); + passthroughs.clear(); +} + +static void AddPhrasetableToFST(istream* in, TextFSTNode* fst) { + int lc = 0; + bool flag = false; + while(*in) { + string line; + getline(*in, line); + if (line.empty()) continue; + ++lc; + fst->AddPhrase(line); + if (lc % 10000 == 0) { flag = true; cerr << '.' << flush; } + if (lc % 500000 == 0) { flag = false; cerr << " [" << lc << ']' << endl << flush; } + } + if (flag) cerr << endl; + cerr << "Loaded " << lc << " source phrases\n"; +} + +FSTNode* LoadTextPhrasetable(istream* in) { + TextFSTNode *fst = new TextFSTNode; + AddPhrasetableToFST(in, fst); + return fst; +} + +FSTNode* LoadTextPhrasetable(const vector& filenames) { + TextFSTNode* fst = new TextFSTNode; + for (int i = 0; i < filenames.size(); ++i) { + ReadFile rf(filenames[i]); + cerr << "Reading phrase from " << filenames[i] << endl; + AddPhrasetableToFST(rf.stream(), fst); + } + return fst; +} + +FSTNode* LoadBinaryPhrasetable(const string& fname_prefix) { + (void) fname_prefix; + assert(!"not implemented yet"); +} + diff --git a/src/phrasetable_fst.h b/src/phrasetable_fst.h new file mode 100644 index 00000000..477de1f7 --- /dev/null +++ b/src/phrasetable_fst.h @@ -0,0 +1,34 @@ +#ifndef _PHRASETABLE_FST_H_ +#define _PHRASETABLE_FST_H_ + +#include +#include + +#include "sparse_vector.h" +#include "trule.h" + +class TargetPhraseSet { + public: + virtual ~TargetPhraseSet(); + virtual const std::vector& GetRules() const = 0; +}; + +class FSTNode { + public: + virtual ~FSTNode(); + virtual const TargetPhraseSet* GetTranslations() const = 0; + virtual bool HasData() const = 0; + virtual bool HasOutgoingNonEpsilonEdges() const = 0; + virtual const FSTNode* Extend(const WordID& t) const = 0; + + // these should only be called on q_0: + virtual void AddPassThroughTranslation(const WordID& w, const SparseVector& feats) = 0; + virtual void ClearPassThroughTranslations() = 0; +}; + +// attn caller: you own the memory +FSTNode* LoadTextPhrasetable(const std::vector& filenames); +FSTNode* LoadTextPhrasetable(std::istream* in); +FSTNode* LoadBinaryPhrasetable(const std::string& fname_prefix); + +#endif diff --git a/src/prob.h b/src/prob.h new file mode 100644 index 00000000..bc297870 --- /dev/null +++ b/src/prob.h @@ -0,0 +1,8 @@ +#ifndef _PROB_H_ +#define _PROB_H_ + +#include "logval.h" + +typedef LogVal prob_t; + +#endif diff --git a/src/sampler.h b/src/sampler.h new file mode 100644 index 00000000..e5840f41 --- /dev/null +++ b/src/sampler.h @@ -0,0 +1,136 @@ +#ifndef SAMPLER_H_ +#define SAMPLER_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "prob.h" + +struct SampleSet; + +template +struct RandomNumberGenerator { + static uint32_t GetTrulyRandomSeed() { + uint32_t seed; + std::ifstream r("/dev/urandom"); + if (r) { + r.read((char*)&seed,sizeof(uint32_t)); + } + if (r.fail() || !r) { + std::cerr << "Warning: could not read from /dev/urandom. Seeding from clock" << std::endl; + seed = time(NULL); + } + std::cerr << "Seeding random number sequence to " << seed << std::endl; + return seed; + } + + RandomNumberGenerator() : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) { + uint32_t seed = GetTrulyRandomSeed(); + m_generator.seed(seed); + } + explicit RandomNumberGenerator(uint32_t seed) : m_dist(0,1), m_generator(), m_random(m_generator,m_dist) { + if (!seed) seed = GetTrulyRandomSeed(); + m_generator.seed(seed); + } + + size_t SelectSample(const prob_t& a, const prob_t& b, double T = 1.0) { + if (T == 1.0) { + if (this->next() > (a / (a + b))) return 1; else return 0; + } else { + assert(!"not implemented"); + } + } + + // T is the annealing temperature, if desired + size_t SelectSample(const SampleSet& ss, double T = 1.0); + + // draw a value from U(0,1) + double next() {return m_random();} + + // draw a value from N(mean,var) + double NextNormal(double mean, double var) { + return boost::normal_distribution(mean, var)(m_random); + } + + // draw a value from a Poisson distribution + // lambda must be greater than 0 + int NextPoisson(int lambda) { + return boost::poisson_distribution(lambda)(m_random); + } + + bool AcceptMetropolisHastings(const prob_t& p_cur, + const prob_t& p_prev, + const prob_t& q_cur, + const prob_t& q_prev) { + const prob_t a = (p_cur / p_prev) * (q_prev / q_cur); + if (log(a) >= 0.0) return true; + return (prob_t(this->next()) < a); + } + + private: + boost::uniform_real<> m_dist; + RNG m_generator; + boost::variate_generator > m_random; +}; + +typedef RandomNumberGenerator MT19937; + +class SampleSet { + public: + const prob_t& operator[](int i) const { return m_scores[i]; } + bool empty() const { return m_scores.empty(); } + void add(const prob_t& s) { m_scores.push_back(s); } + void clear() { m_scores.clear(); } + size_t size() const { return m_scores.size(); } + std::vector m_scores; +}; + +template +size_t RandomNumberGenerator::SelectSample(const SampleSet& ss, double T) { + assert(T > 0.0); + assert(ss.m_scores.size() > 0); + if (ss.m_scores.size() == 1) return 0; + const prob_t annealing_factor(1.0 / T); + const bool anneal = (annealing_factor != prob_t::One()); + prob_t sum = prob_t::Zero(); + if (anneal) { + for (int i = 0; i < ss.m_scores.size(); ++i) + sum += ss.m_scores[i].pow(annealing_factor); // p^(1/T) + } else { + sum = std::accumulate(ss.m_scores.begin(), ss.m_scores.end(), prob_t::Zero()); + } + //for (size_t i = 0; i < ss.m_scores.size(); ++i) std::cerr << ss.m_scores[i] << ","; + //std::cerr << std::endl; + + prob_t random(this->next()); // random number between 0 and 1 + random *= sum; // scale with normalization factor + //std::cerr << "Random number " << random << std::endl; + + //now figure out which sample + size_t position = 1; + sum = ss.m_scores[0]; + if (anneal) { + sum.poweq(annealing_factor); + for (; position < ss.m_scores.size() && sum < random; ++position) + sum += ss.m_scores[position].pow(annealing_factor); + } else { + for (; position < ss.m_scores.size() && sum < random; ++position) + sum += ss.m_scores[position]; + } + //std::cout << "random: " << random << " sample: " << position << std::endl; + //std::cerr << "Sample: " << position-1 << std::endl; + //exit(1); + return position-1; +} + +#endif diff --git a/src/scfg_translator.cc b/src/scfg_translator.cc new file mode 100644 index 00000000..03602c6b --- /dev/null +++ b/src/scfg_translator.cc @@ -0,0 +1,66 @@ +#include "translator.h" + +#include + +#include "hg.h" +#include "grammar.h" +#include "bottom_up_parser.h" +#include "sentence_metadata.h" + +using namespace std; + +Translator::~Translator() {} + +struct SCFGTranslatorImpl { + SCFGTranslatorImpl(const boost::program_options::variables_map& conf) : + max_span_limit(conf["scfg_max_span_limit"].as()), + add_pass_through_rules(conf.count("add_pass_through_rules")), + goal(conf["goal"].as()), + default_nt(conf["scfg_default_nt"].as()) { + vector gfiles = conf["grammar"].as >(); + for (int i = 0; i < gfiles.size(); ++i) { + cerr << "Reading SCFG grammar from " << gfiles[i] << endl; + TextGrammar* g = new TextGrammar(gfiles[i]); + g->SetMaxSpan(max_span_limit); + grammars.push_back(GrammarPtr(g)); + } + if (!conf.count("scfg_no_hiero_glue_grammar")) + grammars.push_back(GrammarPtr(new GlueGrammar(goal, default_nt))); + if (conf.count("scfg_extra_glue_grammar")) + grammars.push_back(GrammarPtr(new GlueGrammar(conf["scfg_extra_glue_grammar"].as()))); + } + + const int max_span_limit; + const bool add_pass_through_rules; + const string goal; + const string default_nt; + vector grammars; + + bool Translate(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* forest) { + vector glist = grammars; + Lattice lattice; + LatticeTools::ConvertTextOrPLF(input, &lattice); + smeta->SetSourceLength(lattice.size()); + if (add_pass_through_rules) + glist.push_back(GrammarPtr(new PassThroughGrammar(lattice, default_nt))); + ExhaustiveBottomUpParser parser(goal, glist); + if (!parser.Parse(lattice, forest)) + return false; + forest->Reweight(weights); + return true; + } +}; + +SCFGTranslator::SCFGTranslator(const boost::program_options::variables_map& conf) : + pimpl_(new SCFGTranslatorImpl(conf)) {} + +bool SCFGTranslator::Translate(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* minus_lm_forest) { + return pimpl_->Translate(input, smeta, weights, minus_lm_forest); +} + diff --git a/src/sentence_metadata.h b/src/sentence_metadata.h new file mode 100644 index 00000000..0178f1f5 --- /dev/null +++ b/src/sentence_metadata.h @@ -0,0 +1,42 @@ +#ifndef _SENTENCE_METADATA_H_ +#define _SENTENCE_METADATA_H_ + +#include +#include "lattice.h" + +struct SentenceMetadata { + SentenceMetadata(int id, const Lattice& ref) : + sent_id_(id), + src_len_(-1), + has_reference_(ref.size() > 0), + trg_len_(ref.size()), + ref_(has_reference_ ? &ref : NULL) {} + + // this should be called by the Translator object after + // it has parsed the source + void SetSourceLength(int sl) { src_len_ = sl; } + + // this should be called if a separate model needs to + // specify how long the target sentence should be + void SetTargetLength(int tl) { + assert(!has_reference_); + trg_len_ = tl; + } + bool HasReference() const { return has_reference_; } + const Lattice& GetReference() const { return *ref_; } + int GetSourceLength() const { return src_len_; } + int GetTargetLength() const { return trg_len_; } + int GetSentenceID() const { return sent_id_; } + + private: + const int sent_id_; + int src_len_; + + // you need to be very careful when depending on these values + // they will only be set during training / alignment contexts + const bool has_reference_; + int trg_len_; + const Lattice* const ref_; +}; + +#endif diff --git a/src/small_vector.h b/src/small_vector.h new file mode 100644 index 00000000..800c1df1 --- /dev/null +++ b/src/small_vector.h @@ -0,0 +1,187 @@ +#ifndef _SMALL_VECTOR_H_ + +#include // std::max - where to get this? +#include +#include + +#define __SV_MAX_STATIC 2 + +class SmallVector { + + public: + SmallVector() : size_(0) {} + + explicit SmallVector(size_t s, int v = 0) : size_(s) { + assert(s < 0x80); + if (s <= __SV_MAX_STATIC) { + for (int i = 0; i < s; ++i) data_.vals[i] = v; + } else { + capacity_ = s; + size_ = s; + data_.ptr = new int[s]; + for (int i = 0; i < size_; ++i) data_.ptr[i] = v; + } + } + + SmallVector(const SmallVector& o) : size_(o.size_) { + if (size_ <= __SV_MAX_STATIC) { + for (int i = 0; i < __SV_MAX_STATIC; ++i) data_.vals[i] = o.data_.vals[i]; + } else { + capacity_ = size_ = o.size_; + data_.ptr = new int[capacity_]; + std::memcpy(data_.ptr, o.data_.ptr, size_ * sizeof(int)); + } + } + + const SmallVector& operator=(const SmallVector& o) { + if (size_ <= __SV_MAX_STATIC) { + if (o.size_ <= __SV_MAX_STATIC) { + size_ = o.size_; + for (int i = 0; i < __SV_MAX_STATIC; ++i) data_.vals[i] = o.data_.vals[i]; + } else { + capacity_ = size_ = o.size_; + data_.ptr = new int[capacity_]; + std::memcpy(data_.ptr, o.data_.ptr, size_ * sizeof(int)); + } + } else { + if (o.size_ <= __SV_MAX_STATIC) { + delete[] data_.ptr; + size_ = o.size_; + for (int i = 0; i < size_; ++i) data_.vals[i] = o.data_.vals[i]; + } else { + if (capacity_ < o.size_) { + delete[] data_.ptr; + capacity_ = o.size_; + data_.ptr = new int[capacity_]; + } + size_ = o.size_; + for (int i = 0; i < size_; ++i) + data_.ptr[i] = o.data_.ptr[i]; + } + } + return *this; + } + + ~SmallVector() { + if (size_ <= __SV_MAX_STATIC) return; + delete[] data_.ptr; + } + + void clear() { + if (size_ > __SV_MAX_STATIC) { + delete[] data_.ptr; + } + size_ = 0; + } + + bool empty() const { return size_ == 0; } + size_t size() const { return size_; } + + inline void ensure_capacity(unsigned char min_size) { + assert(min_size > __SV_MAX_STATIC); + if (min_size < capacity_) return; + unsigned char new_cap = std::max(static_cast(capacity_ << 1), min_size); + int* tmp = new int[new_cap]; + std::memcpy(tmp, data_.ptr, capacity_ * sizeof(int)); + delete[] data_.ptr; + data_.ptr = tmp; + capacity_ = new_cap; + } + + inline void copy_vals_to_ptr() { + capacity_ = __SV_MAX_STATIC * 2; + int* tmp = new int[capacity_]; + for (int i = 0; i < __SV_MAX_STATIC; ++i) tmp[i] = data_.vals[i]; + data_.ptr = tmp; + } + + inline void push_back(int v) { + if (size_ < __SV_MAX_STATIC) { + data_.vals[size_] = v; + ++size_; + return; + } else if (size_ == __SV_MAX_STATIC) { + copy_vals_to_ptr(); + } else if (size_ == capacity_) { + ensure_capacity(size_ + 1); + } + data_.ptr[size_] = v; + ++size_; + } + + int& back() { return this->operator[](size_ - 1); } + const int& back() const { return this->operator[](size_ - 1); } + int& front() { return this->operator[](0); } + const int& front() const { return this->operator[](0); } + + void resize(size_t s, int v = 0) { + if (s <= __SV_MAX_STATIC) { + if (size_ > __SV_MAX_STATIC) { + int tmp[__SV_MAX_STATIC]; + for (int i = 0; i < s; ++i) tmp[i] = data_.ptr[i]; + delete[] data_.ptr; + for (int i = 0; i < s; ++i) data_.vals[i] = tmp[i]; + size_ = s; + return; + } + if (s <= size_) { + size_ = s; + return; + } else { + for (int i = size_; i < s; ++i) + data_.vals[i] = v; + size_ = s; + return; + } + } else { + if (size_ <= __SV_MAX_STATIC) + copy_vals_to_ptr(); + if (s > capacity_) + ensure_capacity(s); + if (s > size_) { + for (int i = size_; i < s; ++i) + data_.ptr[i] = v; + } + size_ = s; + } + } + + int& operator[](size_t i) { + if (size_ <= __SV_MAX_STATIC) return data_.vals[i]; + return data_.ptr[i]; + } + + const int& operator[](size_t i) const { + if (size_ <= __SV_MAX_STATIC) return data_.vals[i]; + return data_.ptr[i]; + } + + bool operator==(const SmallVector& o) const { + if (size_ != o.size_) return false; + if (size_ <= __SV_MAX_STATIC) { + for (size_t i = 0; i < size_; ++i) + if (data_.vals[i] != o.data_.vals[i]) return false; + return true; + } else { + for (size_t i = 0; i < size_; ++i) + if (data_.ptr[i] != o.data_.ptr[i]) return false; + return true; + } + } + + private: + unsigned char capacity_; // only defined when size_ >= __SV_MAX_STATIC + unsigned char size_; + union StorageType { + int vals[__SV_MAX_STATIC]; + int* ptr; + }; + StorageType data_; + +}; + +inline bool operator!=(const SmallVector& a, const SmallVector& b) { + return !(a==b); +} + +#endif diff --git a/src/small_vector_test.cc b/src/small_vector_test.cc new file mode 100644 index 00000000..84237791 --- /dev/null +++ b/src/small_vector_test.cc @@ -0,0 +1,129 @@ +#include "small_vector.h" + +#include +#include +#include +#include + +using namespace std; + +class SVTest : public testing::Test { + protected: + virtual void SetUp() { } + virtual void TearDown() { } +}; + +TEST_F(SVTest, LargerThan2) { + SmallVector v; + SmallVector v2; + v.push_back(0); + v.push_back(1); + v.push_back(2); + assert(v.size() == 3); + assert(v[2] == 2); + assert(v[1] == 1); + assert(v[0] == 0); + v2 = v; + SmallVector copy(v); + assert(copy.size() == 3); + assert(copy[0] == 0); + assert(copy[1] == 1); + assert(copy[2] == 2); + assert(copy == v2); + copy[1] = 99; + assert(copy != v2); + assert(v2.size() == 3); + assert(v2[2] == 2); + assert(v2[1] == 1); + assert(v2[0] == 0); + v2[0] = -2; + v2[1] = -1; + v2[2] = 0; + assert(v2[2] == 0); + assert(v2[1] == -1); + assert(v2[0] == -2); + SmallVector v3(1,1); + assert(v3[0] == 1); + v2 = v3; + assert(v2.size() == 1); + assert(v2[0] == 1); + SmallVector v4(10, 1); + assert(v4.size() == 10); + assert(v4[5] == 1); + assert(v4[9] == 1); + v4 = v; + assert(v4.size() == 3); + assert(v4[2] == 2); + assert(v4[1] == 1); + assert(v4[0] == 0); + SmallVector v5(10, 2); + assert(v5.size() == 10); + assert(v5[7] == 2); + assert(v5[0] == 2); + assert(v.size() == 3); + v = v5; + assert(v.size() == 10); + assert(v[2] == 2); + assert(v[9] == 2); + SmallVector cc; + for (int i = 0; i < 33; ++i) + cc.push_back(i); + for (int i = 0; i < 33; ++i) + assert(cc[i] == i); + cc.resize(20); + assert(cc.size() == 20); + for (int i = 0; i < 20; ++i) + assert(cc[i] == i); + cc[0]=-1; + cc.resize(1, 999); + assert(cc.size() == 1); + assert(cc[0] == -1); + cc.resize(99, 99); + for (int i = 1; i < 99; ++i) { + cerr << i << " " << cc[i] << endl; + assert(cc[i] == 99); + } + cc.clear(); + assert(cc.size() == 0); +} + +TEST_F(SVTest, Small) { + SmallVector v; + SmallVector v1(1,0); + SmallVector v2(2,10); + SmallVector v1a(2,0); + EXPECT_TRUE(v1 != v1a); + EXPECT_TRUE(v1 == v1); + EXPECT_EQ(v1[0], 0); + EXPECT_EQ(v2[1], 10); + EXPECT_EQ(v2[0], 10); + ++v2[1]; + --v2[0]; + EXPECT_EQ(v2[0], 9); + EXPECT_EQ(v2[1], 11); + SmallVector v3(v2); + assert(v3[0] == 9); + assert(v3[1] == 11); + assert(!v3.empty()); + assert(v3.size() == 2); + v3.clear(); + assert(v3.empty()); + assert(v3.size() == 0); + assert(v3 != v2); + assert(v2 != v3); + v3 = v2; + assert(v3 == v2); + assert(v2 == v3); + assert(v3[0] == 9); + assert(v3[1] == 11); + assert(!v3.empty()); + assert(v3.size() == 2); + cerr << sizeof(SmallVector) << endl; + cerr << sizeof(vector) << endl; +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + diff --git a/src/sparse_vector.cc b/src/sparse_vector.cc new file mode 100644 index 00000000..4035b9ef --- /dev/null +++ b/src/sparse_vector.cc @@ -0,0 +1,98 @@ +#include "sparse_vector.h" + +#include +#include + +#include "hg_io.h" + +using namespace std; + +namespace B64 { + +void Encode(double objective, const SparseVector& v, ostream* out) { + const int num_feats = v.num_active(); + size_t tot_size = 0; + const size_t off_objective = tot_size; + tot_size += sizeof(double); // objective + const size_t off_num_feats = tot_size; + tot_size += sizeof(int); // num_feats + const size_t off_data = tot_size; + tot_size += sizeof(unsigned char) * num_feats; // lengths of feature names; + typedef SparseVector::const_iterator const_iterator; + for (const_iterator it = v.begin(); it != v.end(); ++it) + tot_size += FD::Convert(it->first).size(); // feature names; + tot_size += sizeof(double) * num_feats; // gradient + const size_t off_magic = tot_size; + tot_size += 4; // magic + + // size_t b64_size = tot_size * 4 / 3; + // cerr << "Sparse vector binary size: " << tot_size << " (b64 size=" << b64_size << ")\n"; + char* data = new char[tot_size]; + *reinterpret_cast(&data[off_objective]) = objective; + *reinterpret_cast(&data[off_num_feats]) = num_feats; + char* cur = &data[off_data]; + assert(cur - data == off_data); + for (const_iterator it = v.begin(); it != v.end(); ++it) { + const string& fname = FD::Convert(it->first); + *cur++ = static_cast(fname.size()); // name len + memcpy(cur, &fname[0], fname.size()); + cur += fname.size(); + *reinterpret_cast(cur) = it->second; + cur += sizeof(double); + } + assert(cur - data == off_magic); + *reinterpret_cast(cur) = 0xBAABABBAu; + cur += sizeof(unsigned int); + assert(cur - data == tot_size); + b64encode(data, tot_size, out); + delete[] data; +} + +bool Decode(double* objective, SparseVector* v, const char* in, size_t size) { + v->clear(); + if (size % 4 != 0) { + cerr << "B64 error - line % 4 != 0\n"; + return false; + } + const size_t decoded_size = size * 3 / 4 - sizeof(unsigned int); + const size_t buf_size = decoded_size + sizeof(unsigned int); + if (decoded_size < 6) { cerr << "SparseVector decoding error: too short!\n"; return false; } + char* data = new char[buf_size]; + if (!b64decode(reinterpret_cast(in), size, data, buf_size)) { + delete[] data; + return false; + } + size_t cur = 0; + *objective = *reinterpret_cast(data); + cur += sizeof(double); + const int num_feats = *reinterpret_cast(&data[cur]); + cur += sizeof(int); + int fc = 0; + while(fc < num_feats && cur < decoded_size) { + ++fc; + const int fname_len = data[cur++]; + assert(fname_len > 0); + assert(fname_len < 256); + string fname(fname_len, '\0'); + memcpy(&fname[0], &data[cur], fname_len); + cur += fname_len; + const double val = *reinterpret_cast(&data[cur]); + cur += sizeof(double); + int fid = FD::Convert(fname); + v->set_value(fid, val); + } + if(num_feats != fc) { + cerr << "Expected " << num_feats << " but only decoded " << fc << "!\n"; + delete[] data; + return false; + } + if (*reinterpret_cast(&data[cur]) != 0xBAABABBAu) { + cerr << "SparseVector decodeding error : magic does not match!\n"; + delete[] data; + return false; + } + delete[] data; + return true; +} + +} diff --git a/src/sparse_vector.h b/src/sparse_vector.h new file mode 100644 index 00000000..6a8c9bf4 --- /dev/null +++ b/src/sparse_vector.h @@ -0,0 +1,264 @@ +#ifndef _SPARSE_VECTOR_H_ +#define _SPARSE_VECTOR_H_ + +// this is a modified version of code originally written +// by Phil Blunsom + +#include +#include +#include +#include + +#include "fdict.h" + +template +class SparseVector { +public: + SparseVector() {} + + const T operator[](int index) const { + typename std::map::const_iterator found = _values.find(index); + if (found == _values.end()) + return T(0); + else + return found->second; + } + + void set_value(int index, const T &value) { + _values[index] = value; + } + + void add_value(int index, const T &value) { + _values[index] += value; + } + + T value(int index) const { + typename std::map::const_iterator found = _values.find(index); + if (found != _values.end()) + return found->second; + else + return T(0); + } + + void store(std::valarray* target) const { + (*target) *= 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) { + if (it->first >= target->size()) break; + (*target)[it->first] = it->second; + } + } + + int max_index() const { + if (_values.empty()) return 0; + typename std::map::const_iterator found =_values.end(); + --found; + return found->first; + } + + // dot product with a unit vector of the same length + // as the sparse vector + T dot() const { + T sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + sum += it->second; + return sum; + } + + template + S dot(const SparseVector &vec) const { + S sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + { + typename std::map::const_iterator + found = vec._values.find(it->first); + if (found != vec._values.end()) + sum += it->second * found->second; + } + return sum; + } + + template + S dot(const std::vector &vec) const { + S sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + { + if (it->first < static_cast(vec.size())) + sum += it->second * vec[it->first]; + } + return sum; + } + + template + S dot(const S *vec) const { + // this is not range checked! + S sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + sum += it->second * vec[it->first]; + std::cout << "dot(*vec) " << sum << std::endl; + return sum; + } + + T l1norm() const { + T sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + sum += fabs(it->second); + return sum; + } + + T l2norm() const { + T sum = 0; + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + sum += it->second * it->second; + return sqrt(sum); + } + + SparseVector &operator+=(const SparseVector &other) { + for (typename std::map::const_iterator + it = other._values.begin(); it != other._values.end(); ++it) + { + T v = (_values[it->first] += it->second); + if (v == 0) + _values.erase(it->first); + } + return *this; + } + + SparseVector &operator-=(const SparseVector &other) { + for (typename std::map::const_iterator + it = other._values.begin(); it != other._values.end(); ++it) + { + T v = (_values[it->first] -= it->second); + if (v == 0) + _values.erase(it->first); + } + return *this; + } + + SparseVector &operator-=(const double &x) { + for (typename std::map::iterator + it = _values.begin(); it != _values.end(); ++it) + it->second -= x; + return *this; + } + + SparseVector &operator+=(const double &x) { + for (typename std::map::iterator + it = _values.begin(); it != _values.end(); ++it) + it->second += x; + return *this; + } + + SparseVector &operator/=(const double &x) { + for (typename std::map::iterator + it = _values.begin(); it != _values.end(); ++it) + it->second /= x; + return *this; + } + + SparseVector &operator*=(const T& x) { + for (typename std::map::iterator + it = _values.begin(); it != _values.end(); ++it) + it->second *= x; + return *this; + } + + SparseVector operator+(const double &x) const { + SparseVector result = *this; + return result += x; + } + + SparseVector operator-(const double &x) const { + SparseVector result = *this; + return result -= x; + } + + SparseVector operator/(const double &x) const { + SparseVector result = *this; + return result /= x; + } + + std::ostream &operator<<(std::ostream &out) const { + for (typename std::map::const_iterator + it = _values.begin(); it != _values.end(); ++it) + out << (it == _values.begin() ? "" : ";") + << FD::Convert(it->first) << '=' << it->second; + return out; + } + + bool operator<(const SparseVector &other) const { + typename std::map::const_iterator it = _values.begin(); + typename std::map::const_iterator other_it = other._values.begin(); + + for (; it != _values.end() && other_it != other._values.end(); ++it, ++other_it) + { + if (it->first < other_it->first) return true; + if (it->first > other_it->first) return false; + if (it->second < other_it->second) return true; + if (it->second > other_it->second) return false; + } + return _values.size() < other._values.size(); + } + + int num_active() const { return _values.size(); } + bool empty() const { return _values.empty(); } + + typedef typename std::map::const_iterator const_iterator; + const_iterator begin() const { return _values.begin(); } + const_iterator end() const { return _values.end(); } + + void clear() { + _values.clear(); + } + + void swap(SparseVector& other) { + _values.swap(other._values); + } + +private: + std::map _values; +}; + +template +SparseVector operator+(const SparseVector& a, const SparseVector& b) { + SparseVector result = a; + return result += b; +} + +template +SparseVector operator*(const SparseVector& a, const double& b) { + SparseVector result = a; + return result *= b; +} + +template +SparseVector operator*(const SparseVector& a, const T& b) { + SparseVector result = a; + return result *= b; +} + +template +SparseVector operator*(const double& a, const SparseVector& b) { + SparseVector result = b; + return result *= a; +} + +template +std::ostream &operator<<(std::ostream &out, const SparseVector &vec) +{ + return vec.operator<<(out); +} + +namespace B64 { + void Encode(double objective, const SparseVector& v, std::ostream* out); + // returns false if failed to decode + bool Decode(double* objective, SparseVector* v, const char* data, size_t size); +} + +#endif diff --git a/src/stringlib.cc b/src/stringlib.cc new file mode 100644 index 00000000..3ed74bef --- /dev/null +++ b/src/stringlib.cc @@ -0,0 +1,97 @@ +#include "stringlib.h" + +#include +#include +#include +#include + +#include "lattice.h" + +using namespace std; + +void ParseTranslatorInput(const string& line, string* input, string* ref) { + size_t hint = 0; + if (line.find("{\"rules\":") == 0) { + hint = line.find("}}"); + if (hint == string::npos) { + cerr << "Syntax error: " << line << endl; + abort(); + } + hint += 2; + } + size_t pos = line.find("|||", hint); + if (pos == string::npos) { *input = line; return; } + ref->clear(); + *input = line.substr(0, pos - 1); + string rline = line.substr(pos + 4); + if (rline.size() > 0) { + assert(ref); + *ref = rline; + } +} + +void ParseTranslatorInputLattice(const string& line, string* input, Lattice* ref) { + string sref; + ParseTranslatorInput(line, input, &sref); + if (sref.size() > 0) { + assert(ref); + LatticeTools::ConvertTextOrPLF(sref, ref); + } +} + +void ProcessAndStripSGML(string* pline, map* out) { + map& meta = *out; + string& line = *pline; + string lline = LowercaseString(line); + if (lline.find(""); + if (close == string::npos) return; // error + size_t end = lline.find(""); + string seg = Trim(lline.substr(4, close-4)); + string text = line.substr(close+1, end - close - 1); + for (size_t i = 1; i < seg.size(); i++) { + if (seg[i] == '=' && seg[i-1] == ' ') { + string less = seg.substr(0, i-1) + seg.substr(i); + seg = less; i = 0; continue; + } + if (seg[i] == '=' && seg[i+1] == ' ') { + string less = seg.substr(0, i+1); + if (i+2 < seg.size()) less += seg.substr(i+2); + seg = less; i = 0; continue; + } + } + line = Trim(text); + if (seg == "") return; + for (size_t i = 1; i < seg.size(); i++) { + if (seg[i] == '=') { + string label = seg.substr(0, i); + string val = seg.substr(i+1); + if (val[0] == '"') { + val = val.substr(1); + size_t close = val.find('"'); + if (close == string::npos) { + cerr << "SGML parse error: missing \"\n"; + seg = ""; + i = 0; + } else { + seg = val.substr(close+1); + val = val.substr(0, close); + i = 0; + } + } else { + size_t close = val.find(' '); + if (close == string::npos) { + seg = ""; + i = 0; + } else { + seg = val.substr(close+1); + val = val.substr(0, close); + } + } + label = Trim(label); + seg = Trim(seg); + meta[label] = val; + } + } +} + diff --git a/src/stringlib.h b/src/stringlib.h new file mode 100644 index 00000000..d26952c7 --- /dev/null +++ b/src/stringlib.h @@ -0,0 +1,91 @@ +#ifndef _STRINGLIB_H_ + +#include +#include +#include +#include + +// read line in the form of either: +// source +// source ||| target +// source will be returned as a string, target must be a sentence or +// a lattice (in PLF format) and will be returned as a Lattice object +void ParseTranslatorInput(const std::string& line, std::string* input, std::string* ref); +struct Lattice; +void ParseTranslatorInputLattice(const std::string& line, std::string* input, Lattice* ref); + +inline const std::string Trim(const std::string& str, const std::string& dropChars = " \t") { + std::string res = str; + res.erase(str.find_last_not_of(dropChars)+1); + return res.erase(0, res.find_first_not_of(dropChars)); +} + +inline void Tokenize(const std::string& str, char delimiter, std::vector* res) { + std::string s = str; + int last = 0; + res->clear(); + for (int i=0; i < s.size(); ++i) + if (s[i] == delimiter) { + s[i]=0; + if (last != i) { + res->push_back(&s[last]); + } + last = i + 1; + } + if (last != s.size()) + res->push_back(&s[last]); +} + +inline std::string LowercaseString(const std::string& in) { + std::string res(in.size(),' '); + for (int i = 0; i < in.size(); ++i) + res[i] = tolower(in[i]); + return res; +} + +inline int CountSubstrings(const std::string& str, const std::string& sub) { + size_t p = 0; + int res = 0; + while (p < str.size()) { + p = str.find(sub, p); + if (p == std::string::npos) break; + ++res; + p += sub.size(); + } + return res; +} + +inline int SplitOnWhitespace(const std::string& in, std::vector* out) { + out->clear(); + int i = 0; + int start = 0; + std::string cur; + while(i < in.size()) { + if (in[i] == ' ' || in[i] == '\t') { + if (i - start > 0) + out->push_back(in.substr(start, i - start)); + start = i + 1; + } + ++i; + } + if (i > start) + out->push_back(in.substr(start, i - start)); + return out->size(); +} + +inline void SplitCommandAndParam(const std::string& in, std::string* cmd, std::string* param) { + cmd->clear(); + param->clear(); + std::vector x; + SplitOnWhitespace(in, &x); + if (x.size() == 0) return; + *cmd = x[0]; + for (int i = 1; i < x.size(); ++i) { + if (i > 1) { *param += " "; } + *param += x[i]; + } +} + +void ProcessAndStripSGML(std::string* line, std::map* out); + +#endif diff --git a/src/synparse.cc b/src/synparse.cc new file mode 100644 index 00000000..96588f1e --- /dev/null +++ b/src/synparse.cc @@ -0,0 +1,212 @@ +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "prob.h" +#include "tdict.h" +#include "filelib.h" + +using namespace std; +using namespace __gnu_cxx; +namespace po = boost::program_options; + +const prob_t kMONO(1.0); // 0.6 +const prob_t kINV(1.0); // 0.1 +const prob_t kLEX(1.0); // 0.3 + +typedef hash_map, hash_map, prob_t, boost::hash > >, boost::hash > > PTable; +typedef boost::multi_array CChart; +typedef pair SpanType; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("phrasetable,p",po::value(), "[REQD] Phrase pairs for ITG alignment") + ("input,i",po::value()->default_value("-"), "Input file") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (!conf->count("phrasetable")) { + cerr << "Please specify a grammar file with -p \n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void LoadITGPhrasetable(const string& fname, PTable* ptable) { + const WordID sep = TD::Convert("|||"); + ReadFile rf(fname); + istream& in = *rf.stream(); + assert(in); + int lc = 0; + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + ++lc; + vector full, f, e; + TD::ConvertSentence(line, &full); + int i = 0; + for (; i < full.size(); ++i) { + if (full[i] == sep) break; + f.push_back(full[i]); + } + ++i; + for (; i < full.size(); ++i) { + if (full[i] == sep) break; + e.push_back(full[i]); + } + ++i; + prob_t prob(0.000001); + if (i < full.size()) { prob = prob_t(atof(TD::Convert(full[i]))); ++i; } + + if (i < full.size()) { cerr << "Warning line " << lc << " has extra stuff.\n"; } + assert(f.size() > 0); + assert(e.size() > 0); + (*ptable)[f][e] = prob; + } + cerr << "Read " << lc << " phrase pairs\n"; +} + +void FindPhrases(const vector& e, const vector& f, const PTable& pt, CChart* pcc) { + CChart& cc = *pcc; + const size_t n = f.size(); + const size_t m = e.size(); + typedef hash_map, vector, boost::hash > > PhraseToSpan; + PhraseToSpan e_locations; + for (int i = 0; i < m; ++i) { + const int mel = m - i; + vector e_phrase; + for (int el = 0; el < mel; ++el) { + e_phrase.push_back(e[i + el]); + e_locations[e_phrase].push_back(make_pair(i, i + el + 1)); + } + } + //cerr << "Cached the locations of " << e_locations.size() << " e-phrases\n"; + + for (int s = 0; s < n; ++s) { + const int mfl = n - s; + vector f_phrase; + for (int fl = 0; fl < mfl; ++fl) { + f_phrase.push_back(f[s + fl]); + PTable::const_iterator it = pt.find(f_phrase); + if (it == pt.end()) continue; + const hash_map, prob_t, boost::hash > >& es = it->second; + for (hash_map, prob_t, boost::hash > >::const_iterator eit = es.begin(); eit != es.end(); ++eit) { + PhraseToSpan::iterator loc = e_locations.find(eit->first); + if (loc == e_locations.end()) continue; + const vector& espans = loc->second; + for (int j = 0; j < espans.size(); ++j) { + cc[s][s + fl + 1][espans[j].first][espans[j].second] = eit->second; + //cerr << '[' << s << ',' << (s + fl + 1) << ',' << espans[j].first << ',' << espans[j].second << "] is C\n"; + } + } + } + } +} + +long long int evals = 0; + +void ProcessSynchronousCell(const int s, + const int t, + const int u, + const int v, + const prob_t& lex, + const prob_t& mono, + const prob_t& inv, + const CChart& tc, CChart* ntc) { + prob_t& inside = (*ntc)[s][t][u][v]; + // cerr << log(tc[s][t][u][v]) << " + " << log(lex) << endl; + inside = tc[s][t][u][v] * lex; + // cerr << " terminal span: " << log(inside) << endl; + if (t - s == 1) return; + if (v - u == 1) return; + for (int x = s+1; x < t; ++x) { + for (int y = u+1; y < v; ++y) { + const prob_t m = (*ntc)[s][x][u][y] * (*ntc)[x][t][y][v] * mono; + const prob_t i = (*ntc)[s][x][y][v] * (*ntc)[x][t][u][y] * inv; + // cerr << log(i) << "\t" << log(m) << endl; + inside += m; + inside += i; + evals++; + } + } + // cerr << " span: " << log(inside) << endl; +} + +prob_t SynchronousParse(const int n, const int m, const prob_t& lex, const prob_t& mono, const prob_t& inv, const CChart& tc, CChart* ntc) { + for (int fl = 0; fl < n; ++fl) { + for (int el = 0; el < m; ++el) { + const int ms = n - fl; + for (int s = 0; s < ms; ++s) { + const int t = s + fl + 1; + const int mu = m - el; + for (int u = 0; u < mu; ++u) { + const int v = u + el + 1; + //cerr << "Processing cell [" << s << ',' << t << ',' << u << ',' << v << "]\n"; + ProcessSynchronousCell(s, t, u, v, lex, mono, inv, tc, ntc); + } + } + } + } + return (*ntc)[0][n][0][m]; +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + PTable ptable; + LoadITGPhrasetable(conf["phrasetable"].as(), &ptable); + ReadFile rf(conf["input"].as()); + istream& in = *rf.stream(); + int lc = 0; + const WordID sep = TD::Convert("|||"); + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + ++lc; + vector full, f, e; + TD::ConvertSentence(line, &full); + int i = 0; + for (; i < full.size(); ++i) { + if (full[i] == sep) break; + f.push_back(full[i]); + } + ++i; + for (; i < full.size(); ++i) { + if (full[i] == sep) break; + e.push_back(full[i]); + } + if (e.empty()) cerr << "E is empty!\n"; + if (f.empty()) cerr << "F is empty!\n"; + if (e.empty() || f.empty()) continue; + int n = f.size(); + int m = e.size(); + cerr << "Synchronous chart has " << (n * n * m * m) << " cells\n"; + clock_t start = clock(); + CChart cc(boost::extents[n+1][n+1][m+1][m+1]); + FindPhrases(e, f, ptable, &cc); + CChart ntc(boost::extents[n+1][n+1][m+1][m+1]); + prob_t likelihood = SynchronousParse(n, m, kLEX, kMONO, kINV, cc, &ntc); + clock_t end = clock(); + cerr << "log Z: " << log(likelihood) << endl; + cerr << " Z: " << likelihood << endl; + double etime = (end - start) / 1000000.0; + cout << " time: " << etime << endl; + cout << "evals: " << evals << endl; + } + return 0; +} + diff --git a/src/tdict.cc b/src/tdict.cc new file mode 100644 index 00000000..c00d20b8 --- /dev/null +++ b/src/tdict.cc @@ -0,0 +1,49 @@ +#include "Ngram.h" +#include "dict.h" +#include "tdict.h" +#include "Vocab.h" + +using namespace std; + +Vocab* TD::dict_ = new Vocab; + +static const string empty; +static const string space = " "; + +WordID TD::Convert(const std::string& s) { + return dict_->addWord((VocabString)s.c_str()); +} + +const char* TD::Convert(const WordID& w) { + return dict_->getWord((VocabIndex)w); +} + +void TD::GetWordIDs(const std::vector& strings, std::vector* ids) { + ids->clear(); + for (vector::const_iterator i = strings.begin(); i != strings.end(); ++i) + ids->push_back(TD::Convert(*i)); +} + +std::string TD::GetString(const std::vector& str) { + string res; + for (vector::const_iterator i = str.begin(); i != str.end(); ++i) + res += (i == str.begin() ? empty : space) + TD::Convert(*i); + return res; +} + +void TD::ConvertSentence(const std::string& sent, std::vector* ids) { + string s = sent; + int last = 0; + ids->clear(); + for (int i=0; i < s.size(); ++i) + if (s[i] == 32 || s[i] == '\t') { + s[i]=0; + if (last != i) { + ids->push_back(Convert(&s[last])); + } + last = i + 1; + } + if (last != s.size()) + ids->push_back(Convert(&s[last])); +} + diff --git a/src/tdict.h b/src/tdict.h new file mode 100644 index 00000000..9d4318fe --- /dev/null +++ b/src/tdict.h @@ -0,0 +1,19 @@ +#ifndef _TDICT_H_ +#define _TDICT_H_ + +#include +#include +#include "wordid.h" + +class Vocab; + +struct TD { + static Vocab* dict_; + static void ConvertSentence(const std::string& sent, std::vector* ids); + static void GetWordIDs(const std::vector& strings, std::vector* ids); + static std::string GetString(const std::vector& str); + static WordID Convert(const std::string& s); + static const char* Convert(const WordID& w); +}; + +#endif diff --git a/src/timing_stats.cc b/src/timing_stats.cc new file mode 100644 index 00000000..85b95de5 --- /dev/null +++ b/src/timing_stats.cc @@ -0,0 +1,24 @@ +#include "timing_stats.h" + +#include + +using namespace std; + +map Timer::stats; + +Timer::Timer(const string& timername) : start_t(clock()), cur(stats[timername]) {} + +Timer::~Timer() { + ++cur.calls; + const clock_t end_t = clock(); + const double elapsed = (end_t - start_t) / 1000000.0; + cur.total_time += elapsed; +} + +void Timer::Summarize() { + for (map::iterator it = stats.begin(); it != stats.end(); ++it) { + cerr << it->first << ": " << it->second.total_time << " secs (" << it->second.calls << " calls)\n"; + } + stats.clear(); +} + diff --git a/src/timing_stats.h b/src/timing_stats.h new file mode 100644 index 00000000..0a9f7656 --- /dev/null +++ b/src/timing_stats.h @@ -0,0 +1,25 @@ +#ifndef _TIMING_STATS_H_ +#define _TIMING_STATS_H_ + +#include +#include + +struct TimerInfo { + int calls; + double total_time; + TimerInfo() : calls(), total_time() {} +}; + +struct Timer { + Timer(const std::string& info); + ~Timer(); + static void Summarize(); + private: + static std::map stats; + clock_t start_t; + TimerInfo& cur; + Timer(const Timer& other); + const Timer& operator=(const Timer& other); +}; + +#endif diff --git a/src/translator.h b/src/translator.h new file mode 100644 index 00000000..194efbaa --- /dev/null +++ b/src/translator.h @@ -0,0 +1,54 @@ +#ifndef _TRANSLATOR_H_ +#define _TRANSLATOR_H_ + +#include +#include +#include +#include + +class Hypergraph; +class SentenceMetadata; + +class Translator { + public: + virtual ~Translator(); + // returns true if goal reached, false otherwise + // minus_lm_forest will contain the unpruned forest. the + // feature values from the phrase table / grammar / etc + // should be in the forest already - the "late" features + // should not just copy values that are available without + // any context or computation. + // SentenceMetadata contains information about the sentence, + // but it is an input/output parameter since the Translator + // is also responsible for setting the value of src_len. + virtual bool Translate(const std::string& src, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest) = 0; +}; + +class SCFGTranslatorImpl; +class SCFGTranslator : public Translator { + public: + SCFGTranslator(const boost::program_options::variables_map& conf); + bool Translate(const std::string& src, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest); + private: + boost::shared_ptr pimpl_; +}; + +class FSTTranslatorImpl; +class FSTTranslator : public Translator { + public: + FSTTranslator(const boost::program_options::variables_map& conf); + bool Translate(const std::string& src, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* minus_lm_forest); + private: + boost::shared_ptr pimpl_; +}; + +#endif diff --git a/src/trule.cc b/src/trule.cc new file mode 100644 index 00000000..b8f6995e --- /dev/null +++ b/src/trule.cc @@ -0,0 +1,237 @@ +#include "trule.h" + +#include + +#include "stringlib.h" +#include "tdict.h" + +using namespace std; + +static WordID ConvertTrgString(const string& w) { + int len = w.size(); + WordID id = 0; + // [X,0] or [0] + // for target rules, we ignore the category, just keep the index + if (len > 2 && w[0]=='[' && w[len-1]==']' && w[len-2] > '0' && w[len-2] <= '9' && + (len == 3 || (len > 4 && w[len-3] == ','))) { + id = w[len-2] - '0'; + id = 1 - id; + } else { + id = TD::Convert(w); + } + return id; +} + +static WordID ConvertSrcString(const string& w, bool mono = false) { + int len = w.size(); + // [X,0] + // for source rules, we keep the category and ignore the index (source rules are + // always numbered 1, 2, 3... + if (mono) { + if (len > 2 && w[0]=='[' && w[len-1]==']') { + if (len > 4 && w[len-3] == ',') { + cerr << "[ERROR] Monolingual rules mut not have non-terminal indices:\n " + << w << endl; + exit(1); + } + // TODO check that source indices go 1,2,3,etc. + return TD::Convert(w.substr(1, len-2)) * -1; + } else { + return TD::Convert(w); + } + } else { + if (len > 4 && w[0]=='[' && w[len-1]==']' && w[len-3] == ',' && w[len-2] > '0' && w[len-2] <= '9') { + return TD::Convert(w.substr(1, len-4)) * -1; + } else { + return TD::Convert(w); + } + } +} + +static WordID ConvertLHS(const string& w) { + if (w[0] == '[') { + int len = w.size(); + if (len < 3) { cerr << "Format error: " << w << endl; exit(1); } + return TD::Convert(w.substr(1, len-2)) * -1; + } else { + return TD::Convert(w) * -1; + } +} + +TRule* TRule::CreateRuleSynchronous(const std::string& rule) { + TRule* res = new TRule; + if (res->ReadFromString(rule, true, false)) return res; + cerr << "[ERROR] Failed to creating rule from: " << rule << endl; + delete res; + return NULL; +} + +TRule* TRule::CreateRulePhrasetable(const string& rule) { + // TODO make this faster + // TODO add configuration for default NT type + if (rule[0] == '[') { + cerr << "Phrasetable rules shouldn't have a LHS / non-terminals:\n " << rule << endl; + return NULL; + } + TRule* res = new TRule("[X] ||| " + rule, true, false); + if (res->Arity() != 0) { + cerr << "Phrasetable rules should have arity 0:\n " << rule << endl; + delete res; + return NULL; + } + return res; +} + +TRule* TRule::CreateRuleMonolingual(const string& rule) { + return new TRule(rule, false, true); +} + +bool TRule::ReadFromString(const string& line, bool strict, bool mono) { + e_.clear(); + f_.clear(); + scores_.clear(); + + string w; + istringstream is(line); + int format = CountSubstrings(line, "|||"); + if (strict && format < 2) { + cerr << "Bad rule format in strict mode:\n" << line << endl; + return false; + } + if (format >= 2 || (mono && format == 1)) { + while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); } + while(is>>w && w!="|||") { f_.push_back(ConvertSrcString(w, mono)); } + if (!mono) { + while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); } + } + int fv = 0; + if (is) { + string ss; + getline(is, ss); + //cerr << "L: " << ss << endl; + int start = 0; + const int len = ss.size(); + while (start < len) { + while(start < len && (ss[start] == ' ' || ss[start] == ';')) + ++start; + if (start == len) break; + int end = start + 1; + while(end < len && (ss[end] != '=' && ss[end] != ' ' && ss[end] != ';')) + ++end; + if (end == len || ss[end] == ' ' || ss[end] == ';') { + //cerr << "PROC: '" << ss.substr(start, end - start) << "'\n"; + // non-named features + if (end != len) { ss[end] = 0; } + string fname = "PhraseModel_X"; + if (fv > 9) { cerr << "Too many phrasetable scores - used named format\n"; abort(); } + fname[12]='0' + fv; + ++fv; + scores_.set_value(FD::Convert(fname), atof(&ss[start])); + //cerr << "F: " << fname << " VAL=" << scores_.value(FD::Convert(fname)) << endl; + } else { + const int fid = FD::Convert(ss.substr(start, end - start)); + start = end + 1; + end = start + 1; + while(end < len && (ss[end] != ' ' && ss[end] != ';')) + ++end; + if (end < len) { ss[end] = 0; } + assert(start < len); + scores_.set_value(fid, atof(&ss[start])); + //cerr << "F: " << FD::Convert(fid) << " VAL=" << scores_.value(fid) << endl; + } + start = end + 1; + } + } + } else if (format == 1) { + while(is>>w && w!="|||") { lhs_ = ConvertLHS(w); } + while(is>>w && w!="|||") { e_.push_back(ConvertTrgString(w)); } + f_ = e_; + int x = ConvertLHS("[X]"); + for (int i = 0; i < f_.size(); ++i) + if (f_[i] <= 0) { f_[i] = x; } + } else { + cerr << "F: " << format << endl; + cerr << "[ERROR] Don't know how to read:\n" << line << endl; + } + if (mono) { + e_ = f_; + int ci = 0; + for (int i = 0; i < e_.size(); ++i) + if (e_[i] < 0) + e_[i] = ci--; + } + ComputeArity(); + return SanityCheck(); +} + +bool TRule::SanityCheck() const { + vector used(f_.size(), 0); + int ac = 0; + for (int i = 0; i < e_.size(); ++i) { + int ind = e_[i]; + if (ind > 0) continue; + ind = -ind; + if ((++used[ind]) != 1) { + cerr << "[ERROR] e-side variable index " << (ind+1) << " used more than once!\n"; + return false; + } + ac++; + } + if (ac != Arity()) { + cerr << "[ERROR] e-side arity mismatches f-side\n"; + return false; + } + return true; +} + +void TRule::ComputeArity() { + int min = 1; + for (vector::const_iterator i = e_.begin(); i != e_.end(); ++i) + if (*i < min) min = *i; + arity_ = 1 - min; +} + +static string AnonymousStrVar(int i) { + string res("[v]"); + if(!(i <= 0 && i >= -8)) { + cerr << "Can't handle more than 9 non-terminals: index=" << (-i) << endl; + abort(); + } + res[1] = '1' - i; + return res; +} + +string TRule::AsString(bool verbose) const { + ostringstream os; + int idx = 0; + if (lhs_ && verbose) { + os << '[' << TD::Convert(lhs_ * -1) << "] |||"; + for (int i = 0; i < f_.size(); ++i) { + const WordID& w = f_[i]; + if (w < 0) { + int wi = w * -1; + ++idx; + os << " [" << TD::Convert(wi) << ',' << idx << ']'; + } else { + os << ' ' << TD::Convert(w); + } + } + os << " ||| "; + } + if (idx > 9) { + cerr << "Too many non-terminals!\n partial: " << os.str() << endl; + exit(1); + } + for (int i =0; i +#include +#include +#include + +#include "sparse_vector.h" +#include "wordid.h" + +class TRule; +typedef boost::shared_ptr TRulePtr; +struct SpanInfo; + +// Translation rule +class TRule { + public: + TRule() : lhs_(0), prev_i(-1), prev_j(-1) { } + explicit TRule(const std::vector& e) : e_(e), lhs_(0), prev_i(-1), prev_j(-1) {} + TRule(const std::vector& e, const std::vector& f, const WordID& lhs) : + e_(e), f_(f), lhs_(lhs), prev_i(-1), prev_j(-1) {} + + // deprecated - this will be private soon + explicit TRule(const std::string& text, bool strict = false, bool mono = false) { + ReadFromString(text, strict, mono); + } + + // make a rule from a hiero-like rule table, e.g. + // [X] ||| [X,1] DE [X,2] ||| [X,2] of the [X,1] + // if misformatted, returns NULL + static TRule* CreateRuleSynchronous(const std::string& rule); + + // make a rule from a phrasetable entry (i.e., one that has no LHS type), e.g: + // el gato ||| the cat ||| Feature_2=0.34 + static TRule* CreateRulePhrasetable(const std::string& rule); + + // make a rule from a non-synchrnous CFG representation, e.g.: + // [LHS] ||| term1 [NT] term2 [OTHER_NT] [YET_ANOTHER_NT] + static TRule* CreateRuleMonolingual(const std::string& rule); + + void ESubstitute(const std::vector* >& var_values, + std::vector* result) const { + int vc = 0; + result->clear(); + for (std::vector::const_iterator i = e_.begin(); i != e_.end(); ++i) { + const WordID& c = *i; + if (c < 1) { + ++vc; + const std::vector& var_value = *var_values[-c]; + std::copy(var_value.begin(), + var_value.end(), + std::back_inserter(*result)); + } else { + result->push_back(c); + } + } + assert(vc == var_values.size()); + } + + void FSubstitute(const std::vector* >& var_values, + std::vector* result) const { + int vc = 0; + result->clear(); + for (std::vector::const_iterator i = f_.begin(); i != f_.end(); ++i) { + const WordID& c = *i; + if (c < 1) { + const std::vector& var_value = *var_values[vc++]; + std::copy(var_value.begin(), + var_value.end(), + std::back_inserter(*result)); + } else { + result->push_back(c); + } + } + assert(vc == var_values.size()); + } + + bool ReadFromString(const std::string& line, bool strict = false, bool monolingual = false); + + bool Initialized() const { return e_.size(); } + + std::string AsString(bool verbose = true) const; + + static TRule DummyRule() { + TRule res; + res.e_.resize(1, 0); + return res; + } + + const std::vector& f() const { return f_; } + const std::vector& e() const { return e_; } + + int EWords() const { return ELength() - Arity(); } + int FWords() const { return FLength() - Arity(); } + int FLength() const { return f_.size(); } + int ELength() const { return e_.size(); } + int Arity() const { return arity_; } + bool IsUnary() const { return (Arity() == 1) && (f_.size() == 1); } + const SparseVector& GetFeatureValues() const { return scores_; } + double Score(int i) const { return scores_[i]; } + WordID GetLHS() const { return lhs_; } + void ComputeArity(); + + // 0 = first variable, -1 = second variable, -2 = third ... + std::vector e_; + // < 0: *-1 = encoding of category of variable + std::vector f_; + WordID lhs_; + SparseVector scores_; + char arity_; + TRulePtr parent_rule_; // usually NULL, except when doing constrained decoding + + // this is only used when doing synchronous parsing + short int prev_i; + short int prev_j; + + private: + bool SanityCheck() const; +}; + +#endif diff --git a/src/trule_test.cc b/src/trule_test.cc new file mode 100644 index 00000000..02a70764 --- /dev/null +++ b/src/trule_test.cc @@ -0,0 +1,65 @@ +#include "trule.h" + +#include +#include +#include +#include "tdict.h" + +using namespace std; + +class TRuleTest : public testing::Test { + protected: + virtual void SetUp() { } + virtual void TearDown() { } +}; + +TEST_F(TRuleTest,TestFSubstitute) { + TRule r1("[X] ||| ob [X,1] [X,2] sah . ||| whether [X,1] saw [X,2] . ||| 0.99"); + TRule r2("[X] ||| ich ||| i ||| 1.0"); + TRule r3("[X] ||| ihn ||| him ||| 1.0"); + vector*> ants; + vector res2; + r2.FSubstitute(ants, &res2); + assert(TD::GetString(res2) == "ich"); + vector res3; + r3.FSubstitute(ants, &res3); + assert(TD::GetString(res3) == "ihn"); + ants.push_back(&res2); + ants.push_back(&res3); + vector res; + r1.FSubstitute(ants, &res); + cerr << TD::GetString(res) << endl; + assert(TD::GetString(res) == "ob ich ihn sah ."); +} + +TEST_F(TRuleTest,TestPhrasetableRule) { + TRulePtr t(TRule::CreateRulePhrasetable("gato ||| cat ||| PhraseModel_0=-23.2;Foo=1;Bar=12")); + cerr << t->AsString() << endl; + assert(t->scores_.num_active() == 3); +}; + + +TEST_F(TRuleTest,TestMonoRule) { + TRulePtr m(TRule::CreateRuleMonolingual("[LHS] ||| term1 [NT] term2 [NT2] [NT3]")); + assert(m->Arity() == 3); + cerr << m->AsString() << endl; + TRulePtr m2(TRule::CreateRuleMonolingual("[LHS] ||| term1 [NT] term2 [NT2] [NT3] ||| Feature1=0.23")); + assert(m2->Arity() == 3); + cerr << m2->AsString() << endl; + EXPECT_FLOAT_EQ(m2->scores_.value(FD::Convert("Feature1")), 0.23); +} + +TEST_F(TRuleTest,TestRuleR) { + TRule t6; + t6.ReadFromString("[X] ||| den [X,1] sah [X,2] . ||| [X,2] saw the [X,1] . ||| 0.12321 0.23232 0.121"); + cerr << "TEXT: " << t6.AsString() << endl; + EXPECT_EQ(t6.Arity(), 2); + EXPECT_EQ(t6.e_[0], -1); + EXPECT_EQ(t6.e_[3], 0); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + diff --git a/src/ttables.cc b/src/ttables.cc new file mode 100644 index 00000000..2ea960f0 --- /dev/null +++ b/src/ttables.cc @@ -0,0 +1,31 @@ +#include "ttables.h" + +#include + +#include "dict.h" + +using namespace std; +using namespace std::tr1; + +void TTable::DeserializeProbsFromText(std::istream* in) { + int c = 0; + while(*in) { + string e; + string f; + double p; + (*in) >> e >> f >> p; + if (e.empty()) break; + ++c; + ttable[TD::Convert(e)][TD::Convert(f)] = prob_t(p); + } + cerr << "Loaded " << c << " translation parameters.\n"; +} + +void TTable::SerializeHelper(string* out, const Word2Word2Double& o) { + assert(!"not implemented"); +} + +void TTable::DeserializeHelper(const string& in, Word2Word2Double* o) { + assert(!"not implemented"); +} + diff --git a/src/ttables.h b/src/ttables.h new file mode 100644 index 00000000..3ffc238a --- /dev/null +++ b/src/ttables.h @@ -0,0 +1,87 @@ +#ifndef _TTABLES_H_ +#define _TTABLES_H_ + +#include +#include + +#include "wordid.h" +#include "prob.h" +#include "tdict.h" + +class TTable { + public: + TTable() {} + typedef std::map Word2Double; + typedef std::map Word2Word2Double; + inline const prob_t prob(const int& e, const int& f) const { + const Word2Word2Double::const_iterator cit = ttable.find(e); + if (cit != ttable.end()) { + const Word2Double& cpd = cit->second; + const Word2Double::const_iterator it = cpd.find(f); + if (it == cpd.end()) return prob_t(0.00001); + return prob_t(it->second); + } else { + return prob_t(0.00001); + } + } + inline void Increment(const int& e, const int& f) { + counts[e][f] += 1.0; + } + inline void Increment(const int& e, const int& f, double x) { + counts[e][f] += x; + } + void Normalize() { + ttable.swap(counts); + for (Word2Word2Double::iterator cit = ttable.begin(); + cit != ttable.end(); ++cit) { + double tot = 0; + Word2Double& cpd = cit->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + tot += it->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + it->second /= tot; + } + counts.clear(); + } + // adds counts from another TTable - probabilities remain unchanged + TTable& operator+=(const TTable& rhs) { + for (Word2Word2Double::const_iterator it = rhs.counts.begin(); + it != rhs.counts.end(); ++it) { + const Word2Double& cpd = it->second; + Word2Double& tgt = counts[it->first]; + for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { + tgt[j->first] += j->second; + } + } + return *this; + } + void ShowTTable() { + for (Word2Word2Double::iterator it = ttable.begin(); it != ttable.end(); ++it) { + Word2Double& cpd = it->second; + for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) { + std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + } + } + } + void ShowCounts() { + for (Word2Word2Double::iterator it = counts.begin(); it != counts.end(); ++it) { + Word2Double& cpd = it->second; + for (Word2Double::iterator j = cpd.begin(); j != cpd.end(); ++j) { + std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + } + } + } + void DeserializeProbsFromText(std::istream* in); + void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); } + void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); } + void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); } + void DeserializeProbs(const std::string& in) { DeserializeHelper(in, &ttable); } + private: + static void SerializeHelper(std::string*, const Word2Word2Double& o); + static void DeserializeHelper(const std::string&, Word2Word2Double* o); + public: + Word2Word2Double ttable; + Word2Word2Double counts; +}; + +#endif diff --git a/src/viterbi.cc b/src/viterbi.cc new file mode 100644 index 00000000..82b2ce6d --- /dev/null +++ b/src/viterbi.cc @@ -0,0 +1,39 @@ +#include "viterbi.h" + +#include +#include "hg.h" + +using namespace std; + +string ViterbiETree(const Hypergraph& hg) { + vector tmp; + const prob_t p = Viterbi, ETreeTraversal, prob_t, EdgeProb>(hg, &tmp); + return TD::GetString(tmp); +} + +string ViterbiFTree(const Hypergraph& hg) { + vector tmp; + const prob_t p = Viterbi, FTreeTraversal, prob_t, EdgeProb>(hg, &tmp); + return TD::GetString(tmp); +} + +prob_t ViterbiESentence(const Hypergraph& hg, vector* result) { + return Viterbi, ESentenceTraversal, prob_t, EdgeProb>(hg, result); +} + +prob_t ViterbiFSentence(const Hypergraph& hg, vector* result) { + return Viterbi, FSentenceTraversal, prob_t, EdgeProb>(hg, result); +} + +int ViterbiELength(const Hypergraph& hg) { + int len = -1; + Viterbi(hg, &len); + return len; +} + +int ViterbiPathLength(const Hypergraph& hg) { + int len = -1; + Viterbi(hg, &len); + return len; +} + diff --git a/src/viterbi.h b/src/viterbi.h new file mode 100644 index 00000000..46a4f528 --- /dev/null +++ b/src/viterbi.h @@ -0,0 +1,130 @@ +#ifndef _VITERBI_H_ +#define _VITERBI_H_ + +#include +#include "prob.h" +#include "hg.h" +#include "tdict.h" + +// V must implement: +// void operator()(const vector& ants, T* result); +template +WeightType Viterbi(const Hypergraph& hg, + T* result, + const Traversal& traverse = Traversal(), + const WeightFunction& weight = WeightFunction()) { + const int num_nodes = hg.nodes_.size(); + std::vector vit_result(num_nodes); + std::vector vit_weight(num_nodes, WeightType::Zero()); + + for (int i = 0; i < num_nodes; ++i) { + const Hypergraph::Node& cur_node = hg.nodes_[i]; + WeightType* const cur_node_best_weight = &vit_weight[i]; + T* const cur_node_best_result = &vit_result[i]; + + const int num_in_edges = cur_node.in_edges_.size(); + if (num_in_edges == 0) { + *cur_node_best_weight = WeightType(1); + continue; + } + for (int j = 0; j < num_in_edges; ++j) { + const Hypergraph::Edge& edge = hg.edges_[cur_node.in_edges_[j]]; + WeightType score = weight(edge); + std::vector ants(edge.tail_nodes_.size()); + for (int k = 0; k < edge.tail_nodes_.size(); ++k) { + const int tail_node_index = edge.tail_nodes_[k]; + score *= vit_weight[tail_node_index]; + ants[k] = &vit_result[tail_node_index]; + } + if (*cur_node_best_weight < score) { + *cur_node_best_weight = score; + traverse(edge, ants, cur_node_best_result); + } + } + } + std::swap(*result, vit_result.back()); + return vit_weight.back(); +} + +struct PathLengthTraversal { + void operator()(const Hypergraph::Edge& edge, + const std::vector& ants, + int* result) const { + (void) edge; + *result = 1; + for (int i = 0; i < ants.size(); ++i) *result += *ants[i]; + } +}; + +struct ESentenceTraversal { + void operator()(const Hypergraph::Edge& edge, + const std::vector*>& ants, + std::vector* result) const { + edge.rule_->ESubstitute(ants, result); + } +}; + +struct ELengthTraversal { + void operator()(const Hypergraph::Edge& edge, + const std::vector& ants, + int* result) const { + *result = edge.rule_->ELength() - edge.rule_->Arity(); + for (int i = 0; i < ants.size(); ++i) *result += *ants[i]; + } +}; + +struct FSentenceTraversal { + void operator()(const Hypergraph::Edge& edge, + const std::vector*>& ants, + std::vector* result) const { + edge.rule_->FSubstitute(ants, result); + } +}; + +// create a strings of the form (S (X the man) (X said (X he (X would (X go))))) +struct ETreeTraversal { + ETreeTraversal() : left("("), space(" "), right(")") {} + const std::string left; + const std::string space; + const std::string right; + void operator()(const Hypergraph::Edge& edge, + const std::vector*>& ants, + std::vector* result) const { + std::vector tmp; + edge.rule_->ESubstitute(ants, &tmp); + const std::string cat = TD::Convert(edge.rule_->GetLHS() * -1); + if (cat == "Goal") + result->swap(tmp); + else + TD::ConvertSentence(left + cat + space + TD::GetString(tmp) + right, + result); + } +}; + +struct FTreeTraversal { + FTreeTraversal() : left("("), space(" "), right(")") {} + const std::string left; + const std::string space; + const std::string right; + void operator()(const Hypergraph::Edge& edge, + const std::vector*>& ants, + std::vector* result) const { + std::vector tmp; + edge.rule_->FSubstitute(ants, &tmp); + const std::string cat = TD::Convert(edge.rule_->GetLHS() * -1); + if (cat == "Goal") + result->swap(tmp); + else + TD::ConvertSentence(left + cat + space + TD::GetString(tmp) + right, + result); + } +}; + +prob_t ViterbiESentence(const Hypergraph& hg, std::vector* result); +std::string ViterbiETree(const Hypergraph& hg); +prob_t ViterbiFSentence(const Hypergraph& hg, std::vector* result); +std::string ViterbiFTree(const Hypergraph& hg); +int ViterbiELength(const Hypergraph& hg); +int ViterbiPathLength(const Hypergraph& hg); + +#endif diff --git a/src/weights.cc b/src/weights.cc new file mode 100644 index 00000000..bb0a878f --- /dev/null +++ b/src/weights.cc @@ -0,0 +1,73 @@ +#include "weights.h" + +#include + +#include "fdict.h" +#include "filelib.h" + +using namespace std; + +void Weights::InitFromFile(const std::string& filename, vector* feature_list) { + cerr << "Reading weights from " << filename << endl; + ReadFile in_file(filename); + istream& in = *in_file.stream(); + assert(in); + int weight_count = 0; + bool fl = false; + while (in) { + double val = 0; + string buf; + getline(in, buf); + if (buf.size() == 0) continue; + if (buf[0] == '#') continue; + for (int i = 0; i < buf.size(); ++i) + if (buf[i] == '=') buf[i] = ' '; + int start = 0; + while(start < buf.size() && buf[start] == ' ') ++start; + int end = 0; + while(end < buf.size() && buf[end] != ' ') ++end; + int fid = FD::Convert(buf.substr(start, end - start)); + while(end < buf.size() && buf[end] == ' ') ++end; + val = strtod(&buf.c_str()[end], NULL); + if (wv_.size() <= fid) + wv_.resize(fid + 1); + wv_[fid] = val; + if (feature_list) { feature_list->push_back(FD::Convert(fid)); } + ++weight_count; + if (weight_count % 50000 == 0) { cerr << '.' << flush; fl = true; } + if (weight_count % 2000000 == 0) { cerr << " [" << weight_count << "]\n"; fl = false; } + } + if (fl) { cerr << endl; } + cerr << "Loaded " << weight_count << " feature weights\n"; +} + +void Weights::WriteToFile(const std::string& fname, bool hide_zero_value_features) const { + WriteFile out(fname); + ostream& o = *out.stream(); + assert(o); + o.precision(17); + const int num_feats = FD::NumFeats(); + for (int i = 1; i < num_feats; ++i) { + const double val = (i < wv_.size() ? wv_[i] : 0.0); + if (hide_zero_value_features && val == 0.0) continue; + o << FD::Convert(i) << ' ' << val << endl; + } +} + +void Weights::InitVector(std::vector* w) const { + *w = wv_; +} + +void Weights::InitSparseVector(SparseVector* w) const { + for (int i = 1; i < wv_.size(); ++i) { + const double& weight = wv_[i]; + if (weight) w->set_value(i, weight); + } +} + +void Weights::InitFromVector(const std::vector& w) { + wv_ = w; + if (wv_.size() > FD::NumFeats()) + cerr << "WARNING: initializing weight vector has more features than the global feature dictionary!\n"; + wv_.resize(FD::NumFeats(), 0); +} diff --git a/src/weights.h b/src/weights.h new file mode 100644 index 00000000..f19aa3ce --- /dev/null +++ b/src/weights.h @@ -0,0 +1,21 @@ +#ifndef _WEIGHTS_H_ +#define _WEIGHTS_H_ + +#include +#include +#include +#include "sparse_vector.h" + +class Weights { + public: + Weights() {} + void InitFromFile(const std::string& fname, std::vector* feature_list = NULL); + void WriteToFile(const std::string& fname, bool hide_zero_value_features = true) const; + void InitVector(std::vector* w) const; + void InitSparseVector(SparseVector* w) const; + void InitFromVector(const std::vector& w); + private: + std::vector wv_; +}; + +#endif diff --git a/src/weights_test.cc b/src/weights_test.cc new file mode 100644 index 00000000..aa6b3db2 --- /dev/null +++ b/src/weights_test.cc @@ -0,0 +1,28 @@ +#include +#include +#include +#include +#include +#include "weights.h" +#include "tdict.h" +#include "hg.h" + +using namespace std; + +class WeightsTest : public testing::Test { + protected: + virtual void SetUp() { } + virtual void TearDown() { } +}; + + +TEST_F(WeightsTest,Load) { + Weights w; + w.InitFromFile("test_data/weights"); + w.WriteToFile("-"); +} + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/wordid.h b/src/wordid.h new file mode 100644 index 00000000..fb50bcc1 --- /dev/null +++ b/src/wordid.h @@ -0,0 +1,6 @@ +#ifndef _WORD_ID_H_ +#define _WORD_ID_H_ + +typedef int WordID; + +#endif -- cgit v1.2.3