From 925087356b853e2099c1b60d8b757d7aa02121a9 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 2 Oct 2012 00:19:43 -0400 Subject: cdec cleanup, remove bayesian stuff, parsing stuff --- configure.ac | 5 ----- 1 file changed, 5 deletions(-) (limited to 'configure.ac') diff --git a/configure.ac b/configure.ac index ea9e84fb..07ef9fe1 100644 --- a/configure.ac +++ b/configure.ac @@ -114,7 +114,6 @@ AC_CONFIG_FILES([Makefile]) AC_CONFIG_FILES([utils/Makefile]) AC_CONFIG_FILES([mteval/Makefile]) AC_CONFIG_FILES([decoder/Makefile]) -AC_CONFIG_FILES([phrasinator/Makefile]) AC_CONFIG_FILES([training/Makefile]) AC_CONFIG_FILES([training/liblbfgs/Makefile]) AC_CONFIG_FILES([dpmert/Makefile]) @@ -125,10 +124,6 @@ AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) AC_CONFIG_FILES([mira/Makefile]) AC_CONFIG_FILES([dtrain/Makefile]) -AC_CONFIG_FILES([gi/pyp-topics/src/Makefile]) -AC_CONFIG_FILES([gi/clda/src/Makefile]) -AC_CONFIG_FILES([gi/pf/Makefile]) -AC_CONFIG_FILES([gi/markov_al/Makefile]) AC_CONFIG_FILES([rst_parser/Makefile]) AC_CONFIG_FILES([python/setup.py]) -- cgit v1.2.3 From c9c7536ebd387479c2d39c8f1fa91bc047e0cac5 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 2 Oct 2012 00:27:21 -0400 Subject: fix build --- configure.ac | 1 - utils/Makefile.am | 6 ------ 2 files changed, 7 deletions(-) (limited to 'configure.ac') diff --git a/configure.ac b/configure.ac index 07ef9fe1..70e8e932 100644 --- a/configure.ac +++ b/configure.ac @@ -124,7 +124,6 @@ AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) AC_CONFIG_FILES([mira/Makefile]) AC_CONFIG_FILES([dtrain/Makefile]) -AC_CONFIG_FILES([rst_parser/Makefile]) AC_CONFIG_FILES([python/setup.py]) diff --git a/utils/Makefile.am b/utils/Makefile.am index 55d97354..3ad9d69e 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -45,18 +45,12 @@ m_test_SOURCES = m_test.cc m_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz dict_test_SOURCES = dict_test.cc dict_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz -mfcr_test_SOURCES = mfcr_test.cc -mfcr_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz weights_test_SOURCES = weights_test.cc weights_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz -crp_test_SOURCES = crp_test.cc -crp_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz logval_test_SOURCES = logval_test.cc logval_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz small_vector_test_SOURCES = small_vector_test.cc small_vector_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz -unigram_pyp_lm_SOURCES = unigram_pyp_lm.cc -unigram_pyp_lm_LDADD = libutils.a -lz ################################################################ # do NOT NOT NOT add any other -I includes NO NO NO NO NO ###### -- cgit v1.2.3 From fc13e7db2af155a7b636e4061f3567ea3957fd9a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 11 Oct 2012 21:57:18 -0400 Subject: add support for dlopen'd feature functions --- configure.ac | 3 ++- decoder/Makefile.am | 1 + decoder/cdec_ff.cc | 2 ++ decoder/ff.h | 1 + decoder/ff_external.cc | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++ decoder/ff_external.h | 26 +++++++++++++++++++++++ 6 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 decoder/ff_external.cc create mode 100644 decoder/ff_external.h (limited to 'configure.ac') diff --git a/configure.ac b/configure.ac index 70e8e932..967b657c 100644 --- a/configure.ac +++ b/configure.ac @@ -14,7 +14,8 @@ BOOST_REQUIRE([1.44]) BOOST_PROGRAM_OPTIONS BOOST_TEST AM_PATH_PYTHON -# TODO detect Cython, generate python/Makefile that calls "python setup.py build" +AC_CHECK_HEADER(dlfcn.h,AC_DEFINE(HAVE_DLFCN_H)) +AC_CHECK_LIB(dl, dlopen) AC_ARG_ENABLE(mpi, [ --enable-mpi Build MPI binaries, assumes mpi.h is present ], diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 4a98a4f1..28863dbe 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -33,6 +33,7 @@ libcdec_a_SOURCES = \ cfg.cc \ dwarf.cc \ ff_dwarf.cc \ + ff_external.cc \ rule_lexer.cc \ fst_translator.cc \ csplit.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index b516c386..54f6e12b 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -18,6 +18,7 @@ #include "ff_charset.h" #include "ff_wordset.h" #include "ff_dwarf.h" +#include "ff_external.h" #ifdef HAVE_GLC #include @@ -69,6 +70,7 @@ void register_feature_functions() { ff_registry.Register("WordPairFeatures", new FFFactory); ff_registry.Register("WordSet", new FFFactory); ff_registry.Register("Dwarf", new FFFactory); + ff_registry.Register("External", new FFFactory); #ifdef HAVE_GLC ff_registry.Register("ContextCRF", new FFFactory); #endif diff --git a/decoder/ff.h b/decoder/ff.h index 6c22d39f..227787ca 100644 --- a/decoder/ff.h +++ b/decoder/ff.h @@ -27,6 +27,7 @@ typedef std::vector Features; // set of features ids // depends on context, you may also need to implement // FinalTraversalFeatures(...) class FeatureFunction { + friend class ExternalFeature; public: std::string name_; // set by FF factory using usage() bool debug_; // also set by FF factory checking param for immediate initial "debug" diff --git a/decoder/ff_external.cc b/decoder/ff_external.cc new file mode 100644 index 00000000..520e98b1 --- /dev/null +++ b/decoder/ff_external.cc @@ -0,0 +1,57 @@ +#include "ff_external.h" +#include "stringlib.h" + +#include + +using namespace std; + +ExternalFeature::ExternalFeature(const string& param) { + size_t pos = param.find(' '); + string nparam; + string file = param; + if (pos < param.size()) { + nparam = Trim(param.substr(pos + 1)); + file = param.substr(0, pos); + } + if (file.size() < 1) { + cerr << "External requires a path to a dynamic library!\n"; + abort(); + } + lib_handle = dlopen(file.c_str(), RTLD_LAZY); + if (!lib_handle) { + cerr << "dlopen reports: " << dlerror() << endl; + cerr << "Did you provide a full path to the dynamic library?\n"; + abort(); + } + FeatureFunction* (*fn)(const string&) = + (FeatureFunction* (*)(const string&))(dlsym(lib_handle, "create_ff")); + if (!fn) { + cerr << "dlsym reports: " << dlerror() << endl; + abort(); + } + ff_ext = (*fn)(nparam); +} + +ExternalFeature::~ExternalFeature() { + delete ff_ext; + dlclose(lib_handle); +} + +void ExternalFeature::PrepareForInput(const SentenceMetadata& smeta) { + ff_ext->PrepareForInput(smeta); +} + +void ExternalFeature::FinalTraversalFeatures(const void* context, + SparseVector* features) const { + ff_ext->FinalTraversalFeatures(context, features); +} + +void ExternalFeature::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + FeatureVector* features, + FeatureVector* estimated_features, + void* context) const { + ff_ext->TraversalFeaturesImpl(smeta, edge, ant_contexts, features, estimated_features, context); +} + diff --git a/decoder/ff_external.h b/decoder/ff_external.h new file mode 100644 index 00000000..283e58e8 --- /dev/null +++ b/decoder/ff_external.h @@ -0,0 +1,26 @@ +#ifndef _FFEXTERNAL_H_ +#define _FFEXTERNAL_H_ + +#include "ff.h" + +// dynamically loaded feature function +class ExternalFeature : public FeatureFunction { + public: + ExternalFeature(const std::string& param); + ~ExternalFeature(); + virtual void PrepareForInput(const SentenceMetadata& smeta); + virtual void FinalTraversalFeatures(const void* context, + SparseVector* features) const; + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + FeatureVector* features, + FeatureVector* estimated_features, + void* context) const; + private: + void* lib_handle; + FeatureFunction* ff_ext; +}; + +#endif -- cgit v1.2.3 From 070d51e5554cc9840d97556d04bfa64a3d60b38c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 11 Oct 2012 23:13:12 -0400 Subject: example external feature function --- Makefile.am | 3 ++- configure.ac | 1 + example_extff/Makefile.am | 5 ++++ example_extff/README.md | 8 +++++++ example_extff/ff_example.cc | 56 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 example_extff/Makefile.am create mode 100644 example_extff/README.md create mode 100644 example_extff/ff_example.cc (limited to 'configure.ac') diff --git a/Makefile.am b/Makefile.am index c0826532..3e0103a8 100644 --- a/Makefile.am +++ b/Makefile.am @@ -14,7 +14,8 @@ SUBDIRS = \ dpmert \ pro-train \ rampion \ - minrisk + minrisk \ + example_extff #gi/pyp-topics/src gi/clda/src gi/posterior-regularisation/prjava diff --git a/configure.ac b/configure.ac index 967b657c..03a0ee87 100644 --- a/configure.ac +++ b/configure.ac @@ -125,6 +125,7 @@ AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) AC_CONFIG_FILES([mira/Makefile]) AC_CONFIG_FILES([dtrain/Makefile]) +AC_CONFIG_FILES([example_extff/Makefile]) AC_CONFIG_FILES([python/setup.py]) diff --git a/example_extff/Makefile.am b/example_extff/Makefile.am new file mode 100644 index 00000000..ac2694ca --- /dev/null +++ b/example_extff/Makefile.am @@ -0,0 +1,5 @@ +AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils -I../klm -I../decoder + +lib_LTLIBRARIES = libff_example.la +libff_example_la_SOURCES = ff_example.cc +libff_example_la_LDFLAGS = -version-info 1:0:0 -module diff --git a/example_extff/README.md b/example_extff/README.md new file mode 100644 index 00000000..f2aba487 --- /dev/null +++ b/example_extff/README.md @@ -0,0 +1,8 @@ +This is an example of an _external_ feature function which is loaded as a dynamically linked library at run time to compute feature functions over derivations in a hypergraph. To load feature external feature functions, you can specify them in your `cdec.ini` configuration file as follows: + + feature_function=External /path/to/libmy_feature.so + +Any extra options are passed to the external library. + +*Note*: the build system uses [GNU Libtool](http://www.gnu.org/software/libtool/) to create the shared library. This may be placed in a hidden directory called `./libs`. + diff --git a/example_extff/ff_example.cc b/example_extff/ff_example.cc new file mode 100644 index 00000000..51ebf364 --- /dev/null +++ b/example_extff/ff_example.cc @@ -0,0 +1,56 @@ +#include "ff.h" +#include +#include + +using namespace std; + +// example of a "stateful" feature made available as an external library +// This feature looks nodes and their daughters and fires an indicator based +// on the arities of the rules involved. +// (X (X a) b (X c)) - this is a 2 arity parent with children of 0 and 0 arity +// so you get MAF_2_0_0=1 +class ParentChildrenArityFeatures : public FeatureFunction { + public: + ParentChildrenArityFeatures(const string& param) : fids(16, vector(256, -1)) { + SetStateSize(1); // number of bytes extra state required by this Feature + } + virtual void FinalTraversalFeatures(const void* context, + SparseVector* features) const { + // Goal always is arity 1, so there's no discriminative value of + // computing a feature + } + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + FeatureVector* features, + FeatureVector* estimated_features, + void* context) const { + unsigned child_arity_code = 0; + for (unsigned j = 0; j < ant_contexts.size(); ++j) { + child_arity_code <<= 4; + child_arity_code |= *reinterpret_cast(ant_contexts[j]); + } + int& fid = fids[edge.Arity()][child_arity_code]; // reference! + if (fid < 0) { + ostringstream feature_string; + feature_string << "MAF_" << edge.Arity(); + for (unsigned j = 0; j < ant_contexts.size(); ++j) + feature_string << '_' << + static_cast(*reinterpret_cast(ant_contexts[j])); + fid = FD::Convert(feature_string.str()); + } + features->set_value(fid, 1.0); + *reinterpret_cast(context) = edge.Arity(); // save state + } + private: + mutable vector > fids; +}; + +// IMPORTANT: this function must be implemented by any external FF library +// if your library has multiple features, you can use str to configure things +extern "C" FeatureFunction* create_ff(const string& str) { + return new ParentChildrenArityFeatures(str); +} + + -- cgit v1.2.3 From 1fb7bfbbe287e868522613871ed6ca74369ed2a1 Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Mon, 22 Oct 2012 14:04:27 +0100 Subject: Update search, make it compile --- Makefile.am | 1 + configure.ac | 6 +- decoder/Makefile.am | 3 +- decoder/decoder.cc | 8 +- decoder/incremental.cc | 184 +++++++++++++++++++++++++++++++++++++++ decoder/incremental.h | 11 +++ decoder/lazy.cc | 178 -------------------------------------- decoder/lazy.h | 11 --- dtrain/Makefile.am | 2 +- klm/alone/Jamfile | 4 - klm/alone/assemble.cc | 76 ---------------- klm/alone/assemble.hh | 21 ----- klm/alone/graph.hh | 87 ------------------- klm/alone/just_vocab.cc | 14 --- klm/alone/labeled_edge.hh | 30 ------- klm/alone/main.cc | 85 ------------------ klm/alone/read.cc | 118 ------------------------- klm/alone/read.hh | 29 ------- klm/alone/threading.cc | 80 ----------------- klm/alone/threading.hh | 129 --------------------------- klm/alone/vocab.cc | 19 ---- klm/alone/vocab.hh | 34 -------- klm/lm/model.cc | 2 +- klm/lm/vocab.cc | 4 +- klm/lm/vocab.hh | 2 +- klm/search/Jamfile | 2 +- klm/search/Makefile.am | 11 +++ klm/search/arity.hh | 8 -- klm/search/context.hh | 10 +-- klm/search/edge.hh | 53 ++++++++---- klm/search/edge_generator.cc | 144 ++++++++++++++----------------- klm/search/edge_generator.hh | 49 ++++++----- klm/search/edge_queue.cc | 25 ------ klm/search/edge_queue.hh | 73 ---------------- klm/search/final.hh | 41 ++++----- klm/search/header.hh | 57 ++++++++++++ klm/search/source.hh | 48 ----------- klm/search/types.hh | 8 +- klm/search/vertex.cc | 10 +-- klm/search/vertex.hh | 55 ++++++------ klm/search/vertex_generator.cc | 97 ++++++++++++--------- klm/search/vertex_generator.hh | 33 +++---- klm/util/Makefile.am | 2 + klm/util/ersatz_progress.hh | 2 +- klm/util/exception.hh | 2 +- klm/util/pool.cc | 35 ++++++++ klm/util/pool.hh | 45 ++++++++++ klm/util/probing_hash_table.hh | 2 +- klm/util/string_piece.cc | 192 +++++++++++++++++++++++++++++++++++++++++ klm/util/tokenize_piece.hh | 12 +++ mira/Makefile.am | 2 +- training/Makefile.am | 38 ++++---- 52 files changed, 838 insertions(+), 1356 deletions(-) create mode 100644 decoder/incremental.cc create mode 100644 decoder/incremental.h delete mode 100644 decoder/lazy.cc delete mode 100644 decoder/lazy.h delete mode 100644 klm/alone/Jamfile delete mode 100644 klm/alone/assemble.cc delete mode 100644 klm/alone/assemble.hh delete mode 100644 klm/alone/graph.hh delete mode 100644 klm/alone/just_vocab.cc delete mode 100644 klm/alone/labeled_edge.hh delete mode 100644 klm/alone/main.cc delete mode 100644 klm/alone/read.cc delete mode 100644 klm/alone/read.hh delete mode 100644 klm/alone/threading.cc delete mode 100644 klm/alone/threading.hh delete mode 100644 klm/alone/vocab.cc delete mode 100644 klm/alone/vocab.hh create mode 100644 klm/search/Makefile.am delete mode 100644 klm/search/arity.hh delete mode 100644 klm/search/edge_queue.cc delete mode 100644 klm/search/edge_queue.hh create mode 100644 klm/search/header.hh delete mode 100644 klm/search/source.hh create mode 100644 klm/util/pool.cc create mode 100644 klm/util/pool.hh create mode 100644 klm/util/string_piece.cc (limited to 'configure.ac') diff --git a/Makefile.am b/Makefile.am index 3e0103a8..fefc470d 100644 --- a/Makefile.am +++ b/Makefile.am @@ -6,6 +6,7 @@ SUBDIRS = \ mteval \ klm/util \ klm/lm \ + klm/search \ decoder \ training \ training/liblbfgs \ diff --git a/configure.ac b/configure.ac index 03a0ee87..cb132d66 100644 --- a/configure.ac +++ b/configure.ac @@ -12,6 +12,7 @@ AC_PROG_CXX AC_LANG_CPLUSPLUS BOOST_REQUIRE([1.44]) BOOST_PROGRAM_OPTIONS +BOOST_SYSTEM BOOST_TEST AM_PATH_PYTHON AC_CHECK_HEADER(dlfcn.h,AC_DEFINE(HAVE_DLFCN_H)) @@ -73,9 +74,9 @@ fi #BOOST_THREADS CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" -LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS" +LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_SYSTEM_LDFLAGS" # $BOOST_THREAD_LDFLAGS" -LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS" +LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_SYSTEM_LIBS" # $BOOST_THREAD_LIBS" AC_CHECK_HEADER(google/dense_hash_map, @@ -123,6 +124,7 @@ AC_CONFIG_FILES([rampion/Makefile]) AC_CONFIG_FILES([minrisk/Makefile]) AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) +AC_CONFIG_FILES([klm/search/Makefile]) AC_CONFIG_FILES([mira/Makefile]) AC_CONFIG_FILES([dtrain/Makefile]) AC_CONFIG_FILES([example_extff/Makefile]) diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 5c0a1964..f8f427d3 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -17,7 +17,7 @@ trule_test_SOURCES = trule_test.cc trule_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz cdec_SOURCES = cdec.cc -cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/search/libksearch.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils -I../klm @@ -73,6 +73,7 @@ libcdec_a_SOURCES = \ ff_source_syntax.cc \ ff_bleu.cc \ ff_factory.cc \ + incremental.cc \ lexalign.cc \ lextrans.cc \ tagger.cc \ diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 052823ca..fe812011 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -39,7 +39,7 @@ #include "sampler.h" #include "forest_writer.h" // TODO this section should probably be handled by an Observer -#include "lazy.h" +#include "incremental.h" #include "hg_io.h" #include "aligner.h" @@ -412,7 +412,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") ("show_cfg_search_space", "Show the search space as a CFG") ("show_target_graph", po::value(), "Directory to write the target hypergraphs to") - ("lazy_search", po::value(), "Run lazy search with this language model file") + ("incremental_search", po::value(), "Run lazy search with this language model file") ("coarse_to_fine_beam_prune", po::value(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") ("ctf_beam_widen", po::value()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") ("ctf_num_widenings", po::value()->default_value(2), "Widen coarse beam this many times before backing off to full parse") @@ -828,8 +828,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (conf.count("show_target_graph")) HypergraphIO::WriteTarget(conf["show_target_graph"].as(), sent_id, forest); - if (conf.count("lazy_search")) { - PassToLazy(conf["lazy_search"].as().c_str(), CurrentWeightVector(), pop_limit, forest); + if (conf.count("incremental_search")) { + PassToIncremental(conf["incremental_search"].as().c_str(), CurrentWeightVector(), pop_limit, forest); o->NotifyDecodingComplete(smeta); return true; } diff --git a/decoder/incremental.cc b/decoder/incremental.cc new file mode 100644 index 00000000..768bbd65 --- /dev/null +++ b/decoder/incremental.cc @@ -0,0 +1,184 @@ +#include "incremental.h" + +#include "hg.h" +#include "fdict.h" +#include "tdict.h" + +#include "lm/enumerate_vocab.hh" +#include "lm/model.hh" +#include "search/config.hh" +#include "search/context.hh" +#include "search/edge.hh" +#include "search/edge_generator.hh" +#include "search/rule.hh" +#include "search/vertex.hh" +#include "search/vertex_generator.hh" +#include "util/exception.hh" + +#include +#include + +#include +#include + +namespace { + +struct MapVocab : public lm::EnumerateVocab { + public: + MapVocab() {} + + // Do not call after Lookup. + void Add(lm::WordIndex index, const StringPiece &str) { + const WordID cdec_id = TD::Convert(str.as_string()); + if (cdec_id >= out_.size()) out_.resize(cdec_id + 1); + out_[cdec_id] = index; + } + + // Assumes Add has been called and will never be called again. + lm::WordIndex FromCDec(WordID id) const { + return out_[out_.size() > id ? id : 0]; + } + + private: + std::vector out_; +}; + +class IncrementalBase { + public: + IncrementalBase(const std::vector &weights) : + cdec_weights_(weights), + weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { + std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; + } + + virtual ~IncrementalBase() {} + + virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; + + static IncrementalBase *Load(const char *model_file, const std::vector &weights); + + protected: + lm::ngram::Config GetConfig() { + lm::ngram::Config ret; + ret.enumerate_vocab = &vocab_; + return ret; + } + + MapVocab vocab_; + + const std::vector &cdec_weights_; + + const search::Weights weights_; +}; + +template class Incremental : public IncrementalBase { + public: + Incremental(const char *model_file, const std::vector &weights) : IncrementalBase(weights), m_(model_file, GetConfig()) {} + + void Search(unsigned int pop_limit, const Hypergraph &hg) const; + + private: + void ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const; + + const Model m_; +}; + +IncrementalBase *IncrementalBase::Load(const char *model_file, const std::vector &weights) { + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; + switch (model_type) { + case lm::ngram::PROBING: + return new Incremental(model_file, weights); + case lm::ngram::REST_PROBING: + return new Incremental(model_file, weights); + default: + UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); + } +} + +void PrintFinal(const Hypergraph &hg, const search::Final final) { + const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); + const search::Final *child(final.Children()); + for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { + if (*i > 0) { + std::cout << TD::Convert(*i) << ' '; + } else { + PrintFinal(hg, *child++); + } + } +} + +template void Incremental::Search(unsigned int pop_limit, const Hypergraph &hg) const { + boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); + search::Config config(weights_, pop_limit); + search::Context context(config, m_); + + for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { + search::EdgeGenerator gen; + const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; + for (unsigned int j = 0; j < down_edges.size(); ++j) { + unsigned int edge_index = down_edges[j]; + ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], gen); + } + search::VertexGenerator vertex_gen(context, out_vertices[i]); + gen.Search(context, vertex_gen); + } + const search::Final top = out_vertices[hg.nodes_.size() - 2].BestChild(); + if (top.Valid()) { + std::cout << "NO PATH FOUND" << std::endl; + } else { + PrintFinal(hg, top); + std::cout << "||| " << top.GetScore() << std::endl; + } +} + +template void Incremental::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::EdgeGenerator &gen) const { + const std::vector &e = in.rule_->e(); + std::vector words; + words.reserve(e.size()); + std::vector nts; + unsigned int terminals = 0; + float score = 0.0; + for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { + if (*word <= 0) { + nts.push_back(vertices[in.tail_nodes_[-*word]].RootPartial()); + if (nts.back().Empty()) return; + score += nts.back().Bound(); + words.push_back(lm::kMaxWordIndex); + } else { + ++terminals; + words.push_back(vocab_.FromCDec(*word)); + } + } + + if (final) { + words.push_back(m_.GetVocabulary().EndSentence()); + } + + search::PartialEdge out(gen.AllocateEdge(nts.size())); + + memcpy(out.NT(), &nts[0], sizeof(search::PartialVertex) * nts.size()); + + search::Note note; + note.vp = ∈ + out.SetNote(note); + + score += in.rule_->GetFeatureValues().dot(cdec_weights_); + score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; + score += search::ScoreRule(context, words, final, out.Between()); + out.SetScore(score); + + gen.AddEdge(out); +} + +boost::scoped_ptr AwfulGlobalIncremental; + +} // namespace + +void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg) { + if (!AwfulGlobalIncremental.get()) { + std::cerr << "Pop limit " << pop_limit << std::endl; + AwfulGlobalIncremental.reset(IncrementalBase::Load(model_file, weights)); + } + AwfulGlobalIncremental->Search(pop_limit, hg); +} diff --git a/decoder/incremental.h b/decoder/incremental.h new file mode 100644 index 00000000..180383ce --- /dev/null +++ b/decoder/incremental.h @@ -0,0 +1,11 @@ +#ifndef _INCREMENTAL_H_ +#define _INCREMENTAL_H_ + +#include "weights.h" +#include + +class Hypergraph; + +void PassToIncremental(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg); + +#endif // _INCREMENTAL_H_ diff --git a/decoder/lazy.cc b/decoder/lazy.cc deleted file mode 100644 index 1e6a94fe..00000000 --- a/decoder/lazy.cc +++ /dev/null @@ -1,178 +0,0 @@ -#include "hg.h" -#include "lazy.h" -#include "fdict.h" -#include "tdict.h" - -#include "lm/enumerate_vocab.hh" -#include "lm/model.hh" -#include "search/config.hh" -#include "search/context.hh" -#include "search/edge.hh" -#include "search/edge_queue.hh" -#include "search/vertex.hh" -#include "search/vertex_generator.hh" -#include "util/exception.hh" - -#include -#include - -#include -#include - -namespace { - -struct MapVocab : public lm::EnumerateVocab { - public: - MapVocab() {} - - // Do not call after Lookup. - void Add(lm::WordIndex index, const StringPiece &str) { - const WordID cdec_id = TD::Convert(str.as_string()); - if (cdec_id >= out_.size()) out_.resize(cdec_id + 1); - out_[cdec_id] = index; - } - - // Assumes Add has been called and will never be called again. - lm::WordIndex FromCDec(WordID id) const { - return out_[out_.size() > id ? id : 0]; - } - - private: - std::vector out_; -}; - -class LazyBase { - public: - LazyBase(const std::vector &weights) : - cdec_weights_(weights), - weights_(weights[FD::Convert("KLanguageModel")], weights[FD::Convert("KLanguageModel_OOV")], weights[FD::Convert("WordPenalty")]) { - std::cerr << "Weights KLanguageModel " << weights_.LM() << " KLanguageModel_OOV " << weights_.OOV() << " WordPenalty " << weights_.WordPenalty() << std::endl; - } - - virtual ~LazyBase() {} - - virtual void Search(unsigned int pop_limit, const Hypergraph &hg) const = 0; - - static LazyBase *Load(const char *model_file, const std::vector &weights); - - protected: - lm::ngram::Config GetConfig() { - lm::ngram::Config ret; - ret.enumerate_vocab = &vocab_; - return ret; - } - - MapVocab vocab_; - - const std::vector &cdec_weights_; - - const search::Weights weights_; -}; - -template class Lazy : public LazyBase { - public: - Lazy(const char *model_file, const std::vector &weights) : LazyBase(weights), m_(model_file, GetConfig()) {} - - void Search(unsigned int pop_limit, const Hypergraph &hg) const; - - private: - unsigned char ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const; - - const Model m_; -}; - -LazyBase *LazyBase::Load(const char *model_file, const std::vector &weights) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(model_file, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - return new Lazy(model_file, weights); - case lm::ngram::REST_PROBING: - return new Lazy(model_file, weights); - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - -void PrintFinal(const Hypergraph &hg, const search::Final &final) { - const std::vector &words = static_cast(final.GetNote().vp)->rule_->e(); - boost::array::const_iterator child(final.Children().begin()); - for (std::vector::const_iterator i = words.begin(); i != words.end(); ++i) { - if (*i > 0) { - std::cout << TD::Convert(*i) << ' '; - } else { - PrintFinal(hg, **child++); - } - } -} - -template void Lazy::Search(unsigned int pop_limit, const Hypergraph &hg) const { - boost::scoped_array out_vertices(new search::Vertex[hg.nodes_.size()]); - search::Config config(weights_, pop_limit); - search::Context context(config, m_); - - for (unsigned int i = 0; i < hg.nodes_.size() - 1; ++i) { - search::EdgeQueue queue(context.PopLimit()); - const Hypergraph::EdgesVector &down_edges = hg.nodes_[i].in_edges_; - for (unsigned int j = 0; j < down_edges.size(); ++j) { - unsigned int edge_index = down_edges[j]; - unsigned char arity = ConvertEdge(context, i == hg.nodes_.size() - 2, out_vertices.get(), hg.edges_[edge_index], queue.InitializeEdge()); - search::Note note; - note.vp = &hg.edges_[edge_index]; - if (arity != 255) queue.AddEdge(arity, note); - } - search::VertexGenerator vertex_gen(context, out_vertices[i]); - queue.Search(context, vertex_gen); - } - const search::Final *top = out_vertices[hg.nodes_.size() - 2].BestChild(); - if (!top) { - std::cout << "NO PATH FOUND" << std::endl; - } else { - PrintFinal(hg, *top); - std::cout << "||| " << top->Bound() << std::endl; - } -} - -template unsigned char Lazy::ConvertEdge(const search::Context &context, bool final, search::Vertex *vertices, const Hypergraph::Edge &in, search::PartialEdge &out) const { - const std::vector &e = in.rule_->e(); - std::vector words; - unsigned int terminals = 0; - unsigned char nt = 0; - out.score = 0.0; - for (std::vector::const_iterator word = e.begin(); word != e.end(); ++word) { - if (*word <= 0) { - out.nt[nt] = vertices[in.tail_nodes_[-*word]].RootPartial(); - if (out.nt[nt].Empty()) return 255; - out.score += out.nt[nt].Bound(); - ++nt; - words.push_back(lm::kMaxWordIndex); - } else { - ++terminals; - words.push_back(vocab_.FromCDec(*word)); - } - } - for (unsigned char fill = nt; fill < search::kMaxArity; ++fill) { - out.nt[fill] = search::kBlankPartialVertex; - } - - if (final) { - words.push_back(m_.GetVocabulary().EndSentence()); - } - - out.score += in.rule_->GetFeatureValues().dot(cdec_weights_); - out.score -= static_cast(terminals) * context.GetWeights().WordPenalty() / M_LN10; - out.score += search::ScoreRule(context, words, final, out.between); - return nt; -} - -boost::scoped_ptr AwfulGlobalLazy; - -} // namespace - -void PassToLazy(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg) { - if (!AwfulGlobalLazy.get()) { - std::cerr << "Pop limit " << pop_limit << std::endl; - AwfulGlobalLazy.reset(LazyBase::Load(model_file, weights)); - } - AwfulGlobalLazy->Search(pop_limit, hg); -} diff --git a/decoder/lazy.h b/decoder/lazy.h deleted file mode 100644 index 94895b19..00000000 --- a/decoder/lazy.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef _LAZY_H_ -#define _LAZY_H_ - -#include "weights.h" -#include - -class Hypergraph; - -void PassToLazy(const char *model_file, const std::vector &weights, unsigned int pop_limit, const Hypergraph &hg); - -#endif // _LAZY_H_ diff --git a/dtrain/Makefile.am b/dtrain/Makefile.am index 64fef489..ca9581f5 100644 --- a/dtrain/Makefile.am +++ b/dtrain/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = dtrain dtrain_SOURCES = dtrain.cc score.cc -dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/klm/alone/Jamfile b/klm/alone/Jamfile deleted file mode 100644 index 2cc90c05..00000000 --- a/klm/alone/Jamfile +++ /dev/null @@ -1,4 +0,0 @@ -lib standalone : assemble.cc read.cc threading.cc vocab.cc ../lm//kenlm ../util//kenutil ../search//search : .. : : .. ../search//search ../lm//kenlm ; - -exe decode : main.cc standalone main.cc : multi:..//boost_thread ; -exe just_vocab : just_vocab.cc standalone : multi:..//boost_thread ; diff --git a/klm/alone/assemble.cc b/klm/alone/assemble.cc deleted file mode 100644 index 2ae72ce9..00000000 --- a/klm/alone/assemble.cc +++ /dev/null @@ -1,76 +0,0 @@ -#include "alone/assemble.hh" - -#include "alone/labeled_edge.hh" -#include "search/final.hh" - -#include - -namespace alone { - -std::ostream &operator<<(std::ostream &o, const search::Final &final) { - const std::vector &words = static_cast(final.From()).Words(); - if (words.empty()) return o; - const search::Final *const *child = final.Children().data(); - std::vector::const_iterator i(words.begin()); - for (; i != words.end() - 1; ++i) { - if (*i) { - o << **i << ' '; - } else { - o << **child << ' '; - ++child; - } - } - - if (*i) { - if (**i != "") { - o << **i; - } - } else { - o << **child; - } - - return o; -} - -namespace { - -void MakeIndent(std::ostream &o, const char *indent_str, unsigned int level) { - for (unsigned int i = 0; i < level; ++i) - o << indent_str; -} - -void DetailedFinalInternal(std::ostream &o, const search::Final &final, const char *indent_str, unsigned int indent) { - o << "(\n"; - MakeIndent(o, indent_str, indent); - const std::vector &words = static_cast(final.From()).Words(); - const search::Final *const *child = final.Children().data(); - for (std::vector::const_iterator i(words.begin()); i != words.end(); ++i) { - if (*i) { - o << **i; - if (i == words.end() - 1) { - o << '\n'; - MakeIndent(o, indent_str, indent); - } else { - o << ' '; - } - } else { - // One extra indent from the line we're currently on. - o << indent_str; - DetailedFinalInternal(o, **child, indent_str, indent + 1); - for (unsigned int i = 0; i < indent; ++i) o << indent_str; - ++child; - } - } - o << ")=" << final.Bound() << '\n'; -} -} // namespace - -void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str) { - DetailedFinalInternal(o, final, indent_str, 0); -} - -void PrintFinal(const search::Final &final) { - std::cout << final << std::endl; -} - -} // namespace alone diff --git a/klm/alone/assemble.hh b/klm/alone/assemble.hh deleted file mode 100644 index e6b0ad5c..00000000 --- a/klm/alone/assemble.hh +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef ALONE_ASSEMBLE__ -#define ALONE_ASSEMBLE__ - -#include - -namespace search { -class Final; -} // namespace search - -namespace alone { - -std::ostream &operator<<(std::ostream &o, const search::Final &final); - -void DetailedFinal(std::ostream &o, const search::Final &final, const char *indent_str = " "); - -// This isn't called anywhere but makes it easy to print from gdb. -void PrintFinal(const search::Final &final); - -} // namespace alone - -#endif // ALONE_ASSEMBLE__ diff --git a/klm/alone/graph.hh b/klm/alone/graph.hh deleted file mode 100644 index 788352c9..00000000 --- a/klm/alone/graph.hh +++ /dev/null @@ -1,87 +0,0 @@ -#ifndef ALONE_GRAPH__ -#define ALONE_GRAPH__ - -#include "alone/labeled_edge.hh" -#include "search/rule.hh" -#include "search/types.hh" -#include "search/vertex.hh" -#include "util/exception.hh" - -#include -#include -#include - -namespace alone { - -template class FixedAllocator : boost::noncopyable { - public: - FixedAllocator() : current_(NULL), end_(NULL) {} - - void Init(std::size_t count) { - assert(!current_); - array_.reset(new T[count]); - current_ = array_.get(); - end_ = current_ + count; - } - - T &operator[](std::size_t idx) { - return array_.get()[idx]; - } - - T *New() { - T *ret = current_++; - UTIL_THROW_IF(ret >= end_, util::Exception, "Allocating past end"); - return ret; - } - - std::size_t Size() const { - return end_ - array_.get(); - } - - private: - boost::scoped_array array_; - T *current_, *end_; -}; - -class Graph : boost::noncopyable { - public: - typedef LabeledEdge Edge; - typedef search::Vertex Vertex; - - Graph() {} - - void SetCounts(std::size_t vertices, std::size_t edges) { - vertices_.Init(vertices); - edges_.Init(edges); - } - - Vertex *NewVertex() { - return vertices_.New(); - } - - std::size_t VertexSize() const { return vertices_.Size(); } - - Vertex &MutableVertex(std::size_t index) { - return vertices_[index]; - } - - Edge *NewEdge() { - return edges_.New(); - } - - std::size_t EdgeSize() const { return edges_.Size(); } - - void SetRoot(Vertex *root) { root_ = root; } - - Vertex &Root() { return *root_; } - - private: - FixedAllocator vertices_; - FixedAllocator edges_; - - Vertex *root_; -}; - -} // namespace alone - -#endif // ALONE_GRAPH__ diff --git a/klm/alone/just_vocab.cc b/klm/alone/just_vocab.cc deleted file mode 100644 index 35aea5ed..00000000 --- a/klm/alone/just_vocab.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "alone/read.hh" -#include "util/file_piece.hh" - -#include - -int main() { - util::FilePiece f(0, "stdin", &std::cerr); - while (true) { - try { - alone::JustVocab(f, std::cout); - } catch (const util::EndOfFileException &e) { break; } - std::cout << '\n'; - } -} diff --git a/klm/alone/labeled_edge.hh b/klm/alone/labeled_edge.hh deleted file mode 100644 index 94d8cbdf..00000000 --- a/klm/alone/labeled_edge.hh +++ /dev/null @@ -1,30 +0,0 @@ -#ifndef ALONE_LABELED_EDGE__ -#define ALONE_LABELED_EDGE__ - -#include "search/edge.hh" - -#include -#include - -namespace alone { - -class LabeledEdge : public search::Edge { - public: - LabeledEdge() {} - - void AppendWord(const std::string *word) { - words_.push_back(word); - } - - const std::vector &Words() const { - return words_; - } - - private: - // NULL for non-terminals. - std::vector words_; -}; - -} // namespace alone - -#endif // ALONE_LABELED_EDGE__ diff --git a/klm/alone/main.cc b/klm/alone/main.cc deleted file mode 100644 index e09ab01d..00000000 --- a/klm/alone/main.cc +++ /dev/null @@ -1,85 +0,0 @@ -#include "alone/threading.hh" -#include "search/config.hh" -#include "search/context.hh" -#include "util/exception.hh" -#include "util/file_piece.hh" -#include "util/usage.hh" - -#include - -#include -#include - -namespace alone { - -template void ReadLoop(const std::string &graph_prefix, Control &control) { - for (unsigned int sentence = 0; ; ++sentence) { - std::stringstream name; - name << graph_prefix << '/' << sentence; - std::auto_ptr file; - try { - file.reset(new util::FilePiece(name.str().c_str())); - } catch (const util::ErrnoException &e) { - if (e.Error() == ENOENT) return; - throw; - } - control.Add(file.release()); - } -} - -template void RunWithModelType(const char *graph_prefix, const char *model_file, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { - Model model(model_file); - search::Weights weights(weight_str); - search::Config config(weights, pop_limit); - - if (threads > 1) { -#ifdef WITH_THREADS - Controller controller(config, model, threads, std::cout); - ReadLoop(graph_prefix, controller); -#else - UTIL_THROW(util::Exception, "Threading support not compiled in."); -#endif - } else { - InThread controller(config, model, std::cout); - ReadLoop(graph_prefix, controller); - } -} - -void Run(const char *graph_prefix, const char *lm_name, StringPiece weight_str, unsigned int pop_limit, unsigned int threads) { - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(lm_name, model_type)) model_type = lm::ngram::PROBING; - switch (model_type) { - case lm::ngram::PROBING: - RunWithModelType(graph_prefix, lm_name, weight_str, pop_limit, threads); - break; - case lm::ngram::REST_PROBING: - RunWithModelType(graph_prefix, lm_name, weight_str, pop_limit, threads); - break; - default: - UTIL_THROW(util::Exception, "Sorry this lm type isn't supported yet."); - } -} - -} // namespace alone - -int main(int argc, char *argv[]) { - if (argc < 5 || argc > 6) { - std::cerr << argv[0] << " graph_prefix lm \"weights\" pop [threads]" << std::endl; - return 1; - } - -#ifdef WITH_THREADS - unsigned thread_count = boost::thread::hardware_concurrency(); -#else - unsigned thread_count = 1; -#endif - if (argc == 6) { - thread_count = boost::lexical_cast(argv[5]); - UTIL_THROW_IF(!thread_count, util::Exception, "Thread count 0"); - } - UTIL_THROW_IF(!thread_count, util::Exception, "Boost doesn't know how many threads there are. Pass it on the command line."); - alone::Run(argv[1], argv[2], argv[3], boost::lexical_cast(argv[4]), thread_count); - - util::PrintUsage(std::cerr); - return 0; -} diff --git a/klm/alone/read.cc b/klm/alone/read.cc deleted file mode 100644 index 0b20be35..00000000 --- a/klm/alone/read.cc +++ /dev/null @@ -1,118 +0,0 @@ -#include "alone/read.hh" - -#include "alone/graph.hh" -#include "alone/vocab.hh" -#include "search/arity.hh" -#include "search/context.hh" -#include "search/weights.hh" -#include "util/file_piece.hh" - -#include -#include - -#include - -namespace alone { - -namespace { - -template Graph::Edge &ReadEdge(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab, bool final) { - Graph::Edge *ret = to.NewEdge(); - - StringPiece got; - - std::vector words; - unsigned long int terminals = 0; - while ("|||" != (got = from.ReadDelimited())) { - if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) { - // non-terminal - char *end_ptr; - unsigned long int child = std::strtoul(got.data() + 1, &end_ptr, 10); - UTIL_THROW_IF(end_ptr != got.data() + got.size() - 1, FormatException, "Bad non-terminal" << got); - UTIL_THROW_IF(child >= to.VertexSize(), FormatException, "Reference to vertex " << child << " but we only have " << to.VertexSize() << " vertices. Is the file in bottom-up format?"); - ret->Add(to.MutableVertex(child)); - words.push_back(lm::kMaxWordIndex); - ret->AppendWord(NULL); - } else { - const std::pair &found = vocab.FindOrAdd(got); - words.push_back(found.second); - ret->AppendWord(&found.first); - ++terminals; - } - } - if (final) { - // This is not counted for the word penalty. - words.push_back(vocab.EndSentence().second); - ret->AppendWord(&vocab.EndSentence().first); - } - // Hard-coded word penalty. - float additive = context.GetWeights().DotNoLM(from.ReadLine()) - context.GetWeights().WordPenalty() * static_cast(terminals) / M_LN10; - ret->InitRule().Init(context, additive, words, final); - unsigned int arity = ret->GetRule().Arity(); - UTIL_THROW_IF(arity > search::kMaxArity, util::Exception, "Edit search/arity.hh and increase " << search::kMaxArity << " to at least " << arity); - return *ret; -} - -} // namespace - -// TODO: refactor -void JustVocab(util::FilePiece &from, std::ostream &out) { - boost::unordered_set seen; - unsigned long int vertices = from.ReadULong(); - from.ReadULong(); // edges - UTIL_THROW_IF(vertices == 0, FormatException, "Vertex count is zero"); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); - std::string temp; - for (unsigned long int i = 0; i < vertices; ++i) { - unsigned long int edge_count = from.ReadULong(); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); - for (unsigned long int e = 0; e < edge_count; ++e) { - StringPiece got; - while ("|||" != (got = from.ReadDelimited())) { - if ('[' == *got.data() && ']' == got.data()[got.size() - 1]) continue; - temp.assign(got.data(), got.size()); - if (seen.insert(temp).second) out << temp << ' '; - } - from.ReadLine(); // weights - } - } - // Eat sentence - from.ReadLine(); -} - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab) { - unsigned long int vertices; - try { - vertices = from.ReadULong(); - } catch (const util::EndOfFileException &e) { return false; } - unsigned long int edges = from.ReadULong(); - UTIL_THROW_IF(vertices < 2, FormatException, "Vertex count is " << vertices); - UTIL_THROW_IF(edges == 0, FormatException, "Edge count is " << edges); - --vertices; - --edges; - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected newline after counts"); - to.SetCounts(vertices, edges); - Graph::Vertex *vertex; - for (unsigned long int i = 0; ; ++i) { - vertex = to.NewVertex(); - unsigned long int edge_count = from.ReadULong(); - bool root = (i == vertices - 1); - UTIL_THROW_IF('\n' != from.get(), FormatException, "Expected after edge count"); - for (unsigned long int e = 0; e < edge_count; ++e) { - vertex->Add(ReadEdge(context, from, to, vocab, root)); - } - vertex->FinishedAdding(); - if (root) break; - } - to.SetRoot(vertex); - StringPiece str = from.ReadLine(); - UTIL_THROW_IF("1" != str, FormatException, "Expected one edge to root"); - // The edge - from.ReadLine(); - return true; -} - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); - -} // namespace alone diff --git a/klm/alone/read.hh b/klm/alone/read.hh deleted file mode 100644 index 10769a86..00000000 --- a/klm/alone/read.hh +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef ALONE_READ__ -#define ALONE_READ__ - -#include "util/exception.hh" - -#include - -namespace util { class FilePiece; } - -namespace search { template class Context; } - -namespace alone { - -class Graph; -class Vocab; - -class FormatException : public util::Exception { - public: - FormatException() {} - ~FormatException() throw() {} -}; - -void JustVocab(util::FilePiece &from, std::ostream &to); - -template bool ReadCDec(search::Context &context, util::FilePiece &from, Graph &to, Vocab &vocab); - -} // namespace alone - -#endif // ALONE_READ__ diff --git a/klm/alone/threading.cc b/klm/alone/threading.cc deleted file mode 100644 index 475386b6..00000000 --- a/klm/alone/threading.cc +++ /dev/null @@ -1,80 +0,0 @@ -#include "alone/threading.hh" - -#include "alone/assemble.hh" -#include "alone/graph.hh" -#include "alone/read.hh" -#include "alone/vocab.hh" -#include "lm/model.hh" -#include "search/context.hh" -#include "search/vertex_generator.hh" - -#include -#include -#include - -#include - -namespace alone { -template void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out) { - search::Context context(config, model); - Graph graph; - Vocab vocab(model.GetVocabulary()); - { - boost::scoped_ptr in(in_ptr); - ReadCDec(context, *in, graph, vocab); - } - - for (std::size_t i = 0; i < graph.VertexSize(); ++i) { - search::VertexGenerator(context, graph.MutableVertex(i)); - } - search::PartialVertex top = graph.Root().RootPartial(); - if (top.Empty()) { - out << "NO PATH FOUND"; - } else { - search::PartialVertex continuation; - while (!top.Complete()) { - top.Split(continuation); - top = continuation; - } - out << top.End() << " ||| " << top.End().Bound() << std::endl; - } -} - -template void Decode(const search::Config &config, const lm::ngram::ProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); -template void Decode(const search::Config &config, const lm::ngram::RestProbingModel &model, util::FilePiece *in_ptr, std::ostream &out); - -#ifdef WITH_THREADS -template void DecodeHandler::operator()(Input message) { - std::stringstream assemble; - Decode(config_, model_, message.file, assemble); - Produce(message.sentence_id, assemble.str()); -} - -template void DecodeHandler::Produce(unsigned int sentence_id, const std::string &str) { - Output out; - out.sentence_id = sentence_id; - out.str = new std::string(str); - out_.Produce(out); -} - -void PrintHandler::operator()(Output message) { - unsigned int relative = message.sentence_id - done_; - if (waiting_.size() <= relative) waiting_.resize(relative + 1); - waiting_[relative] = message.str; - for (std::string *lead; !waiting_.empty() && (lead = waiting_[0]); waiting_.pop_front(), ++done_) { - out_ << *lead; - delete lead; - } -} - -template Controller::Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to) : - sentence_id_(0), - printer_(decode_workers, 1, boost::ref(to), Output::Poison()), - decoder_(3, decode_workers, boost::in_place(boost::ref(config), boost::ref(model), boost::ref(printer_.In())), Input::Poison()) {} - -template class Controller; -template class Controller; - -#endif - -} // namespace alone diff --git a/klm/alone/threading.hh b/klm/alone/threading.hh deleted file mode 100644 index 0ab0f739..00000000 --- a/klm/alone/threading.hh +++ /dev/null @@ -1,129 +0,0 @@ -#ifndef ALONE_THREADING__ -#define ALONE_THREADING__ - -#ifdef WITH_THREADS -#include "util/pcqueue.hh" -#include "util/pool.hh" -#endif - -#include -#include -#include - -namespace util { -class FilePiece; -} // namespace util - -namespace search { -class Config; -template class Context; -} // namespace search - -namespace alone { - -template void Decode(const search::Config &config, const Model &model, util::FilePiece *in_ptr, std::ostream &out); - -class Graph; - -#ifdef WITH_THREADS -struct SentenceID { - unsigned int sentence_id; - bool operator==(const SentenceID &other) const { - return sentence_id == other.sentence_id; - } -}; - -struct Input : public SentenceID { - util::FilePiece *file; - static Input Poison() { - Input ret; - ret.sentence_id = static_cast(-1); - ret.file = NULL; - return ret; - } -}; - -struct Output : public SentenceID { - std::string *str; - static Output Poison() { - Output ret; - ret.sentence_id = static_cast(-1); - ret.str = NULL; - return ret; - } -}; - -template class DecodeHandler { - public: - typedef Input Request; - - DecodeHandler(const search::Config &config, const Model &model, util::PCQueue &out) : config_(config), model_(model), out_(out) {} - - void operator()(Input message); - - private: - void Produce(unsigned int sentence_id, const std::string &str); - - const search::Config &config_; - - const Model &model_; - - util::PCQueue &out_; -}; - -class PrintHandler { - public: - typedef Output Request; - - explicit PrintHandler(std::ostream &o) : out_(o), done_(0) {} - - void operator()(Output message); - - private: - std::ostream &out_; - std::deque waiting_; - unsigned int done_; -}; - -template class Controller { - public: - // This config must remain valid. - explicit Controller(const search::Config &config, const Model &model, size_t decode_workers, std::ostream &to); - - // Takes ownership of in. - void Add(util::FilePiece *in) { - Input input; - input.sentence_id = sentence_id_++; - input.file = in; - decoder_.Produce(input); - } - - private: - unsigned int sentence_id_; - - util::Pool printer_; - - util::Pool > decoder_; -}; -#endif - -// Same API as controller. -template class InThread { - public: - InThread(const search::Config &config, const Model &model, std::ostream &to) : config_(config), model_(model), to_(to) {} - - // Takes ownership of in. - void Add(util::FilePiece *in) { - Decode(config_, model_, in, to_); - } - - private: - const search::Config &config_; - - const Model &model_; - - std::ostream &to_; -}; - -} // namespace alone -#endif // ALONE_THREADING__ diff --git a/klm/alone/vocab.cc b/klm/alone/vocab.cc deleted file mode 100644 index ffe55301..00000000 --- a/klm/alone/vocab.cc +++ /dev/null @@ -1,19 +0,0 @@ -#include "alone/vocab.hh" - -#include "lm/virtual_interface.hh" -#include "util/string_piece.hh" - -namespace alone { - -Vocab::Vocab(const lm::base::Vocabulary &backing) : backing_(backing), end_sentence_(FindOrAdd("")) {} - -const std::pair &Vocab::FindOrAdd(const StringPiece &str) { - Map::const_iterator i(FindStringPiece(map_, str)); - if (i != map_.end()) return *i; - std::pair to_ins; - to_ins.first.assign(str.data(), str.size()); - to_ins.second = backing_.Index(str); - return *map_.insert(to_ins).first; -} - -} // namespace alone diff --git a/klm/alone/vocab.hh b/klm/alone/vocab.hh deleted file mode 100644 index 3ac0f542..00000000 --- a/klm/alone/vocab.hh +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef ALONE_VOCAB__ -#define ALONE_VOCAB__ - -#include "lm/word_index.hh" -#include "util/string_piece.hh" - -#include -#include - -#include - -namespace lm { namespace base { class Vocabulary; } } - -namespace alone { - -class Vocab { - public: - explicit Vocab(const lm::base::Vocabulary &backing); - - const std::pair &FindOrAdd(const StringPiece &str); - - const std::pair &EndSentence() const { return end_sentence_; } - - private: - typedef boost::unordered_map Map; - Map map_; - - const lm::base::Vocabulary &backing_; - - const std::pair &end_sentence_; -}; - -} // namespace alone -#endif // ALONE_VCOAB__ diff --git a/klm/lm/model.cc b/klm/lm/model.cc index 40af8a63..2fd20481 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -87,7 +87,7 @@ template void GenericModel.. ; +lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : .. ; import testing ; diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am new file mode 100644 index 00000000..ccc5b7f6 --- /dev/null +++ b/klm/search/Makefile.am @@ -0,0 +1,11 @@ +noinst_LIBRARIES = libksearch.a + +libksearch_a_SOURCES = \ + edge_generator.cc \ + rule.cc \ + vertex.cc \ + vertex_generator.cc \ + weights.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. + diff --git a/klm/search/arity.hh b/klm/search/arity.hh deleted file mode 100644 index 09c2c671..00000000 --- a/klm/search/arity.hh +++ /dev/null @@ -1,8 +0,0 @@ -#ifndef SEARCH_ARITY__ -#define SEARCH_ARITY__ -namespace search { - -const unsigned int kMaxArity = 2; - -} // namespace search -#endif // SEARCH_ARITY__ diff --git a/klm/search/context.hh b/klm/search/context.hh index 27940053..62163144 100644 --- a/klm/search/context.hh +++ b/klm/search/context.hh @@ -7,6 +7,7 @@ #include "search/types.hh" #include "search/vertex.hh" #include "util/exception.hh" +#include "util/pool.hh" #include #include @@ -21,10 +22,8 @@ class ContextBase { public: explicit ContextBase(const Config &config) : pop_limit_(config.PopLimit()), weights_(config.GetWeights()) {} - Final *NewFinal() { - Final *ret = final_pool_.construct(); - assert(ret); - return ret; + util::Pool &FinalPool() { + return final_pool_; } VertexNode *NewVertexNode() { @@ -42,7 +41,8 @@ class ContextBase { const Weights &GetWeights() const { return weights_; } private: - boost::object_pool final_pool_; + util::Pool final_pool_; + boost::object_pool vertex_node_pool_; unsigned int pop_limit_; diff --git a/klm/search/edge.hh b/klm/search/edge.hh index 77ab0ade..187904bf 100644 --- a/klm/search/edge.hh +++ b/klm/search/edge.hh @@ -2,30 +2,53 @@ #define SEARCH_EDGE__ #include "lm/state.hh" -#include "search/arity.hh" -#include "search/rule.hh" +#include "search/header.hh" #include "search/types.hh" #include "search/vertex.hh" +#include "util/pool.hh" -#include +#include + +#include namespace search { -struct PartialEdge { - Score score; - // Terminals - lm::ngram::ChartState between[kMaxArity + 1]; - // Non-terminals - PartialVertex nt[kMaxArity]; +// Copyable, but the copy will be shallow. +class PartialEdge : public Header { + public: + // Allow default construction for STL. + PartialEdge() {} + + PartialEdge(util::Pool &pool, Arity arity) + : Header(pool.Allocate(Size(arity, arity + 1)), arity) {} + + PartialEdge(util::Pool &pool, Arity arity, Arity chart_states) + : Header(pool.Allocate(Size(arity, chart_states)), arity) {} - const lm::ngram::ChartState &CompletedState() const { - return between[0]; - } + // Non-terminals + const PartialVertex *NT() const { + return reinterpret_cast(After()); + } + PartialVertex *NT() { + return reinterpret_cast(After()); + } - bool operator<(const PartialEdge &other) const { - return score < other.score; - } + const lm::ngram::ChartState &CompletedState() const { + return *Between(); + } + const lm::ngram::ChartState *Between() const { + return reinterpret_cast(After() + GetArity() * sizeof(PartialVertex)); + } + lm::ngram::ChartState *Between() { + return reinterpret_cast(After() + GetArity() * sizeof(PartialVertex)); + } + + private: + static std::size_t Size(Arity arity, Arity chart_states) { + return kHeaderSize + arity * sizeof(PartialVertex) + chart_states * sizeof(lm::ngram::ChartState); + } }; + } // namespace search #endif // SEARCH_EDGE__ diff --git a/klm/search/edge_generator.cc b/klm/search/edge_generator.cc index 56239dfb..260159b1 100644 --- a/klm/search/edge_generator.cc +++ b/klm/search/edge_generator.cc @@ -4,117 +4,107 @@ #include "lm/partial.hh" #include "search/context.hh" #include "search/vertex.hh" -#include "search/vertex_generator.hh" #include namespace search { -EdgeGenerator::EdgeGenerator(PartialEdge &root, unsigned char arity, Note note) : arity_(arity), note_(note) { -/* for (unsigned char i = 0; i < edge.Arity(); ++i) { - root.nt[i] = edge.GetVertex(i).RootPartial(); - } - for (unsigned char i = edge.Arity(); i < 2; ++i) { - root.nt[i] = kBlankPartialVertex; - }*/ - generate_.push(&root); - top_score_ = root.score; -} - namespace { -template float FastScore(const Context &context, unsigned char victim, unsigned char arity, const PartialEdge &previous, PartialEdge &update) { - memcpy(update.between, previous.between, sizeof(lm::ngram::ChartState) * (arity + 1)); - - float ret = 0.0; - lm::ngram::ChartState *before, *after; - if (victim == 0) { - before = &update.between[0]; - after = &update.between[(arity == 2 && previous.nt[1].Complete()) ? 2 : 1]; - } else { - assert(victim == 1); - assert(arity == 2); - before = &update.between[previous.nt[0].Complete() ? 0 : 1]; - after = &update.between[2]; - } - const lm::ngram::ChartState &previous_reveal = previous.nt[victim].State(); - const PartialVertex &update_nt = update.nt[victim]; +template void FastScore(const Context &context, Arity victim, Arity before_idx, Arity incomplete, const PartialVertex &previous_vertex, PartialEdge update) { + lm::ngram::ChartState *between = update.Between(); + lm::ngram::ChartState *before = &between[before_idx], *after = &between[before_idx + 1]; + + float adjustment = 0.0; + const lm::ngram::ChartState &previous_reveal = previous_vertex.State(); + const PartialVertex &update_nt = update.NT()[victim]; const lm::ngram::ChartState &update_reveal = update_nt.State(); - float just_after = 0.0; if ((update_reveal.left.length > previous_reveal.left.length) || (update_reveal.left.full && !previous_reveal.left.full)) { - just_after += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); + adjustment += lm::ngram::RevealAfter(context.LanguageModel(), before->left, before->right, update_reveal.left, previous_reveal.left.length); } - if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous.nt[victim].RightFull())) { - ret += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); + if ((update_reveal.right.length > previous_reveal.right.length) || (update_nt.RightFull() && !previous_vertex.RightFull())) { + adjustment += lm::ngram::RevealBefore(context.LanguageModel(), update_reveal.right, previous_reveal.right.length, update_nt.RightFull(), after->left, after->right); } if (update_nt.Complete()) { if (update_reveal.left.full) { before->left.full = true; } else { assert(update_reveal.left.length == update_reveal.right.length); - ret += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); + adjustment += lm::ngram::Subsume(context.LanguageModel(), before->left, before->right, after->left, after->right, update_reveal.left.length); } - if (victim == 0) { - update.between[0].right = after->right; - } else { - update.between[2].left = before->left; + before->right = after->right; + // Shift the others shifted one down, covering after. + for (lm::ngram::ChartState *cover = after; cover < between + incomplete; ++cover) { + *cover = *(cover + 1); } } - return previous.score + (ret + just_after) * context.GetWeights().LM(); + update.SetScore(update.GetScore() + adjustment * context.GetWeights().LM()); } } // namespace -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool) { +template PartialEdge EdgeGenerator::Pop(Context &context) { assert(!generate_.empty()); - PartialEdge &top = *generate_.top(); + PartialEdge top = generate_.top(); generate_.pop(); - unsigned int victim = 0; - unsigned char lowest_length = 255; - for (unsigned char i = 0; i != arity_; ++i) { - if (!top.nt[i].Complete() && top.nt[i].Length() < lowest_length) { - lowest_length = top.nt[i].Length(); - victim = i; + PartialVertex *const top_nt = top.NT(); + const Arity arity = top.GetArity(); + + Arity victim = 0; + Arity victim_completed; + Arity incomplete; + // Select victim or return if complete. + { + Arity completed = 0; + unsigned char lowest_length = 255; + for (Arity i = 0; i != arity; ++i) { + if (top_nt[i].Complete()) { + ++completed; + } else if (top_nt[i].Length() < lowest_length) { + lowest_length = top_nt[i].Length(); + victim = i; + victim_completed = completed; + } } - } - if (lowest_length == 255) { - // All states report complete. - top.between[0].right = top.between[arity_].right; - // Now top.between[0] is the full edge state. - top_score_ = generate_.empty() ? -kScoreInf : generate_.top()->score; - return ⊤ + if (lowest_length == 255) { + return top; + } + incomplete = arity - completed; } - unsigned int stay = !victim; - PartialEdge &continuation = *static_cast(partial_edge_pool.malloc()); - float old_bound = top.nt[victim].Bound(); - // The alternate's score will change because alternate.nt[victim] changes. - bool split = top.nt[victim].Split(continuation.nt[victim]); - // top is now the alternate. + PartialVertex old_value(top_nt[victim]); + PartialVertex alternate_changed; + if (top_nt[victim].Split(alternate_changed)) { + PartialEdge alternate(partial_edge_pool_, arity, incomplete + 1); + alternate.SetScore(top.GetScore() + alternate_changed.Bound() - old_value.Bound()); - continuation.nt[stay] = top.nt[stay]; - continuation.score = FastScore(context, victim, arity_, top, continuation); - // TODO: dedupe? - generate_.push(&continuation); + alternate.SetNote(top.GetNote()); + + PartialVertex *alternate_nt = alternate.NT(); + for (Arity i = 0; i < victim; ++i) alternate_nt[i] = top_nt[i]; + alternate_nt[victim] = alternate_changed; + for (Arity i = victim + 1; i < arity; ++i) alternate_nt[i] = top_nt[i]; + + memcpy(alternate.Between(), top.Between(), sizeof(lm::ngram::ChartState) * (incomplete + 1)); - if (split) { - // We have an alternate. - top.score += top.nt[victim].Bound() - old_bound; // TODO: dedupe? - generate_.push(&top); - } else { - partial_edge_pool.free(&top); + generate_.push(alternate); } - top_score_ = generate_.top()->score; - return NULL; + // top is now the continuation. + FastScore(context, victim, victim - victim_completed, incomplete, old_value, top); + // TODO: dedupe? + generate_.push(top); + + // Invalid indicates no new hypothesis generated. + return PartialEdge(); } -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); -template PartialEdge *EdgeGenerator::Pop(Context &context, boost::pool<> &partial_edge_pool); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); +template PartialEdge EdgeGenerator::Pop(Context &context); } // namespace search diff --git a/klm/search/edge_generator.hh b/klm/search/edge_generator.hh index 875ccc5e..582c78b7 100644 --- a/klm/search/edge_generator.hh +++ b/klm/search/edge_generator.hh @@ -3,11 +3,8 @@ #include "search/edge.hh" #include "search/note.hh" +#include "search/types.hh" -#include -#include - -#include #include namespace lm { @@ -20,38 +17,40 @@ namespace search { template class Context; -class VertexGenerator; - -struct PartialEdgePointerLess : std::binary_function { - bool operator()(const PartialEdge *first, const PartialEdge *second) const { - return *first < *second; - } -}; - class EdgeGenerator { public: - EdgeGenerator(PartialEdge &root, unsigned char arity, Note note); + EdgeGenerator() {} - Score TopScore() const { - return top_score_; + PartialEdge AllocateEdge(Arity arity) { + return PartialEdge(partial_edge_pool_, arity); } - Note GetNote() const { - return note_; + void AddEdge(PartialEdge edge) { + generate_.push(edge); } - // Pop. If there's a complete hypothesis, return it. Otherwise return NULL. - template PartialEdge *Pop(Context &context, boost::pool<> &partial_edge_pool); + bool Empty() const { return generate_.empty(); } + + // Pop. If there's a complete hypothesis, return it. Otherwise return an invalid PartialEdge. + template PartialEdge Pop(Context &context); + + template void Search(Context &context, Output &output) { + unsigned to_pop = context.PopLimit(); + while (to_pop > 0 && !generate_.empty()) { + PartialEdge got(Pop(context)); + if (got.Valid()) { + output.NewHypothesis(got); + --to_pop; + } + } + output.FinishedSearch(); + } private: - Score top_score_; - - unsigned char arity_; + util::Pool partial_edge_pool_; - typedef std::priority_queue, PartialEdgePointerLess> Generate; + typedef std::priority_queue Generate; Generate generate_; - - Note note_; }; } // namespace search diff --git a/klm/search/edge_queue.cc b/klm/search/edge_queue.cc deleted file mode 100644 index e3ae6ebf..00000000 --- a/klm/search/edge_queue.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "search/edge_queue.hh" - -#include "lm/left.hh" -#include "search/context.hh" - -#include - -namespace search { - -EdgeQueue::EdgeQueue(unsigned int pop_limit_hint) : partial_edge_pool_(sizeof(PartialEdge), pop_limit_hint * 2) { - take_ = static_cast(partial_edge_pool_.malloc()); -} - -/*void EdgeQueue::AddEdge(PartialEdge &root, unsigned char arity, Note note) { - // Ignore empty edges. - for (unsigned char i = 0; i < edge.Arity(); ++i) { - PartialVertex root(edge.GetVertex(i).RootPartial()); - if (root.Empty()) return; - total_score += root.Bound(); - } - PartialEdge &allocated = *static_cast(partial_edge_pool_.malloc()); - allocated.score = total_score; -}*/ - -} // namespace search diff --git a/klm/search/edge_queue.hh b/klm/search/edge_queue.hh deleted file mode 100644 index 187eaed7..00000000 --- a/klm/search/edge_queue.hh +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef SEARCH_EDGE_QUEUE__ -#define SEARCH_EDGE_QUEUE__ - -#include "search/edge.hh" -#include "search/edge_generator.hh" -#include "search/note.hh" - -#include -#include - -#include - -namespace search { - -template class Context; - -class EdgeQueue { - public: - explicit EdgeQueue(unsigned int pop_limit_hint); - - PartialEdge &InitializeEdge() { - return *take_; - } - - void AddEdge(unsigned char arity, Note note) { - generate_.push(edge_pool_.construct(*take_, arity, note)); - take_ = static_cast(partial_edge_pool_.malloc()); - } - - bool Empty() const { return generate_.empty(); } - - /* Generate hypotheses and send them to output. Normally, output is a - * VertexGenerator, but the decoder may want to route edges to different - * vertices i.e. if they have different LHS non-terminal labels. - */ - template void Search(Context &context, Output &output) { - int to_pop = context.PopLimit(); - while (to_pop > 0 && !generate_.empty()) { - EdgeGenerator *top = generate_.top(); - generate_.pop(); - PartialEdge *ret = top->Pop(context, partial_edge_pool_); - if (ret) { - output.NewHypothesis(*ret, top->GetNote()); - --to_pop; - if (top->TopScore() != -kScoreInf) { - generate_.push(top); - } - } else { - generate_.push(top); - } - } - output.FinishedSearch(); - } - - private: - boost::object_pool edge_pool_; - - struct LessByTopScore : public std::binary_function { - bool operator()(const EdgeGenerator *first, const EdgeGenerator *second) const { - return first->TopScore() < second->TopScore(); - } - }; - - typedef std::priority_queue, LessByTopScore> Generate; - Generate generate_; - - boost::pool<> partial_edge_pool_; - - PartialEdge *take_; -}; - -} // namespace search -#endif // SEARCH_EDGE_QUEUE__ diff --git a/klm/search/final.hh b/klm/search/final.hh index 1b3092ac..50e62cf2 100644 --- a/klm/search/final.hh +++ b/klm/search/final.hh @@ -1,37 +1,34 @@ #ifndef SEARCH_FINAL__ #define SEARCH_FINAL__ -#include "search/arity.hh" -#include "search/note.hh" -#include "search/types.hh" - -#include +#include "search/header.hh" +#include "util/pool.hh" namespace search { -class Final { +// A full hypothesis with pointers to children. +class Final : public Header { public: - typedef boost::array ChildArray; + Final() {} - void Reset(Score bound, Note note, const Final &left, const Final &right) { - bound_ = bound; - note_ = note; - children_[0] = &left; - children_[1] = &right; + Final(util::Pool &pool, Score score, Arity arity, Note note) + : Header(pool.Allocate(Size(arity)), arity) { + SetScore(score); + SetNote(note); } - const ChildArray &Children() const { return children_; } - - Note GetNote() const { return note_; } - - Score Bound() const { return bound_; } + // These are arrays of length GetArity(). + Final *Children() { + return reinterpret_cast(After()); + } + const Final *Children() const { + return reinterpret_cast(After()); + } private: - Score bound_; - - Note note_; - - ChildArray children_; + static std::size_t Size(Arity arity) { + return kHeaderSize + arity * sizeof(const Final); + } }; } // namespace search diff --git a/klm/search/header.hh b/klm/search/header.hh new file mode 100644 index 00000000..25550dbe --- /dev/null +++ b/klm/search/header.hh @@ -0,0 +1,57 @@ +#ifndef SEARCH_HEADER__ +#define SEARCH_HEADER__ + +// Header consisting of Score, Arity, and Note + +#include "search/note.hh" +#include "search/types.hh" + +#include + +namespace search { + +// Copying is shallow. +class Header { + public: + bool Valid() const { return base_; } + + Score GetScore() const { + return *reinterpret_cast(base_); + } + void SetScore(Score to) { + *reinterpret_cast(base_) = to; + } + bool operator<(const Header &other) const { + return GetScore() < other.GetScore(); + } + + Arity GetArity() const { + return *reinterpret_cast(base_ + sizeof(Score)); + } + + Note GetNote() const { + return *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)); + } + void SetNote(Note to) { + *reinterpret_cast(base_ + sizeof(Score) + sizeof(Arity)) = to; + } + + protected: + Header() : base_(NULL) {} + + Header(void *base, Arity arity) : base_(static_cast(base)) { + *reinterpret_cast(base_ + sizeof(Score)) = arity; + } + + static const std::size_t kHeaderSize = sizeof(Score) + sizeof(Arity) + sizeof(Note); + + uint8_t *After() { return base_ + kHeaderSize; } + const uint8_t *After() const { return base_ + kHeaderSize; } + + private: + uint8_t *base_; +}; + +} // namespace search + +#endif // SEARCH_HEADER__ diff --git a/klm/search/source.hh b/klm/search/source.hh deleted file mode 100644 index 11839f7b..00000000 --- a/klm/search/source.hh +++ /dev/null @@ -1,48 +0,0 @@ -#ifndef SEARCH_SOURCE__ -#define SEARCH_SOURCE__ - -#include "search/types.hh" - -#include -#include - -namespace search { - -template class Source { - public: - Source() : bound_(kScoreInf) {} - - Index Size() const { - return final_.size(); - } - - Score Bound() const { - return bound_; - } - - const Final &operator[](Index index) const { - return *final_[index]; - } - - Score ScoreOrBound(Index index) const { - return Size() > index ? final_[index]->Total() : Bound(); - } - - protected: - void AddFinal(const Final &store) { - final_.push_back(&store); - } - - void SetBound(Score to) { - assert(to <= bound_ + 0.001); - bound_ = to; - } - - private: - std::vector final_; - - Score bound_; -}; - -} // namespace search -#endif // SEARCH_SOURCE__ diff --git a/klm/search/types.hh b/klm/search/types.hh index 9726379f..06eb5bfa 100644 --- a/klm/search/types.hh +++ b/klm/search/types.hh @@ -1,17 +1,13 @@ #ifndef SEARCH_TYPES__ #define SEARCH_TYPES__ -#include +#include namespace search { typedef float Score; -const Score kScoreInf = INFINITY; -// This could have been an enum but gcc wants 4 bytes. -typedef bool ExtendDirection; -const ExtendDirection kExtendLeft = 0; -const ExtendDirection kExtendRight = 1; +typedef uint32_t Arity; } // namespace search diff --git a/klm/search/vertex.cc b/klm/search/vertex.cc index cc53c0dd..11f4631f 100644 --- a/klm/search/vertex.cc +++ b/klm/search/vertex.cc @@ -21,9 +21,9 @@ struct GreaterByBound : public std::binary_functionBound(); + bound_ = end_.GetScore(); return; } if (extend_.size() == 1 && parent_ptr) { @@ -39,10 +39,4 @@ void VertexNode::SortAndSet(ContextBase &context, VertexNode **parent_ptr) { bound_ = extend_.front()->Bound(); } -namespace { -VertexNode kBlankVertexNode; -} // namespace - -PartialVertex kBlankPartialVertex(kBlankVertexNode); - } // namespace search diff --git a/klm/search/vertex.hh b/klm/search/vertex.hh index e1a9ad11..52bc1dfe 100644 --- a/klm/search/vertex.hh +++ b/klm/search/vertex.hh @@ -18,7 +18,7 @@ class ContextBase; class VertexNode { public: - VertexNode() : end_(NULL) {} + VertexNode() {} void InitRoot() { extend_.clear(); @@ -26,8 +26,7 @@ class VertexNode { state_.left.length = 0; state_.right.length = 0; right_full_ = false; - bound_ = -kScoreInf; - end_ = NULL; + end_ = Final(); } lm::ngram::ChartState &MutableState() { return state_; } @@ -37,19 +36,20 @@ class VertexNode { extend_.push_back(next); } - void SetEnd(Final *end) { end_ = end; } + void SetEnd(Final end) { + assert(!end_.Valid()); + end_ = end; + } - Final &MutableEnd() { return *end_; } - void SortAndSet(ContextBase &context, VertexNode **parent_pointer); // Should only happen to a root node when the entire vertex is empty. bool Empty() const { - return !end_ && extend_.empty(); + return !end_.Valid() && extend_.empty(); } bool Complete() const { - return end_; + return end_.Valid(); } const lm::ngram::ChartState &State() const { return state_; } @@ -63,8 +63,8 @@ class VertexNode { return state_.left.length + state_.right.length; } - // May be NULL. - const Final *End() const { return end_; } + // Will be invalid unless this is a leaf. + const Final End() const { return end_; } const VertexNode &operator[](size_t index) const { return *extend_[index]; @@ -81,7 +81,7 @@ class VertexNode { bool right_full_; Score bound_; - Final *end_; + Final end_; }; class PartialVertex { @@ -97,7 +97,7 @@ class PartialVertex { const lm::ngram::ChartState &State() const { return back_->State(); } bool RightFull() const { return back_->RightFull(); } - Score Bound() const { return Complete() ? back_->End()->Bound() : (*back_)[index_].Bound(); } + Score Bound() const { return Complete() ? back_->End().GetScore() : (*back_)[index_].Bound(); } unsigned char Length() const { return back_->Length(); } @@ -105,20 +105,24 @@ class PartialVertex { return index_ + 1 < back_->Size(); } - // Split into continuation and alternative, rendering this the alternative. - bool Split(PartialVertex &continuation) { + // Split into continuation and alternative, rendering this the continuation. + bool Split(PartialVertex &alternative) { assert(!Complete()); - continuation.back_ = &((*back_)[index_]); - continuation.index_ = 0; + bool ret; if (index_ + 1 < back_->Size()) { - ++index_; - return true; + alternative.index_ = index_ + 1; + alternative.back_ = back_; + ret = true; + } else { + ret = false; } - return false; + back_ = &((*back_)[index_]); + index_ = 0; + return ret; } - const Final &End() const { - return *back_->End(); + const Final End() const { + return back_->End(); } private: @@ -126,25 +130,22 @@ class PartialVertex { unsigned int index_; }; -extern PartialVertex kBlankPartialVertex; - class Vertex { public: Vertex() {} PartialVertex RootPartial() const { return PartialVertex(root_); } - const Final *BestChild() const { + const Final BestChild() const { PartialVertex top(RootPartial()); if (top.Empty()) { - return NULL; + return Final(); } else { PartialVertex continuation; while (!top.Complete()) { top.Split(continuation); - top = continuation; } - return &top.End(); + return top.End(); } } diff --git a/klm/search/vertex_generator.cc b/klm/search/vertex_generator.cc index d94e6e06..0945fe55 100644 --- a/klm/search/vertex_generator.cc +++ b/klm/search/vertex_generator.cc @@ -10,74 +10,85 @@ namespace search { VertexGenerator::VertexGenerator(ContextBase &context, Vertex &gen) : context_(context), gen_(gen) { gen.root_.InitRoot(); - root_.under = &gen.root_; } namespace { + const uint64_t kCompleteAdd = static_cast(-1); -} // namespace -void VertexGenerator::NewHypothesis(const PartialEdge &partial, Note note) { - const lm::ngram::ChartState &state = partial.CompletedState(); - std::pair got(existing_.insert(std::pair(hash_value(state), NULL))); - if (!got.second) { - // Found it already. - Final &exists = *got.first->second; - if (exists.Bound() < partial.score) { - exists.Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - } - return; +// Parallel structure to VertexNode. +struct Trie { + Trie() : under(NULL) {} + + VertexNode *under; + boost::unordered_map extend; +}; + +Trie &FindOrInsert(ContextBase &context, Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { + Trie &next = node.extend[added]; + if (!next.under) { + next.under = context.NewVertexNode(); + lm::ngram::ChartState &writing = next.under->MutableState(); + writing = state; + writing.left.full &= left_full && state.left.full; + next.under->MutableRightFull() = right_full && state.left.full; + writing.left.length = left; + writing.right.length = right; + node.under->AddExtend(next.under); } + return next; +} + +void CompleteTransition(ContextBase &context, Trie &starter, PartialEdge partial) { + Final final(context.FinalPool(), partial.GetScore(), partial.GetArity(), partial.GetNote()); + Final *child_out = final.Children(); + const PartialVertex *part = partial.NT(); + const PartialVertex *const part_end_loop = part + partial.GetArity(); + for (; part != part_end_loop; ++part, ++child_out) + *child_out = part->End(); + + starter.under->SetEnd(final); +} + +void AddHypothesis(ContextBase &context, Trie &root, PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + unsigned char left = 0, right = 0; - Trie *node = &root_; + Trie *node = &root; while (true) { if (left == state.left.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, true, right, false); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, true, right, false); for (; right < state.right.length; ++right) { - node = &FindOrInsert(*node, state.right.words[right], state, left, true, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, true, right + 1, false); } break; } - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, false); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, false); left++; if (right == state.right.length) { - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, left, false, right, true); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, left, false, right, true); for (; left < state.left.length; ++left) { - node = &FindOrInsert(*node, state.left.pointers[left], state, left + 1, false, right, true); + node = &FindOrInsert(context, *node, state.left.pointers[left], state, left + 1, false, right, true); } break; } - node = &FindOrInsert(*node, state.right.words[right], state, left, false, right + 1, false); + node = &FindOrInsert(context, *node, state.right.words[right], state, left, false, right + 1, false); right++; } - node = &FindOrInsert(*node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); - got.first->second = CompleteTransition(*node, state, note, partial); + node = &FindOrInsert(context, *node, kCompleteAdd - state.left.full, state, state.left.length, true, state.right.length, true); + CompleteTransition(context, *node, partial); } -VertexGenerator::Trie &VertexGenerator::FindOrInsert(VertexGenerator::Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full) { - VertexGenerator::Trie &next = node.extend[added]; - if (!next.under) { - next.under = context_.NewVertexNode(); - lm::ngram::ChartState &writing = next.under->MutableState(); - writing = state; - writing.left.full &= left_full && state.left.full; - next.under->MutableRightFull() = right_full && state.left.full; - writing.left.length = left; - writing.right.length = right; - node.under->AddExtend(next.under); - } - return next; -} +} // namespace -Final *VertexGenerator::CompleteTransition(VertexGenerator::Trie &starter, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial) { - VertexNode &node = *starter.under; - assert(node.State().left.full == state.left.full); - assert(!node.End()); - Final *final = context_.NewFinal(); - final->Reset(partial.score, note, partial.nt[0].End(), partial.nt[1].End()); - node.SetEnd(final); - return final; +void VertexGenerator::FinishedSearch() { + Trie root; + root.under = &gen_.root_; + for (Existing::const_iterator i(existing_.begin()); i != existing_.end(); ++i) { + AddHypothesis(context_, root, i->second); + } + root.under->SortAndSet(context_, NULL); } } // namespace search diff --git a/klm/search/vertex_generator.hh b/klm/search/vertex_generator.hh index 6b98da3e..60e86112 100644 --- a/klm/search/vertex_generator.hh +++ b/klm/search/vertex_generator.hh @@ -1,13 +1,11 @@ #ifndef SEARCH_VERTEX_GENERATOR__ #define SEARCH_VERTEX_GENERATOR__ -#include "search/note.hh" +#include "search/edge.hh" #include "search/vertex.hh" #include -#include - namespace lm { namespace ngram { class ChartState; @@ -18,40 +16,29 @@ namespace search { class ContextBase; class Final; -struct PartialEdge; class VertexGenerator { public: VertexGenerator(ContextBase &context, Vertex &gen); - void NewHypothesis(const PartialEdge &partial, Note note); - - void FinishedSearch() { - root_.under->SortAndSet(context_, NULL); + void NewHypothesis(PartialEdge partial) { + const lm::ngram::ChartState &state = partial.CompletedState(); + std::pair ret(existing_.insert(std::make_pair(hash_value(state), partial))); + if (!ret.second && ret.first->second < partial) { + ret.first->second = partial; + } } + void FinishedSearch(); + const Vertex &Generating() const { return gen_; } private: - // Parallel structure to VertexNode. - struct Trie { - Trie() : under(NULL) {} - - VertexNode *under; - boost::unordered_map extend; - }; - - Trie &FindOrInsert(Trie &node, uint64_t added, const lm::ngram::ChartState &state, unsigned char left, bool left_full, unsigned char right, bool right_full); - - Final *CompleteTransition(Trie &node, const lm::ngram::ChartState &state, Note note, const PartialEdge &partial); - ContextBase &context_; Vertex &gen_; - Trie root_; - - typedef boost::unordered_map Existing; + typedef boost::unordered_map Existing; Existing existing_; }; diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 5ceccf2c..5306850f 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -26,6 +26,8 @@ libklm_util_a_SOURCES = \ file_piece.cc \ mmap.cc \ murmur_hash.cc \ + pool.cc \ + string_piece.cc \ usage.cc AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index ff4d590f..9909736d 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -4,7 +4,7 @@ #include #include -#include +#include // Ersatz version of boost::progress so core language model doesn't depend on // boost. Also adds option to print nothing. diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 83f99cd6..053a850b 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -6,7 +6,7 @@ #include #include -#include +#include namespace util { diff --git a/klm/util/pool.cc b/klm/util/pool.cc new file mode 100644 index 00000000..2dffd06f --- /dev/null +++ b/klm/util/pool.cc @@ -0,0 +1,35 @@ +#include "util/pool.hh" + +#include + +namespace util { + +Pool::Pool() { + current_ = NULL; + current_end_ = NULL; +} + +Pool::~Pool() { + FreeAll(); +} + +void Pool::FreeAll() { + for (std::vector::const_iterator i(free_list_.begin()); i != free_list_.end(); ++i) { + free(*i); + } + free_list_.clear(); + current_ = NULL; + current_end_ = NULL; +} + +void *Pool::More(std::size_t size) { + std::size_t amount = std::max(static_cast(32) << free_list_.size(), size); + uint8_t *ret = static_cast(malloc(amount)); + if (!ret) throw std::bad_alloc(); + free_list_.push_back(ret); + current_ = ret + size; + current_end_ = ret + amount; + return ret; +} + +} // namespace util diff --git a/klm/util/pool.hh b/klm/util/pool.hh new file mode 100644 index 00000000..72f8a0c8 --- /dev/null +++ b/klm/util/pool.hh @@ -0,0 +1,45 @@ +// Very simple pool. It can only allocate memory. And all of the memory it +// allocates must be freed at the same time. + +#ifndef UTIL_POOL__ +#define UTIL_POOL__ + +#include + +#include + +namespace util { + +class Pool { + public: + Pool(); + + ~Pool(); + + void *Allocate(std::size_t size) { + void *ret = current_; + current_ += size; + if (current_ < current_end_) { + return ret; + } else { + return More(size); + } + } + + void FreeAll(); + + private: + void *More(std::size_t size); + + std::vector free_list_; + + uint8_t *current_, *current_end_; + + // no copying + Pool(const Pool &); + Pool &operator=(const Pool &); +}; + +} // namespace util + +#endif // UTIL_POOL__ diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 770faa7e..4a8aff35 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -8,7 +8,7 @@ #include #include -#include +#include namespace util { diff --git a/klm/util/string_piece.cc b/klm/util/string_piece.cc new file mode 100644 index 00000000..b422cefc --- /dev/null +++ b/klm/util/string_piece.cc @@ -0,0 +1,192 @@ +// Copyright 2004 The RE2 Authors. All Rights Reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in string_piece.hh. + +#include "util/string_piece.hh" + +#include + +#include + +#ifndef HAVE_ICU + +typedef StringPiece::size_type size_type; + +void StringPiece::CopyToString(std::string* target) const { + target->assign(ptr_, length_); +} + +size_type StringPiece::find(const StringPiece& s, size_type pos) const { + if (length_ < 0 || pos > static_cast(length_)) + return npos; + + const char* result = std::search(ptr_ + pos, ptr_ + length_, + s.ptr_, s.ptr_ + s.length_); + const size_type xpos = result - ptr_; + return xpos + s.length_ <= length_ ? xpos : npos; +} + +size_type StringPiece::find(char c, size_type pos) const { + if (length_ <= 0 || pos >= static_cast(length_)) { + return npos; + } + const char* result = std::find(ptr_ + pos, ptr_ + length_, c); + return result != ptr_ + length_ ? result - ptr_ : npos; +} + +size_type StringPiece::rfind(const StringPiece& s, size_type pos) const { + if (length_ < s.length_) return npos; + const size_t ulen = length_; + if (s.length_ == 0) return std::min(ulen, pos); + + const char* last = ptr_ + std::min(ulen - s.length_, pos) + s.length_; + const char* result = std::find_end(ptr_, last, s.ptr_, s.ptr_ + s.length_); + return result != last ? result - ptr_ : npos; +} + +size_type StringPiece::rfind(char c, size_type pos) const { + if (length_ <= 0) return npos; + for (int i = std::min(pos, static_cast(length_ - 1)); + i >= 0; --i) { + if (ptr_[i] == c) { + return i; + } + } + return npos; +} + +// For each character in characters_wanted, sets the index corresponding +// to the ASCII code of that character to 1 in table. This is used by +// the find_.*_of methods below to tell whether or not a character is in +// the lookup table in constant time. +// The argument `table' must be an array that is large enough to hold all +// the possible values of an unsigned char. Thus it should be be declared +// as follows: +// bool table[UCHAR_MAX + 1] +static inline void BuildLookupTable(const StringPiece& characters_wanted, + bool* table) { + const size_type length = characters_wanted.length(); + const char* const data = characters_wanted.data(); + for (size_type i = 0; i < length; ++i) { + table[static_cast(data[i])] = true; + } +} + +size_type StringPiece::find_first_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0 || s.length_ == 0) + return npos; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_first_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = pos; i < length_; ++i) { + if (lookup[static_cast(ptr_[i])]) { + return i; + } + } + return npos; +} + +size_type StringPiece::find_first_not_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0) + return npos; + + if (s.length_ == 0) + return 0; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_first_not_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = pos; i < length_; ++i) { + if (!lookup[static_cast(ptr_[i])]) { + return i; + } + } + return npos; +} + +size_type StringPiece::find_first_not_of(char c, size_type pos) const { + if (length_ == 0) + return npos; + + for (; pos < length_; ++pos) { + if (ptr_[pos] != c) { + return pos; + } + } + return npos; +} + +size_type StringPiece::find_last_of(const StringPiece& s, size_type pos) const { + if (length_ == 0 || s.length_ == 0) + return npos; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_last_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (size_type i = std::min(pos, length_ - 1); ; --i) { + if (lookup[static_cast(ptr_[i])]) + return i; + if (i == 0) + break; + } + return npos; +} + +size_type StringPiece::find_last_not_of(const StringPiece& s, + size_type pos) const { + if (length_ == 0) + return npos; + + size_type i = std::min(pos, length_ - 1); + if (s.length_ == 0) + return i; + + // Avoid the cost of BuildLookupTable() for a single-character search. + if (s.length_ == 1) + return find_last_not_of(s.ptr_[0], pos); + + bool lookup[UCHAR_MAX + 1] = { false }; + BuildLookupTable(s, lookup); + for (; ; --i) { + if (!lookup[static_cast(ptr_[i])]) + return i; + if (i == 0) + break; + } + return npos; +} + +size_type StringPiece::find_last_not_of(char c, size_type pos) const { + if (length_ == 0) + return npos; + + for (size_type i = std::min(pos, length_ - 1); ; --i) { + if (ptr_[i] != c) + return i; + if (i == 0) + break; + } + return npos; +} + +StringPiece StringPiece::substr(size_type pos, size_type n) const { + if (pos > length_) pos = length_; + if (n > length_ - pos) n = length_ - pos; + return StringPiece(ptr_ + pos, n); +} + +const size_type StringPiece::npos = size_type(-1); + +#endif // !HAVE_ICU diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index c7e1c863..4a7f5460 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -54,6 +54,18 @@ class AnyCharacter { StringPiece chars_; }; +class AnyCharacterLast { + public: + explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {} + + StringPiece Find(const StringPiece &in) const { + return StringPiece(std::find_end(in.data(), in.data() + in.size(), chars_.data(), chars_.data() + chars_.size()), 1); + } + + private: + StringPiece chars_; +}; + template class TokenIter : public boost::iterator_facade, const StringPiece, boost::forward_traversal_tag> { public: TokenIter() {} diff --git a/mira/Makefile.am b/mira/Makefile.am index 7b4a4e12..3f8f17cd 100644 --- a/mira/Makefile.am +++ b/mira/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = kbest_mira kbest_mira_SOURCES = kbest_mira.cc -kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/Makefile.am b/training/Makefile.am index 5254333a..f9c25391 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -32,60 +32,60 @@ libtraining_a_SOURCES = \ risk.cc mpi_online_optimize_SOURCES = mpi_online_optimize.cc -mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc -mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc -mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_extract_features_SOURCES = mpi_extract_features.cc -mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc cllh_observer.cc -mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc cllh_observer.cc -mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz augment_grammar_SOURCES = augment_grammar.cc -augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz test_ngram_SOURCES = test_ngram.cc -test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz fast_align_SOURCES = fast_align.cc ttables.cc -fast_align_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +fast_align_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz lbl_model_SOURCES = lbl_model.cc -lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz grammar_convert_SOURCES = grammar_convert.cc -grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz optimize_test_SOURCES = optimize_test.cc -optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz collapse_weights_SOURCES = collapse_weights.cc -collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz lbfgs_test_SOURCES = lbfgs_test.cc -lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc -mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_em_map_adapter_SOURCES = mr_em_map_adapter.cc -mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_reduce_to_weights_SOURCES = mr_reduce_to_weights.cc -mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz mr_em_adapted_reduce_SOURCES = mr_em_adapted_reduce.cc -mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz plftools_SOURCES = plftools.cc -plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/utils/libutils.a -lz +plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm -- cgit v1.2.3 From 2f482858e63dc7f62ac3be5b7ed7e0644b63353e Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Nov 2012 17:21:30 -0500 Subject: make meteor jar configurable at build time --- configure.ac | 23 ++++++++++++++++++++--- mteval/Makefile.am | 14 +++++++++++++- mteval/ns.cc | 17 +++++++++-------- 3 files changed, 42 insertions(+), 12 deletions(-) (limited to 'configure.ac') diff --git a/configure.ac b/configure.ac index cb132d66..233009ca 100644 --- a/configure.ac +++ b/configure.ac @@ -13,6 +13,7 @@ AC_LANG_CPLUSPLUS BOOST_REQUIRE([1.44]) BOOST_PROGRAM_OPTIONS BOOST_SYSTEM +BOOST_SERIALIZATION BOOST_TEST AM_PATH_PYTHON AC_CHECK_HEADER(dlfcn.h,AC_DEFINE(HAVE_DLFCN_H)) @@ -26,10 +27,24 @@ AM_CONDITIONAL([MPI], [test "x$mpi" = xyes]) if test "x$mpi" = xyes then - BOOST_SERIALIZATION AC_DEFINE([HAVE_MPI], [1], [flag for MPI]) - # TODO BOOST_MPI needs to be implemented - LIBS="$LIBS -lboost_mpi $BOOST_SERIALIZATION_LIBS" + LIBS="$LIBS -lboost_mpi" +fi + +AM_CONDITIONAL([HAVE_METEOR], false) +AC_ARG_WITH(meteor, + [AC_HELP_STRING([--with-meteor=PATH], [(optional) path to METEOR jar])], + [with_meteor=$withval], + [with_meteor=no] + ) + +if test "x$with_meteor" != 'xno' +then + AC_CHECK_FILE([$with_meteor], + [AC_DEFINE([HAVE_METEOR], [1], [flag for METEOR jar library])], + [AC_MSG_ERROR([Cannot find METEOR jar!])]) + AC_SUBST(METEOR_JAR,"${with_meteor}") + AM_CONDITIONAL([HAVE_METEOR], true) fi AM_CONDITIONAL([HAVE_CMPH], false) @@ -129,6 +144,8 @@ AC_CONFIG_FILES([mira/Makefile]) AC_CONFIG_FILES([dtrain/Makefile]) AC_CONFIG_FILES([example_extff/Makefile]) +AC_CONFIG_FILES([mteval/meteor_jar.cc]) + AC_CONFIG_FILES([python/setup.py]) AC_OUTPUT diff --git a/mteval/Makefile.am b/mteval/Makefile.am index 22550c99..5e9bba91 100644 --- a/mteval/Makefile.am +++ b/mteval/Makefile.am @@ -8,7 +8,19 @@ TESTS = scorer_test noinst_LIBRARIES = libmteval.a -libmteval_a_SOURCES = ter.cc comb_scorer.cc aer_scorer.cc scorer.cc external_scorer.cc ns.cc ns_ter.cc ns_ext.cc ns_comb.cc ns_docscorer.cc ns_cer.cc +libmteval_a_SOURCES = \ + aer_scorer.cc \ + comb_scorer.cc \ + external_scorer.cc \ + meteor_jar.cc \ + ns.cc \ + ns_cer.cc \ + ns_comb.cc \ + ns_docscorer.cc \ + ns_ext.cc \ + ns_ter.cc \ + scorer.cc \ + ter.cc fast_score_SOURCES = fast_score.cc fast_score_LDADD = libmteval.a $(top_srcdir)/utils/libutils.a -lz diff --git a/mteval/ns.cc b/mteval/ns.cc index f3a82ce0..7d73061c 100644 --- a/mteval/ns.cc +++ b/mteval/ns.cc @@ -19,6 +19,8 @@ using namespace std; map EvaluationMetric::instances_; +extern const char* meteor_jar_path; + SegmentEvaluator::~SegmentEvaluator() {} EvaluationMetric::~EvaluationMetric() {} @@ -235,13 +237,7 @@ struct BleuMetric : public EvaluationMetric { EvaluationMetric* EvaluationMetric::Instance(const string& imetric_id) { static bool is_first = true; - static string meteor_jar_path = "/cab0/tools/meteor-1.3/meteor-1.3.jar"; if (is_first) { - const char* ppath = getenv("METEOR_JAR"); - if (ppath) { - cerr << "METEOR_JAR environment variable set to " << ppath << endl; - meteor_jar_path = ppath; - } instances_["NULL"] = NULL; is_first = false; } @@ -259,11 +255,16 @@ EvaluationMetric* EvaluationMetric::Instance(const string& imetric_id) { } else if (metric_id == "TER") { m = new TERMetric; } else if (metric_id == "METEOR") { +#if HAVE_METEOR if (!FileExists(meteor_jar_path)) { - cerr << meteor_jar_path << " not found. Set METEOR_JAR environment variable.\n"; + cerr << meteor_jar_path << " not found!\n"; abort(); } - m = new ExternalMetric("METEOR", "java -Xmx1536m -jar " + meteor_jar_path + " - - -mira -lower -t tune -l en"); + m = new ExternalMetric("METEOR", string("java -Xmx1536m -jar ") + meteor_jar_path + " - - -mira -lower -t tune -l en"); +#else + cerr << "cdec was not built with the --with-meteor option." << endl; + abort(); +#endif } else if (metric_id.find("COMB:") == 0) { m = new CombinationMetric(metric_id); } else if (metric_id == "CER") { -- cgit v1.2.3 From c401956e25295bdb97dd633817ff9a4f1dcf8c4c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 17 Nov 2012 17:46:06 -0500 Subject: remove some old stuff --- configure.ac | 45 --------------------------------------------- decoder/Makefile.am | 4 ---- 2 files changed, 49 deletions(-) (limited to 'configure.ac') diff --git a/configure.ac b/configure.ac index 233009ca..5f18beaa 100644 --- a/configure.ac +++ b/configure.ac @@ -68,25 +68,6 @@ then AM_CONDITIONAL([HAVE_CMPH], true) fi -AM_CONDITIONAL([HAVE_EIGEN], false) -AC_ARG_WITH(eigen, - [AC_HELP_STRING([--with-eigen=PATH], [(optional) path to Eigen linear algebra library])], - [with_eigen=$withval], - [with_eigen=no] - ) - -if test "x$with_eigen" != 'xno' -then - SAVE_CPPFLAGS="$CPPFLAGS" - CPPFLAGS="$CPPFLAGS -I${with_eigen}" - - AC_CHECK_HEADER(Eigen/Dense, - [AC_DEFINE([HAVE_EIGEN], [1], [flag for Eigen linear algebra library])], - [AC_MSG_ERROR([Cannot find Eigen!])]) - - AM_CONDITIONAL([HAVE_EIGEN], true) -fi - #BOOST_THREADS CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_SYSTEM_LDFLAGS" @@ -99,32 +80,6 @@ AC_CHECK_HEADER(google/dense_hash_map, AC_PROG_INSTALL -AM_CONDITIONAL([GLC], false) -AC_ARG_WITH(glc, - [AC_HELP_STRING([--with-glc=PATH], [(optional) path to Global Lexical Coherence package (Context CRF)])], - [with_glc=$withval], - [with_glc=no] - ) -FF_GLC="" - -if test "x$with_glc" != 'xno' -then - SAVE_CPPFLAGS="$CPPFLAGS" - CPPFLAGS="$CPPFLAGS -I${with_glc} -I${with_glc}/cdec" - - #AC_CHECK_HEADER(ff_glc.h, - # [AC_DEFINE([HAVE_GLC], [], [flag for GLC])], - # [AC_MSG_ERROR([Cannot find GLC!])]) - - AC_DEFINE([HAVE_GLC], [], [flag for GLC]) - #LIB_RANDLM="-lrandlm" - #LDFLAGS="$LDFLAGS -L${with_glc}/lib" - #LIBS="$LIBS $LIB_RANDLM" - #FMTLIBS="$FMTLIBS libglc.a" - AC_SUBST(FF_GLC,"${with_glc}/cdec/ff_glc.cc") - AM_CONDITIONAL([GLC], true) -fi - CPPFLAGS="-DPIC -fPIC $CPPFLAGS -DHAVE_CONFIG_H" AC_CONFIG_FILES([Makefile]) diff --git a/decoder/Makefile.am b/decoder/Makefile.am index f8f427d3..6914fa0f 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -83,7 +83,3 @@ libcdec_a_SOURCES = \ json_parse.cc \ grammar.cc -if GLC - # Until we build GLC as a library... - libcdec_a_SOURCES += ff_glc.cc string_util.cc feature-factory.cc -endif -- cgit v1.2.3 From 28ba2377913445399f9e98163907c84d9f8ffa5a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 18 Nov 2012 11:33:03 -0500 Subject: fix part 2 --- Makefile.am | 2 +- configure.ac | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'configure.ac') diff --git a/Makefile.am b/Makefile.am index fefc470d..7ca7268a 100644 --- a/Makefile.am +++ b/Makefile.am @@ -13,7 +13,7 @@ SUBDIRS = \ mira \ dtrain \ dpmert \ - pro-train \ + pro \ rampion \ minrisk \ example_extff diff --git a/configure.ac b/configure.ac index 5f18beaa..34e138c2 100644 --- a/configure.ac +++ b/configure.ac @@ -89,7 +89,7 @@ AC_CONFIG_FILES([decoder/Makefile]) AC_CONFIG_FILES([training/Makefile]) AC_CONFIG_FILES([training/liblbfgs/Makefile]) AC_CONFIG_FILES([dpmert/Makefile]) -AC_CONFIG_FILES([pro-train/Makefile]) +AC_CONFIG_FILES([pro/Makefile]) AC_CONFIG_FILES([rampion/Makefile]) AC_CONFIG_FILES([minrisk/Makefile]) AC_CONFIG_FILES([klm/util/Makefile]) -- cgit v1.2.3 From fbdacabc85bea65d735f2cb7f92b98e08ce72d04 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 18 Nov 2012 11:38:03 -0500 Subject: fix reference to serialization --- configure.ac | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'configure.ac') diff --git a/configure.ac b/configure.ac index 34e138c2..09fc5c5b 100644 --- a/configure.ac +++ b/configure.ac @@ -70,9 +70,9 @@ fi #BOOST_THREADS CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" -LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_SYSTEM_LDFLAGS" +LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_SERIALIZATION_LDFLAGS $BOOST_SYSTEM_LDFLAGS" # $BOOST_THREAD_LDFLAGS" -LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_SYSTEM_LIBS" +LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_SERIALIZATION_LIBS $BOOST_SYSTEM_LIBS" # $BOOST_THREAD_LIBS" AC_CHECK_HEADER(google/dense_hash_map, -- cgit v1.2.3 From 8aa29810bb77611cc20b7a384897ff6703783ea1 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 18 Nov 2012 13:35:42 -0500 Subject: major restructure of the training code --- .gitignore | 28 +- Makefile.am | 7 +- configure.ac | 32 +- dpmert/Makefile.am | 33 - dpmert/README.shared-mem | 9 - dpmert/ces.cc | 90 -- dpmert/ces.h | 16 - dpmert/decode-and-evaluate.pl | 246 ---- dpmert/divide_refs.py | 15 - dpmert/dpmert.pl | 617 --------- dpmert/error_surface.cc | 42 - dpmert/error_surface.h | 24 - dpmert/libcall.pl | 71 - dpmert/line_mediator.pl | 116 -- dpmert/line_optimizer.cc | 114 -- dpmert/line_optimizer.h | 48 - dpmert/lo_test.cc | 229 --- dpmert/mert_geometry.cc | 185 --- dpmert/mert_geometry.h | 81 -- dpmert/mr_dpmert_generate_mapper_input.cc | 81 -- dpmert/mr_dpmert_map.cc | 112 -- dpmert/mr_dpmert_reduce.cc | 77 -- dpmert/parallelize.pl | 423 ------ dpmert/sentclient.c | 76 - dpmert/sentserver.c | 515 ------- dpmert/sentserver.h | 6 - dpmert/test_aer/README | 8 - dpmert/test_aer/cdec.ini | 3 - dpmert/test_aer/corpus.src | 3 - dpmert/test_aer/grammar | 12 - dpmert/test_aer/ref.0 | 3 - dpmert/test_aer/weights | 13 - dpmert/test_data/0.json.gz | Bin 13709 -> 0 bytes dpmert/test_data/1.json.gz | Bin 204803 -> 0 bytes dpmert/test_data/c2e.txt.0 | 2 - dpmert/test_data/c2e.txt.1 | 2 - dpmert/test_data/c2e.txt.2 | 2 - dpmert/test_data/c2e.txt.3 | 2 - dpmert/test_data/re.txt.0 | 5 - dpmert/test_data/re.txt.1 | 5 - dpmert/test_data/re.txt.2 | 5 - dpmert/test_data/re.txt.3 | 5 - dtrain/Makefile.am | 7 - dtrain/README.md | 48 - dtrain/dtrain.cc | 657 --------- dtrain/dtrain.h | 97 -- dtrain/hstreaming/avg.rb | 32 - dtrain/hstreaming/cdec.ini | 22 - dtrain/hstreaming/dtrain.ini | 15 - dtrain/hstreaming/dtrain.sh | 9 - dtrain/hstreaming/hadoop-streaming-job.sh | 30 - dtrain/hstreaming/lplp.rb | 131 -- dtrain/hstreaming/red-test | 9 - dtrain/kbestget.h | 152 -- dtrain/ksampler.h | 61 - dtrain/pairsampling.h | 149 -- dtrain/parallelize.rb | 79 -- dtrain/parallelize/test/cdec.ini | 22 - dtrain/parallelize/test/dtrain.ini | 15 - dtrain/parallelize/test/in | 10 - dtrain/parallelize/test/refs | 10 - dtrain/score.cc | 254 ---- dtrain/score.h | 212 --- dtrain/test/example/README | 8 - dtrain/test/example/cdec.ini | 25 - dtrain/test/example/dtrain.ini | 22 - dtrain/test/example/expected-output | 89 -- dtrain/test/parallelize/cdec.ini | 22 - dtrain/test/parallelize/dtrain.ini | 15 - dtrain/test/parallelize/in | 10 - dtrain/test/parallelize/refs | 10 - dtrain/test/toy/cdec.ini | 2 - dtrain/test/toy/dtrain.ini | 12 - dtrain/test/toy/input | 2 - minrisk/Makefile.am | 6 - minrisk/minrisk.pl | 540 -------- minrisk/minrisk_generate_input.pl | 18 - minrisk/minrisk_optimize.cc | 197 --- mira/Makefile.am | 6 - mira/kbest_mira.cc | 309 ----- pro/Makefile.am | 11 - pro/README.shared-mem | 9 - pro/mr_pro_generate_mapper_input.pl | 18 - pro/mr_pro_map.cc | 201 --- pro/mr_pro_reduce.cc | 286 ---- pro/pro.pl | 555 -------- rampion/Makefile.am | 6 - rampion/rampion.pl | 540 -------- rampion/rampion_cccp.cc | 168 --- rampion/rampion_generate_input.pl | 18 - training/Makefile.am | 100 +- training/add-model1-features-to-scfg.pl | 93 -- training/candidate_set.cc | 169 --- training/candidate_set.h | 60 - training/cllh_observer.cc | 52 - training/cllh_observer.h | 26 - training/collapse_weights.cc | 110 -- training/crf/Makefile.am | 27 + training/crf/cllh_observer.cc | 52 + training/crf/cllh_observer.h | 26 + training/crf/mpi_batch_optimize.cc | 372 +++++ training/crf/mpi_compute_cllh.cc | 134 ++ training/crf/mpi_extract_features.cc | 151 ++ training/crf/mpi_extract_reachable.cc | 163 +++ training/crf/mpi_flex_optimize.cc | 386 ++++++ training/crf/mpi_online_optimize.cc | 374 +++++ training/dep-reorder/conll2reordering-forest.pl | 65 - training/dep-reorder/george.conll | 4 - training/dep-reorder/scripts/conll2simplecfg.pl | 57 - training/dpmert/Makefile.am | 25 + training/dpmert/ces.cc | 90 ++ training/dpmert/ces.h | 16 + training/dpmert/divide_refs.py | 15 + training/dpmert/dpmert.pl | 618 +++++++++ training/dpmert/error_surface.cc | 42 + training/dpmert/error_surface.h | 24 + training/dpmert/line_mediator.pl | 116 ++ training/dpmert/line_optimizer.cc | 114 ++ training/dpmert/line_optimizer.h | 48 + training/dpmert/lo_test.cc | 229 +++ training/dpmert/mert_geometry.cc | 185 +++ training/dpmert/mert_geometry.h | 81 ++ training/dpmert/mr_dpmert_generate_mapper_input.cc | 81 ++ training/dpmert/mr_dpmert_map.cc | 112 ++ training/dpmert/mr_dpmert_reduce.cc | 77 ++ training/dpmert/test_aer/README | 8 + training/dpmert/test_aer/cdec.ini | 3 + training/dpmert/test_aer/corpus.src | 3 + training/dpmert/test_aer/grammar | 12 + training/dpmert/test_aer/ref.0 | 3 + training/dpmert/test_aer/weights | 13 + training/dpmert/test_data/0.json.gz | Bin 0 -> 13709 bytes training/dpmert/test_data/1.json.gz | Bin 0 -> 204803 bytes training/dpmert/test_data/c2e.txt.0 | 2 + training/dpmert/test_data/c2e.txt.1 | 2 + training/dpmert/test_data/c2e.txt.2 | 2 + training/dpmert/test_data/c2e.txt.3 | 2 + training/dpmert/test_data/re.txt.0 | 5 + training/dpmert/test_data/re.txt.1 | 5 + training/dpmert/test_data/re.txt.2 | 5 + training/dpmert/test_data/re.txt.3 | 5 + training/dtrain/Makefile.am | 7 + training/dtrain/README.md | 48 + training/dtrain/dtrain.cc | 657 +++++++++ training/dtrain/dtrain.h | 97 ++ training/dtrain/hstreaming/avg.rb | 32 + training/dtrain/hstreaming/cdec.ini | 22 + training/dtrain/hstreaming/dtrain.ini | 15 + training/dtrain/hstreaming/dtrain.sh | 9 + training/dtrain/hstreaming/hadoop-streaming-job.sh | 30 + training/dtrain/hstreaming/lplp.rb | 131 ++ training/dtrain/hstreaming/red-test | 9 + training/dtrain/kbestget.h | 152 ++ training/dtrain/ksampler.h | 61 + training/dtrain/pairsampling.h | 149 ++ training/dtrain/parallelize.rb | 79 ++ training/dtrain/parallelize/test/cdec.ini | 22 + training/dtrain/parallelize/test/dtrain.ini | 15 + training/dtrain/parallelize/test/in | 10 + training/dtrain/parallelize/test/refs | 10 + training/dtrain/score.cc | 254 ++++ training/dtrain/score.h | 212 +++ training/dtrain/test/example/README | 8 + training/dtrain/test/example/cdec.ini | 25 + training/dtrain/test/example/dtrain.ini | 22 + training/dtrain/test/example/expected-output | 89 ++ training/dtrain/test/parallelize/cdec.ini | 22 + training/dtrain/test/parallelize/dtrain.ini | 15 + training/dtrain/test/parallelize/in | 10 + training/dtrain/test/parallelize/refs | 10 + training/dtrain/test/toy/cdec.ini | 2 + training/dtrain/test/toy/dtrain.ini | 12 + training/dtrain/test/toy/input | 2 + training/entropy.cc | 41 - training/entropy.h | 22 - training/fast_align.cc | 281 ---- training/feature_expectations.cc | 232 ---- training/grammar_convert.cc | 348 ----- training/lbfgs.h | 1459 -------------------- training/lbfgs_test.cc | 117 -- training/lbl_model.cc | 421 ------ training/minrisk/Makefile.am | 6 + training/minrisk/minrisk.pl | 540 ++++++++ training/minrisk/minrisk_generate_input.pl | 18 + training/minrisk/minrisk_optimize.cc | 197 +++ training/mira/Makefile.am | 6 + training/mira/kbest_mira.cc | 309 +++++ training/mpi_batch_optimize.cc | 372 ----- training/mpi_compute_cllh.cc | 134 -- training/mpi_em_optimize.cc | 389 ------ training/mpi_extract_features.cc | 151 -- training/mpi_extract_reachable.cc | 163 --- training/mpi_flex_optimize.cc | 386 ------ training/mpi_online_optimize.cc | 374 ----- training/mr_em_adapted_reduce.cc | 173 --- training/mr_em_map_adapter.cc | 160 --- training/mr_optimize_reduce.cc | 231 ---- training/mr_reduce_to_weights.cc | 109 -- training/online_optimizer.cc | 16 - training/online_optimizer.h | 129 -- training/optimize.cc | 102 -- training/optimize.h | 92 -- training/optimize_test.cc | 118 -- training/pro/Makefile.am | 11 + training/pro/mr_pro_generate_mapper_input.pl | 18 + training/pro/mr_pro_map.cc | 201 +++ training/pro/mr_pro_reduce.cc | 286 ++++ training/pro/pro.pl | 555 ++++++++ training/rampion/Makefile.am | 6 + training/rampion/rampion.pl | 540 ++++++++ training/rampion/rampion_cccp.cc | 168 +++ training/rampion/rampion_generate_input.pl | 18 + training/risk.cc | 45 - training/risk.h | 26 - training/ttables.cc | 31 - training/ttables.h | 101 -- training/utils/candidate_set.cc | 169 +++ training/utils/candidate_set.h | 60 + training/utils/decode-and-evaluate.pl | 246 ++++ training/utils/entropy.cc | 41 + training/utils/entropy.h | 22 + training/utils/grammar_convert.cc | 348 +++++ training/utils/lbfgs.h | 1459 ++++++++++++++++++++ training/utils/lbfgs_test.cc | 117 ++ training/utils/libcall.pl | 71 + training/utils/online_optimizer.cc | 16 + training/utils/online_optimizer.h | 129 ++ training/utils/optimize.cc | 102 ++ training/utils/optimize.h | 92 ++ training/utils/optimize_test.cc | 118 ++ training/utils/parallelize.pl | 423 ++++++ training/utils/risk.cc | 45 + training/utils/risk.h | 26 + training/utils/sentclient.c | 76 + training/utils/sentserver.c | 515 +++++++ training/utils/sentserver.h | 6 + word-aligner/Makefile.am | 6 + word-aligner/fast_align.cc | 281 ++++ word-aligner/makefiles/makefile.grammars | 2 +- word-aligner/paste-parallel-files.pl | 35 - word-aligner/ttables.cc | 31 + word-aligner/ttables.h | 101 ++ 242 files changed, 13304 insertions(+), 15426 deletions(-) delete mode 100644 dpmert/Makefile.am delete mode 100644 dpmert/README.shared-mem delete mode 100644 dpmert/ces.cc delete mode 100644 dpmert/ces.h delete mode 100755 dpmert/decode-and-evaluate.pl delete mode 100755 dpmert/divide_refs.py delete mode 100755 dpmert/dpmert.pl delete mode 100644 dpmert/error_surface.cc delete mode 100644 dpmert/error_surface.h delete mode 100644 dpmert/libcall.pl delete mode 100755 dpmert/line_mediator.pl delete mode 100644 dpmert/line_optimizer.cc delete mode 100644 dpmert/line_optimizer.h delete mode 100644 dpmert/lo_test.cc delete mode 100644 dpmert/mert_geometry.cc delete mode 100644 dpmert/mert_geometry.h delete mode 100644 dpmert/mr_dpmert_generate_mapper_input.cc delete mode 100644 dpmert/mr_dpmert_map.cc delete mode 100644 dpmert/mr_dpmert_reduce.cc delete mode 100755 dpmert/parallelize.pl delete mode 100644 dpmert/sentclient.c delete mode 100644 dpmert/sentserver.c delete mode 100644 dpmert/sentserver.h delete mode 100644 dpmert/test_aer/README delete mode 100644 dpmert/test_aer/cdec.ini delete mode 100644 dpmert/test_aer/corpus.src delete mode 100644 dpmert/test_aer/grammar delete mode 100644 dpmert/test_aer/ref.0 delete mode 100644 dpmert/test_aer/weights delete mode 100644 dpmert/test_data/0.json.gz delete mode 100644 dpmert/test_data/1.json.gz delete mode 100644 dpmert/test_data/c2e.txt.0 delete mode 100644 dpmert/test_data/c2e.txt.1 delete mode 100644 dpmert/test_data/c2e.txt.2 delete mode 100644 dpmert/test_data/c2e.txt.3 delete mode 100644 dpmert/test_data/re.txt.0 delete mode 100644 dpmert/test_data/re.txt.1 delete mode 100644 dpmert/test_data/re.txt.2 delete mode 100644 dpmert/test_data/re.txt.3 delete mode 100644 dtrain/Makefile.am delete mode 100644 dtrain/README.md delete mode 100644 dtrain/dtrain.cc delete mode 100644 dtrain/dtrain.h delete mode 100755 dtrain/hstreaming/avg.rb delete mode 100644 dtrain/hstreaming/cdec.ini delete mode 100644 dtrain/hstreaming/dtrain.ini delete mode 100755 dtrain/hstreaming/dtrain.sh delete mode 100755 dtrain/hstreaming/hadoop-streaming-job.sh delete mode 100755 dtrain/hstreaming/lplp.rb delete mode 100644 dtrain/hstreaming/red-test delete mode 100644 dtrain/kbestget.h delete mode 100644 dtrain/ksampler.h delete mode 100644 dtrain/pairsampling.h delete mode 100755 dtrain/parallelize.rb delete mode 100644 dtrain/parallelize/test/cdec.ini delete mode 100644 dtrain/parallelize/test/dtrain.ini delete mode 100644 dtrain/parallelize/test/in delete mode 100644 dtrain/parallelize/test/refs delete mode 100644 dtrain/score.cc delete mode 100644 dtrain/score.h delete mode 100644 dtrain/test/example/README delete mode 100644 dtrain/test/example/cdec.ini delete mode 100644 dtrain/test/example/dtrain.ini delete mode 100644 dtrain/test/example/expected-output delete mode 100644 dtrain/test/parallelize/cdec.ini delete mode 100644 dtrain/test/parallelize/dtrain.ini delete mode 100644 dtrain/test/parallelize/in delete mode 100644 dtrain/test/parallelize/refs delete mode 100644 dtrain/test/toy/cdec.ini delete mode 100644 dtrain/test/toy/dtrain.ini delete mode 100644 dtrain/test/toy/input delete mode 100644 minrisk/Makefile.am delete mode 100755 minrisk/minrisk.pl delete mode 100755 minrisk/minrisk_generate_input.pl delete mode 100644 minrisk/minrisk_optimize.cc delete mode 100644 mira/Makefile.am delete mode 100644 mira/kbest_mira.cc delete mode 100644 pro/Makefile.am delete mode 100644 pro/README.shared-mem delete mode 100755 pro/mr_pro_generate_mapper_input.pl delete mode 100644 pro/mr_pro_map.cc delete mode 100644 pro/mr_pro_reduce.cc delete mode 100755 pro/pro.pl delete mode 100644 rampion/Makefile.am delete mode 100755 rampion/rampion.pl delete mode 100644 rampion/rampion_cccp.cc delete mode 100755 rampion/rampion_generate_input.pl delete mode 100755 training/add-model1-features-to-scfg.pl delete mode 100644 training/candidate_set.cc delete mode 100644 training/candidate_set.h delete mode 100644 training/cllh_observer.cc delete mode 100644 training/cllh_observer.h delete mode 100644 training/collapse_weights.cc create mode 100644 training/crf/Makefile.am create mode 100644 training/crf/cllh_observer.cc create mode 100644 training/crf/cllh_observer.h create mode 100644 training/crf/mpi_batch_optimize.cc create mode 100644 training/crf/mpi_compute_cllh.cc create mode 100644 training/crf/mpi_extract_features.cc create mode 100644 training/crf/mpi_extract_reachable.cc create mode 100644 training/crf/mpi_flex_optimize.cc create mode 100644 training/crf/mpi_online_optimize.cc delete mode 100755 training/dep-reorder/conll2reordering-forest.pl delete mode 100644 training/dep-reorder/george.conll delete mode 100755 training/dep-reorder/scripts/conll2simplecfg.pl create mode 100644 training/dpmert/Makefile.am create mode 100644 training/dpmert/ces.cc create mode 100644 training/dpmert/ces.h create mode 100755 training/dpmert/divide_refs.py create mode 100755 training/dpmert/dpmert.pl create mode 100644 training/dpmert/error_surface.cc create mode 100644 training/dpmert/error_surface.h create mode 100755 training/dpmert/line_mediator.pl create mode 100644 training/dpmert/line_optimizer.cc create mode 100644 training/dpmert/line_optimizer.h create mode 100644 training/dpmert/lo_test.cc create mode 100644 training/dpmert/mert_geometry.cc create mode 100644 training/dpmert/mert_geometry.h create mode 100644 training/dpmert/mr_dpmert_generate_mapper_input.cc create mode 100644 training/dpmert/mr_dpmert_map.cc create mode 100644 training/dpmert/mr_dpmert_reduce.cc create mode 100644 training/dpmert/test_aer/README create mode 100644 training/dpmert/test_aer/cdec.ini create mode 100644 training/dpmert/test_aer/corpus.src create mode 100644 training/dpmert/test_aer/grammar create mode 100644 training/dpmert/test_aer/ref.0 create mode 100644 training/dpmert/test_aer/weights create mode 100644 training/dpmert/test_data/0.json.gz create mode 100644 training/dpmert/test_data/1.json.gz create mode 100644 training/dpmert/test_data/c2e.txt.0 create mode 100644 training/dpmert/test_data/c2e.txt.1 create mode 100644 training/dpmert/test_data/c2e.txt.2 create mode 100644 training/dpmert/test_data/c2e.txt.3 create mode 100644 training/dpmert/test_data/re.txt.0 create mode 100644 training/dpmert/test_data/re.txt.1 create mode 100644 training/dpmert/test_data/re.txt.2 create mode 100644 training/dpmert/test_data/re.txt.3 create mode 100644 training/dtrain/Makefile.am create mode 100644 training/dtrain/README.md create mode 100644 training/dtrain/dtrain.cc create mode 100644 training/dtrain/dtrain.h create mode 100755 training/dtrain/hstreaming/avg.rb create mode 100644 training/dtrain/hstreaming/cdec.ini create mode 100644 training/dtrain/hstreaming/dtrain.ini create mode 100755 training/dtrain/hstreaming/dtrain.sh create mode 100755 training/dtrain/hstreaming/hadoop-streaming-job.sh create mode 100755 training/dtrain/hstreaming/lplp.rb create mode 100644 training/dtrain/hstreaming/red-test create mode 100644 training/dtrain/kbestget.h create mode 100644 training/dtrain/ksampler.h create mode 100644 training/dtrain/pairsampling.h create mode 100755 training/dtrain/parallelize.rb create mode 100644 training/dtrain/parallelize/test/cdec.ini create mode 100644 training/dtrain/parallelize/test/dtrain.ini create mode 100644 training/dtrain/parallelize/test/in create mode 100644 training/dtrain/parallelize/test/refs create mode 100644 training/dtrain/score.cc create mode 100644 training/dtrain/score.h create mode 100644 training/dtrain/test/example/README create mode 100644 training/dtrain/test/example/cdec.ini create mode 100644 training/dtrain/test/example/dtrain.ini create mode 100644 training/dtrain/test/example/expected-output create mode 100644 training/dtrain/test/parallelize/cdec.ini create mode 100644 training/dtrain/test/parallelize/dtrain.ini create mode 100644 training/dtrain/test/parallelize/in create mode 100644 training/dtrain/test/parallelize/refs create mode 100644 training/dtrain/test/toy/cdec.ini create mode 100644 training/dtrain/test/toy/dtrain.ini create mode 100644 training/dtrain/test/toy/input delete mode 100644 training/entropy.cc delete mode 100644 training/entropy.h delete mode 100644 training/fast_align.cc delete mode 100644 training/feature_expectations.cc delete mode 100644 training/grammar_convert.cc delete mode 100644 training/lbfgs.h delete mode 100644 training/lbfgs_test.cc delete mode 100644 training/lbl_model.cc create mode 100644 training/minrisk/Makefile.am create mode 100755 training/minrisk/minrisk.pl create mode 100755 training/minrisk/minrisk_generate_input.pl create mode 100644 training/minrisk/minrisk_optimize.cc create mode 100644 training/mira/Makefile.am create mode 100644 training/mira/kbest_mira.cc delete mode 100644 training/mpi_batch_optimize.cc delete mode 100644 training/mpi_compute_cllh.cc delete mode 100644 training/mpi_em_optimize.cc delete mode 100644 training/mpi_extract_features.cc delete mode 100644 training/mpi_extract_reachable.cc delete mode 100644 training/mpi_flex_optimize.cc delete mode 100644 training/mpi_online_optimize.cc delete mode 100644 training/mr_em_adapted_reduce.cc delete mode 100644 training/mr_em_map_adapter.cc delete mode 100644 training/mr_optimize_reduce.cc delete mode 100644 training/mr_reduce_to_weights.cc delete mode 100644 training/online_optimizer.cc delete mode 100644 training/online_optimizer.h delete mode 100644 training/optimize.cc delete mode 100644 training/optimize.h delete mode 100644 training/optimize_test.cc create mode 100644 training/pro/Makefile.am create mode 100755 training/pro/mr_pro_generate_mapper_input.pl create mode 100644 training/pro/mr_pro_map.cc create mode 100644 training/pro/mr_pro_reduce.cc create mode 100755 training/pro/pro.pl create mode 100644 training/rampion/Makefile.am create mode 100755 training/rampion/rampion.pl create mode 100644 training/rampion/rampion_cccp.cc create mode 100755 training/rampion/rampion_generate_input.pl delete mode 100644 training/risk.cc delete mode 100644 training/risk.h delete mode 100644 training/ttables.cc delete mode 100644 training/ttables.h create mode 100644 training/utils/candidate_set.cc create mode 100644 training/utils/candidate_set.h create mode 100755 training/utils/decode-and-evaluate.pl create mode 100644 training/utils/entropy.cc create mode 100644 training/utils/entropy.h create mode 100644 training/utils/grammar_convert.cc create mode 100644 training/utils/lbfgs.h create mode 100644 training/utils/lbfgs_test.cc create mode 100644 training/utils/libcall.pl create mode 100644 training/utils/online_optimizer.cc create mode 100644 training/utils/online_optimizer.h create mode 100644 training/utils/optimize.cc create mode 100644 training/utils/optimize.h create mode 100644 training/utils/optimize_test.cc create mode 100755 training/utils/parallelize.pl create mode 100644 training/utils/risk.cc create mode 100644 training/utils/risk.h create mode 100644 training/utils/sentclient.c create mode 100644 training/utils/sentserver.c create mode 100644 training/utils/sentserver.h create mode 100644 word-aligner/Makefile.am create mode 100644 word-aligner/fast_align.cc delete mode 100755 word-aligner/paste-parallel-files.pl create mode 100644 word-aligner/ttables.cc create mode 100644 word-aligner/ttables.h (limited to 'configure.ac') diff --git a/.gitignore b/.gitignore index aa2e64eb..c6023822 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +example_extff/ff_example.lo +example_extff/libff_example.la +mteval/meteor_jar.cc *.a *.aux *.bbl @@ -176,4 +179,27 @@ utils/reconstruct_weights utils/small_vector_test utils/ts utils/weights_test -utils/unigram_pyp_lm +training/crf/mpi_batch_optimize +training/crf/mpi_compute_cllh +training/crf/mpi_extract_features +training/crf/mpi_extract_reachable +training/crf/mpi_flex_optimize +training/crf/mpi_online_optimize +training/dpmert/lo_test +training/dpmert/mr_dpmert_generate_mapper_input +training/dpmert/mr_dpmert_map +training/dpmert/mr_dpmert_reduce +training/dpmert/sentclient +training/dpmert/sentserver +training/dtrain/dtrain +training/minrisk/minrisk_optimize +training/mira/kbest_mira +training/pro/mr_pro_map +training/pro/mr_pro_reduce +training/rampion/rampion_cccp +training/utils/Makefile.am +training/utils/lbfgs_test +training/utils/optimize_test +training/utils/sentclient +training/utils/sentserver +word-aligner/fast_align diff --git a/Makefile.am b/Makefile.am index 7ca7268a..dbf604a1 100644 --- a/Makefile.am +++ b/Makefile.am @@ -10,12 +10,7 @@ SUBDIRS = \ decoder \ training \ training/liblbfgs \ - mira \ - dtrain \ - dpmert \ - pro \ - rampion \ - minrisk \ + word-aligner \ example_extff #gi/pyp-topics/src gi/clda/src gi/posterior-regularisation/prjava diff --git a/configure.ac b/configure.ac index 09fc5c5b..366112a3 100644 --- a/configure.ac +++ b/configure.ac @@ -82,26 +82,34 @@ AC_PROG_INSTALL CPPFLAGS="-DPIC -fPIC $CPPFLAGS -DHAVE_CONFIG_H" +# core cdec stuff AC_CONFIG_FILES([Makefile]) AC_CONFIG_FILES([utils/Makefile]) AC_CONFIG_FILES([mteval/Makefile]) +AC_CONFIG_FILES([mteval/meteor_jar.cc]) AC_CONFIG_FILES([decoder/Makefile]) -AC_CONFIG_FILES([training/Makefile]) -AC_CONFIG_FILES([training/liblbfgs/Makefile]) -AC_CONFIG_FILES([dpmert/Makefile]) -AC_CONFIG_FILES([pro/Makefile]) -AC_CONFIG_FILES([rampion/Makefile]) -AC_CONFIG_FILES([minrisk/Makefile]) +AC_CONFIG_FILES([python/setup.py]) +AC_CONFIG_FILES([word-aligner/Makefile]) + +# KenLM stuff AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) AC_CONFIG_FILES([klm/search/Makefile]) -AC_CONFIG_FILES([mira/Makefile]) -AC_CONFIG_FILES([dtrain/Makefile]) -AC_CONFIG_FILES([example_extff/Makefile]) -AC_CONFIG_FILES([mteval/meteor_jar.cc]) - -AC_CONFIG_FILES([python/setup.py]) +# training stuff +AC_CONFIG_FILES([training/Makefile]) +AC_CONFIG_FILES([training/utils/Makefile]) +AC_CONFIG_FILES([training/liblbfgs/Makefile]) +AC_CONFIG_FILES([training/crf/Makefile]) +AC_CONFIG_FILES([training/dpmert/Makefile]) +AC_CONFIG_FILES([training/pro/Makefile]) +AC_CONFIG_FILES([training/rampion/Makefile]) +AC_CONFIG_FILES([training/minrisk/Makefile]) +AC_CONFIG_FILES([training/mira/Makefile]) +AC_CONFIG_FILES([training/dtrain/Makefile]) + +# external feature function example code +AC_CONFIG_FILES([example_extff/Makefile]) AC_OUTPUT diff --git a/dpmert/Makefile.am b/dpmert/Makefile.am deleted file mode 100644 index 00768271..00000000 --- a/dpmert/Makefile.am +++ /dev/null @@ -1,33 +0,0 @@ -bin_PROGRAMS = \ - mr_dpmert_map \ - mr_dpmert_reduce \ - mr_dpmert_generate_mapper_input \ - sentserver \ - sentclient - -noinst_PROGRAMS = \ - lo_test -TESTS = lo_test - -sentserver_SOURCES = sentserver.c -sentserver_LDFLAGS = -pthread - -sentclient_SOURCES = sentclient.c -sentclient_LDFLAGS = -pthread - -mr_dpmert_generate_mapper_input_SOURCES = mr_dpmert_generate_mapper_input.cc line_optimizer.cc -mr_dpmert_generate_mapper_input_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz - -# nbest2hg_SOURCES = nbest2hg.cc -# nbest2hg_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lfst -lz - -mr_dpmert_map_SOURCES = mert_geometry.cc ces.cc error_surface.cc mr_dpmert_map.cc line_optimizer.cc -mr_dpmert_map_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz - -mr_dpmert_reduce_SOURCES = error_surface.cc ces.cc mr_dpmert_reduce.cc line_optimizer.cc mert_geometry.cc -mr_dpmert_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz - -lo_test_SOURCES = lo_test.cc ces.cc mert_geometry.cc error_surface.cc line_optimizer.cc -lo_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz - -AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/dpmert/README.shared-mem b/dpmert/README.shared-mem deleted file mode 100644 index 7728efc0..00000000 --- a/dpmert/README.shared-mem +++ /dev/null @@ -1,9 +0,0 @@ -If you want to run dist-vest.pl on a very large shared memory machine, do the -following: - - ./dist-vest.pl --use-make I --decode-nodes J --weights weights.init --source-file=dev.src --ref-files=dev.ref.* cdec.ini - -This will use I jobs for doing the line search and J jobs to run the decoder. Typically, since the -decoder must load grammars, language models, etc., J should be smaller than I, but this will depend -on the system you are running on and the complexity of the models used for decoding. - diff --git a/dpmert/ces.cc b/dpmert/ces.cc deleted file mode 100644 index 157b2d17..00000000 --- a/dpmert/ces.cc +++ /dev/null @@ -1,90 +0,0 @@ -#include "ces.h" - -#include -#include -#include - -// TODO, if AER is to be optimized again, we will need this -// #include "aligner.h" -#include "lattice.h" -#include "mert_geometry.h" -#include "error_surface.h" -#include "ns.h" - -using namespace std; - -const bool minimize_segments = true; // if adjacent segments have equal scores, merge them - -void ComputeErrorSurface(const SegmentEvaluator& ss, - const ConvexHull& ve, - ErrorSurface* env, - const EvaluationMetric* metric, - const Hypergraph& hg) { - vector prev_trans; - const vector >& ienv = ve.GetSortedSegs(); - env->resize(ienv.size()); - SufficientStats prev_score; // defaults to 0 - int j = 0; - for (unsigned i = 0; i < ienv.size(); ++i) { - const MERTPoint& seg = *ienv[i]; - vector trans; -#if 0 - if (type == AER) { - vector edges(hg.edges_.size(), false); - seg.CollectEdgesUsed(&edges); // get the set of edges in the viterbi - // alignment - ostringstream os; - const string* psrc = ss.GetSource(); - if (psrc == NULL) { - cerr << "AER scoring in VEST requires source, but it is missing!\n"; - abort(); - } - size_t pos = psrc->rfind(" ||| "); - if (pos == string::npos) { - cerr << "Malformed source for AER: expected |||\nINPUT: " << *psrc << endl; - abort(); - } - Lattice src; - Lattice ref; - LatticeTools::ConvertTextOrPLF(psrc->substr(0, pos), &src); - LatticeTools::ConvertTextOrPLF(psrc->substr(pos + 5), &ref); - AlignerTools::WriteAlignment(src, ref, hg, &os, true, 0, &edges); - string tstr = os.str(); - TD::ConvertSentence(tstr.substr(tstr.rfind(" ||| ") + 5), &trans); - } else { -#endif - seg.ConstructTranslation(&trans); - //} - //cerr << "Scoring: " << TD::GetString(trans) << endl; - if (trans == prev_trans) { - if (!minimize_segments) { - ErrorSegment& out = (*env)[j]; - out.delta.fields.clear(); - out.x = seg.x; - ++j; - } - //cerr << "Identical translation, skipping scoring\n"; - } else { - SufficientStats score; - ss.Evaluate(trans, &score); - // cerr << "score= " << score->ComputeScore() << "\n"; - //string x1; score.Encode(&x1); cerr << "STATS: " << x1 << endl; - const SufficientStats delta = score - prev_score; - //string x2; delta.Encode(&x2); cerr << "DELTA: " << x2 << endl; - //string xx; delta.Encode(&xx); cerr << xx << endl; - prev_trans.swap(trans); - prev_score = score; - if ((!minimize_segments) || (!delta.IsAdditiveIdentity())) { - ErrorSegment& out = (*env)[j]; - out.delta = delta; - out.x = seg.x; - ++j; - } - } - } - // cerr << " In segments: " << ienv.size() << endl; - // cerr << "Out segments: " << j << endl; - assert(j > 0); - env->resize(j); -} - diff --git a/dpmert/ces.h b/dpmert/ces.h deleted file mode 100644 index e4fa2080..00000000 --- a/dpmert/ces.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef _CES_H_ -#define _CES_H_ - -class ConvexHull; -class Hypergraph; -class SegmentEvaluator; -class ErrorSurface; -class EvaluationMetric; - -void ComputeErrorSurface(const SegmentEvaluator& ss, - const ConvexHull& convex_hull, - ErrorSurface* es, - const EvaluationMetric* metric, - const Hypergraph& hg); - -#endif diff --git a/dpmert/decode-and-evaluate.pl b/dpmert/decode-and-evaluate.pl deleted file mode 100755 index fe765d00..00000000 --- a/dpmert/decode-and-evaluate.pl +++ /dev/null @@ -1,246 +0,0 @@ -#!/usr/bin/env perl -use strict; -my @ORIG_ARGV=@ARGV; -use Cwd qw(getcwd); -my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../environment"; } - -# Skip local config (used for distributing jobs) if we're running in local-only mode -use LocalConfig; -use Getopt::Long; -use File::Basename qw(basename); -my $QSUB_CMD = qsub_args(mert_memory()); - -require "libcall.pl"; - -# Default settings -my $default_jobs = env_default_jobs(); -my $bin_dir = $SCRIPT_DIR; -die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; -my $FAST_SCORE="$bin_dir/../mteval/fast_score"; -die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; -my $parallelize = "$bin_dir/parallelize.pl"; -my $libcall = "$bin_dir/libcall.pl"; -my $sentserver = "$bin_dir/sentserver"; -my $sentclient = "$bin_dir/sentclient"; -my $LocalConfig = "$SCRIPT_DIR/../environment/LocalConfig.pm"; - -my $SCORER = $FAST_SCORE; -my $cdec = "$bin_dir/../decoder/cdec"; -die "Can't find decoder in $cdec" unless -x $cdec; -die "Can't find $parallelize" unless -x $parallelize; -die "Can't find $libcall" unless -e $libcall; -my $decoder = $cdec; -my $jobs = $default_jobs; # number of decode nodes -my $pmem = "9g"; -my $help = 0; -my $config; -my $test_set; -my $weights; -my $use_make = 1; -my $useqsub; -my $cpbin=1; -# Process command-line options -if (GetOptions( - "jobs=i" => \$jobs, - "help" => \$help, - "qsub" => \$useqsub, - "input=s" => \$test_set, - "config=s" => \$config, - "weights=s" => \$weights, -) == 0 || @ARGV!=0 || $help) { - print_help(); - exit; -} - -if ($useqsub) { - $use_make = 0; - die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); -} - -my @missing_args = (); - -if (!defined $test_set) { push @missing_args, "--input"; } -if (!defined $config) { push @missing_args, "--config"; } -if (!defined $weights) { push @missing_args, "--weights"; } -die "Please specify missing arguments: " . join (', ', @missing_args) . "\nUse --help for more information.\n" if (@missing_args); - -my @tf = localtime(time); -my $tname = basename($test_set); -$tname =~ s/\.(sgm|sgml|xml)$//i; -my $dir = "eval.$tname." . sprintf('%d%02d%02d-%02d%02d%02d', 1900+$tf[5], $tf[4], $tf[3], $tf[2], $tf[1], $tf[0]); - -my $time = unchecked_output("date"); - -check_call("mkdir -p $dir"); - -split_devset($test_set, "$dir/test.input.raw", "$dir/test.refs"); -my $refs = "-r $dir/test.refs"; -my $newsrc = "$dir/test.input"; -enseg("$dir/test.input.raw", $newsrc); -my $src_file = $newsrc; -open F, "<$src_file" or die "Can't read $src_file: $!"; close F; - -my $test_trans="$dir/test.trans"; -my $logdir="$dir/logs"; -my $decoderLog="$logdir/decoder.sentserver.log"; -check_call("mkdir -p $logdir"); - -#decode -print STDERR "RUNNING DECODER AT "; -print STDERR unchecked_output("date"); -my $decoder_cmd = "$decoder -c $config --weights $weights"; -my $pcmd; -if ($use_make) { - $pcmd = "cat $src_file | $parallelize --workdir $dir --use-fork -p $pmem -e $logdir -j $jobs --"; -} else { - $pcmd = "cat $src_file | $parallelize --workdir $dir -p $pmem -e $logdir -j $jobs --"; -} -my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $test_trans"; -check_bash_call($cmd); -print STDERR "DECODER COMPLETED AT "; -print STDERR unchecked_output("date"); -print STDERR "\nOUTPUT: $test_trans\n\n"; -my $bleu = check_output("cat $test_trans | $SCORER $refs -m ibm_bleu"); -chomp $bleu; -print STDERR "BLEU: $bleu\n"; -my $ter = check_output("cat $test_trans | $SCORER $refs -m ter"); -chomp $ter; -print STDERR " TER: $ter\n"; -open TR, ">$dir/test.scores" or die "Can't write $dir/test.scores: $!"; -print TR <$newsrc"); - my $i=0; - while (my $line=){ - chomp $line; - if ($line =~ /^\s* tags, you must include a zero-based id attribute"; - } - } else { - print NEWSRC "$line\n"; - } - $i++; - } - close SRC; - close NEWSRC; -} - -sub print_help { - my $executable = basename($0); chomp $executable; - print << "Help"; - -Usage: $executable [options] - - $executable --config cdec.ini --weights weights.txt [--jobs N] [--qsub] - -Options: - - --help - Print this message and exit. - - --config - A path to the cdec.ini file. - - --weights - A file specifying feature weights. - - --dir - Directory for intermediate and output files. - -Job control options: - - --jobs - Number of decoder processes to run in parallel. [default=$default_jobs] - - --qsub - Use qsub to run jobs in parallel (qsub must be configured in - environment/LocalEnvironment.pm) - - --pmem - Amount of physical memory requested for parallel decoding jobs - (used with qsub requests only) - -Help -} - -sub convert { - my ($str) = @_; - my @ps = split /;/, $str; - my %dict = (); - for my $p (@ps) { - my ($k, $v) = split /=/, $p; - $dict{$k} = $v; - } - return %dict; -} - - - -sub cmdline { - return join ' ',($0,@ORIG_ARGV); -} - -#buggy: last arg gets quoted sometimes? -my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; -my $shell_escape_in_quote=qr{[\\"\$`!]}; - -sub escape_shell { - my ($arg)=@_; - return undef unless defined $arg; - if ($arg =~ /$is_shell_special/) { - $arg =~ s/($shell_escape_in_quote)/\\$1/g; - return "\"$arg\""; - } - return $arg; -} - -sub escaped_shell_args { - return map {local $_=$_;chomp;escape_shell($_)} @_; -} - -sub escaped_shell_args_str { - return join ' ',&escaped_shell_args(@_); -} - -sub escaped_cmdline { - return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); -} - -sub split_devset { - my ($infile, $outsrc, $outref) = @_; - open F, "<$infile" or die "Can't read $infile: $!"; - open S, ">$outsrc" or die "Can't write $outsrc: $!"; - open R, ">$outref" or die "Can't write $outref: $!"; - while() { - chomp; - my ($src, @refs) = split /\s*\|\|\|\s*/; - die "Malformed devset line: $_\n" unless scalar @refs > 0; - print S "$src\n"; - print R join(' ||| ', @refs) . "\n"; - } - close R; - close S; - close F; -} - diff --git a/dpmert/divide_refs.py b/dpmert/divide_refs.py deleted file mode 100755 index b478f918..00000000 --- a/dpmert/divide_refs.py +++ /dev/null @@ -1,15 +0,0 @@ -#!/usr/bin/env python -import sys - -(numRefs, outPrefix) = sys.argv[1:] -numRefs = int(numRefs) - -outs = [open(outPrefix+str(i), "w") for i in range(numRefs)] - -i = 0 -for line in sys.stdin: - outs[i].write(line) - i = (i + 1) % numRefs - -for out in outs: - out.close() diff --git a/dpmert/dpmert.pl b/dpmert/dpmert.pl deleted file mode 100755 index c4f98870..00000000 --- a/dpmert/dpmert.pl +++ /dev/null @@ -1,617 +0,0 @@ -#!/usr/bin/env perl -use strict; -my @ORIG_ARGV=@ARGV; -use Cwd qw(getcwd); -my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../environment"; } - -# Skip local config (used for distributing jobs) if we're running in local-only mode -use LocalConfig; -use Getopt::Long; -use File::Basename qw(basename); -require "libcall.pl"; - -my $QSUB_CMD = qsub_args(mert_memory()); - -# Default settings -my $srcFile; # deprecated -my $refFiles; # deprecated -my $default_jobs = env_default_jobs(); -my $bin_dir = $SCRIPT_DIR; -die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; -my $FAST_SCORE="$bin_dir/../mteval/fast_score"; -die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; -my $MAPINPUT = "$bin_dir/mr_dpmert_generate_mapper_input"; -my $MAPPER = "$bin_dir/mr_dpmert_map"; -my $REDUCER = "$bin_dir/mr_dpmert_reduce"; -my $parallelize = "$bin_dir/parallelize.pl"; -my $libcall = "$bin_dir/libcall.pl"; -my $sentserver = "$bin_dir/sentserver"; -my $sentclient = "$bin_dir/sentclient"; -my $LocalConfig = "$SCRIPT_DIR/../environment/LocalConfig.pm"; - -my $SCORER = $FAST_SCORE; -die "Can't find $MAPPER" unless -x $MAPPER; -my $cdec = "$bin_dir/../decoder/cdec"; -die "Can't find decoder in $cdec" unless -x $cdec; -die "Can't find $parallelize" unless -x $parallelize; -die "Can't find $libcall" unless -e $libcall; -my $decoder = $cdec; -my $lines_per_mapper = 200; -my $rand_directions = 15; -my $iteration = 1; -my $best_weights; -my $max_iterations = 15; -my $optimization_iters = 6; -my $jobs = $default_jobs; # number of decode nodes -my $pmem = "9g"; -my $disable_clean = 0; -my %seen_weights; -my $help = 0; -my $epsilon = 0.0001; -my $last_score = -10000000; -my $metric = "ibm_bleu"; -my $dir; -my $iniFile; -my $weights; -my $initialWeights; -my $bleu_weight=1; -my $use_make = 1; # use make to parallelize line search -my $useqsub; -my $pass_suffix = ''; -my $devset; -# Process command-line options -if (GetOptions( - "config=s" => \$iniFile, - "weights=s" => \$initialWeights, - "devset=s" => \$devset, - "jobs=i" => \$jobs, - "pass-suffix=s" => \$pass_suffix, - "help" => \$help, - "qsub" => \$useqsub, - "iterations=i" => \$max_iterations, - "pmem=s" => \$pmem, - "random-directions=i" => \$rand_directions, - "metric=s" => \$metric, - "source-file=s" => \$srcFile, - "output-dir=s" => \$dir, -) == 0 || @ARGV!=0 || $help) { - print_help(); - exit; -} - -if ($useqsub) { - $use_make = 0; - die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); -} - -my @missing_args = (); -if (defined $srcFile || defined $refFiles) { - die <) { $devSize++; } -close F; - -unless($best_weights){ $best_weights = $weights; } -unless($projected_score){ $projected_score = 0.0; } -$seen_weights{$weights} = 1; - -my $random_seed = int(time / 1000); -my $lastWeightsFile; -my $lastPScore = 0; -# main optimization loop -while (1){ - print STDERR "\n\nITERATION $iteration\n==========\n"; - - if ($iteration > $max_iterations){ - print STDERR "\nREACHED STOPPING CRITERION: Maximum iterations\n"; - last; - } - # iteration-specific files - my $runFile="$dir/run.raw.$iteration"; - my $onebestFile="$dir/1best.$iteration"; - my $logdir="$dir/logs.$iteration"; - my $decoderLog="$logdir/decoder.sentserver.log.$iteration"; - my $scorerLog="$logdir/scorer.log.$iteration"; - check_call("mkdir -p $logdir"); - - - #decode - print STDERR "RUNNING DECODER AT "; - print STDERR unchecked_output("date"); - my $im1 = $iteration - 1; - my $weightsFile="$dir/weights.$im1"; - my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; - my $pcmd; - if ($use_make) { - $pcmd = "cat $srcFile | $parallelize --workdir $dir --use-fork -p $pmem -e $logdir -j $jobs --"; - } else { - $pcmd = "cat $srcFile | $parallelize --workdir $dir -p $pmem -e $logdir -j $jobs --"; - } - my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $runFile"; - print STDERR "COMMAND:\n$cmd\n"; - check_bash_call($cmd); - my $num_hgs; - my $num_topbest; - my $retries = 0; - while($retries < 5) { - $num_hgs = check_output("ls $dir/hgs/*.gz | wc -l"); - $num_topbest = check_output("wc -l < $runFile"); - print STDERR "NUMBER OF HGs: $num_hgs\n"; - print STDERR "NUMBER OF TOP-BEST HYPs: $num_topbest\n"; - if($devSize == $num_hgs && $devSize == $num_topbest) { - last; - } else { - print STDERR "Incorrect number of hypergraphs or topbest. Waiting for distributed filesystem and retrying...\n"; - sleep(3); - } - $retries++; - } - die "Dev set contains $devSize sentences, but we don't have topbest and hypergraphs for all these! Decoder failure? Check $decoderLog\n" if ($devSize != $num_hgs || $devSize != $num_topbest); - my $dec_score = check_output("cat $runFile | $SCORER $refs -m $metric"); - chomp $dec_score; - print STDERR "DECODER SCORE: $dec_score\n"; - - # save space - check_call("gzip -f $runFile"); - check_call("gzip -f $decoderLog"); - - # run optimizer - print STDERR "RUNNING OPTIMIZER AT "; - print STDERR unchecked_output("date"); - my $mergeLog="$logdir/prune-merge.log.$iteration"; - - my $score = 0; - my $icc = 0; - my $inweights="$dir/weights.$im1"; - for (my $opt_iter=1; $opt_iter<$optimization_iters; $opt_iter++) { - print STDERR "\nGENERATE OPTIMIZATION STRATEGY (OPT-ITERATION $opt_iter/$optimization_iters)\n"; - print STDERR unchecked_output("date"); - $icc++; - $cmd="$MAPINPUT -w $inweights -r $dir/hgs -s $devSize -d $rand_directions > $dir/agenda.$im1-$opt_iter"; - print STDERR "COMMAND:\n$cmd\n"; - check_call($cmd); - check_call("mkdir -p $dir/splag.$im1"); - $cmd="split -a 3 -l $lines_per_mapper $dir/agenda.$im1-$opt_iter $dir/splag.$im1/mapinput."; - print STDERR "COMMAND:\n$cmd\n"; - check_call($cmd); - opendir(DIR, "$dir/splag.$im1") or die "Can't open directory: $!"; - my @shards = grep { /^mapinput\./ } readdir(DIR); - closedir DIR; - die "No shards!" unless scalar @shards > 0; - my $joblist = ""; - my $nmappers = 0; - my @mapoutputs = (); - @cleanupcmds = (); - my %o2i = (); - my $first_shard = 1; - my $mkfile; # only used with makefiles - my $mkfilename; - if ($use_make) { - $mkfilename = "$dir/splag.$im1/domap.mk"; - open $mkfile, ">$mkfilename" or die "Couldn't write $mkfilename: $!"; - print $mkfile "all: $dir/splag.$im1/map.done\n\n"; - } - my @mkouts = (); # only used with makefiles - for my $shard (@shards) { - my $mapoutput = $shard; - my $client_name = $shard; - $client_name =~ s/mapinput.//; - $client_name = "dpmert.$client_name"; - $mapoutput =~ s/mapinput/mapoutput/; - push @mapoutputs, "$dir/splag.$im1/$mapoutput"; - $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; - my $script = "$MAPPER -s $srcFile -m $metric $refs < $dir/splag.$im1/$shard | sort -t \$'\\t' -k 1 > $dir/splag.$im1/$mapoutput"; - if ($use_make) { - my $script_file = "$dir/scripts/map.$shard"; - open F, ">$script_file" or die "Can't write $script_file: $!"; - print F "#!/bin/bash\n"; - print F "$script\n"; - close F; - my $output = "$dir/splag.$im1/$mapoutput"; - push @mkouts, $output; - chmod(0755, $script_file) or die "Can't chmod $script_file: $!"; - if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } - print $mkfile "$output: $dir/splag.$im1/$shard\n\t$script_file\n\n"; - } else { - my $script_file = "$dir/scripts/map.$shard"; - open F, ">$script_file" or die "Can't write $script_file: $!"; - print F "$script\n"; - close F; - if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } - - $nmappers++; - my $qcmd = "$QSUB_CMD -N $client_name -o /dev/null -e $logdir/$client_name.ER $script_file"; - my $jobid = check_output("$qcmd"); - chomp $jobid; - $jobid =~ s/^(\d+)(.*?)$/\1/g; - $jobid =~ s/^Your job (\d+) .*$/\1/; - push(@cleanupcmds, "qdel $jobid 2> /dev/null"); - print STDERR " $jobid"; - if ($joblist == "") { $joblist = $jobid; } - else {$joblist = $joblist . "\|" . $jobid; } - } - } - if ($use_make) { - print $mkfile "$dir/splag.$im1/map.done: @mkouts\n\ttouch $dir/splag.$im1/map.done\n\n"; - close $mkfile; - my $mcmd = "make -j $jobs -f $mkfilename"; - print STDERR "\nExecuting: $mcmd\n"; - check_call($mcmd); - } else { - print STDERR "\nLaunched $nmappers mappers.\n"; - sleep 8; - print STDERR "Waiting for mappers to complete...\n"; - while ($nmappers > 0) { - sleep 5; - my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat | grep -v ' C '"))); - $nmappers = scalar @livejobs; - } - print STDERR "All mappers complete.\n"; - } - my $tol = 0; - my $til = 0; - for my $mo (@mapoutputs) { - my $olines = get_lines($mo); - my $ilines = get_lines($o2i{$mo}); - $tol += $olines; - $til += $ilines; - die "$mo: output lines ($olines) doesn't match input lines ($ilines)" unless $olines==$ilines; - } - print STDERR "Results for $tol/$til lines\n"; - print STDERR "\nSORTING AND RUNNING VEST REDUCER\n"; - print STDERR unchecked_output("date"); - $cmd="sort -t \$'\\t' -k 1 @mapoutputs | $REDUCER -m $metric > $dir/redoutput.$im1"; - print STDERR "COMMAND:\n$cmd\n"; - check_bash_call($cmd); - $cmd="sort -nk3 $DIR_FLAG '-t|' $dir/redoutput.$im1 | head -1"; - # sort returns failure even when it doesn't fail for some reason - my $best=unchecked_output("$cmd"); chomp $best; - print STDERR "$best\n"; - my ($oa, $x, $xscore) = split /\|/, $best; - $score = $xscore; - print STDERR "PROJECTED SCORE: $score\n"; - if (abs($x) < $epsilon) { - print STDERR "\nOPTIMIZER: no score improvement: abs($x) < $epsilon\n"; - last; - } - my $psd = $score - $last_score; - $last_score = $score; - if (abs($psd) < $epsilon) { - print STDERR "\nOPTIMIZER: no score improvement: abs($psd) < $epsilon\n"; - last; - } - my ($origin, $axis) = split /\s+/, $oa; - - my %ori = convert($origin); - my %axi = convert($axis); - - my $finalFile="$dir/weights.$im1-$opt_iter"; - open W, ">$finalFile" or die "Can't write: $finalFile: $!"; - my $norm = 0; - for my $k (sort keys %ori) { - my $dd = $ori{$k} + $axi{$k} * $x; - $norm += $dd * $dd; - } - $norm = sqrt($norm); - $norm = 1; - for my $k (sort keys %ori) { - my $v = ($ori{$k} + $axi{$k} * $x) / $norm; - print W "$k $v\n"; - } - check_call("rm $dir/splag.$im1/*"); - $inweights = $finalFile; - } - $lastWeightsFile = "$dir/weights.$iteration"; - check_call("cp $inweights $lastWeightsFile"); - if ($icc < 2) { - print STDERR "\nREACHED STOPPING CRITERION: score change too little\n"; - last; - } - $lastPScore = $score; - $iteration++; - print STDERR "\n==========\n"; -} - -check_call("cp $lastWeightsFile $dir/weights.final"); -print STDERR "\nFINAL WEIGHTS: $dir/weights.final\n(Use -w with the decoder)\n\n"; -print STDOUT "$dir/weights.final\n"; -exit 0; - - -sub get_lines { - my $fn = shift @_; - open FL, "<$fn" or die "Couldn't read $fn: $!"; - my $lc = 0; - while() { $lc++; } - return $lc; -} - -sub read_weights_file { - my ($file) = @_; - open F, "<$file" or die "Couldn't read $file: $!"; - my @r = (); - my $pm = -1; - while() { - next if /^#/; - next if /^\s*$/; - chomp; - if (/^(.+)\s+(.+)$/) { - my $m = $1; - my $w = $2; - die "Weights out of order: $m <= $pm" unless $m > $pm; - push @r, $w; - } else { - warn "Unexpected feature name in weight file: $_"; - } - } - close F; - return join ' ', @r; -} - -sub update_weights_file { - my ($neww, $rfn, $rpts) = @_; - my @feats = @$rfn; - my @pts = @$rpts; - my $num_feats = scalar @feats; - my $num_pts = scalar @pts; - die "$num_feats (num_feats) != $num_pts (num_pts)" unless $num_feats == $num_pts; - open G, ">$neww" or die; - for (my $i = 0; $i < $num_feats; $i++) { - my $f = $feats[$i]; - my $lambda = $pts[$i]; - print G "$f $lambda\n"; - } - close G; -} - -sub enseg { - my $src = shift; - my $newsrc = shift; - open(SRC, $src); - open(NEWSRC, ">$newsrc"); - my $i=0; - while (my $line=){ - chomp $line; - if ($line =~ /^\s* tags, you must include a zero-based id attribute"; - } - } else { - print NEWSRC "$line\n"; - } - $i++; - } - close SRC; - close NEWSRC; -} - -sub print_help { - - my $executable = basename($0); chomp $executable; - print << "Help"; - -Usage: $executable [options] - - $executable [options] - Runs a complete MERT optimization. Required options are --weights, - --devset, and --config. - -Options: - - --config [-c ] - The decoder configuration file. - - --devset [-d ] - The source *and* references for the development set. - - --weights [-w ] - A file specifying initial feature weights. The format is - FeatureName_1 value1 - FeatureName_2 value2 - **All and only the weights listed in will be optimized!** - - --metric - Metric to optimize. - Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi - - --iterations - Maximum number of iterations to run. If not specified, defaults - to 10. - - --pass-suffix - If the decoder is doing multi-pass decoding, the pass suffix "2", - "3", etc., is used to control what iteration of weights is set. - - --rand-directions - MERT will attempt to optimize along all of the principle directions, - set this parameter to explore other directions. Defaults to 5. - - --output-dir - Directory for intermediate and output files. - - --help - Print this message and exit. - -Job control options: - - --jobs - Number of decoder processes to run in parallel. [default=$default_jobs] - - --qsub - Use qsub to run jobs in parallel (qsub must be configured in - environment/LocalEnvironment.pm) - - --pmem - Amount of physical memory requested for parallel decoding jobs - (used with qsub requests only) - -Help -} - -sub convert { - my ($str) = @_; - my @ps = split /;/, $str; - my %dict = (); - for my $p (@ps) { - my ($k, $v) = split /=/, $p; - $dict{$k} = $v; - } - return %dict; -} - - - -sub cmdline { - return join ' ',($0,@ORIG_ARGV); -} - -#buggy: last arg gets quoted sometimes? -my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; -my $shell_escape_in_quote=qr{[\\"\$`!]}; - -sub escape_shell { - my ($arg)=@_; - return undef unless defined $arg; - if ($arg =~ /$is_shell_special/) { - $arg =~ s/($shell_escape_in_quote)/\\$1/g; - return "\"$arg\""; - } - return $arg; -} - -sub escaped_shell_args { - return map {local $_=$_;chomp;escape_shell($_)} @_; -} - -sub escaped_shell_args_str { - return join ' ',&escaped_shell_args(@_); -} - -sub escaped_cmdline { - return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); -} - -sub split_devset { - my ($infile, $outsrc, $outref) = @_; - open F, "<$infile" or die "Can't read $infile: $!"; - open S, ">$outsrc" or die "Can't write $outsrc: $!"; - open R, ">$outref" or die "Can't write $outref: $!"; - while() { - chomp; - my ($src, @refs) = split /\s*\|\|\|\s*/; - die "Malformed devset line: $_\n" unless scalar @refs > 0; - print S "$src\n"; - print R join(' ||| ', @refs) . "\n"; - } - close R; - close S; - close F; -} - diff --git a/dpmert/error_surface.cc b/dpmert/error_surface.cc deleted file mode 100644 index 515b67f8..00000000 --- a/dpmert/error_surface.cc +++ /dev/null @@ -1,42 +0,0 @@ -#include "error_surface.h" - -#include -#include - -using namespace std; - -ErrorSurface::~ErrorSurface() {} - -void ErrorSurface::Serialize(std::string* out) const { - const int segments = this->size(); - ostringstream os(ios::binary); - os.write((const char*)&segments,sizeof(segments)); - for (int i = 0; i < segments; ++i) { - const ErrorSegment& cur = (*this)[i]; - string senc; - cur.delta.Encode(&senc); - assert(senc.size() < 1024); - unsigned char len = senc.size(); - os.write((const char*)&cur.x, sizeof(cur.x)); - os.write((const char*)&len, sizeof(len)); - os.write((const char*)&senc[0], len); - } - *out = os.str(); -} - -void ErrorSurface::Deserialize(const std::string& in) { - istringstream is(in, ios::binary); - int segments; - is.read((char*)&segments, sizeof(segments)); - this->resize(segments); - for (int i = 0; i < segments; ++i) { - ErrorSegment& cur = (*this)[i]; - unsigned char len; - is.read((char*)&cur.x, sizeof(cur.x)); - is.read((char*)&len, sizeof(len)); - string senc(len, '\0'); assert(senc.size() == len); - is.read((char*)&senc[0], len); - cur.delta = SufficientStats(senc); - } -} - diff --git a/dpmert/error_surface.h b/dpmert/error_surface.h deleted file mode 100644 index bb65847b..00000000 --- a/dpmert/error_surface.h +++ /dev/null @@ -1,24 +0,0 @@ -#ifndef _ERROR_SURFACE_H_ -#define _ERROR_SURFACE_H_ - -#include -#include - -#include "ns.h" - -class Score; - -struct ErrorSegment { - double x; - SufficientStats delta; - ErrorSegment() : x(0), delta() {} -}; - -class ErrorSurface : public std::vector { - public: - ~ErrorSurface(); - void Serialize(std::string* out) const; - void Deserialize(const std::string& in); -}; - -#endif diff --git a/dpmert/libcall.pl b/dpmert/libcall.pl deleted file mode 100644 index c7d0f128..00000000 --- a/dpmert/libcall.pl +++ /dev/null @@ -1,71 +0,0 @@ -use IPC::Open3; -use Symbol qw(gensym); - -$DUMMY_STDERR = gensym(); -$DUMMY_STDIN = gensym(); - -# Run the command and ignore failures -sub unchecked_call { - system("@_") -} - -# Run the command and return its output, if any ignoring failures -sub unchecked_output { - return `@_` -} - -# WARNING: Do not use this for commands that will return large amounts -# of stdout or stderr -- they might block indefinitely -sub check_output { - print STDERR "Executing and gathering output: @_\n"; - - my $pid = open3($DUMMY_STDIN, \*PH, $DUMMY_STDERR, @_); - my $proc_output = ""; - while( ) { - $proc_output .= $_; - } - waitpid($pid, 0); - # TODO: Grab signal that the process died from - my $child_exit_status = $? >> 8; - if($child_exit_status == 0) { - return $proc_output; - } else { - print STDERR "ERROR: Execution of @_ failed.\n"; - exit(1); - } -} - -# Based on Moses' safesystem sub -sub check_call { - print STDERR "Executing: @_\n"; - system(@_); - my $exitcode = $? >> 8; - if($exitcode == 0) { - return 0; - } elsif ($? == -1) { - print STDERR "ERROR: Failed to execute: @_\n $!\n"; - exit(1); - - } elsif ($? & 127) { - printf STDERR "ERROR: Execution of: @_\n died with signal %d, %s coredump\n", - ($? & 127), ($? & 128) ? 'with' : 'without'; - exit(1); - - } else { - print STDERR "Failed with exit code: $exitcode\n" if $exitcode; - exit($exitcode); - } -} - -sub check_bash_call { - my @args = ( "bash", "-auxeo", "pipefail", "-c", "@_"); - check_call(@args); -} - -sub check_bash_output { - my @args = ( "bash", "-auxeo", "pipefail", "-c", "@_"); - return check_output(@args); -} - -# perl module weirdness... -return 1; diff --git a/dpmert/line_mediator.pl b/dpmert/line_mediator.pl deleted file mode 100755 index bc2bb24c..00000000 --- a/dpmert/line_mediator.pl +++ /dev/null @@ -1,116 +0,0 @@ -#!/usr/bin/perl -w -#hooks up two processes, 2nd of which has one line of output per line of input, expected by the first, which starts off the communication - -# if you don't know how to fork/exec in a C program, this could be helpful under limited cirmustances (would be ok to liaise with sentserver) - -#WARNING: because it waits for the result from command 2 after sending every line, and especially if command 1 does the same, using sentserver as command 2 won't actually buy you any real parallelism. - -use strict; -use IPC::Open2; -use POSIX qw(pipe dup2 STDIN_FILENO STDOUT_FILENO); - -my $quiet=!$ENV{DEBUG}; -$quiet=1 if $ENV{QUIET}; -sub info { - local $,=' '; - print STDERR @_ unless $quiet; -} - -my $mode='CROSS'; -my $ser='DIRECT'; -$mode='PIPE' if $ENV{PIPE}; -$mode='SNAKE' if $ENV{SNAKE}; -$mode='CROSS' if $ENV{CROSS}; -$ser='SERIAL' if $ENV{SERIAL}; -$ser='DIRECT' if $ENV{DIRECT}; -$ser='SERIAL' if $mode eq 'SNAKE'; -info("mode: $mode\n"); -info("connection: $ser\n"); - - -my @c1; -if (scalar @ARGV) { - do { - push @c1,shift - } while scalar @ARGV && $c1[$#c1] ne '--'; -} -pop @c1; -my @c2=@ARGV; -@ARGV=(); -(scalar @c1 && scalar @c2) || die qq{ -usage: $0 cmd1 args -- cmd2 args -all options are environment variables. -DEBUG=1 env var enables debugging output. -CROSS=1 hooks up two processes, 2nd of which has one line of output per line of input, expected by the first, which starts off the communication. crosses stdin/stderr of cmd1 and cmd2 line by line (both must flush on newline and output. cmd1 initiates the conversation (sends the first line). default: attempts to cross stdin/stdout of c1 and c2 directly (via two unidirectional posix pipes created before fork). -SERIAL=1: (no parallelism possible) but lines exchanged are logged if DEBUG. -if SNAKE then stdin -> c1 -> c2 -> c1 -> stdout. -if PIPE then stdin -> c1 -> c2 -> stdout (same as shell c1|c2, but with SERIAL you can see the intermediate in real time; you could do similar with c1 | tee /dev/fd/2 |c2. -DIRECT=1 (default) will override SERIAL=1. -CROSS=1 (default) will override SNAKE or PIPE. -}; - -info("1 cmd:",@c1,"\n"); -info("2 cmd:",@c2,"\n"); - -sub lineto { - select $_[0]; - $|=1; - shift; - print @_; -} - -if ($ser eq 'SERIAL') { - my ($R1,$W1,$R2,$W2); - my $c1p=open2($R1,$W1,@c1); # Open2 R W backward from Open3. - my $c2p=open2($R2,$W2,@c2); - if ($mode eq 'CROSS') { - while(<$R1>) { - info("1:",$_); - lineto($W2,$_); - last unless defined ($_=<$R2>); - info("1|2:",$_); - lineto($W1,$_); - } - } else { - my $snake=$mode eq 'SNAKE'; - while() { - info("IN:",$_); - lineto($W1,$_); - last unless defined ($_=<$R1>); - info("IN|1:",$_); - lineto($W2,$_); - last unless defined ($_=<$R2>); - info("IN|1|2:",$_); - if ($snake) { - lineto($W1,$_); - last unless defined ($_=<$R1>); - info("IN|1|2|1:",$_); - } - lineto(*STDOUT,$_); - } - } -} else { - info("DIRECT mode\n"); - my @rw1=POSIX::pipe(); - my @rw2=POSIX::pipe(); - my $pid=undef; - $SIG{CHLD} = sub { wait }; - while (not defined ($pid=fork())) { - sleep 1; - } - my $pipe = $mode eq 'PIPE'; - unless ($pipe) { - POSIX::close(STDOUT_FILENO); - POSIX::close(STDIN_FILENO); - } - if ($pid) { - POSIX::dup2($rw1[1],STDOUT_FILENO); - POSIX::dup2($rw2[0],STDIN_FILENO) unless $pipe; - exec @c1; - } else { - POSIX::dup2($rw2[1],STDOUT_FILENO) unless $pipe; - POSIX::dup2($rw1[0],STDIN_FILENO); - exec @c2; - } - while (wait()!=-1) {} -} diff --git a/dpmert/line_optimizer.cc b/dpmert/line_optimizer.cc deleted file mode 100644 index 9cf33502..00000000 --- a/dpmert/line_optimizer.cc +++ /dev/null @@ -1,114 +0,0 @@ -#include "line_optimizer.h" - -#include -#include - -#include "sparse_vector.h" -#include "ns.h" - -using namespace std; - -typedef ErrorSurface::const_iterator ErrorIter; - -// sort by increasing x-ints -struct IntervalComp { - bool operator() (const ErrorIter& a, const ErrorIter& b) const { - return a->x < b->x; - } -}; - -double LineOptimizer::LineOptimize( - const EvaluationMetric* metric, - const vector& surfaces, - const LineOptimizer::ScoreType type, - float* best_score, - const double epsilon) { - // cerr << "MIN=" << MINIMIZE_SCORE << " MAX=" << MAXIMIZE_SCORE << " MINE=" << type << endl; - vector all_ints; - for (vector::const_iterator i = surfaces.begin(); - i != surfaces.end(); ++i) { - const ErrorSurface& surface = *i; - for (ErrorIter j = surface.begin(); j != surface.end(); ++j) - all_ints.push_back(j); - } - sort(all_ints.begin(), all_ints.end(), IntervalComp()); - double last_boundary = all_ints.front()->x; - SufficientStats acc; - float& cur_best_score = *best_score; - cur_best_score = (type == MAXIMIZE_SCORE ? - -numeric_limits::max() : numeric_limits::max()); - bool left_edge = true; - double pos = numeric_limits::quiet_NaN(); - for (vector::iterator i = all_ints.begin(); - i != all_ints.end(); ++i) { - const ErrorSegment& seg = **i; - if (seg.x - last_boundary > epsilon) { - float sco = metric->ComputeScore(acc); - if ((type == MAXIMIZE_SCORE && sco > cur_best_score) || - (type == MINIMIZE_SCORE && sco < cur_best_score) ) { - cur_best_score = sco; - if (left_edge) { - pos = seg.x - 0.1; - left_edge = false; - } else { - pos = last_boundary + (seg.x - last_boundary) / 2; - } - //cerr << "NEW BEST: " << pos << " (score=" << cur_best_score << ")\n"; - } - // string xx = metric->DetailedScore(acc); cerr << "---- " << xx; -#undef SHOW_ERROR_SURFACES -#ifdef SHOW_ERROR_SURFACES - cerr << "x=" << seg.x << "\ts=" << sco << "\n"; -#endif - last_boundary = seg.x; - } - // cerr << "x-boundary=" << seg.x << "\n"; - //string x2; acc.Encode(&x2); cerr << " ACC: " << x2 << endl; - //string x1; seg.delta.Encode(&x1); cerr << " DELTA: " << x1 << endl; - acc += seg.delta; - } - float sco = metric->ComputeScore(acc); - if ((type == MAXIMIZE_SCORE && sco > cur_best_score) || - (type == MINIMIZE_SCORE && sco < cur_best_score) ) { - cur_best_score = sco; - if (left_edge) { - pos = 0; - } else { - pos = last_boundary + 1000.0; - } - } - return pos; -} - -void LineOptimizer::RandomUnitVector(const vector& features_to_optimize, - SparseVector* axis, - RandomNumberGenerator* rng) { - axis->clear(); - for (int i = 0; i < features_to_optimize.size(); ++i) - axis->set_value(features_to_optimize[i], rng->NextNormal(0.0,1.0)); - (*axis) /= axis->l2norm(); -} - -void LineOptimizer::CreateOptimizationDirections( - const vector& features_to_optimize, - int additional_random_directions, - RandomNumberGenerator* rng, - vector >* dirs - , bool include_orthogonal - ) { - dirs->clear(); - typedef SparseVector Dir; - vector &out=*dirs; - int i=0; - if (include_orthogonal) - for (;i - -#include "sparse_vector.h" -#include "error_surface.h" -#include "sampler.h" - -class EvaluationMetric; -class Weights; - -struct LineOptimizer { - - // use MINIMIZE_SCORE for things like TER, WER - // MAXIMIZE_SCORE for things like BLEU - enum ScoreType { MAXIMIZE_SCORE, MINIMIZE_SCORE }; - - // merge all the error surfaces together into a global - // error surface and find (the middle of) the best segment - static double LineOptimize( - const EvaluationMetric* metric, - const std::vector& envs, - const LineOptimizer::ScoreType type, - float* best_score, - const double epsilon = 1.0/65536.0); - - // return a random vector of length 1 where all dimensions - // not listed in dimensions will be 0. - static void RandomUnitVector(const std::vector& dimensions, - SparseVector* axis, - RandomNumberGenerator* rng); - - // generate a list of directions to optimize; the list will - // contain the orthogonal vectors corresponding to the dimensions in - // primary and then additional_random_directions directions in those - // dimensions as well. All vectors will be length 1. - static void CreateOptimizationDirections( - const std::vector& primary, - int additional_random_directions, - RandomNumberGenerator* rng, - std::vector >* dirs - , bool include_primary=true - ); - -}; - -#endif diff --git a/dpmert/lo_test.cc b/dpmert/lo_test.cc deleted file mode 100644 index 95a08d3d..00000000 --- a/dpmert/lo_test.cc +++ /dev/null @@ -1,229 +0,0 @@ -#define BOOST_TEST_MODULE LineOptimizerTest -#include -#include - -#include -#include -#include - -#include - -#include "ns.h" -#include "ns_docscorer.h" -#include "ces.h" -#include "fdict.h" -#include "hg.h" -#include "kbest.h" -#include "hg_io.h" -#include "filelib.h" -#include "inside_outside.h" -#include "viterbi.h" -#include "mert_geometry.h" -#include "line_optimizer.h" - -using namespace std; - -const char* ref11 = "australia reopens embassy in manila"; -const char* ref12 = "( afp , manila , january 2 ) australia reopened its embassy in the philippines today , which was shut down about seven weeks ago due to what was described as a specific threat of a terrorist attack ."; -const char* ref21 = "australia reopened manila embassy"; -const char* ref22 = "( agence france-presse , manila , 2nd ) - australia reopened its embassy in the philippines today . the embassy was closed seven weeks ago after what was described as a specific threat of a terrorist attack ."; -const char* ref31 = "australia to reopen embassy in manila"; -const char* ref32 = "( afp report from manila , january 2 ) australia reopened its embassy in the philippines today . seven weeks ago , the embassy was shut down due to so - called confirmed terrorist attack threats ."; -const char* ref41 = "australia to re - open its embassy to manila"; -const char* ref42 = "( afp , manila , thursday ) australia reopens its embassy to manila , which was closed for the so - called \" clear \" threat of terrorist attack 7 weeks ago ."; - -BOOST_AUTO_TEST_CASE( TestCheckNaN) { - double x = 0; - double y = 0; - double z = x / y; - BOOST_CHECK_EQUAL(true, std::isnan(z)); -} - -BOOST_AUTO_TEST_CASE(TestConvexHull) { - boost::shared_ptr a1(new MERTPoint(-1, 0)); - boost::shared_ptr b1(new MERTPoint(1, 0)); - boost::shared_ptr a2(new MERTPoint(-1, 1)); - boost::shared_ptr b2(new MERTPoint(1, -1)); - vector > sa; sa.push_back(a1); sa.push_back(b1); - vector > sb; sb.push_back(a2); sb.push_back(b2); - ConvexHull a(sa); - cerr << a << endl; - ConvexHull b(sb); - ConvexHull c = a; - c *= b; - cerr << a << " (*) " << b << " = " << c << endl; - BOOST_CHECK_EQUAL(3, c.size()); -} - -BOOST_AUTO_TEST_CASE(TestConvexHullInside) { - const string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}"; - Hypergraph hg; - istringstream instr(json); - HypergraphIO::ReadFromJSON(&instr, &hg); - SparseVector wts; - wts.set_value(FD::Convert("f1"), 0.4); - wts.set_value(FD::Convert("f2"), 1.0); - hg.Reweight(wts); - vector, prob_t> > list; - std::vector > features; - KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 10); - for (int i = 0; i < 10; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(hg.nodes_.size() - 1, i); - if (!d) break; - cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; - } - SparseVector dir; dir.set_value(FD::Convert("f1"), 1.0); - ConvexHullWeightFunction wf(wts, dir); - ConvexHull env = Inside(hg, NULL, wf); - cerr << env << endl; - const vector >& segs = env.GetSortedSegs(); - dir *= segs[1]->x; - wts += dir; - hg.Reweight(wts); - KBest::KBestDerivations, ESentenceTraversal> kbest2(hg, 10); - for (int i = 0; i < 10; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest2.LazyKthBest(hg.nodes_.size() - 1, i); - if (!d) break; - cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; - } - for (unsigned i = 0; i < segs.size(); ++i) { - cerr << "seg=" << i << endl; - vector trans; - segs[i]->ConstructTranslation(&trans); - cerr << TD::GetString(trans) << endl; - } -} - -BOOST_AUTO_TEST_CASE( TestS1) { - int fPhraseModel_0 = FD::Convert("PhraseModel_0"); - int fPhraseModel_1 = FD::Convert("PhraseModel_1"); - int fPhraseModel_2 = FD::Convert("PhraseModel_2"); - int fLanguageModel = FD::Convert("LanguageModel"); - int fWordPenalty = FD::Convert("WordPenalty"); - int fPassThrough = FD::Convert("PassThrough"); - SparseVector wts; - wts.set_value(fWordPenalty, 4.25); - wts.set_value(fLanguageModel, -1.1165); - wts.set_value(fPhraseModel_0, -0.96); - wts.set_value(fPhraseModel_1, -0.65); - wts.set_value(fPhraseModel_2, -0.77); - wts.set_value(fPassThrough, -10.0); - - vector to_optimize; - to_optimize.push_back(fWordPenalty); - to_optimize.push_back(fLanguageModel); - to_optimize.push_back(fPhraseModel_0); - to_optimize.push_back(fPhraseModel_1); - to_optimize.push_back(fPhraseModel_2); - - std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : "test_data"); - - Hypergraph hg; - ReadFile rf(path + "/0.json.gz"); - HypergraphIO::ReadFromJSON(rf.stream(), &hg); - hg.Reweight(wts); - - Hypergraph hg2; - ReadFile rf2(path + "/1.json.gz"); - HypergraphIO::ReadFromJSON(rf2.stream(), &hg2); - hg2.Reweight(wts); - - vector > refs1(4); - TD::ConvertSentence(ref11, &refs1[0]); - TD::ConvertSentence(ref21, &refs1[1]); - TD::ConvertSentence(ref31, &refs1[2]); - TD::ConvertSentence(ref41, &refs1[3]); - vector > refs2(4); - TD::ConvertSentence(ref12, &refs2[0]); - TD::ConvertSentence(ref22, &refs2[1]); - TD::ConvertSentence(ref32, &refs2[2]); - TD::ConvertSentence(ref42, &refs2[3]); - vector envs(2); - - RandomNumberGenerator rng; - - vector > axes; // directions to search - LineOptimizer::CreateOptimizationDirections( - to_optimize, - 10, - &rng, - &axes); - assert(axes.size() == 10 + to_optimize.size()); - for (unsigned i = 0; i < axes.size(); ++i) - cerr << axes[i] << endl; - const SparseVector& axis = axes[0]; - - cerr << "Computing Viterbi envelope using inside algorithm...\n"; - cerr << "axis: " << axis << endl; - clock_t t_start=clock(); - ConvexHullWeightFunction wf(wts, axis); // wts = starting point, axis = search direction - envs[0] = Inside(hg, NULL, wf); - envs[1] = Inside(hg2, NULL, wf); - - vector es(2); - EvaluationMetric* metric = EvaluationMetric::Instance("IBM_BLEU"); - boost::shared_ptr scorer1 = metric->CreateSegmentEvaluator(refs1); - boost::shared_ptr scorer2 = metric->CreateSegmentEvaluator(refs2); - ComputeErrorSurface(*scorer1, envs[0], &es[0], metric, hg); - ComputeErrorSurface(*scorer2, envs[1], &es[1], metric, hg2); - cerr << envs[0].size() << " " << envs[1].size() << endl; - cerr << es[0].size() << " " << es[1].size() << endl; - envs.clear(); - clock_t t_env=clock(); - float score; - double m = LineOptimizer::LineOptimize(metric,es, LineOptimizer::MAXIMIZE_SCORE, &score); - clock_t t_opt=clock(); - cerr << "line optimizer returned: " << m << " (SCORE=" << score << ")\n"; - BOOST_CHECK_CLOSE(0.48719698, score, 1e-5); - SparseVector res = axis; - res *= m; - res += wts; - cerr << "res: " << res << endl; - cerr << "ENVELOPE PROCESSING=" << (static_cast(t_env - t_start) / 1000.0) << endl; - cerr << " LINE OPTIMIZATION=" << (static_cast(t_opt - t_env) / 1000.0) << endl; - hg.Reweight(res); - hg2.Reweight(res); - vector t1,t2; - ViterbiESentence(hg, &t1); - ViterbiESentence(hg2, &t2); - cerr << TD::GetString(t1) << endl; - cerr << TD::GetString(t2) << endl; -} - -BOOST_AUTO_TEST_CASE(TestZeroOrigin) { - const string json = "{\"rules\":[1,\"[X7] ||| blA ||| without ||| LHSProb=3.92173 LexE2F=2.90799 LexF2E=1.85003 GenerativeProb=10.5381 RulePenalty=1 XFE=2.77259 XEF=0.441833 LabelledEF=2.63906 LabelledFE=4.96981 LogRuleCount=0.693147\",2,\"[X7] ||| blA ||| except ||| LHSProb=4.92173 LexE2F=3.90799 LexF2E=1.85003 GenerativeProb=11.5381 RulePenalty=1 XFE=2.77259 XEF=1.44183 LabelledEF=2.63906 LabelledFE=4.96981 LogRuleCount=1.69315\",3,\"[S] ||| [X7,1] ||| [1] ||| GlueTop=1\",4,\"[X28] ||| EnwAn ||| title ||| LHSProb=3.96802 LexE2F=2.22462 LexF2E=1.83258 GenerativeProb=10.0863 RulePenalty=1 XFE=0 XEF=1.20397 LabelledEF=1.20397 LabelledFE=-1.98341e-08 LogRuleCount=1.09861\",5,\"[X0] ||| EnwAn ||| funny ||| LHSProb=3.98479 LexE2F=1.79176 LexF2E=3.21888 GenerativeProb=11.1681 RulePenalty=1 XFE=0 XEF=2.30259 LabelledEF=2.30259 LabelledFE=0 LogRuleCount=0 SingletonRule=1\",6,\"[X8] ||| [X7,1] EnwAn ||| entitled [1] ||| LHSProb=3.82533 LexE2F=3.21888 LexF2E=2.52573 GenerativeProb=11.3276 RulePenalty=1 XFE=1.20397 XEF=1.20397 LabelledEF=2.30259 LabelledFE=2.30259 LogRuleCount=0 SingletonRule=1\",7,\"[S] ||| [S,1] [X28,2] ||| [1] [2] ||| Glue=1\",8,\"[S] ||| [S,1] [X0,2] ||| [1] [2] ||| Glue=1\",9,\"[S] ||| [X8,1] ||| [1] ||| GlueTop=1\",10,\"[Goal] ||| [S,1] ||| [1]\"],\"features\":[\"PassThrough\",\"Glue\",\"GlueTop\",\"LanguageModel\",\"WordPenalty\",\"LHSProb\",\"LexE2F\",\"LexF2E\",\"GenerativeProb\",\"RulePenalty\",\"XFE\",\"XEF\",\"LabelledEF\",\"LabelledFE\",\"LogRuleCount\",\"SingletonRule\"],\"edges\":[{\"tail\":[],\"spans\":[0,1,-1,-1],\"feats\":[5,3.92173,6,2.90799,7,1.85003,8,10.5381,9,1,10,2.77259,11,0.441833,12,2.63906,13,4.96981,14,0.693147],\"rule\":1},{\"tail\":[],\"spans\":[0,1,-1,-1],\"feats\":[5,4.92173,6,3.90799,7,1.85003,8,11.5381,9,1,10,2.77259,11,1.44183,12,2.63906,13,4.96981,14,1.69315],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X7\"},\"edges\":[{\"tail\":[0],\"spans\":[0,1,-1,-1],\"feats\":[2,1],\"rule\":3}],\"node\":{\"in_edges\":[2],\"cat\":\"S\"},\"edges\":[{\"tail\":[],\"spans\":[1,2,-1,-1],\"feats\":[5,3.96802,6,2.22462,7,1.83258,8,10.0863,9,1,11,1.20397,12,1.20397,13,-1.98341e-08,14,1.09861],\"rule\":4}],\"node\":{\"in_edges\":[3],\"cat\":\"X28\"},\"edges\":[{\"tail\":[],\"spans\":[1,2,-1,-1],\"feats\":[5,3.98479,6,1.79176,7,3.21888,8,11.1681,9,1,11,2.30259,12,2.30259,15,1],\"rule\":5}],\"node\":{\"in_edges\":[4],\"cat\":\"X0\"},\"edges\":[{\"tail\":[0],\"spans\":[0,2,-1,-1],\"feats\":[5,3.82533,6,3.21888,7,2.52573,8,11.3276,9,1,10,1.20397,11,1.20397,12,2.30259,13,2.30259,15,1],\"rule\":6}],\"node\":{\"in_edges\":[5],\"cat\":\"X8\"},\"edges\":[{\"tail\":[1,2],\"spans\":[0,2,-1,-1],\"feats\":[1,1],\"rule\":7},{\"tail\":[1,3],\"spans\":[0,2,-1,-1],\"feats\":[1,1],\"rule\":8},{\"tail\":[4],\"spans\":[0,2,-1,-1],\"feats\":[2,1],\"rule\":9}],\"node\":{\"in_edges\":[6,7,8],\"cat\":\"S\"},\"edges\":[{\"tail\":[5],\"spans\":[0,2,-1,-1],\"feats\":[],\"rule\":10}],\"node\":{\"in_edges\":[9],\"cat\":\"Goal\"}}"; - Hypergraph hg; - istringstream instr(json); - HypergraphIO::ReadFromJSON(&instr, &hg); - SparseVector wts; - wts.set_value(FD::Convert("PassThrough"), -0.929201533002898); - hg.Reweight(wts); - - vector, prob_t> > list; - std::vector > features; - KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 10); - for (int i = 0; i < 10; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(hg.nodes_.size() - 1, i); - if (!d) break; - cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; - } - - SparseVector axis; axis.set_value(FD::Convert("Glue"),1.0); - ConvexHullWeightFunction wf(wts, axis); // wts = starting point, axis = search direction - vector envs(1); - envs[0] = Inside(hg, NULL, wf); - - vector > mr(4); - TD::ConvertSentence("untitled", &mr[0]); - TD::ConvertSentence("with no title", &mr[1]); - TD::ConvertSentence("without a title", &mr[2]); - TD::ConvertSentence("without title", &mr[3]); - EvaluationMetric* metric = EvaluationMetric::Instance("IBM_BLEU"); - boost::shared_ptr scorer1 = metric->CreateSegmentEvaluator(mr); - vector es(1); - ComputeErrorSurface(*scorer1, envs[0], &es[0], metric, hg); -} - diff --git a/dpmert/mert_geometry.cc b/dpmert/mert_geometry.cc deleted file mode 100644 index d6973658..00000000 --- a/dpmert/mert_geometry.cc +++ /dev/null @@ -1,185 +0,0 @@ -#include "mert_geometry.h" - -#include -#include - -using namespace std; - -ConvexHull::ConvexHull(int i) { - if (i == 0) { - // do nothing - <> - } else if (i == 1) { - points.push_back(boost::shared_ptr(new MERTPoint(0, 0, 0, boost::shared_ptr(), boost::shared_ptr()))); - assert(this->IsMultiplicativeIdentity()); - } else { - cerr << "Only can create ConvexHull semiring 0 and 1 with this constructor!\n"; - abort(); - } -} - -const ConvexHull ConvexHullWeightFunction::operator()(const Hypergraph::Edge& e) const { - const double m = direction.dot(e.feature_values_); - const double b = origin.dot(e.feature_values_); - MERTPoint* point = new MERTPoint(m, b, e); - return ConvexHull(1, point); -} - -ostream& operator<<(ostream& os, const ConvexHull& env) { - os << '<'; - const vector >& points = env.GetSortedSegs(); - for (int i = 0; i < points.size(); ++i) - os << (i==0 ? "" : "|") << "x=" << points[i]->x << ",b=" << points[i]->b << ",m=" << points[i]->m << ",p1=" << points[i]->p1 << ",p2=" << points[i]->p2; - return os << '>'; -} - -#define ORIGINAL_MERT_IMPLEMENTATION 1 -#ifdef ORIGINAL_MERT_IMPLEMENTATION - -struct SlopeCompare { - bool operator() (const boost::shared_ptr& a, const boost::shared_ptr& b) const { - return a->m < b->m; - } -}; - -const ConvexHull& ConvexHull::operator+=(const ConvexHull& other) { - if (!other.is_sorted) other.Sort(); - if (points.empty()) { - points = other.points; - return *this; - } - is_sorted = false; - int j = points.size(); - points.resize(points.size() + other.points.size()); - for (int i = 0; i < other.points.size(); ++i) - points[j++] = other.points[i]; - assert(j == points.size()); - return *this; -} - -void ConvexHull::Sort() const { - sort(points.begin(), points.end(), SlopeCompare()); - const int k = points.size(); - int j = 0; - for (int i = 0; i < k; ++i) { - MERTPoint l = *points[i]; - l.x = kMinusInfinity; - // cerr << "m=" << l.m << endl; - if (0 < j) { - if (points[j-1]->m == l.m) { // lines are parallel - if (l.b <= points[j-1]->b) continue; - --j; - } - while(0 < j) { - l.x = (l.b - points[j-1]->b) / (points[j-1]->m - l.m); - if (points[j-1]->x < l.x) break; - --j; - } - if (0 == j) l.x = kMinusInfinity; - } - *points[j++] = l; - } - points.resize(j); - is_sorted = true; -} - -const ConvexHull& ConvexHull::operator*=(const ConvexHull& other) { - if (other.IsMultiplicativeIdentity()) { return *this; } - if (this->IsMultiplicativeIdentity()) { (*this) = other; return *this; } - - if (!is_sorted) Sort(); - if (!other.is_sorted) other.Sort(); - - if (this->IsEdgeEnvelope()) { -// if (other.size() > 1) -// cerr << *this << " (TIMES) " << other << endl; - boost::shared_ptr edge_parent = points[0]; - const double& edge_b = edge_parent->b; - const double& edge_m = edge_parent->m; - points.clear(); - for (int i = 0; i < other.points.size(); ++i) { - const MERTPoint& p = *other.points[i]; - const double m = p.m + edge_m; - const double b = p.b + edge_b; - const double& x = p.x; // x's don't change with * - points.push_back(boost::shared_ptr(new MERTPoint(x, m, b, edge_parent, other.points[i]))); - assert(points.back()->p1->edge); - } -// if (other.size() > 1) -// cerr << " = " << *this << endl; - } else { - vector > new_points; - int this_i = 0; - int other_i = 0; - const int this_size = points.size(); - const int other_size = other.points.size(); - double cur_x = kMinusInfinity; // moves from left to right across the - // real numbers, stopping for all inter- - // sections - double this_next_val = (1 < this_size ? points[1]->x : kPlusInfinity); - double other_next_val = (1 < other_size ? other.points[1]->x : kPlusInfinity); - while (this_i < this_size && other_i < other_size) { - const MERTPoint& this_point = *points[this_i]; - const MERTPoint& other_point= *other.points[other_i]; - const double m = this_point.m + other_point.m; - const double b = this_point.b + other_point.b; - - new_points.push_back(boost::shared_ptr(new MERTPoint(cur_x, m, b, points[this_i], other.points[other_i]))); - int comp = 0; - if (this_next_val < other_next_val) comp = -1; else - if (this_next_val > other_next_val) comp = 1; - if (0 == comp) { // the next values are equal, advance both indices - ++this_i; - ++other_i; - cur_x = this_next_val; // could be other_next_val (they're equal!) - this_next_val = (this_i+1 < this_size ? points[this_i+1]->x : kPlusInfinity); - other_next_val = (other_i+1 < other_size ? other.points[other_i+1]->x : kPlusInfinity); - } else { // advance the i with the lower x, update cur_x - if (-1 == comp) { - ++this_i; - cur_x = this_next_val; - this_next_val = (this_i+1 < this_size ? points[this_i+1]->x : kPlusInfinity); - } else { - ++other_i; - cur_x = other_next_val; - other_next_val = (other_i+1 < other_size ? other.points[other_i+1]->x : kPlusInfinity); - } - } - } - points.swap(new_points); - } - //cerr << "Multiply: result=" << (*this) << endl; - return *this; -} - -// recursively construct translation -void MERTPoint::ConstructTranslation(vector* trans) const { - const MERTPoint* cur = this; - vector > ant_trans; - while(!cur->edge) { - ant_trans.resize(ant_trans.size() + 1); - cur->p2->ConstructTranslation(&ant_trans.back()); - cur = cur->p1.get(); - } - size_t ant_size = ant_trans.size(); - vector*> pants(ant_size); - assert(ant_size == cur->edge->tail_nodes_.size()); - --ant_size; - for (int i = 0; i < pants.size(); ++i) pants[ant_size - i] = &ant_trans[i]; - cur->edge->rule_->ESubstitute(pants, trans); -} - -void MERTPoint::CollectEdgesUsed(std::vector* edges_used) const { - if (edge) { - assert(edge->id_ < edges_used->size()); - (*edges_used)[edge->id_] = true; - } - if (p1) p1->CollectEdgesUsed(edges_used); - if (p2) p2->CollectEdgesUsed(edges_used); -} - -#else - -// THIS IS THE NEW FASTER IMPLEMENTATION OF THE MERT SEMIRING OPERATIONS - -#endif - diff --git a/dpmert/mert_geometry.h b/dpmert/mert_geometry.h deleted file mode 100644 index a8b6959e..00000000 --- a/dpmert/mert_geometry.h +++ /dev/null @@ -1,81 +0,0 @@ -#ifndef _MERT_GEOMETRY_H_ -#define _MERT_GEOMETRY_H_ - -#include -#include -#include - -#include "hg.h" -#include "sparse_vector.h" - -static const double kMinusInfinity = -std::numeric_limits::infinity(); -static const double kPlusInfinity = std::numeric_limits::infinity(); - -struct MERTPoint { - MERTPoint() : x(), m(), b(), edge() {} - MERTPoint(double _m, double _b) : - x(kMinusInfinity), m(_m), b(_b), edge() {} - MERTPoint(double _x, double _m, double _b, const boost::shared_ptr& p1_, const boost::shared_ptr& p2_) : - x(_x), m(_m), b(_b), p1(p1_), p2(p2_), edge() {} - MERTPoint(double _m, double _b, const Hypergraph::Edge& edge) : - x(kMinusInfinity), m(_m), b(_b), edge(&edge) {} - - double x; // x intersection with previous segment in env, or -inf if none - double m; // this line's slope - double b; // intercept with y-axis - - // we keep a pointer to the "parents" of this segment so we can reconstruct - // the Viterbi translation corresponding to this segment - boost::shared_ptr p1; - boost::shared_ptr p2; - - // only MERTPoints created from an edge using the ConvexHullWeightFunction - // have rules - // TRulePtr rule; - const Hypergraph::Edge* edge; - - // recursively recover the Viterbi translation that will result from setting - // the weights to origin + axis * x, where x is any value from this->x up - // until the next largest x in the containing ConvexHull - void ConstructTranslation(std::vector* trans) const; - void CollectEdgesUsed(std::vector* edges_used) const; -}; - -// this is the semiring value type, -// it defines constructors for 0, 1, and the operations + and * -struct ConvexHull { - // create semiring zero - ConvexHull() : is_sorted(true) {} // zero - // for debugging: - ConvexHull(const std::vector >& s) : points(s) { Sort(); } - // create semiring 1 or 0 - explicit ConvexHull(int i); - ConvexHull(int n, MERTPoint* point) : is_sorted(true), points(n, boost::shared_ptr(point)) {} - const ConvexHull& operator+=(const ConvexHull& other); - const ConvexHull& operator*=(const ConvexHull& other); - bool IsMultiplicativeIdentity() const { - return size() == 1 && (points[0]->b == 0.0 && points[0]->m == 0.0) && (!points[0]->edge) && (!points[0]->p1) && (!points[0]->p2); } - const std::vector >& GetSortedSegs() const { - if (!is_sorted) Sort(); - return points; - } - size_t size() const { return points.size(); } - - private: - bool IsEdgeEnvelope() const { - return points.size() == 1 && points[0]->edge; } - void Sort() const; - mutable bool is_sorted; - mutable std::vector > points; -}; -std::ostream& operator<<(std::ostream& os, const ConvexHull& env); - -struct ConvexHullWeightFunction { - ConvexHullWeightFunction(const SparseVector& ori, - const SparseVector& dir) : origin(ori), direction(dir) {} - const ConvexHull operator()(const Hypergraph::Edge& e) const; - const SparseVector origin; - const SparseVector direction; -}; - -#endif diff --git a/dpmert/mr_dpmert_generate_mapper_input.cc b/dpmert/mr_dpmert_generate_mapper_input.cc deleted file mode 100644 index 199cd23a..00000000 --- a/dpmert/mr_dpmert_generate_mapper_input.cc +++ /dev/null @@ -1,81 +0,0 @@ -#include -#include - -#include -#include - -#include "filelib.h" -#include "weights.h" -#include "line_optimizer.h" - -using namespace std; -namespace po = boost::program_options; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("dev_set_size,s",po::value(),"[REQD] Development set size (# of parallel sentences)") - ("forest_repository,r",po::value(),"[REQD] Path to forest repository") - ("weights,w",po::value(),"[REQD] Current feature weights file") - ("optimize_feature,o",po::value >(), "Feature to optimize (if none specified, all weights listed in the weights file will be optimized)") - ("random_directions,d",po::value()->default_value(20),"Number of random directions to run the line optimizer in") - ("help,h", "Help"); - po::options_description dcmdline_options; - dcmdline_options.add(opts); - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - bool flag = false; - if (conf->count("dev_set_size") == 0) { - cerr << "Please specify the size of the development set using -d N\n"; - flag = true; - } - if (conf->count("weights") == 0) { - cerr << "Please specify the starting-point weights using -w \n"; - flag = true; - } - if (conf->count("forest_repository") == 0) { - cerr << "Please specify the forest repository location using -r \n"; - flag = true; - } - if (flag || conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -int main(int argc, char** argv) { - RandomNumberGenerator rng; - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - vector features; - SparseVector origin; - vector w; - Weights::InitFromFile(conf["weights"].as(), &w, &features); - Weights::InitSparseVector(w, &origin); - const string forest_repository = conf["forest_repository"].as(); - if (!DirectoryExists(forest_repository)) { - cerr << "Forest repository directory " << forest_repository << " not found!\n"; - return 1; - } - if (conf.count("optimize_feature") > 0) - features=conf["optimize_feature"].as >(); - vector > directions; - vector fids(features.size()); - for (unsigned i = 0; i < features.size(); ++i) - fids[i] = FD::Convert(features[i]); - LineOptimizer::CreateOptimizationDirections( - fids, - conf["random_directions"].as(), - &rng, - &directions); - unsigned dev_set_size = conf["dev_set_size"].as(); - for (unsigned i = 0; i < dev_set_size; ++i) { - for (unsigned j = 0; j < directions.size(); ++j) { - cout << forest_repository << '/' << i << ".json.gz " << i << ' '; - print(cout, origin, "=", ";"); - cout << ' '; - print(cout, directions[j], "=", ";"); - cout << endl; - } - } - return 0; -} diff --git a/dpmert/mr_dpmert_map.cc b/dpmert/mr_dpmert_map.cc deleted file mode 100644 index d1efcf96..00000000 --- a/dpmert/mr_dpmert_map.cc +++ /dev/null @@ -1,112 +0,0 @@ -#include -#include -#include -#include - -#include -#include - -#include "ns.h" -#include "ns_docscorer.h" -#include "ces.h" -#include "filelib.h" -#include "stringlib.h" -#include "sparse_vector.h" -#include "mert_geometry.h" -#include "inside_outside.h" -#include "error_surface.h" -#include "b64tools.h" -#include "hg_io.h" - -using namespace std; -namespace po = boost::program_options; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") - ("source,s",po::value(), "Source file (ignored, except for AER)") - ("evaluation_metric,m",po::value()->default_value("ibm_bleu"), "Evaluation metric being optimized") - ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") - ("help,h", "Help"); - po::options_description dcmdline_options; - dcmdline_options.add(opts); - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - bool flag = false; - if (!conf->count("reference")) { - cerr << "Please specify one or more references using -r \n"; - flag = true; - } - if (flag || conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -bool ReadSparseVectorString(const string& s, SparseVector* v) { -#if 0 - // this should work, but untested. - std::istringstream i(s); - i>>*v; -#else - vector fields; - Tokenize(s, ';', &fields); - if (fields.empty()) return false; - for (unsigned i = 0; i < fields.size(); ++i) { - vector pair(2); - Tokenize(fields[i], '=', &pair); - if (pair.size() != 2) { - cerr << "Error parsing vector string: " << fields[i] << endl; - return false; - } - v->set_value(FD::Convert(pair[0]), atof(pair[1].c_str())); - } - return true; -#endif -} - -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - const string evaluation_metric = conf["evaluation_metric"].as(); - EvaluationMetric* metric = EvaluationMetric::Instance(evaluation_metric); - DocumentScorer ds(metric, conf["reference"].as >()); - cerr << "Loaded " << ds.size() << " references for scoring with " << evaluation_metric << endl; - Hypergraph hg; - string last_file; - ReadFile in_read(conf["input"].as()); - istream &in=*in_read.stream(); - while(in) { - string line; - getline(in, line); - if (line.empty()) continue; - istringstream is(line); - int sent_id; - string file, s_origin, s_direction; - // path-to-file (JSON) sent_ed starting-point search-direction - is >> file >> sent_id >> s_origin >> s_direction; - SparseVector origin; - ReadSparseVectorString(s_origin, &origin); - SparseVector direction; - ReadSparseVectorString(s_direction, &direction); - // cerr << "File: " << file << "\nDir: " << direction << "\n X: " << origin << endl; - if (last_file != file) { - last_file = file; - ReadFile rf(file); - HypergraphIO::ReadFromJSON(rf.stream(), &hg); - } - const ConvexHullWeightFunction wf(origin, direction); - const ConvexHull hull = Inside(hg, NULL, wf); - - ErrorSurface es; - ComputeErrorSurface(*ds[sent_id], hull, &es, metric, hg); - //cerr << "Viterbi envelope has " << ve.size() << " segments\n"; - // cerr << "Error surface has " << es.size() << " segments\n"; - string val; - es.Serialize(&val); - cout << 'M' << ' ' << s_origin << ' ' << s_direction << '\t'; - B64::b64encode(val.c_str(), val.size(), &cout); - cout << endl << flush; - } - return 0; -} diff --git a/dpmert/mr_dpmert_reduce.cc b/dpmert/mr_dpmert_reduce.cc deleted file mode 100644 index 31512a03..00000000 --- a/dpmert/mr_dpmert_reduce.cc +++ /dev/null @@ -1,77 +0,0 @@ -#include -#include -#include -#include - -#include -#include - -#include "sparse_vector.h" -#include "error_surface.h" -#include "line_optimizer.h" -#include "b64tools.h" -#include "stringlib.h" - -using namespace std; -namespace po = boost::program_options; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("evaluation_metric,m",po::value(), "Evaluation metric (IBM_BLEU, etc.)") - ("help,h", "Help"); - po::options_description dcmdline_options; - dcmdline_options.add(opts); - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - bool flag = conf->count("evaluation_metric") == 0; - if (flag || conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - const string evaluation_metric = conf["evaluation_metric"].as(); - EvaluationMetric* metric = EvaluationMetric::Instance(evaluation_metric); - LineOptimizer::ScoreType opt_type = LineOptimizer::MAXIMIZE_SCORE; - if (metric->IsErrorMetric()) - opt_type = LineOptimizer::MINIMIZE_SCORE; - - vector esv; - string last_key, line, key, val; - while(getline(cin, line)) { - size_t ks = line.find("\t"); - assert(string::npos != ks); - assert(ks > 2); - key = line.substr(2, ks - 2); - val = line.substr(ks + 1); - if (key != last_key) { - if (!last_key.empty()) { - float score; - double x = LineOptimizer::LineOptimize(metric, esv, opt_type, &score); - cout << last_key << "|" << x << "|" << score << endl; - } - last_key.swap(key); - esv.clear(); - } - if (val.size() % 4 != 0) { - cerr << "B64 encoding error 1! Skipping.\n"; - continue; - } - string encoded(val.size() / 4 * 3, '\0'); - if (!B64::b64decode(reinterpret_cast(&val[0]), val.size(), &encoded[0], encoded.size())) { - cerr << "B64 encoding error 2! Skipping.\n"; - continue; - } - esv.push_back(ErrorSurface()); - esv.back().Deserialize(encoded); - } - if (!esv.empty()) { - float score; - double x = LineOptimizer::LineOptimize(metric, esv, opt_type, &score); - cout << last_key << "|" << x << "|" << score << endl; - } - return 0; -} diff --git a/dpmert/parallelize.pl b/dpmert/parallelize.pl deleted file mode 100755 index d2ebaeea..00000000 --- a/dpmert/parallelize.pl +++ /dev/null @@ -1,423 +0,0 @@ -#!/usr/bin/env perl - -# Author: Adam Lopez -# -# This script takes a command that processes input -# from stdin one-line-at-time, and parallelizes it -# on the cluster using David Chiang's sentserver/ -# sentclient architecture. -# -# Prerequisites: the command *must* read each line -# without waiting for subsequent lines of input -# (for instance, a command which must read all lines -# of input before processing will not work) and -# return it to the output *without* buffering -# multiple lines. - -#TODO: if -j 1, run immediately, not via sentserver? possible differences in environment might make debugging harder - -#ANNOYANCE: if input is shorter than -j n lines, or at the very last few lines, repeatedly sleeps. time cut down to 15s from 60s - -my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../environment"; } -use LocalConfig; - -use Cwd qw/ abs_path cwd getcwd /; -use File::Temp qw/ tempfile /; -use Getopt::Long; -use IPC::Open2; -use strict; -use POSIX ":sys_wait_h"; - -use File::Basename; -my $myDir = dirname(__FILE__); -print STDERR __FILE__." -> $myDir\n"; -push(@INC, $myDir); -require "libcall.pl"; - -my $tailn=5; # +0 = concatenate all the client logs. 5 = last 5 lines -my $recycle_clients; # spawn new clients when previous ones terminate -my $stay_alive; # dont let server die when having zero clients -my $joblist = ""; -my $errordir=""; -my $multiline; -my $workdir = '.'; -my $numnodes = 8; -my $user = $ENV{"USER"}; -my $pmem = "9g"; -my $basep=50300; -my $randp=300; -my $tryp=50; -my $no_which; -my $no_cd; - -my $DEBUG=$ENV{DEBUG}; -print STDERR "DEBUG=$DEBUG output enabled.\n" if $DEBUG; -my $verbose = 1; -sub verbose { - if ($verbose) { - print STDERR @_,"\n"; - } -} -sub debug { - if ($DEBUG) { - my ($package, $filename, $line) = caller; - print STDERR "DEBUG: $filename($line): ",join(' ',@_),"\n"; - } -} -my $is_shell_special=qr.[ \t\n\\><|&;"'`~*?{}$!()].; -my $shell_escape_in_quote=qr.[\\"\$`!].; -sub escape_shell { - my ($arg)=@_; - return undef unless defined $arg; - return '""' unless $arg; - if ($arg =~ /$is_shell_special/) { - $arg =~ s/($shell_escape_in_quote)/\\$1/g; - return "\"$arg\""; - } - return $arg; -} -sub preview_files { - my ($l,$skipempty,$footer,$n)=@_; - $n=$tailn unless defined $n; - my @f=grep { ! ($skipempty && -z $_) } @$l; - my $fn=join(' ',map {escape_shell($_)} @f); - my $cmd="tail -n $n $fn"; - unchecked_output("$cmd").($footer?"\nNONEMPTY FILES:\n$fn\n":""); -} -sub prefix_dirname($) { - #like `dirname but if ends in / then return the whole thing - local ($_)=@_; - if (/\/$/) { - $_; - } else { - s#/[^/]$##; - $_ ? $_ : ''; - } -} -sub ensure_final_slash($) { - local ($_)=@_; - m#/$# ? $_ : ($_."/"); -} -sub extend_path($$;$$) { - my ($base,$ext,$mkdir,$baseisdir)=@_; - if (-d $base) { - $base.="/"; - } else { - my $dir; - if ($baseisdir) { - $dir=$base; - $base.='/' unless $base =~ /\/$/; - } else { - $dir=prefix_dirname($base); - } - my @cmd=("/bin/mkdir","-p",$dir); - check_call(@cmd) if $mkdir; - } - return $base.$ext; -} - -my $abscwd=abs_path(&getcwd); -sub print_help; - -my $use_fork; -my @pids; - -# Process command-line options -unless (GetOptions( - "stay-alive" => \$stay_alive, - "recycle-clients" => \$recycle_clients, - "error-dir=s" => \$errordir, - "multi-line" => \$multiline, - "workdir=s" => \$workdir, - "use-fork" => \$use_fork, - "verbose" => \$verbose, - "jobs=i" => \$numnodes, - "pmem=s" => \$pmem, - "baseport=i" => \$basep, -# "iport=i" => \$randp, #for short name -i - "no-which!" => \$no_which, - "no-cd!" => \$no_cd, - "tailn=s" => \$tailn, -) && scalar @ARGV){ - print_help(); - die "bad options."; -} - -my $cmd = ""; -my $prog=shift; -if ($no_which) { - $cmd=$prog; -} else { - $cmd=check_output("which $prog"); - chomp $cmd; - die "$prog not found - $cmd" unless $cmd; -} -#$cmd=abs_path($cmd); -for my $arg (@ARGV) { - $cmd .= " ".escape_shell($arg); -} -die "Please specify a command to parallelize\n" if $cmd eq ''; - -my $cdcmd=$no_cd ? '' : ("cd ".escape_shell($abscwd)."\n"); - -my $executable = $cmd; -$executable =~ s/^\s*(\S+)($|\s.*)/$1/; -$executable=check_output("basename $executable"); -chomp $executable; - - -print STDERR "Parallelizing ($numnodes ways): $cmd\n\n"; - -# create -e dir and save .sh -use File::Temp qw/tempdir/; -unless ($errordir) { - $errordir=tempdir("$executable.XXXXXX",CLEANUP=>1); -} -if ($errordir) { - my $scriptfile=extend_path("$errordir/","$executable.sh",1,1); - -d $errordir || die "should have created -e dir $errordir"; - open SF,">",$scriptfile || die; - print SF "$cdcmd$cmd\n"; - close SF; - chmod 0755,$scriptfile; - $errordir=abs_path($errordir); - &verbose("-e dir: $errordir"); -} - -# set cleanup handler -my @cleanup_cmds; -sub cleanup; -sub cleanup_and_die; -$SIG{INT} = "cleanup_and_die"; -$SIG{TERM} = "cleanup_and_die"; -$SIG{HUP} = "cleanup_and_die"; - -# other subs: -sub numof_live_jobs; -sub launch_job_on_node; - - -# vars -my $mydir = check_output("dirname $0"); chomp $mydir; -my $sentserver = "$mydir/sentserver"; -my $sentclient = "$mydir/sentclient"; -my $host = check_output("hostname"); -chomp $host; - - -# find open port -srand; -my $port = 50300+int(rand($randp)); -my $endp=$port+$tryp; -sub listening_port_lines { - my $quiet=$verbose?'':'2>/dev/null'; - return unchecked_output("netstat -a -n $quiet | grep LISTENING | grep -i tcp"); -} -my $netstat=&listening_port_lines; - -if ($verbose){ print STDERR "Testing port $port...";} - -while ($netstat=~/$port/ || &listening_port_lines=~/$port/){ - if ($verbose){ print STDERR "port is busy\n";} - $port++; - if ($port > $endp){ - die "Unable to find open port\n"; - } - if ($verbose){ print STDERR "Testing port $port... "; } -} -if ($verbose){ - print STDERR "port $port is available\n"; -} - -my $key = int(rand()*1000000); - -my $multiflag = ""; -if ($multiline){ $multiflag = "-m"; print STDERR "expecting multiline output.\n"; } -my $stay_alive_flag = ""; -if ($stay_alive){ $stay_alive_flag = "--stay-alive"; print STDERR "staying alive while no clients are connected.\n"; } - -my $node_count = 0; -my $script = ""; -# fork == one thread runs the sentserver, while the -# other spawns the sentclient commands. -my $pid = fork; -if ($pid == 0) { # child - sleep 8; # give other thread time to start sentserver - $script = "$cdcmd$sentclient $host:$port:$key $cmd"; - - if ($verbose){ - print STDERR "Client script:\n====\n"; - print STDERR $script; - print STDERR "====\n"; - } - for (my $jobn=0; $jobn<$numnodes; $jobn++){ - launch_job(); - } - if ($recycle_clients) { - my $ret; - my $livejobs; - while (1) { - $ret = waitpid($pid, WNOHANG); - #print STDERR "waitpid $pid ret = $ret \n"; - last if ($ret != 0); - $livejobs = numof_live_jobs(); - if ($numnodes >= $livejobs ) { # a client terminated, OR # lines of input was less than -j - print STDERR "num of requested nodes = $numnodes; num of currently live jobs = $livejobs; Client terminated - launching another.\n"; - launch_job(); - } else { - sleep 15; - } - } - } - print STDERR "CHILD PROCESSES SPAWNED ... WAITING\n"; - for my $p (@pids) { - waitpid($p, 0); - } -} else { -# my $todo = "$sentserver -k $key $multiflag $port "; - my $todo = "$sentserver -k $key $multiflag $port $stay_alive_flag "; - if ($verbose){ print STDERR "Running: $todo\n"; } - check_call($todo); - print STDERR "Call to $sentserver returned.\n"; - cleanup(); - exit(0); -} - -sub numof_live_jobs { - if ($use_fork) { - die "not implemented"; - } else { - # We can probably continue decoding if the qstat error is only temporary - my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat"))); - return ($#livejobs + 1); - } -} -my (@errors,@outs,@cmds); - -sub launch_job { - if ($use_fork) { return launch_job_fork(); } - my $errorfile = "/dev/null"; - my $outfile = "/dev/null"; - $node_count++; - my $clientname = $executable; - $clientname =~ s/^(.{4}).*$/$1/; - $clientname = "$clientname.$node_count"; - if ($errordir){ - $errorfile = "$errordir/$clientname.ER"; - $outfile = "$errordir/$clientname.OU"; - push @errors,$errorfile; - push @outs,$outfile; - } - my $todo = qsub_args($pmem) . " -N $clientname -o $outfile -e $errorfile"; - push @cmds,$todo; - - print STDERR "Running: $todo\n"; - local(*QOUT, *QIN); - open2(\*QOUT, \*QIN, $todo) or die "Failed to open2: $!"; - print QIN $script; - close QIN; - while (my $jobid=){ - chomp $jobid; - if ($verbose){ print STDERR "Launched client job: $jobid"; } - $jobid =~ s/^(\d+)(.*?)$/\1/g; - $jobid =~ s/^Your job (\d+) .*$/\1/; - print STDERR " short job id $jobid\n"; - if ($verbose){ - print STDERR "cd: $abscwd\n"; - print STDERR "cmd: $cmd\n"; - } - if ($joblist == "") { $joblist = $jobid; } - else {$joblist = $joblist . "\|" . $jobid; } - my $cleanfn="qdel $jobid 2> /dev/null"; - push(@cleanup_cmds, $cleanfn); - } - close QOUT; -} - -sub launch_job_fork { - my $errorfile = "/dev/null"; - my $outfile = "/dev/null"; - $node_count++; - my $clientname = $executable; - $clientname =~ s/^(.{4}).*$/$1/; - $clientname = "$clientname.$node_count"; - if ($errordir){ - $errorfile = "$errordir/$clientname.ER"; - $outfile = "$errordir/$clientname.OU"; - push @errors,$errorfile; - push @outs,$outfile; - } - my $pid = fork; - if ($pid == 0) { - my ($fh, $scr_name) = get_temp_script(); - print $fh $script; - close $fh; - my $todo = "/bin/bash -xeo pipefail $scr_name 1> $outfile 2> $errorfile"; - print STDERR "EXEC: $todo\n"; - my $out = check_output("$todo"); - unlink $scr_name or warn "Failed to remove $scr_name"; - exit 0; - } else { - push @pids, $pid; - } -} - -sub get_temp_script { - my ($fh, $filename) = tempfile( "$workdir/workXXXX", SUFFIX => '.sh'); - return ($fh, $filename); -} - -sub cleanup_and_die { - cleanup(); - die "\n"; -} - -sub cleanup { - print STDERR "Cleaning up...\n"; - for $cmd (@cleanup_cmds){ - print STDERR " Cleanup command: $cmd\n"; - eval $cmd; - } - print STDERR "outputs:\n",preview_files(\@outs,1),"\n"; - print STDERR "errors:\n",preview_files(\@errors,1),"\n"; - print STDERR "cmd:\n",$cmd,"\n"; - print STDERR " cat $errordir/*.ER\nfor logs.\n"; - print STDERR "Cleanup finished.\n"; -} - -sub print_help -{ - my $name = check_output("basename $0"); chomp $name; - print << "Help"; - -usage: $name [options] - - Automatic black-box parallelization of commands. - -options: - - --use-fork - Instead of using qsub, use fork. - - -e, --error-dir - Retain output files from jobs in , rather - than silently deleting them. - - -m, --multi-line - Expect that command may produce multiple output - lines for a single input line. $name makes a - reasonable attempt to obtain all output before - processing additional inputs. However, use of this - option is inherently unsafe. - - -v, --verbose - Print diagnostic informatoin on stderr. - - -j, --jobs - Number of jobs to use. - - -p, --pmem - pmem setting for each job. - -Help -} diff --git a/dpmert/sentclient.c b/dpmert/sentclient.c deleted file mode 100644 index 91d994ab..00000000 --- a/dpmert/sentclient.c +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright (c) 2001 by David Chiang. All rights reserved.*/ - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "sentserver.h" - -int main (int argc, char *argv[]) { - int sock, port; - char *s, *key; - struct hostent *hp; - struct sockaddr_in server; - int errors = 0; - - if (argc < 3) { - fprintf(stderr, "Usage: sentclient host[:port[:key]] command [args ...]\n"); - exit(1); - } - - s = strchr(argv[1], ':'); - key = NULL; - - if (s == NULL) { - port = DEFAULT_PORT; - } else { - *s = '\0'; - s+=1; - /* dumb hack */ - key = strchr(s, ':'); - if (key != NULL){ - *key = '\0'; - key += 1; - } - port = atoi(s); - } - - sock = socket(AF_INET, SOCK_STREAM, 0); - - hp = gethostbyname(argv[1]); - if (hp == NULL) { - fprintf(stderr, "unknown host %s\n", argv[1]); - exit(1); - } - - bzero((char *)&server, sizeof(server)); - bcopy(hp->h_addr, (char *)&server.sin_addr, hp->h_length); - server.sin_family = hp->h_addrtype; - server.sin_port = htons(port); - - while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) { - perror("connect()"); - sleep(1); - errors++; - if (errors > 5) - exit(1); - } - - close(0); - close(1); - dup2(sock, 0); - dup2(sock, 1); - - if (key != NULL){ - write(1, key, strlen(key)); - write(1, "\n", 1); - } - - execvp(argv[2], argv+2); - return 0; -} diff --git a/dpmert/sentserver.c b/dpmert/sentserver.c deleted file mode 100644 index c20b4fa6..00000000 --- a/dpmert/sentserver.c +++ /dev/null @@ -1,515 +0,0 @@ -/* Copyright (c) 2001 by David Chiang. All rights reserved.*/ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "sentserver.h" - -#define MAX_CLIENTS 64 - -struct clientinfo { - int s; - struct sockaddr_in sin; -}; - -struct line { - int id; - char *s; - int status; - struct line *next; -} *head, **ptail; - -int n_sent = 0, n_received=0, n_flushed=0; - -#define STATUS_RUNNING 0 -#define STATUS_ABORTED 1 -#define STATUS_FINISHED 2 - -pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER; -pthread_mutex_t clients_mutex = PTHREAD_MUTEX_INITIALIZER; -pthread_mutex_t input_mutex = PTHREAD_MUTEX_INITIALIZER; - -int n_clients = 0; -int s; -int expect_multiline_output = 0; -int log_mutex = 0; -int stay_alive = 0; /* dont panic and die with zero clients */ - -void queue_finish(struct line *node, char *s, int fid); -char * read_line(int fd, int multiline); -void done (int code); - -struct line * queue_get(int fid) { - struct line *cur; - char *s, *synch; - - if (log_mutex) fprintf(stderr, "Getting for data for fid %d\n", fid); - if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); - pthread_mutex_lock(&queue_mutex); - - /* First, check for aborted sentences. */ - - if (log_mutex) fprintf(stderr, " Checking queue for aborted jobs (fid %d)\n", fid); - for (cur = head; cur != NULL; cur = cur->next) { - if (cur->status == STATUS_ABORTED) { - cur->status = STATUS_RUNNING; - - if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); - pthread_mutex_unlock(&queue_mutex); - - return cur; - } - } - if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); - pthread_mutex_unlock(&queue_mutex); - - /* Otherwise, read a new one. */ - if (log_mutex) fprintf(stderr, "Locking input mutex (%d)\n", fid); - if (log_mutex) fprintf(stderr, " Reading input for new data (fid %d)\n", fid); - pthread_mutex_lock(&input_mutex); - s = read_line(0,0); - - while (s) { - if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); - pthread_mutex_lock(&queue_mutex); - if (log_mutex) fprintf(stderr, "Unlocking input mutex (%d)\n", fid); - pthread_mutex_unlock(&input_mutex); - - cur = malloc(sizeof (struct line)); - cur->id = n_sent; - cur->s = s; - cur->next = NULL; - - *ptail = cur; - ptail = &cur->next; - - n_sent++; - - if (strcmp(s,"===SYNCH===\n")==0){ - fprintf(stderr, "Received ===SYNCH=== signal (fid %d)\n", fid); - // Note: queue_finish calls free(cur->s). - // Therefore we need to create a new string here. - synch = malloc((strlen("===SYNCH===\n")+2) * sizeof (char)); - synch = strcpy(synch, s); - - if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); - pthread_mutex_unlock(&queue_mutex); - queue_finish(cur, synch, fid); /* handles its own lock */ - - if (log_mutex) fprintf(stderr, "Locking input mutex (%d)\n", fid); - if (log_mutex) fprintf(stderr, " Reading input for new data (fid %d)\n", fid); - pthread_mutex_lock(&input_mutex); - - s = read_line(0,0); - } else { - if (log_mutex) fprintf(stderr, " Received new data %d (fid %d)\n", cur->id, fid); - cur->status = STATUS_RUNNING; - if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); - pthread_mutex_unlock(&queue_mutex); - return cur; - } - } - - if (log_mutex) fprintf(stderr, "Unlocking input mutex (%d)\n", fid); - pthread_mutex_unlock(&input_mutex); - /* Only way to reach this point: no more output */ - - if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); - pthread_mutex_lock(&queue_mutex); - if (head == NULL) { - fprintf(stderr, "Reached end of file. Exiting.\n"); - done(0); - } else - ptail = NULL; /* This serves as a signal that there is no more input */ - if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); - pthread_mutex_unlock(&queue_mutex); - - return NULL; -} - -void queue_panic() { - struct line *next; - while (head && head->status == STATUS_FINISHED) { - /* Write out finished sentences */ - if (head->status == STATUS_FINISHED) { - fputs(head->s, stdout); - fflush(stdout); - } - /* Write out blank line for unfinished sentences */ - if (head->status == STATUS_ABORTED) { - fputs("\n", stdout); - fflush(stdout); - } - /* By defition, there cannot be any RUNNING sentences, since - function is only called when n_clients == 0 */ - free(head->s); - next = head->next; - free(head); - head = next; - n_flushed++; - } - fclose(stdout); - fprintf(stderr, "All clients died. Panicking, flushing completed sentences and exiting.\n"); - done(1); -} - -void queue_abort(struct line *node, int fid) { - if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); - pthread_mutex_lock(&queue_mutex); - node->status = STATUS_ABORTED; - if (n_clients == 0) { - if (stay_alive) { - fprintf(stderr, "Warning! No live clients detected! Staying alive, will retry soon.\n"); - } else { - queue_panic(); - } - } - if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); - pthread_mutex_unlock(&queue_mutex); -} - - -void queue_print() { - struct line *cur; - - fprintf(stderr, " Queue\n"); - - for (cur = head; cur != NULL; cur = cur->next) { - switch(cur->status) { - case STATUS_RUNNING: - fprintf(stderr, " %d running ", cur->id); break; - case STATUS_ABORTED: - fprintf(stderr, " %d aborted ", cur->id); break; - case STATUS_FINISHED: - fprintf(stderr, " %d finished ", cur->id); break; - - } - fprintf(stderr, "\n"); - //fprintf(stderr, cur->s); - } -} - -void queue_finish(struct line *node, char *s, int fid) { - struct line *next; - if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); - pthread_mutex_lock(&queue_mutex); - - free(node->s); - node->s = s; - node->status = STATUS_FINISHED; - n_received++; - - /* Flush out finished nodes */ - while (head && head->status == STATUS_FINISHED) { - - if (log_mutex) fprintf(stderr, " Flushing finished node %d\n", head->id); - - fputs(head->s, stdout); - fflush(stdout); - if (log_mutex) fprintf(stderr, " Flushed node %d\n", head->id); - free(head->s); - - next = head->next; - free(head); - - head = next; - - n_flushed++; - - if (head == NULL) { /* empty queue */ - if (ptail == NULL) { /* This can only happen if set in queue_get as signal that there is no more input. */ - fprintf(stderr, "All sentences finished. Exiting.\n"); - done(0); - } else /* ptail pointed at something which was just popped off the stack -- reset to head*/ - ptail = &head; - } - } - - if (log_mutex) fprintf(stderr, " Flushing output %d\n", head->id); - fflush(stdout); - fprintf(stderr, "%d sentences sent, %d sentences finished, %d sentences flushed\n", n_sent, n_received, n_flushed); - - if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); - pthread_mutex_unlock(&queue_mutex); - -} - -char * read_line(int fd, int multiline) { - int size = 80; - char errorbuf[100]; - char *s = malloc(size+2); - int result, errors=0; - int i = 0; - - result = read(fd, s+i, 1); - - while (1) { - if (result < 0) { - perror("read()"); - sprintf(errorbuf, "Error code: %d\n", errno); - fprintf(stderr, errorbuf); - errors++; - if (errors > 5) { - free(s); - return NULL; - } else { - sleep(1); /* retry after delay */ - } - } else if (result == 0) { - break; - } else if (multiline==0 && s[i] == '\n') { - break; - } else { - if (s[i] == '\n'){ - /* if we've reached this point, - then multiline must be 1, and we're - going to poll the fd for an additional - line of data. The basic design is to - run a select on the filedescriptor fd. - Select will return under two conditions: - if there is data on the fd, or if a - timeout is reached. We'll select on this - fd. If select returns because there's data - ready, keep going; else assume there's no - more and return the data we already have. - */ - - fd_set set; - FD_ZERO(&set); - FD_SET(fd, &set); - - struct timeval timeout; - timeout.tv_sec = 3; // number of seconds for timeout - timeout.tv_usec = 0; - - int ready = select(FD_SETSIZE, &set, NULL, NULL, &timeout); - if (ready<1){ - break; // no more data, stop looping - } - } - i++; - - if (i == size) { - size = size*2; - s = realloc(s, size+2); - } - } - - result = read(fd, s+i, 1); - } - - if (result == 0 && i == 0) { /* end of file */ - free(s); - return NULL; - } - - s[i] = '\n'; - s[i+1] = '\0'; - - return s; -} - -void * new_client(void *arg) { - struct clientinfo *client = (struct clientinfo *)arg; - struct line *cur; - int result; - char *s; - char errorbuf[100]; - - pthread_mutex_lock(&clients_mutex); - n_clients++; - pthread_mutex_unlock(&clients_mutex); - - fprintf(stderr, "Client connected (%d connected)\n", n_clients); - - for (;;) { - - cur = queue_get(client->s); - - if (cur) { - /* fprintf(stderr, "Sending to client: %s", cur->s); */ - fprintf(stderr, "Sending data %d to client (fid %d)\n", cur->id, client->s); - result = write(client->s, cur->s, strlen(cur->s)); - if (result < strlen(cur->s)){ - perror("write()"); - sprintf(errorbuf, "Error code: %d\n", errno); - fprintf(stderr, errorbuf); - - pthread_mutex_lock(&clients_mutex); - n_clients--; - pthread_mutex_unlock(&clients_mutex); - - fprintf(stderr, "Client died (%d connected)\n", n_clients); - queue_abort(cur, client->s); - - close(client->s); - free(client); - - pthread_exit(NULL); - } - } else { - close(client->s); - pthread_mutex_lock(&clients_mutex); - n_clients--; - pthread_mutex_unlock(&clients_mutex); - fprintf(stderr, "Client dismissed (%d connected)\n", n_clients); - pthread_exit(NULL); - } - - s = read_line(client->s,expect_multiline_output); - if (s) { - /* fprintf(stderr, "Client (fid %d) returned: %s", client->s, s); */ - fprintf(stderr, "Client (fid %d) returned data %d\n", client->s, cur->id); -// queue_print(); - queue_finish(cur, s, client->s); - } else { - pthread_mutex_lock(&clients_mutex); - n_clients--; - pthread_mutex_unlock(&clients_mutex); - - fprintf(stderr, "Client died (%d connected)\n", n_clients); - queue_abort(cur, client->s); - - close(client->s); - free(client); - - pthread_exit(NULL); - } - - } - return 0; -} - -void done (int code) { - close(s); - exit(code); -} - - - -int main (int argc, char *argv[]) { - struct sockaddr_in sin, from; - int g; - socklen_t len; - struct clientinfo *client; - int port; - int opt; - int errors = 0; - int argi; - char *key = NULL, *client_key; - int use_key = 0; - /* the key stuff here doesn't provide any - real measure of security, it's mainly to keep - jobs from bumping into each other. */ - - pthread_t tid; - port = DEFAULT_PORT; - - for (argi=1; argi < argc; argi++){ - if (strcmp(argv[argi], "-m")==0){ - expect_multiline_output = 1; - } else if (strcmp(argv[argi], "-k")==0){ - argi++; - if (argi == argc){ - fprintf(stderr, "Key must be specified after -k\n"); - exit(1); - } - key = argv[argi]; - use_key = 1; - } else if (strcmp(argv[argi], "--stay-alive")==0){ - stay_alive = 1; /* dont panic and die with zero clients */ - } else { - port = atoi(argv[argi]); - } - } - - /* Initialize data structures */ - head = NULL; - ptail = &head; - - /* Set up listener */ - s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - opt = 1; - setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); - - sin.sin_family = AF_INET; - sin.sin_addr.s_addr = htonl(INADDR_ANY); - sin.sin_port = htons(port); - while (bind(s, (struct sockaddr *) &sin, sizeof(sin)) < 0) { - perror("bind()"); - sleep(1); - errors++; - if (errors > 100) - exit(1); - } - - len = sizeof(sin); - getsockname(s, (struct sockaddr *) &sin, &len); - - fprintf(stderr, "Listening on port %hu\n", ntohs(sin.sin_port)); - - while (listen(s, MAX_CLIENTS) < 0) { - perror("listen()"); - sleep(1); - errors++; - if (errors > 100) - exit(1); - } - - for (;;) { - len = sizeof(from); - g = accept(s, (struct sockaddr *)&from, &len); - if (g < 0) { - perror("accept()"); - sleep(1); - continue; - } - client = malloc(sizeof(struct clientinfo)); - client->s = g; - bcopy(&from, &client->sin, len); - - if (use_key){ - fd_set set; - FD_ZERO(&set); - FD_SET(client->s, &set); - - struct timeval timeout; - timeout.tv_sec = 3; // number of seconds for timeout - timeout.tv_usec = 0; - - int ready = select(FD_SETSIZE, &set, NULL, NULL, &timeout); - if (ready<1){ - fprintf(stderr, "Prospective client failed to respond with correct key.\n"); - close(client->s); - free(client); - } else { - client_key = read_line(client->s,0); - client_key[strlen(client_key)-1]='\0'; /* chop trailing newline */ - if (strcmp(key, client_key)==0){ - pthread_create(&tid, NULL, new_client, client); - } else { - fprintf(stderr, "Prospective client failed to respond with correct key.\n"); - close(client->s); - free(client); - } - free(client_key); - } - } else { - pthread_create(&tid, NULL, new_client, client); - } - } - -} - - - diff --git a/dpmert/sentserver.h b/dpmert/sentserver.h deleted file mode 100644 index cd17a546..00000000 --- a/dpmert/sentserver.h +++ /dev/null @@ -1,6 +0,0 @@ -#ifndef SENTSERVER_H -#define SENTSERVER_H - -#define DEFAULT_PORT 50000 - -#endif diff --git a/dpmert/test_aer/README b/dpmert/test_aer/README deleted file mode 100644 index 819b2e32..00000000 --- a/dpmert/test_aer/README +++ /dev/null @@ -1,8 +0,0 @@ -To run the test: - -../dist-vest.pl --local --metric aer cdec.ini --source-file corpus.src --ref-files=ref.0 --weights weights - -This will optimize the parameters of the tiny lexical translation model -so as to minimize the AER of the Viterbi alignment on the development -set in corpus.src according to the reference alignments in ref.0. - diff --git a/dpmert/test_aer/cdec.ini b/dpmert/test_aer/cdec.ini deleted file mode 100644 index 08187848..00000000 --- a/dpmert/test_aer/cdec.ini +++ /dev/null @@ -1,3 +0,0 @@ -formalism=lextrans -grammar=grammar -aligner=true diff --git a/dpmert/test_aer/corpus.src b/dpmert/test_aer/corpus.src deleted file mode 100644 index 31b23971..00000000 --- a/dpmert/test_aer/corpus.src +++ /dev/null @@ -1,3 +0,0 @@ -el gato negro ||| the black cat -el gato ||| the cat -el libro ||| the book diff --git a/dpmert/test_aer/grammar b/dpmert/test_aer/grammar deleted file mode 100644 index 9d857824..00000000 --- a/dpmert/test_aer/grammar +++ /dev/null @@ -1,12 +0,0 @@ -el ||| cat ||| F1=1 -el ||| the ||| F2=1 -el ||| black ||| F3=1 -el ||| book ||| F11=1 -gato ||| cat ||| F4=1 NN=1 -gato ||| black ||| F5=1 -gato ||| the ||| F6=1 -negro ||| the ||| F7=1 -negro ||| cat ||| F8=1 -negro ||| black ||| F9=1 -libro ||| the ||| F10=1 -libro ||| book ||| F12=1 NN=1 diff --git a/dpmert/test_aer/ref.0 b/dpmert/test_aer/ref.0 deleted file mode 100644 index 734a9c5b..00000000 --- a/dpmert/test_aer/ref.0 +++ /dev/null @@ -1,3 +0,0 @@ -0-0 1-2 2-1 -0-0 1-1 -0-0 1-1 diff --git a/dpmert/test_aer/weights b/dpmert/test_aer/weights deleted file mode 100644 index afc9282e..00000000 --- a/dpmert/test_aer/weights +++ /dev/null @@ -1,13 +0,0 @@ -F1 0.1 -F2 -.5980815 -F3 0.24235 -F4 0.625 -F5 0.4514 -F6 0.112316 -F7 -0.123415 -F8 -0.25390285 -F9 -0.23852 -F10 0.646 -F11 0.413141 -F12 0.343216 -NN -0.1215 diff --git a/dpmert/test_data/0.json.gz b/dpmert/test_data/0.json.gz deleted file mode 100644 index 30f8dd77..00000000 Binary files a/dpmert/test_data/0.json.gz and /dev/null differ diff --git a/dpmert/test_data/1.json.gz b/dpmert/test_data/1.json.gz deleted file mode 100644 index c82cc179..00000000 Binary files a/dpmert/test_data/1.json.gz and /dev/null differ diff --git a/dpmert/test_data/c2e.txt.0 b/dpmert/test_data/c2e.txt.0 deleted file mode 100644 index 12c4abe9..00000000 --- a/dpmert/test_data/c2e.txt.0 +++ /dev/null @@ -1,2 +0,0 @@ -australia reopens embassy in manila -( afp , manila , january 2 ) australia reopened its embassy in the philippines today , which was shut down about seven weeks ago due to what was described as a specific threat of a terrorist attack . diff --git a/dpmert/test_data/c2e.txt.1 b/dpmert/test_data/c2e.txt.1 deleted file mode 100644 index 4ac12df1..00000000 --- a/dpmert/test_data/c2e.txt.1 +++ /dev/null @@ -1,2 +0,0 @@ -australia reopened manila embassy -( agence france-presse , manila , 2nd ) - australia reopened its embassy in the philippines today . the embassy was closed seven weeks ago after what was described as a specific threat of a terrorist attack . diff --git a/dpmert/test_data/c2e.txt.2 b/dpmert/test_data/c2e.txt.2 deleted file mode 100644 index 2f67b72f..00000000 --- a/dpmert/test_data/c2e.txt.2 +++ /dev/null @@ -1,2 +0,0 @@ -australia to reopen embassy in manila -( afp report from manila , january 2 ) australia reopened its embassy in the philippines today . seven weeks ago , the embassy was shut down due to so-called confirmed terrorist attack threats . diff --git a/dpmert/test_data/c2e.txt.3 b/dpmert/test_data/c2e.txt.3 deleted file mode 100644 index 5483cef6..00000000 --- a/dpmert/test_data/c2e.txt.3 +++ /dev/null @@ -1,2 +0,0 @@ -australia to re - open its embassy to manila -( afp , manila , thursday ) australia reopens its embassy to manila , which was closed for the so-called " clear " threat of terrorist attack 7 weeks ago . diff --git a/dpmert/test_data/re.txt.0 b/dpmert/test_data/re.txt.0 deleted file mode 100644 index 86eff087..00000000 --- a/dpmert/test_data/re.txt.0 +++ /dev/null @@ -1,5 +0,0 @@ -erdogan states turkey to reject any pressures to urge it to recognize cyprus -ankara 12 - 1 ( afp ) - turkish prime minister recep tayyip erdogan announced today , wednesday , that ankara will reject any pressure by the european union to urge it to recognize cyprus . this comes two weeks before the summit of european union state and government heads who will decide whether or nor membership negotiations with ankara should be opened . -erdogan told " ntv " television station that " the european union cannot address us by imposing new conditions on us with regard to cyprus . -we will discuss this dossier in the course of membership negotiations . " -he added " let me be clear , i cannot sidestep turkey , this is something we cannot accept . " diff --git a/dpmert/test_data/re.txt.1 b/dpmert/test_data/re.txt.1 deleted file mode 100644 index 2140f198..00000000 --- a/dpmert/test_data/re.txt.1 +++ /dev/null @@ -1,5 +0,0 @@ -erdogan confirms turkey will resist any pressure to recognize cyprus -ankara 12 - 1 ( afp ) - the turkish head of government , recep tayyip erdogan , announced today ( wednesday ) that ankara would resist any pressure the european union might exercise in order to force it into recognizing cyprus . this comes two weeks before a summit of european union heads of state and government , who will decide whether or not to open membership negotiations with ankara . -erdogan said to the ntv television channel : " the european union cannot engage with us through imposing new conditions on us with regard to cyprus . -we shall discuss this issue in the course of the membership negotiations . " -he added : " let me be clear - i cannot confine turkey . this is something we do not accept . " diff --git a/dpmert/test_data/re.txt.2 b/dpmert/test_data/re.txt.2 deleted file mode 100644 index 94e46286..00000000 --- a/dpmert/test_data/re.txt.2 +++ /dev/null @@ -1,5 +0,0 @@ -erdogan confirms that turkey will reject any pressures to encourage it to recognize cyprus -ankara , 12 / 1 ( afp ) - the turkish prime minister recep tayyip erdogan declared today , wednesday , that ankara will reject any pressures that the european union may apply on it to encourage to recognize cyprus . this comes two weeks before a summit of the heads of countries and governments of the european union , who will decide on whether or not to start negotiations on joining with ankara . -erdogan told the ntv television station that " it is not possible for the european union to talk to us by imposing new conditions on us regarding cyprus . -we shall discuss this dossier during the negotiations on joining . " -and he added , " let me be clear . turkey's arm should not be twisted ; this is something we cannot accept . " diff --git a/dpmert/test_data/re.txt.3 b/dpmert/test_data/re.txt.3 deleted file mode 100644 index f87c3308..00000000 --- a/dpmert/test_data/re.txt.3 +++ /dev/null @@ -1,5 +0,0 @@ -erdogan stresses that turkey will reject all pressures to force it to recognize cyprus -ankara 12 - 1 ( afp ) - turkish prime minister recep tayyip erdogan announced today , wednesday , that ankara would refuse all pressures applied on it by the european union to force it to recognize cyprus . that came two weeks before the summit of the presidents and prime ministers of the european union , who would decide on whether to open negotiations on joining with ankara or not . -erdogan said to " ntv " tv station that the " european union can not communicate with us by imposing on us new conditions related to cyprus . -we will discuss this file during the negotiations on joining . " -he added , " let me be clear . turkey's arm should not be twisted . this is unacceptable to us . " diff --git a/dtrain/Makefile.am b/dtrain/Makefile.am deleted file mode 100644 index ca9581f5..00000000 --- a/dtrain/Makefile.am +++ /dev/null @@ -1,7 +0,0 @@ -bin_PROGRAMS = dtrain - -dtrain_SOURCES = dtrain.cc score.cc -dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval - diff --git a/dtrain/README.md b/dtrain/README.md deleted file mode 100644 index 7edabbf1..00000000 --- a/dtrain/README.md +++ /dev/null @@ -1,48 +0,0 @@ -This is a simple (and parallelizable) tuning method for cdec -which is able to train the weights of very many (sparse) features. -It was used here: - "Joint Feature Selection in Distributed Stochastic - Learning for Large-Scale Discriminative Training in - SMT" -(Simianer, Riezler, Dyer; ACL 2012) - - -Building --------- -Builds when building cdec, see ../BUILDING . -To build only parts needed for dtrain do -``` - autoreconf -ifv - ./configure [--disable-gtest] - cd dtrain/; make -``` - -Running -------- -To run this on a dev set locally: -``` - #define DTRAIN_LOCAL -``` -otherwise remove that line or undef, then recompile. You need a single -grammar file or input annotated with per-sentence grammars (psg) as you -would use with cdec. Additionally you need to give dtrain a file with -references (--refs) when running locally. - -The input for use with hadoop streaming looks like this: -``` - \t\t\t -``` -To convert a psg to this format you need to replace all "\n" -by "\t". Make sure there are no tabs in your data. - -For an example of local usage (with the 'distributed' format) -the see test/example/ . This expects dtrain to be built without -DTRAIN_LOCAL. - -Legal ------ -Copyright (c) 2012 by Patrick Simianer - -See the file ../LICENSE.txt for the licensing terms that this software is -released under. - diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc deleted file mode 100644 index 18286668..00000000 --- a/dtrain/dtrain.cc +++ /dev/null @@ -1,657 +0,0 @@ -#include "dtrain.h" - - -bool -dtrain_init(int argc, char** argv, po::variables_map* cfg) -{ - po::options_description ini("Configuration File Options"); - ini.add_options() - ("input", po::value()->default_value("-"), "input file") - ("output", po::value()->default_value("-"), "output weights file, '-' for STDOUT") - ("input_weights", po::value(), "input weights file (e.g. from previous iteration)") - ("decoder_config", po::value(), "configuration file for cdec") - ("print_weights", po::value(), "weights to print on each iteration") - ("stop_after", po::value()->default_value(0), "stop after X input sentences") - ("tmp", po::value()->default_value("/tmp"), "temp dir to use") - ("keep", po::value()->zero_tokens(), "keep weights files for each iteration") - ("hstreaming", po::value(), "run in hadoop streaming mode, arg is a task id") - ("epochs", po::value()->default_value(10), "# of iterations T (per shard)") - ("k", po::value()->default_value(100), "how many translations to sample") - ("sample_from", po::value()->default_value("kbest"), "where to sample translations from: 'kbest', 'forest'") - ("filter", po::value()->default_value("uniq"), "filter kbest list: 'not', 'uniq'") - ("pair_sampling", po::value()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'") - ("hi_lo", po::value()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5") - ("pair_threshold", po::value()->default_value(0.), "bleu [0,1] threshold to filter pairs") - ("N", po::value()->default_value(4), "N for Ngrams (BLEU)") - ("scorer", po::value()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_, lc_") - ("learning_rate", po::value()->default_value(1.0), "learning rate") - ("gamma", po::value()->default_value(0.), "gamma for SVM (0 for perceptron)") - ("select_weights", po::value()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)") - ("rescale", po::value()->zero_tokens(), "rescale weight vector after each input") - ("l1_reg", po::value()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") - ("l1_reg_strength", po::value(), "l1 regularization strength") - ("fselect", po::value()->default_value(-1), "select top x percent (or by threshold) of features after each epoch NOT IMPLEMENTED") // TODO - ("approx_bleu_d", po::value()->default_value(0.9), "discount for approx. BLEU") - ("scale_bleu_diff", po::value()->zero_tokens(), "learning rate <- bleu diff of a misranked pair") - ("loss_margin", po::value()->default_value(0.), "update if no error in pref pair but model scores this near") - ("max_pairs", po::value()->default_value(std::numeric_limits::max()), "max. # of pairs per Sent.") -#ifdef DTRAIN_LOCAL - ("refs,r", po::value(), "references in local mode") -#endif - ("noup", po::value()->zero_tokens(), "do not update weights"); - po::options_description cl("Command Line Options"); - cl.add_options() - ("config,c", po::value(), "dtrain config file") - ("quiet,q", po::value()->zero_tokens(), "be quiet") - ("verbose,v", po::value()->zero_tokens(), "be verbose"); - cl.add(ini); - po::store(parse_command_line(argc, argv, cl), *cfg); - if (cfg->count("config")) { - ifstream ini_f((*cfg)["config"].as().c_str()); - po::store(po::parse_config_file(ini_f, ini), *cfg); - } - po::notify(*cfg); - if (!cfg->count("decoder_config")) { - cerr << cl << endl; - return false; - } - if (cfg->count("hstreaming") && (*cfg)["output"].as() != "-") { - cerr << "When using 'hstreaming' the 'output' param should be '-'." << endl; - return false; - } -#ifdef DTRAIN_LOCAL - if ((*cfg)["input"].as() == "-") { - cerr << "Can't use stdin as input with this binary. Recompile without DTRAIN_LOCAL" << endl; - return false; - } -#endif - if ((*cfg)["sample_from"].as() != "kbest" - && (*cfg)["sample_from"].as() != "forest") { - cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as() << "', use 'kbest' or 'forest'." << endl; - return false; - } - if ((*cfg)["sample_from"].as() == "kbest" && (*cfg)["filter"].as() != "uniq" && - (*cfg)["filter"].as() != "not") { - cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as() << "', use 'uniq' or 'not'." << endl; - return false; - } - if ((*cfg)["pair_sampling"].as() != "all" && (*cfg)["pair_sampling"].as() != "XYX" && - (*cfg)["pair_sampling"].as() != "PRO") { - cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as() << "'." << endl; - return false; - } - if(cfg->count("hi_lo") && (*cfg)["pair_sampling"].as() != "XYX") { - cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl; - } - if((*cfg)["hi_lo"].as() > 0.5 || (*cfg)["hi_lo"].as() < 0.01) { - cerr << "hi_lo must lie in [0.01, 0.5]" << endl; - return false; - } - if ((*cfg)["pair_threshold"].as() < 0) { - cerr << "The threshold must be >= 0!" << endl; - return false; - } - if ((*cfg)["select_weights"].as() != "last" && (*cfg)["select_weights"].as() != "best" && - (*cfg)["select_weights"].as() != "avg" && (*cfg)["select_weights"].as() != "VOID") { - cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as() << "', use 'last' or 'best'." << endl; - return false; - } - return true; -} - -int -main(int argc, char** argv) -{ - // handle most parameters - po::variables_map cfg; - if (!dtrain_init(argc, argv, &cfg)) exit(1); // something is wrong - bool quiet = false; - if (cfg.count("quiet")) quiet = true; - bool verbose = false; - if (cfg.count("verbose")) verbose = true; - bool noup = false; - if (cfg.count("noup")) noup = true; - bool hstreaming = false; - string task_id; - if (cfg.count("hstreaming")) { - hstreaming = true; - quiet = true; - task_id = cfg["hstreaming"].as(); - cerr.precision(17); - } - bool rescale = false; - if (cfg.count("rescale")) rescale = true; - HSReporter rep(task_id); - bool keep = false; - if (cfg.count("keep")) keep = true; - - const unsigned k = cfg["k"].as(); - const unsigned N = cfg["N"].as(); - const unsigned T = cfg["epochs"].as(); - const unsigned stop_after = cfg["stop_after"].as(); - const string filter_type = cfg["filter"].as(); - const string sample_from = cfg["sample_from"].as(); - const string pair_sampling = cfg["pair_sampling"].as(); - const score_t pair_threshold = cfg["pair_threshold"].as(); - const string select_weights = cfg["select_weights"].as(); - const float hi_lo = cfg["hi_lo"].as(); - const score_t approx_bleu_d = cfg["approx_bleu_d"].as(); - const unsigned max_pairs = cfg["max_pairs"].as(); - weight_t loss_margin = cfg["loss_margin"].as(); - if (loss_margin > 9998.) loss_margin = std::numeric_limits::max(); - bool scale_bleu_diff = false; - if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true; - bool average = false; - if (select_weights == "avg") - average = true; - vector print_weights; - if (cfg.count("print_weights")) - boost::split(print_weights, cfg["print_weights"].as(), boost::is_any_of(" ")); - - // setup decoder - register_feature_functions(); - SetSilent(true); - ReadFile ini_rf(cfg["decoder_config"].as()); - if (!quiet) - cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as() << "'" << endl; - Decoder decoder(ini_rf.stream()); - - // scoring metric/scorer - string scorer_str = cfg["scorer"].as(); - LocalScorer* scorer; - if (scorer_str == "bleu") { - scorer = dynamic_cast(new BleuScorer); - } else if (scorer_str == "stupid_bleu") { - scorer = dynamic_cast(new StupidBleuScorer); - } else if (scorer_str == "smooth_bleu") { - scorer = dynamic_cast(new SmoothBleuScorer); - } else if (scorer_str == "sum_bleu") { - scorer = dynamic_cast(new SumBleuScorer); - } else if (scorer_str == "sumexp_bleu") { - scorer = dynamic_cast(new SumExpBleuScorer); - } else if (scorer_str == "sumwhatever_bleu") { - scorer = dynamic_cast(new SumWhateverBleuScorer); - } else if (scorer_str == "approx_bleu") { - scorer = dynamic_cast(new ApproxBleuScorer(N, approx_bleu_d)); - } else if (scorer_str == "lc_bleu") { - scorer = dynamic_cast(new LinearBleuScorer(N)); - } else { - cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl; - exit(1); - } - vector bleu_weights; - scorer->Init(N, bleu_weights); - - // setup decoder observer - MT19937 rng; // random number generator, only for forest sampling - HypSampler* observer; - if (sample_from == "kbest") - observer = dynamic_cast(new KBestGetter(k, filter_type)); - else - observer = dynamic_cast(new KSampler(k, &rng)); - observer->SetScorer(scorer); - - // init weights - vector& dense_weights = decoder.CurrentWeightVector(); - SparseVector lambdas, cumulative_penalties, w_average; - if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as(), &dense_weights); - Weights::InitSparseVector(dense_weights, &lambdas); - - // meta params for perceptron, SVM - weight_t eta = cfg["learning_rate"].as(); - weight_t gamma = cfg["gamma"].as(); - - // l1 regularization - bool l1naive = false; - bool l1clip = false; - bool l1cumul = false; - weight_t l1_reg = 0; - if (cfg["l1_reg"].as() != "none") { - string s = cfg["l1_reg"].as(); - if (s == "naive") l1naive = true; - else if (s == "clip") l1clip = true; - else if (s == "cumul") l1cumul = true; - l1_reg = cfg["l1_reg_strength"].as(); - } - - // output - string output_fn = cfg["output"].as(); - // input - string input_fn = cfg["input"].as(); - ReadFile input(input_fn); - // buffer input for t > 0 - vector src_str_buf; // source strings (decoder takes only strings) - vector > ref_ids_buf; // references as WordID vecs - // where temp files go - string tmp_path = cfg["tmp"].as(); -#ifdef DTRAIN_LOCAL - string refs_fn = cfg["refs"].as(); - ReadFile refs(refs_fn); -#else - string grammar_buf_fn = gettmpf(tmp_path, "dtrain-grammars"); - ogzstream grammar_buf_out; - grammar_buf_out.open(grammar_buf_fn.c_str()); -#endif - - unsigned in_sz = std::numeric_limits::max(); // input index, input size - vector > all_scores; - score_t max_score = 0.; - unsigned best_it = 0; - float overall_time = 0.; - - // output cfg - if (!quiet) { - cerr << _p5; - cerr << endl << "dtrain" << endl << "Parameters:" << endl; - cerr << setw(25) << "k " << k << endl; - cerr << setw(25) << "N " << N << endl; - cerr << setw(25) << "T " << T << endl; - cerr << setw(25) << "scorer '" << scorer_str << "'" << endl; - if (scorer_str == "approx_bleu") - cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl; - cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl; - if (sample_from == "kbest") - cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl; - if (!scale_bleu_diff) cerr << setw(25) << "learning rate " << eta << endl; - else cerr << setw(25) << "learning rate " << "bleu diff" << endl; - cerr << setw(25) << "gamma " << gamma << endl; - cerr << setw(25) << "loss margin " << loss_margin << endl; - cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl; - if (pair_sampling == "XYX") - cerr << setw(25) << "hi lo " << hi_lo << endl; - cerr << setw(25) << "pair threshold " << pair_threshold << endl; - cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl; - if (cfg.count("l1_reg")) - cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as() << "'" << endl; - if (rescale) - cerr << setw(25) << "rescale " << rescale << endl; - cerr << setw(25) << "max pairs " << max_pairs << endl; - cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as() << "'" << endl; - cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; -#ifdef DTRAIN_LOCAL - cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; -#endif - cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; - if (cfg.count("input_weights")) - cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as() << "'" << endl; - if (stop_after > 0) - cerr << setw(25) << "stop_after " << stop_after << endl; - if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl; - } - - - for (unsigned t = 0; t < T; t++) // T epochs - { - - if (hstreaming) cerr << "reporter:status:Iteration #" << t+1 << " of " << T << endl; - - time_t start, end; - time(&start); -#ifndef DTRAIN_LOCAL - igzstream grammar_buf_in; - if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str()); -#endif - score_t score_sum = 0.; - score_t model_sum(0); - unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0; - if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl; - - while(true) - { - - string in; - bool next = false, stop = false; // next iteration or premature stop - if (t == 0) { - if(!getline(*input, in)) next = true; - } else { - if (ii == in_sz) next = true; // stop if we reach the end of our input - } - // stop after X sentences (but still go on for those) - if (stop_after > 0 && stop_after == ii && !next) stop = true; - - // produce some pretty output - if (!quiet && !verbose) { - if (ii == 0) cerr << " "; - if ((ii+1) % (DTRAIN_DOTS) == 0) { - cerr << "."; - cerr.flush(); - } - if ((ii+1) % (20*DTRAIN_DOTS) == 0) { - cerr << " " << ii+1 << endl; - if (!next && !stop) cerr << " "; - } - if (stop) { - if (ii % (20*DTRAIN_DOTS) != 0) cerr << " " << ii << endl; - cerr << "Stopping after " << stop_after << " input sentences." << endl; - } else { - if (next) { - if (ii % (20*DTRAIN_DOTS) != 0) cerr << " " << ii << endl; - } - } - } - - // next iteration - if (next || stop) break; - - // weights - lambdas.init_vector(&dense_weights); - - // getting input - vector ref_ids; // reference as vector -#ifndef DTRAIN_LOCAL - vector in_split; // input: sid\tsrc\tref\tpsg - if (t == 0) { - // handling input - split_in(in, in_split); - if (hstreaming && ii == 0) cerr << "reporter:counter:" << task_id << ",First ID," << in_split[0] << endl; - // getting reference - vector ref_tok; - boost::split(ref_tok, in_split[2], boost::is_any_of(" ")); - register_and_convert(ref_tok, ref_ids); - ref_ids_buf.push_back(ref_ids); - // process and set grammar - bool broken_grammar = true; // ignore broken grammars - for (string::iterator it = in.begin(); it != in.end(); it++) { - if (!isspace(*it)) { - broken_grammar = false; - break; - } - } - if (broken_grammar) { - cerr << "Broken grammar for " << ii+1 << "! Ignoring this input." << endl; - continue; - } - boost::replace_all(in, "\t", "\n"); - in += "\n"; - grammar_buf_out << in << DTRAIN_GRAMMAR_DELIM << " " << in_split[0] << endl; - decoder.AddSupplementalGrammarFromString(in); - src_str_buf.push_back(in_split[1]); - // decode - observer->SetRef(ref_ids); - decoder.Decode(in_split[1], observer); - } else { - // get buffered grammar - string grammar_str; - while (true) { - string rule; - getline(grammar_buf_in, rule); - if (boost::starts_with(rule, DTRAIN_GRAMMAR_DELIM)) break; - grammar_str += rule + "\n"; - } - decoder.AddSupplementalGrammarFromString(grammar_str); - // decode - observer->SetRef(ref_ids_buf[ii]); - decoder.Decode(src_str_buf[ii], observer); - } -#else - if (t == 0) { - string r_; - getline(*refs, r_); - vector ref_tok; - boost::split(ref_tok, r_, boost::is_any_of(" ")); - register_and_convert(ref_tok, ref_ids); - ref_ids_buf.push_back(ref_ids); - src_str_buf.push_back(in); - } else { - ref_ids = ref_ids_buf[ii]; - } - observer->SetRef(ref_ids); - if (t == 0) - decoder.Decode(in, observer); - else - decoder.Decode(src_str_buf[ii], observer); -#endif - - // get (scored) samples - vector* samples = observer->GetSamples(); - - if (verbose) { - cerr << "--- ref for " << ii << ": "; - if (t > 0) printWordIDVec(ref_ids_buf[ii]); - else printWordIDVec(ref_ids); - cerr << endl; - for (unsigned u = 0; u < samples->size(); u++) { - cerr << _p2 << _np << "[" << u << ". '"; - printWordIDVec((*samples)[u].w); - cerr << "'" << endl; - cerr << "SCORE=" << (*samples)[u].score << ",model="<< (*samples)[u].model << endl; - cerr << "F{" << (*samples)[u].f << "} ]" << endl << endl; - } - } - - score_sum += (*samples)[0].score; // stats for 1best - model_sum += (*samples)[0].model; - - f_count += observer->get_f_count(); - list_sz += observer->get_sz(); - - // weight updates - if (!noup) { - // get pairs - vector > pairs; - if (pair_sampling == "all") - all_pairs(samples, pairs, pair_threshold, max_pairs); - if (pair_sampling == "XYX") - partXYX(samples, pairs, pair_threshold, max_pairs, hi_lo); - if (pair_sampling == "PRO") - PROsampling(samples, pairs, pair_threshold, max_pairs); - npairs += pairs.size(); - - for (vector >::iterator it = pairs.begin(); - it != pairs.end(); it++) { -#ifdef DTRAIN_FASTER_PERCEPTRON - bool rank_error = true; // pair sampling already did this for us - rank_errors++; - score_t margin = std::numeric_limits::max(); -#else - bool rank_error = it->first.model <= it->second.model; - if (rank_error) rank_errors++; - score_t margin = fabs(fabs(it->first.model) - fabs(it->second.model)); - if (!rank_error && margin < loss_margin) margin_violations++; -#endif - if (scale_bleu_diff) eta = it->first.score - it->second.score; - if (rank_error || margin < loss_margin) { - SparseVector diff_vec = it->first.f - it->second.f; - lambdas.plus_eq_v_times_s(diff_vec, eta); - if (gamma) - lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); - } - } - - // l1 regularization - if (l1naive) { - for (unsigned d = 0; d < lambdas.size(); d++) { - weight_t v = lambdas.get(d); - lambdas.set_value(d, v - sign(v) * l1_reg); - } - } else if (l1clip) { - for (unsigned d = 0; d < lambdas.size(); d++) { - if (lambdas.nonzero(d)) { - weight_t v = lambdas.get(d); - if (v > 0) { - lambdas.set_value(d, max(0., v - l1_reg)); - } else { - lambdas.set_value(d, min(0., v + l1_reg)); - } - } - } - } else if (l1cumul) { - weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input - for (unsigned d = 0; d < lambdas.size(); d++) { - if (lambdas.nonzero(d)) { - weight_t v = lambdas.get(d); - weight_t penalty = 0; - if (v > 0) { - penalty = max(0., v-(acc_penalty + cumulative_penalties.get(d))); - } else { - penalty = min(0., v+(acc_penalty - cumulative_penalties.get(d))); - } - lambdas.set_value(d, penalty); - cumulative_penalties.set_value(d, cumulative_penalties.get(d)+penalty); - } - } - } - - } - - if (rescale) lambdas /= lambdas.l2norm(); - - ++ii; - - if (hstreaming) { - rep.update_counter("Seen #"+boost::lexical_cast(t+1), 1u); - rep.update_counter("Seen", 1u); - } - - } // input loop - - if (average) w_average += lambdas; - - if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset(); - - if (t == 0) { - in_sz = ii; // remember size of input (# lines) - if (hstreaming) { - rep.update_counter("|Input|", ii); - rep.update_gcounter("|Input|", ii); - rep.update_gcounter("Shards", 1u); - } - } - -#ifndef DTRAIN_LOCAL - if (t == 0) { - grammar_buf_out.close(); - } else { - grammar_buf_in.close(); - } -#endif - - // print some stats - score_t score_avg = score_sum/(score_t)in_sz; - score_t model_avg = model_sum/(score_t)in_sz; - score_t score_diff, model_diff; - if (t > 0) { - score_diff = score_avg - all_scores[t-1].first; - model_diff = model_avg - all_scores[t-1].second; - } else { - score_diff = score_avg; - model_diff = model_avg; - } - - unsigned nonz = 0; - if (!quiet || hstreaming) nonz = (unsigned)lambdas.num_nonzero(); - - if (!quiet) { - cerr << _p5 << _p << "WEIGHTS" << endl; - for (vector::iterator it = print_weights.begin(); it != print_weights.end(); it++) { - cerr << setw(18) << *it << " = " << lambdas.get(FD::Convert(*it)) << endl; - } - cerr << " ---" << endl; - cerr << _np << " 1best avg score: " << score_avg; - cerr << _p << " (" << score_diff << ")" << endl; - cerr << _np << " 1best avg model score: " << model_avg; - cerr << _p << " (" << model_diff << ")" << endl; - cerr << " avg # pairs: "; - cerr << _np << npairs/(float)in_sz << endl; - cerr << " avg # rank err: "; - cerr << rank_errors/(float)in_sz << endl; -#ifndef DTRAIN_FASTER_PERCEPTRON - cerr << " avg # margin viol: "; - cerr << margin_violations/(float)in_sz << endl; -#endif - cerr << " non0 feature count: " << nonz << endl; - cerr << " avg list sz: " << list_sz/(float)in_sz << endl; - cerr << " avg f count: " << f_count/(float)list_sz << endl; - } - - if (hstreaming) { - rep.update_counter("Score 1best avg #"+boost::lexical_cast(t+1), (unsigned)(score_avg*DTRAIN_SCALE)); - rep.update_counter("Model 1best avg #"+boost::lexical_cast(t+1), (unsigned)(model_avg*DTRAIN_SCALE)); - rep.update_counter("Pairs avg #"+boost::lexical_cast(t+1), (unsigned)((npairs/(weight_t)in_sz)*DTRAIN_SCALE)); - rep.update_counter("Rank errors avg #"+boost::lexical_cast(t+1), (unsigned)((rank_errors/(weight_t)in_sz)*DTRAIN_SCALE)); - rep.update_counter("Margin violations avg #"+boost::lexical_cast(t+1), (unsigned)((margin_violations/(weight_t)in_sz)*DTRAIN_SCALE)); - rep.update_counter("Non zero feature count #"+boost::lexical_cast(t+1), nonz); - rep.update_gcounter("Non zero feature count #"+boost::lexical_cast(t+1), nonz); - } - - pair remember; - remember.first = score_avg; - remember.second = model_avg; - all_scores.push_back(remember); - if (score_avg > max_score) { - max_score = score_avg; - best_it = t; - } - time (&end); - float time_diff = difftime(end, start); - overall_time += time_diff; - if (!quiet) { - cerr << _p2 << _np << "(time " << time_diff/60. << " min, "; - cerr << time_diff/in_sz << " s/S)" << endl; - } - if (t+1 != T && !quiet) cerr << endl; - - if (noup) break; - - // write weights to file - if (select_weights == "best" || keep) { - lambdas.init_vector(&dense_weights); - string w_fn = "weights." + boost::lexical_cast(t) + ".gz"; - Weights::WriteToFile(w_fn, dense_weights, true); - } - - } // outer loop - - if (average) w_average /= (weight_t)T; - -#ifndef DTRAIN_LOCAL - unlink(grammar_buf_fn.c_str()); -#endif - - if (!noup) { - if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl; - if (select_weights == "last" || average) { // last, average - WriteFile of(output_fn); // works with '-' - ostream& o = *of.stream(); - o.precision(17); - o << _np; - if (average) { - for (SparseVector::iterator it = w_average.begin(); it != w_average.end(); ++it) { - if (it->second == 0) continue; - o << FD::Convert(it->first) << '\t' << it->second << endl; - } - } else { - for (SparseVector::iterator it = lambdas.begin(); it != lambdas.end(); ++it) { - if (it->second == 0) continue; - o << FD::Convert(it->first) << '\t' << it->second << endl; - } - } - } else if (select_weights == "VOID") { // do nothing with the weights - } else { // best - if (output_fn != "-") { - CopyFile("weights."+boost::lexical_cast(best_it)+".gz", output_fn); - } else { - ReadFile bestw("weights."+boost::lexical_cast(best_it)+".gz"); - string o; - cout.precision(17); - cout << _np; - while(getline(*bestw, o)) cout << o << endl; - } - if (!keep) { - for (unsigned i = 0; i < T; i++) { - string s = "weights." + boost::lexical_cast(i) + ".gz"; - unlink(s.c_str()); - } - } - } - if (output_fn == "-" && hstreaming) cout << "__SHARD_COUNT__\t1" << endl; - if (!quiet) cerr << "done" << endl; - } - - if (!quiet) { - cerr << _p5 << _np << endl << "---" << endl << "Best iteration: "; - cerr << best_it+1 << " [SCORE '" << scorer_str << "'=" << max_score << "]." << endl; - cerr << "This took " << overall_time/60. << " min." << endl; - } -} - diff --git a/dtrain/dtrain.h b/dtrain/dtrain.h deleted file mode 100644 index 4b6f415c..00000000 --- a/dtrain/dtrain.h +++ /dev/null @@ -1,97 +0,0 @@ -#ifndef _DTRAIN_H_ -#define _DTRAIN_H_ - -#undef DTRAIN_FASTER_PERCEPTRON // only look at misranked pairs - // DO NOT USE WITH SVM! -//#define DTRAIN_LOCAL -#define DTRAIN_DOTS 10 // after how many inputs to display a '.' -#define DTRAIN_GRAMMAR_DELIM "########EOS########" -#define DTRAIN_SCALE 100000 - - -#include -#include -#include - -#include -#include - -#include "ksampler.h" -#include "pairsampling.h" - -#include "filelib.h" - - -using namespace std; -using namespace dtrain; -namespace po = boost::program_options; - -inline void register_and_convert(const vector& strs, vector& ids) -{ - vector::const_iterator it; - for (it = strs.begin(); it < strs.end(); it++) - ids.push_back(TD::Convert(*it)); -} - -inline string gettmpf(const string path, const string infix) -{ - char fn[path.size() + infix.size() + 8]; - strcpy(fn, path.c_str()); - strcat(fn, "/"); - strcat(fn, infix.c_str()); - strcat(fn, "-XXXXXX"); - if (!mkstemp(fn)) { - cerr << "Cannot make temp file in" << path << " , exiting." << endl; - exit(1); - } - return string(fn); -} - -inline void split_in(string& s, vector& parts) -{ - unsigned f = 0; - for(unsigned i = 0; i < 3; i++) { - unsigned e = f; - f = s.find("\t", f+1); - if (e != 0) parts.push_back(s.substr(e+1, f-e-1)); - else parts.push_back(s.substr(0, f)); - } - s.erase(0, f+1); -} - -struct HSReporter -{ - string task_id_; - - HSReporter(string task_id) : task_id_(task_id) {} - - inline void update_counter(string name, unsigned amount) { - cerr << "reporter:counter:" << task_id_ << "," << name << "," << amount << endl; - } - inline void update_gcounter(string name, unsigned amount) { - cerr << "reporter:counter:Global," << name << "," << amount << endl; - } -}; - -inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); } -inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); } -inline ostream& _p2(ostream& out) { return out << setprecision(2); } -inline ostream& _p5(ostream& out) { return out << setprecision(5); } - -inline void printWordIDVec(vector& v) -{ - for (unsigned i = 0; i < v.size(); i++) { - cerr << TD::Convert(v[i]); - if (i < v.size()-1) cerr << " "; - } -} - -template -inline T sign(T z) -{ - if (z == 0) return 0; - return z < 0 ? -1 : +1; -} - -#endif - diff --git a/dtrain/hstreaming/avg.rb b/dtrain/hstreaming/avg.rb deleted file mode 100755 index 2599c732..00000000 --- a/dtrain/hstreaming/avg.rb +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env ruby -# first arg may be an int of custom shard count - -shard_count_key = "__SHARD_COUNT__" - -STDIN.set_encoding 'utf-8' -STDOUT.set_encoding 'utf-8' - -w = {} -c = {} -w.default = 0 -c.default = 0 -while line = STDIN.gets - key, val = line.split /\s/ - w[key] += val.to_f - c[key] += 1 -end - -if ARGV.size == 0 - shard_count = w["__SHARD_COUNT__"] -else - shard_count = ARGV[0].to_f -end -w.each_key { |k| - if k == shard_count_key - next - else - puts "#{k}\t#{w[k]/shard_count}" - #puts "# #{c[k]}" - end -} - diff --git a/dtrain/hstreaming/cdec.ini b/dtrain/hstreaming/cdec.ini deleted file mode 100644 index d4f5cecd..00000000 --- a/dtrain/hstreaming/cdec.ini +++ /dev/null @@ -1,22 +0,0 @@ -formalism=scfg -add_pass_through_rules=true -scfg_max_span_limit=15 -intersection_strategy=cube_pruning -cubepruning_pop_limit=30 -feature_function=WordPenalty -feature_function=KLanguageModel nc-wmt11.en.srilm.gz -#feature_function=ArityPenalty -#feature_function=CMR2008ReorderingFeatures -#feature_function=Dwarf -#feature_function=InputIndicator -#feature_function=LexNullJump -#feature_function=NewJump -#feature_function=NgramFeatures -#feature_function=NonLatinCount -#feature_function=OutputIndicator -#feature_function=RuleIdentityFeatures -#feature_function=RuleNgramFeatures -#feature_function=RuleShape -#feature_function=SourceSpanSizeFeatures -#feature_function=SourceWordPenalty -#feature_function=SpanFeatures diff --git a/dtrain/hstreaming/dtrain.ini b/dtrain/hstreaming/dtrain.ini deleted file mode 100644 index a2c219a1..00000000 --- a/dtrain/hstreaming/dtrain.ini +++ /dev/null @@ -1,15 +0,0 @@ -input=- -output=- -decoder_config=cdec.ini -tmp=/var/hadoop/mapred/local/ -epochs=1 -k=100 -N=4 -learning_rate=0.0001 -gamma=0 -scorer=stupid_bleu -sample_from=kbest -filter=uniq -pair_sampling=XYX -pair_threshold=0 -select_weights=last diff --git a/dtrain/hstreaming/dtrain.sh b/dtrain/hstreaming/dtrain.sh deleted file mode 100755 index 877ff94c..00000000 --- a/dtrain/hstreaming/dtrain.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -# script to run dtrain with a task id - -pushd . &>/dev/null -cd .. -ID=$(basename $(pwd)) # attempt_... -popd &>/dev/null -./dtrain -c dtrain.ini --hstreaming $ID - diff --git a/dtrain/hstreaming/hadoop-streaming-job.sh b/dtrain/hstreaming/hadoop-streaming-job.sh deleted file mode 100755 index 92419956..00000000 --- a/dtrain/hstreaming/hadoop-streaming-job.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/sh - -EXP=a_simple_test - -# change these vars to fit your hadoop installation -HADOOP_HOME=/usr/lib/hadoop-0.20 -JAR=contrib/streaming/hadoop-streaming-0.20.2-cdh3u1.jar -HSTREAMING="$HADOOP_HOME/bin/hadoop jar $HADOOP_HOME/$JAR" - - IN=input_on_hdfs -OUT=output_weights_on_hdfs - -# you can -reducer to NONE if you want to -# do feature selection/averaging locally (e.g. to -# keep weights of all epochs) -$HSTREAMING \ - -mapper "dtrain.sh" \ - -reducer "ruby lplp.rb l2 select_k 100000" \ - -input $IN \ - -output $OUT \ - -file dtrain.sh \ - -file lplp.rb \ - -file ../dtrain \ - -file dtrain.ini \ - -file cdec.ini \ - -file ../test/example/nc-wmt11.en.srilm.gz \ - -jobconf mapred.reduce.tasks=30 \ - -jobconf mapred.max.map.failures.percent=0 \ - -jobconf mapred.job.name="dtrain $EXP" - diff --git a/dtrain/hstreaming/lplp.rb b/dtrain/hstreaming/lplp.rb deleted file mode 100755 index f0cd58c5..00000000 --- a/dtrain/hstreaming/lplp.rb +++ /dev/null @@ -1,131 +0,0 @@ -# lplp.rb - -# norms -def l0(feature_column, n) - if feature_column.size >= n then return 1 else return 0 end -end - -def l1(feature_column, n=-1) - return feature_column.map { |i| i.abs }.reduce { |sum,i| sum+i } -end - -def l2(feature_column, n=-1) - return Math.sqrt feature_column.map { |i| i.abs2 }.reduce { |sum,i| sum+i } -end - -def linfty(feature_column, n=-1) - return feature_column.map { |i| i.abs }.max -end - -# stats -def median(feature_column, n) - return feature_column.concat(0.step(n-feature_column.size-1).map{|i|0}).sort[feature_column.size/2] -end - -def mean(feature_column, n) - return feature_column.reduce { |sum, i| sum+i } / n -end - -# selection -def select_k(weights, norm_fun, n, k=10000) - weights.sort{|a,b| norm_fun.call(b[1], n) <=> norm_fun.call(a[1], n)}.each { |p| - puts "#{p[0]}\t#{mean(p[1], n)}" - k -= 1 - if k == 0 then break end - } -end - -def cut(weights, norm_fun, n, epsilon=0.0001) - weights.each { |k,v| - if norm_fun.call(v, n).abs >= epsilon - puts "#{k}\t#{mean(v, n)}" - end - } -end - -# test -def _test() - puts - w = {} - w["a"] = [1, 2, 3] - w["b"] = [1, 2] - w["c"] = [66] - w["d"] = [10, 20, 30] - n = 3 - puts w.to_s - puts - puts "select_k" - puts "l0 expect ad" - select_k(w, method(:l0), n, 2) - puts "l1 expect cd" - select_k(w, method(:l1), n, 2) - puts "l2 expect c" - select_k(w, method(:l2), n, 1) - puts - puts "cut" - puts "l1 expect cd" - cut(w, method(:l1), n, 7) - puts - puts "median" - a = [1,2,3,4,5] - puts a.to_s - puts median(a, 5) - puts - puts "#{median(a, 7)} <- that's because we add missing 0s:" - puts a.concat(0.step(7-a.size-1).map{|i|0}).to_s - puts - puts "mean expect bc" - w.clear - w["a"] = [2] - w["b"] = [2.1] - w["c"] = [2.2] - cut(w, method(:mean), 1, 2.05) - exit -end -#_test() - -# actually do something -def usage() - puts "lplp.rb [n] < " - puts " l0...: norms for selection" - puts "select_k: only output top k (according to the norm of their column vector) features" - puts " cut: output features with weight >= threshold" - puts " n: if we do not have a shard count use this number for averaging" - exit -end - -if ARGV.size < 3 then usage end -norm_fun = method(ARGV[0].to_sym) -type = ARGV[1] -x = ARGV[2].to_f - -shard_count_key = "__SHARD_COUNT__" - -STDIN.set_encoding 'utf-8' -STDOUT.set_encoding 'utf-8' - -w = {} -shard_count = 0 -while line = STDIN.gets - key, val = line.split /\s+/ - if key == shard_count_key - shard_count += 1 - next - end - if w.has_key? key - w[key].push val.to_f - else - w[key] = [val.to_f] - end -end - -if ARGV.size == 4 then shard_count = ARGV[3].to_f end - -if type == 'cut' - cut(w, norm_fun, shard_count, x) -elsif type == 'select_k' - select_k(w, norm_fun, shard_count, x) -else - puts "oh oh" -end - diff --git a/dtrain/hstreaming/red-test b/dtrain/hstreaming/red-test deleted file mode 100644 index 2623d697..00000000 --- a/dtrain/hstreaming/red-test +++ /dev/null @@ -1,9 +0,0 @@ -a 1 -b 2 -c 3.5 -a 1 -b 2 -c 3.5 -d 1 -e 2 -__SHARD_COUNT__ 2 diff --git a/dtrain/kbestget.h b/dtrain/kbestget.h deleted file mode 100644 index dd8882e1..00000000 --- a/dtrain/kbestget.h +++ /dev/null @@ -1,152 +0,0 @@ -#ifndef _DTRAIN_KBESTGET_H_ -#define _DTRAIN_KBESTGET_H_ - -#include "kbest.h" // cdec -#include "sentence_metadata.h" - -#include "verbose.h" -#include "viterbi.h" -#include "ff_register.h" -#include "decoder.h" -#include "weights.h" -#include "logval.h" - -using namespace std; - -namespace dtrain -{ - - -typedef double score_t; - -struct ScoredHyp -{ - vector w; - SparseVector f; - score_t model; - score_t score; - unsigned rank; -}; - -struct LocalScorer -{ - unsigned N_; - vector w_; - - virtual score_t - Score(vector& hyp, vector& ref, const unsigned rank, const unsigned src_len)=0; - - void Reset() {} // only for approx bleu - - inline void - Init(unsigned N, vector weights) - { - assert(N > 0); - N_ = N; - if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_); - else w_ = weights; - } - - inline score_t - brevity_penalty(const unsigned hyp_len, const unsigned ref_len) - { - if (hyp_len > ref_len) return 1; - return exp(1 - (score_t)ref_len/hyp_len); - } -}; - -struct HypSampler : public DecoderObserver -{ - LocalScorer* scorer_; - vector* ref_; - unsigned f_count_, sz_; - virtual vector* GetSamples()=0; - inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; } - inline void SetRef(vector& ref) { ref_ = &ref; } - inline unsigned get_f_count() { return f_count_; } - inline unsigned get_sz() { return sz_; } -}; -//////////////////////////////////////////////////////////////////////////////// - - - - -struct KBestGetter : public HypSampler -{ - const unsigned k_; - const string filter_type_; - vector s_; - unsigned src_len_; - - KBestGetter(const unsigned k, const string filter_type) : - k_(k), filter_type_(filter_type) {} - - virtual void - NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) - { - src_len_ = smeta.GetSourceLength(); - KBestScored(*hg); - } - - vector* GetSamples() { return &s_; } - - void - KBestScored(const Hypergraph& forest) - { - if (filter_type_ == "uniq") { - KBestUnique(forest); - } else if (filter_type_ == "not") { - KBestNoFilter(forest); - } - } - - void - KBestUnique(const Hypergraph& forest) - { - s_.clear(); sz_ = f_count_ = 0; - KBest::KBestDerivations, ESentenceTraversal, - KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_); - for (unsigned i = 0; i < k_; ++i) { - const KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique, - prob_t, EdgeProb>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - ScoredHyp h; - h.w = d->yield; - h.f = d->feature_values; - h.model = log(d->score); - h.rank = i; - h.score = scorer_->Score(h.w, *ref_, i, src_len_); - s_.push_back(h); - sz_++; - f_count_ += h.f.size(); - } - } - - void - KBestNoFilter(const Hypergraph& forest) - { - s_.clear(); sz_ = f_count_ = 0; - KBest::KBestDerivations, ESentenceTraversal> kbest(forest, k_); - for (unsigned i = 0; i < k_; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - ScoredHyp h; - h.w = d->yield; - h.f = d->feature_values; - h.model = log(d->score); - h.rank = i; - h.score = scorer_->Score(h.w, *ref_, i, src_len_); - s_.push_back(h); - sz_++; - f_count_ += h.f.size(); - } - } -}; - - -} // namespace - -#endif - diff --git a/dtrain/ksampler.h b/dtrain/ksampler.h deleted file mode 100644 index bc2f56cd..00000000 --- a/dtrain/ksampler.h +++ /dev/null @@ -1,61 +0,0 @@ -#ifndef _DTRAIN_KSAMPLER_H_ -#define _DTRAIN_KSAMPLER_H_ - -#include "hg_sampler.h" // cdec -#include "kbestget.h" -#include "score.h" - -namespace dtrain -{ - -bool -cmp_hyp_by_model_d(ScoredHyp a, ScoredHyp b) -{ - return a.model > b.model; -} - -struct KSampler : public HypSampler -{ - const unsigned k_; - vector s_; - MT19937* prng_; - score_t (*scorer)(NgramCounts&, const unsigned, const unsigned, unsigned, vector); - unsigned src_len_; - - explicit KSampler(const unsigned k, MT19937* prng) : - k_(k), prng_(prng) {} - - virtual void - NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) - { - src_len_ = smeta.GetSourceLength(); - ScoredSamples(*hg); - } - - vector* GetSamples() { return &s_; } - - void ScoredSamples(const Hypergraph& forest) { - s_.clear(); sz_ = f_count_ = 0; - std::vector samples; - HypergraphSampler::sample_hypotheses(forest, k_, prng_, &samples); - for (unsigned i = 0; i < k_; ++i) { - ScoredHyp h; - h.w = samples[i].words; - h.f = samples[i].fmap; - h.model = log(samples[i].model_score); - h.rank = i; - h.score = scorer_->Score(h.w, *ref_, i, src_len_); - s_.push_back(h); - sz_++; - f_count_ += h.f.size(); - } - sort(s_.begin(), s_.end(), cmp_hyp_by_model_d); - for (unsigned i = 0; i < s_.size(); i++) s_[i].rank = i; - } -}; - - -} // namespace - -#endif - diff --git a/dtrain/pairsampling.h b/dtrain/pairsampling.h deleted file mode 100644 index 84be1efb..00000000 --- a/dtrain/pairsampling.h +++ /dev/null @@ -1,149 +0,0 @@ -#ifndef _DTRAIN_PAIRSAMPLING_H_ -#define _DTRAIN_PAIRSAMPLING_H_ - -namespace dtrain -{ - - -bool -accept_pair(score_t a, score_t b, score_t threshold) -{ - if (fabs(a - b) < threshold) return false; - return true; -} - -bool -cmp_hyp_by_score_d(ScoredHyp a, ScoredHyp b) -{ - return a.score > b.score; -} - -inline void -all_pairs(vector* s, vector >& training, score_t threshold, unsigned max, float _unused=1) -{ - sort(s->begin(), s->end(), cmp_hyp_by_score_d); - unsigned sz = s->size(); - bool b = false; - unsigned count = 0; - for (unsigned i = 0; i < sz-1; i++) { - for (unsigned j = i+1; j < sz; j++) { - if (threshold > 0) { - if (accept_pair((*s)[i].score, (*s)[j].score, threshold)) - training.push_back(make_pair((*s)[i], (*s)[j])); - } else { - if ((*s)[i].score != (*s)[j].score) - training.push_back(make_pair((*s)[i], (*s)[j])); - } - if (++count == max) { - b = true; - break; - } - } - if (b) break; - } -} - -/* - * multipartite ranking - * sort (descending) by bleu - * compare top X to middle Y and low X - * cmp middle Y to low X - */ - -inline void -partXYX(vector* s, vector >& training, score_t threshold, unsigned max, float hi_lo) -{ - unsigned sz = s->size(); - if (sz < 2) return; - sort(s->begin(), s->end(), cmp_hyp_by_score_d); - unsigned sep = round(sz*hi_lo); - unsigned sep_hi = sep; - if (sz > 4) while (sep_hi < sz && (*s)[sep_hi-1].score == (*s)[sep_hi].score) ++sep_hi; - else sep_hi = 1; - bool b = false; - unsigned count = 0; - for (unsigned i = 0; i < sep_hi; i++) { - for (unsigned j = sep_hi; j < sz; j++) { -#ifdef DTRAIN_FASTER_PERCEPTRON - if ((*s)[i].model <= (*s)[j].model) { -#endif - if (threshold > 0) { - if (accept_pair((*s)[i].score, (*s)[j].score, threshold)) - training.push_back(make_pair((*s)[i], (*s)[j])); - } else { - if ((*s)[i].score != (*s)[j].score) - training.push_back(make_pair((*s)[i], (*s)[j])); - } - if (++count == max) { - b = true; - break; - } -#ifdef DTRAIN_FASTER_PERCEPTRON - } -#endif - } - if (b) break; - } - unsigned sep_lo = sz-sep; - while (sep_lo > 0 && (*s)[sep_lo-1].score == (*s)[sep_lo].score) --sep_lo; - for (unsigned i = sep_hi; i < sz-sep_lo; i++) { - for (unsigned j = sz-sep_lo; j < sz; j++) { -#ifdef DTRAIN_FASTER_PERCEPTRON - if ((*s)[i].model <= (*s)[j].model) { -#endif - if (threshold > 0) { - if (accept_pair((*s)[i].score, (*s)[j].score, threshold)) - training.push_back(make_pair((*s)[i], (*s)[j])); - } else { - if ((*s)[i].score != (*s)[j].score) - training.push_back(make_pair((*s)[i], (*s)[j])); - } - if (++count == max) return; -#ifdef DTRAIN_FASTER_PERCEPTRON - } -#endif - } - } -} - -/* - * pair sampling as in - * 'Tuning as Ranking' (Hopkins & May, 2011) - * count = 5000 - * threshold = 5% BLEU (0.05 for param 3) - * cut = top 50 - */ -bool -_PRO_cmp_pair_by_diff_d(pair a, pair b) -{ - return (fabs(a.first.score - a.second.score)) > (fabs(b.first.score - b.second.score)); -} -inline void -PROsampling(vector* s, vector >& training, score_t threshold, unsigned max, float _unused=1) -{ - unsigned max_count = 5000, count = 0, sz = s->size(); - bool b = false; - for (unsigned i = 0; i < sz-1; i++) { - for (unsigned j = i+1; j < sz; j++) { - if (accept_pair((*s)[i].score, (*s)[j].score, threshold)) { - training.push_back(make_pair((*s)[i], (*s)[j])); - if (++count == max_count) { - b = true; - break; - } - } - } - if (b) break; - } - if (training.size() > 50) { - sort(training.begin(), training.end(), _PRO_cmp_pair_by_diff_d); - training.erase(training.begin()+50, training.end()); - } - return; -} - - -} // namespace - -#endif - diff --git a/dtrain/parallelize.rb b/dtrain/parallelize.rb deleted file mode 100755 index 1d277ff6..00000000 --- a/dtrain/parallelize.rb +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env ruby - - -if ARGV.size != 5 - STDERR.write "Usage: " - STDERR.write "ruby parallelize.rb <#shards> \n" - exit -end - -dtrain_bin = '/home/pks/bin/dtrain_local' -ruby = '/usr/bin/ruby' -lplp_rb = '/home/pks/mt/cdec-dtrain/dtrain/hstreaming/lplp.rb' -lplp_args = 'l2 select_k 100000' -gzip = '/bin/gzip' - -num_shards = ARGV[0].to_i -input = ARGV[1] -refs = ARGV[2] -epochs = ARGV[3].to_i -ini = ARGV[4] - - -`mkdir work` - -def make_shards(input, refs, num_shards) - lc = `wc -l #{input}`.split.first.to_i - shard_sz = lc / num_shards - leftover = lc % num_shards - in_f = File.new input, 'r' - refs_f = File.new refs, 'r' - shard_in_files = [] - shard_refs_files = [] - 0.upto(num_shards-1) { |shard| - shard_in = File.new "work/shard.#{shard}.in", 'w+' - shard_refs = File.new "work/shard.#{shard}.refs", 'w+' - 0.upto(shard_sz-1) { |i| - shard_in.write in_f.gets - shard_refs.write refs_f.gets - } - shard_in_files << shard_in - shard_refs_files << shard_refs - } - while leftover > 0 - shard_in_files[-1].write in_f.gets - shard_refs_files[-1].write refs_f.gets - leftover -= 1 - end - (shard_in_files + shard_refs_files).each do |f| f.close end - in_f.close - refs_f.close -end - -make_shards input, refs, num_shards - -0.upto(epochs-1) { |epoch| - pids = [] - input_weights = '' - if epoch > 0 then input_weights = "--input_weights work/weights.#{epoch-1}" end - weights_files = [] - 0.upto(num_shards-1) { |shard| - pids << Kernel.fork { - `#{dtrain_bin} -c #{ini}\ - --input work/shard.#{shard}.in\ - --refs work/shard.#{shard}.refs #{input_weights}\ - --output work/weights.#{shard}.#{epoch}\ - &> work/out.#{shard}.#{epoch}` - } - weights_files << "work/weights.#{shard}.#{epoch}" - } - pids.each { |pid| Process.wait(pid) } - cat = File.new('work/weights_cat', 'w+') - weights_files.each { |f| cat.write File.new(f, 'r').read } - cat.close - `#{ruby} #{lplp_rb} #{lplp_args} #{num_shards} < work/weights_cat &> work/weights.#{epoch}` -} - -`rm work/weights_cat` -`#{gzip} work/*` - diff --git a/dtrain/parallelize/test/cdec.ini b/dtrain/parallelize/test/cdec.ini deleted file mode 100644 index 72e99dc5..00000000 --- a/dtrain/parallelize/test/cdec.ini +++ /dev/null @@ -1,22 +0,0 @@ -formalism=scfg -add_pass_through_rules=true -intersection_strategy=cube_pruning -cubepruning_pop_limit=200 -scfg_max_span_limit=15 -feature_function=WordPenalty -feature_function=KLanguageModel /stor/dat/wmt12/en/news_only/m/wmt12.news.en.3.kenv5 -#feature_function=ArityPenalty -#feature_function=CMR2008ReorderingFeatures -#feature_function=Dwarf -#feature_function=InputIndicator -#feature_function=LexNullJump -#feature_function=NewJump -#feature_function=NgramFeatures -#feature_function=NonLatinCount -#feature_function=OutputIndicator -#feature_function=RuleIdentityFeatures -#feature_function=RuleNgramFeatures -#feature_function=RuleShape -#feature_function=SourceSpanSizeFeatures -#feature_function=SourceWordPenalty -#feature_function=SpanFeatures diff --git a/dtrain/parallelize/test/dtrain.ini b/dtrain/parallelize/test/dtrain.ini deleted file mode 100644 index 03f9d240..00000000 --- a/dtrain/parallelize/test/dtrain.ini +++ /dev/null @@ -1,15 +0,0 @@ -k=100 -N=4 -learning_rate=0.0001 -gamma=0 -loss_margin=0 -epochs=1 -scorer=stupid_bleu -sample_from=kbest -filter=uniq -pair_sampling=XYX -hi_lo=0.1 -select_weights=last -print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PhraseModel_5 PhraseModel_6 PassThrough -tmp=/tmp -decoder_config=cdec.ini diff --git a/dtrain/parallelize/test/in b/dtrain/parallelize/test/in deleted file mode 100644 index a312809f..00000000 --- a/dtrain/parallelize/test/in +++ /dev/null @@ -1,10 +0,0 @@ -barack obama erhält als vierter us @-@ präsident den frieden nobelpreis -der amerikanische präsident barack obama kommt für 26 stunden nach oslo , norwegen , um hier als vierter us @-@ präsident in der geschichte den frieden nobelpreis entgegen zunehmen . -darüber hinaus erhält er das diplom sowie die medaille und einen scheck über 1,4 mio. dollar für seine außer gewöhnlichen bestrebungen um die intensivierung der welt diplomatie und zusammen arbeit unter den völkern . -der chef des weißen hauses kommt morgen zusammen mit seiner frau michelle in der nordwegischen metropole an und wird die ganze zeit beschäftigt sein . -zunächst stattet er dem nobel @-@ institut einen besuch ab , wo er überhaupt zum ersten mal mit den fünf ausschuss mitglieder zusammen trifft , die ihn im oktober aus 172 leuten und 33 organisationen gewählt haben . -das präsidenten paar hat danach ein treffen mit dem norwegischen könig harald v. und königin sonja eingeplant . -nachmittags erreicht dann der besuch seinen höhepunkt mit der zeremonie , bei der obama den prestige preis übernimmt . -diesen erhält er als der vierte us @-@ präsident , aber erst als der dritte , der den preis direkt im amt entgegen nimmt . -das weiße haus avisierte schon , dass obama bei der übernahme des preises über den afghanistan krieg sprechen wird . -der präsident will diesem thema nicht ausweichen , weil er weiß , dass er den preis als ein präsident übernimmt , der zur zeit krieg in zwei ländern führt . diff --git a/dtrain/parallelize/test/refs b/dtrain/parallelize/test/refs deleted file mode 100644 index 4d3128cb..00000000 --- a/dtrain/parallelize/test/refs +++ /dev/null @@ -1,10 +0,0 @@ -barack obama becomes the fourth american president to receive the nobel peace prize -the american president barack obama will fly into oslo , norway for 26 hours to receive the nobel peace prize , the fourth american president in history to do so . -he will receive a diploma , medal and cheque for 1.4 million dollars for his exceptional efforts to improve global diplomacy and encourage international cooperation , amongst other things . -the head of the white house will be flying into the norwegian city in the morning with his wife michelle and will have a busy schedule . -first , he will visit the nobel institute , where he will have his first meeting with the five committee members who selected him from 172 people and 33 organisations . -the presidential couple then has a meeting scheduled with king harald v and queen sonja of norway . -then , in the afternoon , the visit will culminate in a grand ceremony , at which obama will receive the prestigious award . -he will be the fourth american president to be awarded the prize , and only the third to have received it while actually in office . -the white house has stated that , when he accepts the prize , obama will speak about the war in afghanistan . -the president does not want to skirt around this topic , as he realises that he is accepting the prize as a president whose country is currently at war in two countries . diff --git a/dtrain/score.cc b/dtrain/score.cc deleted file mode 100644 index 34fc86a9..00000000 --- a/dtrain/score.cc +++ /dev/null @@ -1,254 +0,0 @@ -#include "score.h" - -namespace dtrain -{ - - -/* - * bleu - * - * as in "BLEU: a Method for Automatic Evaluation - * of Machine Translation" - * (Papineni et al. '02) - * - * NOTE: 0 if for one n \in {1..N} count is 0 - */ -score_t -BleuScorer::Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len) -{ - if (hyp_len == 0 || ref_len == 0) return 0.; - unsigned M = N_; - vector v = w_; - if (ref_len < N_) { - M = ref_len; - for (unsigned i = 0; i < M; i++) v[i] = 1/((score_t)M); - } - score_t sum = 0; - for (unsigned i = 0; i < M; i++) { - if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) return 0.; - sum += v[i] * log((score_t)counts.clipped_[i]/counts.sum_[i]); - } - return brevity_penalty(hyp_len, ref_len) * exp(sum); -} - -score_t -BleuScorer::Score(vector& hyp, vector& ref, - const unsigned /*rank*/, const unsigned /*src_len*/) -{ - unsigned hyp_len = hyp.size(), ref_len = ref.size(); - if (hyp_len == 0 || ref_len == 0) return 0.; - NgramCounts counts = make_ngram_counts(hyp, ref, N_); - return Bleu(counts, hyp_len, ref_len); -} - -/* - * 'stupid' bleu - * - * as in "ORANGE: a Method for Evaluating - * Automatic Evaluation Metrics - * for Machine Translation" - * (Lin & Och '04) - * - * NOTE: 0 iff no 1gram match - */ -score_t -StupidBleuScorer::Score(vector& hyp, vector& ref, - const unsigned /*rank*/, const unsigned /*src_len*/) -{ - unsigned hyp_len = hyp.size(), ref_len = ref.size(); - if (hyp_len == 0 || ref_len == 0) return 0.; - NgramCounts counts = make_ngram_counts(hyp, ref, N_); - unsigned M = N_; - vector v = w_; - if (ref_len < N_) { - M = ref_len; - for (unsigned i = 0; i < M; i++) v[i] = 1/((score_t)M); - } - score_t sum = 0, add = 0; - for (unsigned i = 0; i < M; i++) { - if (i == 0 && (counts.sum_[i] == 0 || counts.clipped_[i] == 0)) return 0.; - if (i == 1) add = 1; - sum += v[i] * log(((score_t)counts.clipped_[i] + add)/((counts.sum_[i] + add))); - } - return brevity_penalty(hyp_len, ref_len) * exp(sum); -} - -/* - * smooth bleu - * - * as in "An End-to-End Discriminative Approach - * to Machine Translation" - * (Liang et al. '06) - * - * NOTE: max is 0.9375 (with N=4) - */ -score_t -SmoothBleuScorer::Score(vector& hyp, vector& ref, - const unsigned /*rank*/, const unsigned /*src_len*/) -{ - unsigned hyp_len = hyp.size(), ref_len = ref.size(); - if (hyp_len == 0 || ref_len == 0) return 0.; - NgramCounts counts = make_ngram_counts(hyp, ref, N_); - unsigned M = N_; - if (ref_len < N_) M = ref_len; - score_t sum = 0.; - vector i_bleu; - for (unsigned i = 0; i < M; i++) i_bleu.push_back(0.); - for (unsigned i = 0; i < M; i++) { - if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) { - break; - } else { - score_t i_ng = log((score_t)counts.clipped_[i]/counts.sum_[i]); - for (unsigned j = i; j < M; j++) { - i_bleu[j] += (1/((score_t)j+1)) * i_ng; - } - } - sum += exp(i_bleu[i])/pow(2.0, (double)(N_-i)); - } - return brevity_penalty(hyp_len, ref_len) * sum; -} - -/* - * 'sum' bleu - * - * sum up Ngram precisions - */ -score_t -SumBleuScorer::Score(vector& hyp, vector& ref, - const unsigned /*rank*/, const unsigned /*src_len*/) -{ - unsigned hyp_len = hyp.size(), ref_len = ref.size(); - if (hyp_len == 0 || ref_len == 0) return 0.; - NgramCounts counts = make_ngram_counts(hyp, ref, N_); - unsigned M = N_; - if (ref_len < N_) M = ref_len; - score_t sum = 0.; - unsigned j = 1; - for (unsigned i = 0; i < M; i++) { - if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) break; - sum += ((score_t)counts.clipped_[i]/counts.sum_[i])/pow(2.0, (double) (N_-j+1)); - j++; - } - return brevity_penalty(hyp_len, ref_len) * sum; -} - -/* - * 'sum' (exp) bleu - * - * sum up exp(Ngram precisions) - */ -score_t -SumExpBleuScorer::Score(vector& hyp, vector& ref, - const unsigned /*rank*/, const unsigned /*src_len*/) -{ - unsigned hyp_len = hyp.size(), ref_len = ref.size(); - if (hyp_len == 0 || ref_len == 0) return 0.; - NgramCounts counts = make_ngram_counts(hyp, ref, N_); - unsigned M = N_; - if (ref_len < N_) M = ref_len; - score_t sum = 0.; - unsigned j = 1; - for (unsigned i = 0; i < M; i++) { - if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) break; - sum += exp(((score_t)counts.clipped_[i]/counts.sum_[i]))/pow(2.0, (double) (N_-j+1)); - j++; - } - return brevity_penalty(hyp_len, ref_len) * sum; -} - -/* - * 'sum' (whatever) bleu - * - * sum up exp(weight * log(Ngram precisions)) - */ -score_t -SumWhateverBleuScorer::Score(vector& hyp, vector& ref, - const unsigned /*rank*/, const unsigned /*src_len*/) -{ - unsigned hyp_len = hyp.size(), ref_len = ref.size(); - if (hyp_len == 0 || ref_len == 0) return 0.; - NgramCounts counts = make_ngram_counts(hyp, ref, N_); - unsigned M = N_; - vector v = w_; - if (ref_len < N_) { - M = ref_len; - for (unsigned i = 0; i < M; i++) v[i] = 1/((score_t)M); - } - score_t sum = 0.; - unsigned j = 1; - for (unsigned i = 0; i < M; i++) { - if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) break; - sum += exp(v[i] * log(((score_t)counts.clipped_[i]/counts.sum_[i])))/pow(2.0, (double) (N_-j+1)); - j++; - } - return brevity_penalty(hyp_len, ref_len) * sum; -} - -/* - * approx. bleu - * - * as in "Online Large-Margin Training of Syntactic - * and Structural Translation Features" - * (Chiang et al. '08) - * - * NOTE: Needs some more code in dtrain.cc . - * No scaling by src len. - */ -score_t -ApproxBleuScorer::Score(vector& hyp, vector& ref, - const unsigned rank, const unsigned src_len) -{ - unsigned hyp_len = hyp.size(), ref_len = ref.size(); - if (ref_len == 0) return 0.; - score_t score = 0.; - NgramCounts counts(N_); - if (hyp_len > 0) { - counts = make_ngram_counts(hyp, ref, N_); - NgramCounts tmp = glob_onebest_counts_ + counts; - score = Bleu(tmp, hyp_len, ref_len); - } - if (rank == 0) { // 'context of 1best translations' - glob_onebest_counts_ += counts; - glob_onebest_counts_ *= discount_; - glob_hyp_len_ = discount_ * (glob_hyp_len_ + hyp_len); - glob_ref_len_ = discount_ * (glob_ref_len_ + ref_len); - glob_src_len_ = discount_ * (glob_src_len_ + src_len); - } - return score; -} - -/* - * Linear (Corpus) Bleu - * - * as in "Lattice Minimum Bayes-Risk Decoding - * for Statistical Machine Translation" - * (Tromble et al. '08) - * - */ -score_t -LinearBleuScorer::Score(vector& hyp, vector& ref, - const unsigned rank, const unsigned /*src_len*/) -{ - unsigned hyp_len = hyp.size(), ref_len = ref.size(); - if (ref_len == 0) return 0.; - unsigned M = N_; - if (ref_len < N_) M = ref_len; - NgramCounts counts(M); - if (hyp_len > 0) - counts = make_ngram_counts(hyp, ref, M); - score_t ret = 0.; - for (unsigned i = 0; i < M; i++) { - if (counts.sum_[i] == 0 || onebest_counts_.sum_[i] == 0) break; - ret += counts.sum_[i]/onebest_counts_.sum_[i]; - } - ret = -(hyp_len/(score_t)onebest_len_) + (1./M) * ret; - if (rank == 0) { - onebest_len_ += hyp_len; - onebest_counts_ += counts; - } - return ret; -} - - -} // namespace - diff --git a/dtrain/score.h b/dtrain/score.h deleted file mode 100644 index f317c903..00000000 --- a/dtrain/score.h +++ /dev/null @@ -1,212 +0,0 @@ -#ifndef _DTRAIN_SCORE_H_ -#define _DTRAIN_SCORE_H_ - -#include "kbestget.h" - -using namespace std; - -namespace dtrain -{ - - -struct NgramCounts -{ - unsigned N_; - map clipped_; - map sum_; - - NgramCounts(const unsigned N) : N_(N) { Zero(); } - - inline void - operator+=(const NgramCounts& rhs) - { - if (rhs.N_ > N_) Resize(rhs.N_); - for (unsigned i = 0; i < N_; i++) { - this->clipped_[i] += rhs.clipped_.find(i)->second; - this->sum_[i] += rhs.sum_.find(i)->second; - } - } - - inline const NgramCounts - operator+(const NgramCounts &other) const - { - NgramCounts result = *this; - result += other; - return result; - } - - inline void - operator*=(const score_t rhs) - { - for (unsigned i = 0; i < N_; i++) { - this->clipped_[i] *= rhs; - this->sum_[i] *= rhs; - } - } - - inline void - Add(const unsigned count, const unsigned ref_count, const unsigned i) - { - assert(i < N_); - if (count > ref_count) { - clipped_[i] += ref_count; - } else { - clipped_[i] += count; - } - sum_[i] += count; - } - - inline void - Zero() - { - for (unsigned i = 0; i < N_; i++) { - clipped_[i] = 0.; - sum_[i] = 0.; - } - } - - inline void - One() - { - for (unsigned i = 0; i < N_; i++) { - clipped_[i] = 1.; - sum_[i] = 1.; - } - } - - inline void - Print() - { - for (unsigned i = 0; i < N_; i++) { - cout << i+1 << "grams (clipped):\t" << clipped_[i] << endl; - cout << i+1 << "grams:\t\t\t" << sum_[i] << endl; - } - } - - inline void Resize(unsigned N) - { - if (N == N_) return; - else if (N > N_) { - for (unsigned i = N_; i < N; i++) { - clipped_[i] = 0.; - sum_[i] = 0.; - } - } else { // N < N_ - for (unsigned i = N_-1; i > N-1; i--) { - clipped_.erase(i); - sum_.erase(i); - } - } - N_ = N; - } -}; - -typedef map, unsigned> Ngrams; - -inline Ngrams -make_ngrams(const vector& s, const unsigned N) -{ - Ngrams ngrams; - vector ng; - for (size_t i = 0; i < s.size(); i++) { - ng.clear(); - for (unsigned j = i; j < min(i+N, s.size()); j++) { - ng.push_back(s[j]); - ngrams[ng]++; - } - } - return ngrams; -} - -inline NgramCounts -make_ngram_counts(const vector& hyp, const vector& ref, const unsigned N) -{ - Ngrams hyp_ngrams = make_ngrams(hyp, N); - Ngrams ref_ngrams = make_ngrams(ref, N); - NgramCounts counts(N); - Ngrams::iterator it; - Ngrams::iterator ti; - for (it = hyp_ngrams.begin(); it != hyp_ngrams.end(); it++) { - ti = ref_ngrams.find(it->first); - if (ti != ref_ngrams.end()) { - counts.Add(it->second, ti->second, it->first.size() - 1); - } else { - counts.Add(it->second, 0, it->first.size() - 1); - } - } - return counts; -} - -struct BleuScorer : public LocalScorer -{ - score_t Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len); - score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); -}; - -struct StupidBleuScorer : public LocalScorer -{ - score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); -}; - -struct SmoothBleuScorer : public LocalScorer -{ - score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); -}; - -struct SumBleuScorer : public LocalScorer -{ - score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); -}; - -struct SumExpBleuScorer : public LocalScorer -{ - score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); -}; - -struct SumWhateverBleuScorer : public LocalScorer -{ - score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); -}; - -struct ApproxBleuScorer : public BleuScorer -{ - NgramCounts glob_onebest_counts_; - unsigned glob_hyp_len_, glob_ref_len_, glob_src_len_; - score_t discount_; - - ApproxBleuScorer(unsigned N, score_t d) : glob_onebest_counts_(NgramCounts(N)), discount_(d) - { - glob_hyp_len_ = glob_ref_len_ = glob_src_len_ = 0; - } - - inline void Reset() { - glob_onebest_counts_.Zero(); - glob_hyp_len_ = glob_ref_len_ = glob_src_len_ = 0.; - } - - score_t Score(vector& hyp, vector& ref, const unsigned rank, const unsigned src_len); -}; - -struct LinearBleuScorer : public BleuScorer -{ - unsigned onebest_len_; - NgramCounts onebest_counts_; - - LinearBleuScorer(unsigned N) : onebest_len_(1), onebest_counts_(N) - { - onebest_counts_.One(); - } - - score_t Score(vector& hyp, vector& ref, const unsigned rank, const unsigned /*src_len*/); - - inline void Reset() { - onebest_len_ = 1; - onebest_counts_.One(); - } -}; - - -} // namespace - -#endif - diff --git a/dtrain/test/example/README b/dtrain/test/example/README deleted file mode 100644 index 6937b11b..00000000 --- a/dtrain/test/example/README +++ /dev/null @@ -1,8 +0,0 @@ -Small example of input format for distributed training. -Call dtrain from cdec/dtrain/ with ./dtrain -c test/example/dtrain.ini . - -For this to work, undef 'DTRAIN_LOCAL' in dtrain.h -and recompile. - -Data is here: http://simianer.de/#dtrain - diff --git a/dtrain/test/example/cdec.ini b/dtrain/test/example/cdec.ini deleted file mode 100644 index d5955f0e..00000000 --- a/dtrain/test/example/cdec.ini +++ /dev/null @@ -1,25 +0,0 @@ -formalism=scfg -add_pass_through_rules=true -scfg_max_span_limit=15 -intersection_strategy=cube_pruning -cubepruning_pop_limit=30 -feature_function=WordPenalty -feature_function=KLanguageModel test/example/nc-wmt11.en.srilm.gz -# all currently working feature functions for translation: -# (with those features active that were used in the ACL paper) -#feature_function=ArityPenalty -#feature_function=CMR2008ReorderingFeatures -#feature_function=Dwarf -#feature_function=InputIndicator -#feature_function=LexNullJump -#feature_function=NewJump -#feature_function=NgramFeatures -#feature_function=NonLatinCount -#feature_function=OutputIndicator -feature_function=RuleIdentityFeatures -feature_function=RuleSourceBigramFeatures -feature_function=RuleTargetBigramFeatures -feature_function=RuleShape -#feature_function=SourceSpanSizeFeatures -#feature_function=SourceWordPenalty -#feature_function=SpanFeatures diff --git a/dtrain/test/example/dtrain.ini b/dtrain/test/example/dtrain.ini deleted file mode 100644 index 72d50ca1..00000000 --- a/dtrain/test/example/dtrain.ini +++ /dev/null @@ -1,22 +0,0 @@ -input=test/example/nc-wmt11.1k.gz # use '-' for STDIN -output=- # a weights file (add .gz for gzip compression) or STDOUT '-' -select_weights=VOID # don't output weights -decoder_config=test/example/cdec.ini # config for cdec -# weights for these features will be printed on each iteration -print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PhraseModel_5 PhraseModel_6 PassThrough -tmp=/tmp -stop_after=10 # stop epoch after 10 inputs - -# interesting stuff -epochs=2 # run over input 2 times -k=100 # use 100best lists -N=4 # optimize (approx) BLEU4 -scorer=stupid_bleu # use 'stupid' BLEU+1 -learning_rate=1.0 # learning rate, don't care if gamma=0 (perceptron) -gamma=0 # use SVM reg -sample_from=kbest # use kbest lists (as opposed to forest) -filter=uniq # only unique entries in kbest (surface form) -pair_sampling=XYX -hi_lo=0.1 # 10 vs 80 vs 10 and 80 vs 10 here -pair_threshold=0 # minimum distance in BLEU (this will still only use pairs with diff > 0) -loss_margin=0 diff --git a/dtrain/test/example/expected-output b/dtrain/test/example/expected-output deleted file mode 100644 index 05326763..00000000 --- a/dtrain/test/example/expected-output +++ /dev/null @@ -1,89 +0,0 @@ - cdec cfg 'test/example/cdec.ini' -Loading the LM will be faster if you build a binary file. -Reading test/example/nc-wmt11.en.srilm.gz -----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100 -**************************************************************************************************** - Example feature: Shape_S00000_T00000 -Seeding random number sequence to 2912000813 - -dtrain -Parameters: - k 100 - N 4 - T 2 - scorer 'stupid_bleu' - sample from 'kbest' - filter 'uniq' - learning rate 1 - gamma 0 - loss margin 0 - pairs 'XYX' - hi lo 0.1 - pair threshold 0 - select weights 'VOID' - l1 reg 0 'none' - max pairs 4294967295 - cdec cfg 'test/example/cdec.ini' - input 'test/example/nc-wmt11.1k.gz' - output '-' - stop_after 10 -(a dot represents 10 inputs) -Iteration #1 of 2. - . 10 -Stopping after 10 input sentences. -WEIGHTS - Glue = -637 - WordPenalty = +1064 - LanguageModel = +1175.3 - LanguageModel_OOV = -1437 - PhraseModel_0 = +1935.6 - PhraseModel_1 = +2499.3 - PhraseModel_2 = +964.96 - PhraseModel_3 = +1410.8 - PhraseModel_4 = -5977.9 - PhraseModel_5 = +522 - PhraseModel_6 = +1089 - PassThrough = -1308 - --- - 1best avg score: 0.16963 (+0.16963) - 1best avg model score: 64485 (+64485) - avg # pairs: 1494.4 - avg # rank err: 702.6 - avg # margin viol: 0 - non0 feature count: 528 - avg list sz: 85.7 - avg f count: 102.75 -(time 0.083 min, 0.5 s/S) - -Iteration #2 of 2. - . 10 -WEIGHTS - Glue = -1196 - WordPenalty = +809.52 - LanguageModel = +3112.1 - LanguageModel_OOV = -1464 - PhraseModel_0 = +3895.5 - PhraseModel_1 = +4683.4 - PhraseModel_2 = +1092.8 - PhraseModel_3 = +1079.6 - PhraseModel_4 = -6827.7 - PhraseModel_5 = -888 - PhraseModel_6 = +142 - PassThrough = -1335 - --- - 1best avg score: 0.277 (+0.10736) - 1best avg model score: -3110.5 (-67595) - avg # pairs: 1144.2 - avg # rank err: 529.1 - avg # margin viol: 0 - non0 feature count: 859 - avg list sz: 74.9 - avg f count: 112.84 -(time 0.067 min, 0.4 s/S) - -Writing weights file to '-' ... -done - ---- -Best iteration: 2 [SCORE 'stupid_bleu'=0.277]. -This took 0.15 min. diff --git a/dtrain/test/parallelize/cdec.ini b/dtrain/test/parallelize/cdec.ini deleted file mode 100644 index 72e99dc5..00000000 --- a/dtrain/test/parallelize/cdec.ini +++ /dev/null @@ -1,22 +0,0 @@ -formalism=scfg -add_pass_through_rules=true -intersection_strategy=cube_pruning -cubepruning_pop_limit=200 -scfg_max_span_limit=15 -feature_function=WordPenalty -feature_function=KLanguageModel /stor/dat/wmt12/en/news_only/m/wmt12.news.en.3.kenv5 -#feature_function=ArityPenalty -#feature_function=CMR2008ReorderingFeatures -#feature_function=Dwarf -#feature_function=InputIndicator -#feature_function=LexNullJump -#feature_function=NewJump -#feature_function=NgramFeatures -#feature_function=NonLatinCount -#feature_function=OutputIndicator -#feature_function=RuleIdentityFeatures -#feature_function=RuleNgramFeatures -#feature_function=RuleShape -#feature_function=SourceSpanSizeFeatures -#feature_function=SourceWordPenalty -#feature_function=SpanFeatures diff --git a/dtrain/test/parallelize/dtrain.ini b/dtrain/test/parallelize/dtrain.ini deleted file mode 100644 index 03f9d240..00000000 --- a/dtrain/test/parallelize/dtrain.ini +++ /dev/null @@ -1,15 +0,0 @@ -k=100 -N=4 -learning_rate=0.0001 -gamma=0 -loss_margin=0 -epochs=1 -scorer=stupid_bleu -sample_from=kbest -filter=uniq -pair_sampling=XYX -hi_lo=0.1 -select_weights=last -print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PhraseModel_5 PhraseModel_6 PassThrough -tmp=/tmp -decoder_config=cdec.ini diff --git a/dtrain/test/parallelize/in b/dtrain/test/parallelize/in deleted file mode 100644 index a312809f..00000000 --- a/dtrain/test/parallelize/in +++ /dev/null @@ -1,10 +0,0 @@ -barack obama erhält als vierter us @-@ präsident den frieden nobelpreis -der amerikanische präsident barack obama kommt für 26 stunden nach oslo , norwegen , um hier als vierter us @-@ präsident in der geschichte den frieden nobelpreis entgegen zunehmen . -darüber hinaus erhält er das diplom sowie die medaille und einen scheck über 1,4 mio. dollar für seine außer gewöhnlichen bestrebungen um die intensivierung der welt diplomatie und zusammen arbeit unter den völkern . -der chef des weißen hauses kommt morgen zusammen mit seiner frau michelle in der nordwegischen metropole an und wird die ganze zeit beschäftigt sein . -zunächst stattet er dem nobel @-@ institut einen besuch ab , wo er überhaupt zum ersten mal mit den fünf ausschuss mitglieder zusammen trifft , die ihn im oktober aus 172 leuten und 33 organisationen gewählt haben . -das präsidenten paar hat danach ein treffen mit dem norwegischen könig harald v. und königin sonja eingeplant . -nachmittags erreicht dann der besuch seinen höhepunkt mit der zeremonie , bei der obama den prestige preis übernimmt . -diesen erhält er als der vierte us @-@ präsident , aber erst als der dritte , der den preis direkt im amt entgegen nimmt . -das weiße haus avisierte schon , dass obama bei der übernahme des preises über den afghanistan krieg sprechen wird . -der präsident will diesem thema nicht ausweichen , weil er weiß , dass er den preis als ein präsident übernimmt , der zur zeit krieg in zwei ländern führt . diff --git a/dtrain/test/parallelize/refs b/dtrain/test/parallelize/refs deleted file mode 100644 index 4d3128cb..00000000 --- a/dtrain/test/parallelize/refs +++ /dev/null @@ -1,10 +0,0 @@ -barack obama becomes the fourth american president to receive the nobel peace prize -the american president barack obama will fly into oslo , norway for 26 hours to receive the nobel peace prize , the fourth american president in history to do so . -he will receive a diploma , medal and cheque for 1.4 million dollars for his exceptional efforts to improve global diplomacy and encourage international cooperation , amongst other things . -the head of the white house will be flying into the norwegian city in the morning with his wife michelle and will have a busy schedule . -first , he will visit the nobel institute , where he will have his first meeting with the five committee members who selected him from 172 people and 33 organisations . -the presidential couple then has a meeting scheduled with king harald v and queen sonja of norway . -then , in the afternoon , the visit will culminate in a grand ceremony , at which obama will receive the prestigious award . -he will be the fourth american president to be awarded the prize , and only the third to have received it while actually in office . -the white house has stated that , when he accepts the prize , obama will speak about the war in afghanistan . -the president does not want to skirt around this topic , as he realises that he is accepting the prize as a president whose country is currently at war in two countries . diff --git a/dtrain/test/toy/cdec.ini b/dtrain/test/toy/cdec.ini deleted file mode 100644 index 98b02d44..00000000 --- a/dtrain/test/toy/cdec.ini +++ /dev/null @@ -1,2 +0,0 @@ -formalism=scfg -add_pass_through_rules=true diff --git a/dtrain/test/toy/dtrain.ini b/dtrain/test/toy/dtrain.ini deleted file mode 100644 index a091732f..00000000 --- a/dtrain/test/toy/dtrain.ini +++ /dev/null @@ -1,12 +0,0 @@ -decoder_config=test/toy/cdec.ini -input=test/toy/input -output=- -print_weights=logp shell_rule house_rule small_rule little_rule PassThrough -k=4 -N=4 -epochs=2 -scorer=bleu -sample_from=kbest -filter=uniq -pair_sampling=all -learning_rate=1 diff --git a/dtrain/test/toy/input b/dtrain/test/toy/input deleted file mode 100644 index 4d10a9ea..00000000 --- a/dtrain/test/toy/input +++ /dev/null @@ -1,2 +0,0 @@ -0 ich sah ein kleines haus i saw a little house [S] ||| [NP,1] [VP,2] ||| [1] [2] ||| logp=0 [NP] ||| ich ||| i ||| logp=0 [NP] ||| ein [NN,1] ||| a [1] ||| logp=0 [NN] ||| [JJ,1] haus ||| [1] house ||| logp=0 house_rule=1 [NN] ||| [JJ,1] haus ||| [1] shell ||| logp=0 shell_rule=1 [JJ] ||| kleines ||| small ||| logp=0 small_rule=1 [JJ] ||| kleines ||| little ||| logp=0 little_rule=1 [JJ] ||| grosses ||| big ||| logp=0 [JJ] ||| grosses ||| large ||| logp=0 [VP] ||| [V,1] [NP,2] ||| [1] [2] ||| logp=0 [V] ||| sah ||| saw ||| logp=0 [V] ||| fand ||| found ||| logp=0 -1 ich fand ein kleines haus i found a little house [S] ||| [NP,1] [VP,2] ||| [1] [2] ||| logp=0 [NP] ||| ich ||| i ||| logp=0 [NP] ||| ein [NN,1] ||| a [1] ||| logp=0 [NN] ||| [JJ,1] haus ||| [1] house ||| logp=0 house_rule=1 [NN] ||| [JJ,1] haus ||| [1] shell ||| logp=0 shell_rule=1 [JJ] ||| kleines ||| small ||| logp=0 small_rule=1 [JJ] ||| kleines ||| little ||| logp=0 little_rule=1 [JJ] ||| grosses ||| big ||| logp=0 [JJ] ||| grosses ||| large ||| logp=0 [VP] ||| [V,1] [NP,2] ||| [1] [2] ||| logp=0 [V] ||| sah ||| saw ||| logp=0 [V] ||| fand ||| found ||| logp=0 diff --git a/minrisk/Makefile.am b/minrisk/Makefile.am deleted file mode 100644 index a24f047c..00000000 --- a/minrisk/Makefile.am +++ /dev/null @@ -1,6 +0,0 @@ -bin_PROGRAMS = minrisk_optimize - -minrisk_optimize_SOURCES = minrisk_optimize.cc -minrisk_optimize_LDADD = $(top_srcdir)/training/libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/training/liblbfgs/liblbfgs.a -lz - -AM_CPPFLAGS = -W -Wall $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training diff --git a/minrisk/minrisk.pl b/minrisk/minrisk.pl deleted file mode 100755 index d05b9595..00000000 --- a/minrisk/minrisk.pl +++ /dev/null @@ -1,540 +0,0 @@ -#!/usr/bin/env perl -use strict; -my @ORIG_ARGV=@ARGV; -use Cwd qw(getcwd); -my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../environment"; } - -# Skip local config (used for distributing jobs) if we're running in local-only mode -use LocalConfig; -use Getopt::Long; -use IPC::Open2; -use POSIX ":sys_wait_h"; -my $QSUB_CMD = qsub_args(mert_memory()); -my $default_jobs = env_default_jobs(); - -my $VEST_DIR="$SCRIPT_DIR/../dpmert"; -require "$VEST_DIR/libcall.pl"; - -# Default settings -my $srcFile; -my $refFiles; -my $bin_dir = $SCRIPT_DIR; -die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; -my $FAST_SCORE="$bin_dir/../mteval/fast_score"; -die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; -my $MAPINPUT = "$bin_dir/minrisk_generate_input.pl"; -my $MAPPER = "$bin_dir/minrisk_optimize"; -my $parallelize = "$VEST_DIR/parallelize.pl"; -my $libcall = "$VEST_DIR/libcall.pl"; -my $sentserver = "$VEST_DIR/sentserver"; -my $sentclient = "$VEST_DIR/sentclient"; -my $LocalConfig = "$SCRIPT_DIR/../environment/LocalConfig.pm"; - -my $SCORER = $FAST_SCORE; -die "Can't find $MAPPER" unless -x $MAPPER; -my $cdec = "$bin_dir/../decoder/cdec"; -die "Can't find decoder in $cdec" unless -x $cdec; -die "Can't find $parallelize" unless -x $parallelize; -die "Can't find $libcall" unless -e $libcall; -my $decoder = $cdec; -my $lines_per_mapper = 30; -my $iteration = 1; -my $best_weights; -my $psi = 1; -my $default_max_iter = 30; -my $max_iterations = $default_max_iter; -my $jobs = $default_jobs; # number of decode nodes -my $pmem = "4g"; -my $disable_clean = 0; -my %seen_weights; -my $help = 0; -my $epsilon = 0.0001; -my $dryrun = 0; -my $last_score = -10000000; -my $metric = "ibm_bleu"; -my $dir; -my $iniFile; -my $weights; -my $use_make = 1; # use make to parallelize -my $useqsub = 0; -my $initial_weights; -my $pass_suffix = ''; -my $cpbin=1; - -# regularization strength -my $tune_regularizer = 0; -my $reg = 500; -my $reg_previous = 5000; -my $dont_accum = 0; - -# Process command-line options -Getopt::Long::Configure("no_auto_abbrev"); -if (GetOptions( - "jobs=i" => \$jobs, - "dont-clean" => \$disable_clean, - "dont-accumulate" => \$dont_accum, - "pass-suffix=s" => \$pass_suffix, - "qsub" => \$useqsub, - "dry-run" => \$dryrun, - "epsilon=s" => \$epsilon, - "help" => \$help, - "weights=s" => \$initial_weights, - "reg=f" => \$reg, - "use-make=i" => \$use_make, - "max-iterations=i" => \$max_iterations, - "pmem=s" => \$pmem, - "cpbin!" => \$cpbin, - "ref-files=s" => \$refFiles, - "metric=s" => \$metric, - "source-file=s" => \$srcFile, - "workdir=s" => \$dir, -) == 0 || @ARGV!=1 || $help) { - print_help(); - exit; -} - -die "--tune-regularizer is no longer supported with --reg-previous and --reg. Please tune manually.\n" if $tune_regularizer; - -if ($useqsub) { - $use_make = 0; - die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); -} - -my @missing_args = (); -if (!defined $srcFile) { push @missing_args, "--source-file"; } -if (!defined $refFiles) { push @missing_args, "--ref-files"; } -if (!defined $initial_weights) { push @missing_args, "--weights"; } -die "Please specify missing arguments: " . join (', ', @missing_args) . "\n" if (@missing_args); - -if ($metric =~ /^(combi|ter)$/i) { - $lines_per_mapper = 5; -} - -($iniFile) = @ARGV; - - -sub write_config; -sub enseg; -sub print_help; - -my $nodelist; -my $host =check_output("hostname"); chomp $host; -my $bleu; -my $interval_count = 0; -my $logfile; -my $projected_score; - -# used in sorting scores -my $DIR_FLAG = '-r'; -if ($metric =~ /^ter$|^aer$/i) { - $DIR_FLAG = ''; -} - -my $refs_comma_sep = get_comma_sep_refs('r',$refFiles); - -unless ($dir){ - $dir = "minrisk"; -} -unless ($dir =~ /^\//){ # convert relative path to absolute path - my $basedir = check_output("pwd"); - chomp $basedir; - $dir = "$basedir/$dir"; -} - - -# Initializations and helper functions -srand; - -my @childpids = (); -my @cleanupcmds = (); - -sub cleanup { - print STDERR "Cleanup...\n"; - for my $pid (@childpids){ unchecked_call("kill $pid"); } - for my $cmd (@cleanupcmds){ unchecked_call("$cmd"); } - exit 1; -}; -# Always call cleanup, no matter how we exit -*CORE::GLOBAL::exit = - sub{ cleanup(); }; -$SIG{INT} = "cleanup"; -$SIG{TERM} = "cleanup"; -$SIG{HUP} = "cleanup"; - -my $decoderBase = check_output("basename $decoder"); chomp $decoderBase; -my $newIniFile = "$dir/$decoderBase.ini"; -my $inputFileName = "$dir/input"; -my $user = $ENV{"USER"}; -# process ini file --e $iniFile || die "Error: could not open $iniFile for reading\n"; -open(INI, $iniFile); - -use File::Basename qw(basename); -#pass bindir, refs to vars holding bin -sub modbin { - local $_; - my $bindir=shift; - check_call("mkdir -p $bindir"); - -d $bindir || die "couldn't make bindir $bindir"; - for (@_) { - my $src=$$_; - $$_="$bindir/".basename($src); - check_call("cp -p $src $$_"); - } -} -sub dirsize { - opendir ISEMPTY,$_[0]; - return scalar(readdir(ISEMPTY))-1; -} -my @allweights; -if ($dryrun){ - write_config(*STDERR); - exit 0; -} else { - if (-e $dir && dirsize($dir)>1 && -e "$dir/hgs" ){ # allow preexisting logfile, binaries, but not dist-pro.pl outputs - die "ERROR: working dir $dir already exists\n\n"; - } else { - -e $dir || mkdir $dir; - mkdir "$dir/hgs"; - modbin("$dir/bin",\$LocalConfig,\$cdec,\$SCORER,\$MAPINPUT,\$MAPPER,\$parallelize,\$sentserver,\$sentclient,\$libcall) if $cpbin; - mkdir "$dir/scripts"; - my $cmdfile="$dir/rerun-pro.sh"; - open CMD,'>',$cmdfile; - print CMD "cd ",&getcwd,"\n"; -# print CMD &escaped_cmdline,"\n"; #buggy - last arg is quoted. - my $cline=&cmdline."\n"; - print CMD $cline; - close CMD; - print STDERR $cline; - chmod(0755,$cmdfile); - check_call("cp $initial_weights $dir/weights.0"); - die "Can't find weights.0" unless (-e "$dir/weights.0"); - } - write_config(*STDERR); -} - - -# Generate initial files and values -check_call("cp $iniFile $newIniFile"); -$iniFile = $newIniFile; - -my $newsrc = "$dir/dev.input"; -enseg($srcFile, $newsrc); -$srcFile = $newsrc; -my $devSize = 0; -open F, "<$srcFile" or die "Can't read $srcFile: $!"; -while() { $devSize++; } -close F; - -unless($best_weights){ $best_weights = $weights; } -unless($projected_score){ $projected_score = 0.0; } -$seen_weights{$weights} = 1; -my $kbest = "$dir/kbest"; -if ($dont_accum) { - $kbest = ''; -} else { - check_call("mkdir -p $kbest"); - $kbest = "--kbest_repository $kbest"; -} - -my $random_seed = int(time / 1000); -my $lastWeightsFile; -my $lastPScore = 0; -# main optimization loop -while (1){ - print STDERR "\n\nITERATION $iteration\n==========\n"; - - if ($iteration > $max_iterations){ - print STDERR "\nREACHED STOPPING CRITERION: Maximum iterations\n"; - last; - } - # iteration-specific files - my $runFile="$dir/run.raw.$iteration"; - my $onebestFile="$dir/1best.$iteration"; - my $logdir="$dir/logs.$iteration"; - my $decoderLog="$logdir/decoder.sentserver.log.$iteration"; - my $scorerLog="$logdir/scorer.log.$iteration"; - check_call("mkdir -p $logdir"); - - - #decode - print STDERR "RUNNING DECODER AT "; - print STDERR unchecked_output("date"); - my $im1 = $iteration - 1; - my $weightsFile="$dir/weights.$im1"; - push @allweights, "-w $dir/weights.$im1"; - `rm -f $dir/hgs/*.gz`; - my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; - my $pcmd; - if ($use_make) { - $pcmd = "cat $srcFile | $parallelize --use-fork -p $pmem -e $logdir -j $jobs --"; - } else { - $pcmd = "cat $srcFile | $parallelize -p $pmem -e $logdir -j $jobs --"; - } - my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $runFile"; - print STDERR "COMMAND:\n$cmd\n"; - check_bash_call($cmd); - my $num_hgs; - my $num_topbest; - my $retries = 0; - while($retries < 5) { - $num_hgs = check_output("ls $dir/hgs/*.gz | wc -l"); - $num_topbest = check_output("wc -l < $runFile"); - print STDERR "NUMBER OF HGs: $num_hgs\n"; - print STDERR "NUMBER OF TOP-BEST HYPs: $num_topbest\n"; - if($devSize == $num_hgs && $devSize == $num_topbest) { - last; - } else { - print STDERR "Incorrect number of hypergraphs or topbest. Waiting for distributed filesystem and retrying...\n"; - sleep(3); - } - $retries++; - } - die "Dev set contains $devSize sentences, but we don't have topbest and hypergraphs for all these! Decoder failure? Check $decoderLog\n" if ($devSize != $num_hgs || $devSize != $num_topbest); - my $dec_score = check_output("cat $runFile | $SCORER $refs_comma_sep -m $metric"); - chomp $dec_score; - print STDERR "DECODER SCORE: $dec_score\n"; - - # save space - check_call("gzip -f $runFile"); - check_call("gzip -f $decoderLog"); - - # run optimizer - print STDERR "RUNNING OPTIMIZER AT "; - print STDERR unchecked_output("date"); - print STDERR " - GENERATE TRAINING EXEMPLARS\n"; - my $mergeLog="$logdir/prune-merge.log.$iteration"; - - my $score = 0; - my $icc = 0; - my $inweights="$dir/weights.$im1"; - my $outweights="$dir/weights.$iteration"; - $cmd="$MAPINPUT $dir/hgs > $dir/agenda.$im1"; - print STDERR "COMMAND:\n$cmd\n"; - check_call($cmd); - $cmd="$MAPPER $refs_comma_sep -m $metric -i $dir/agenda.$im1 $kbest -w $inweights > $outweights"; - check_call($cmd); - $lastWeightsFile = $outweights; - $iteration++; - `rm hgs/*.gz`; - print STDERR "\n==========\n"; -} - -print STDERR "\nFINAL WEIGHTS: $lastWeightsFile\n(Use -w with the decoder)\n\n"; - -print STDOUT "$lastWeightsFile\n"; - -exit 0; - -sub get_lines { - my $fn = shift @_; - open FL, "<$fn" or die "Couldn't read $fn: $!"; - my $lc = 0; - while() { $lc++; } - return $lc; -} - -sub get_comma_sep_refs { - my ($r,$p) = @_; - my $o = check_output("echo $p"); - chomp $o; - my @files = split /\s+/, $o; - return "-$r " . join(" -$r ", @files); -} - -sub read_weights_file { - my ($file) = @_; - open F, "<$file" or die "Couldn't read $file: $!"; - my @r = (); - my $pm = -1; - while() { - next if /^#/; - next if /^\s*$/; - chomp; - if (/^(.+)\s+(.+)$/) { - my $m = $1; - my $w = $2; - die "Weights out of order: $m <= $pm" unless $m > $pm; - push @r, $w; - } else { - warn "Unexpected feature name in weight file: $_"; - } - } - close F; - return join ' ', @r; -} - -# subs -sub write_config { - my $fh = shift; - my $cleanup = "yes"; - if ($disable_clean) {$cleanup = "no";} - - print $fh "\n"; - print $fh "DECODER: $decoder\n"; - print $fh "INI FILE: $iniFile\n"; - print $fh "WORKING DIR: $dir\n"; - print $fh "SOURCE (DEV): $srcFile\n"; - print $fh "REFS (DEV): $refFiles\n"; - print $fh "EVAL METRIC: $metric\n"; - print $fh "MAX ITERATIONS: $max_iterations\n"; - print $fh "JOBS: $jobs\n"; - print $fh "HEAD NODE: $host\n"; - print $fh "PMEM (DECODING): $pmem\n"; - print $fh "CLEANUP: $cleanup\n"; -} - -sub update_weights_file { - my ($neww, $rfn, $rpts) = @_; - my @feats = @$rfn; - my @pts = @$rpts; - my $num_feats = scalar @feats; - my $num_pts = scalar @pts; - die "$num_feats (num_feats) != $num_pts (num_pts)" unless $num_feats == $num_pts; - open G, ">$neww" or die; - for (my $i = 0; $i < $num_feats; $i++) { - my $f = $feats[$i]; - my $lambda = $pts[$i]; - print G "$f $lambda\n"; - } - close G; -} - -sub enseg { - my $src = shift; - my $newsrc = shift; - open(SRC, $src); - open(NEWSRC, ">$newsrc"); - my $i=0; - while (my $line=){ - chomp $line; - if ($line =~ /^\s* tags, you must include a zero-based id attribute"; - } - } else { - print NEWSRC "$line\n"; - } - $i++; - } - close SRC; - close NEWSRC; - die "Empty dev set!" if ($i == 0); -} - -sub print_help { - - my $executable = check_output("basename $0"); chomp $executable; - print << "Help"; - -Usage: $executable [options] - - $executable [options] - Runs a complete PRO optimization using the ini file specified. - -Required: - - --ref-files - Dev set ref files. This option takes only a single string argument. - To use multiple files (including file globbing), this argument should - be quoted. - - --source-file - Dev set source file. - - --weights - Initial weights file (use empty file to start from 0) - -General options: - - --help - Print this message and exit. - - --dont-accumulate - Don't accumulate k-best lists from multiple iterations. - - --max-iterations - Maximum number of iterations to run. If not specified, defaults - to $default_max_iter. - - --metric - Metric to optimize. - Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi - - --pass-suffix - If the decoder is doing multi-pass decoding, the pass suffix "2", - "3", etc., is used to control what iteration of weights is set. - - --workdir - Directory for intermediate and output files. If not specified, the - name is derived from the ini filename. Assuming that the ini - filename begins with the decoder name and ends with ini, the default - name of the working directory is inferred from the middle part of - the filename. E.g. an ini file named decoder.foo.ini would have - a default working directory name foo. - -Regularization options: - - --reg - l2 regularization strength [default=500]. The greater this value, - the closer to zero the weights will be. - -Job control options: - - --jobs - Number of decoder processes to run in parallel. [default=$default_jobs] - - --qsub - Use qsub to run jobs in parallel (qsub must be configured in - environment/LocalEnvironment.pm) - - --pmem - Amount of physical memory requested for parallel decoding jobs - (used with qsub requests only) - -Help -} - -sub convert { - my ($str) = @_; - my @ps = split /;/, $str; - my %dict = (); - for my $p (@ps) { - my ($k, $v) = split /=/, $p; - $dict{$k} = $v; - } - return %dict; -} - - -sub cmdline { - return join ' ',($0,@ORIG_ARGV); -} - -#buggy: last arg gets quoted sometimes? -my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; -my $shell_escape_in_quote=qr{[\\"\$`!]}; - -sub escape_shell { - my ($arg)=@_; - return undef unless defined $arg; - if ($arg =~ /$is_shell_special/) { - $arg =~ s/($shell_escape_in_quote)/\\$1/g; - return "\"$arg\""; - } - return $arg; -} - -sub escaped_shell_args { - return map {local $_=$_;chomp;escape_shell($_)} @_; -} - -sub escaped_shell_args_str { - return join ' ',&escaped_shell_args(@_); -} - -sub escaped_cmdline { - return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); -} diff --git a/minrisk/minrisk_generate_input.pl b/minrisk/minrisk_generate_input.pl deleted file mode 100755 index b30fc4fd..00000000 --- a/minrisk/minrisk_generate_input.pl +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/perl -w -use strict; - -die "Usage: $0 HG_DIR\n" unless scalar @ARGV == 1; -my $d = shift @ARGV; -die "Can't find directory $d" unless -d $d; - -opendir(DIR, $d) or die "Can't read $d: $!"; -my @hgs = grep { /\.gz$/ } readdir(DIR); -closedir DIR; - -for my $hg (@hgs) { - my $file = $hg; - my $id = $hg; - $id =~ s/(\.json)?\.gz//; - print "$d/$file $id\n"; -} - diff --git a/minrisk/minrisk_optimize.cc b/minrisk/minrisk_optimize.cc deleted file mode 100644 index da8b5260..00000000 --- a/minrisk/minrisk_optimize.cc +++ /dev/null @@ -1,197 +0,0 @@ -#include -#include -#include -#include - -#include -#include - -#include "liblbfgs/lbfgs++.h" -#include "filelib.h" -#include "stringlib.h" -#include "weights.h" -#include "hg_io.h" -#include "kbest.h" -#include "viterbi.h" -#include "ns.h" -#include "ns_docscorer.h" -#include "candidate_set.h" -#include "risk.h" -#include "entropy.h" - -using namespace std; -namespace po = boost::program_options; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") - ("weights,w",po::value(), "[REQD] Weights files from current iterations") - ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") - ("evaluation_metric,m",po::value()->default_value("IBM_BLEU"), "Evaluation metric (ibm_bleu, koehn_bleu, nist_bleu, ter, meteor, etc.)") - ("temperature,T",po::value()->default_value(0.0), "Temperature parameter for objective (>0 increases the entropy)") - ("l1_strength,C",po::value()->default_value(0.0), "L1 regularization strength") - ("memory_buffers,M",po::value()->default_value(20), "Memory buffers used in LBFGS") - ("kbest_repository,R",po::value(), "Accumulate k-best lists from previous iterations (parameter is path to repository)") - ("kbest_size,k",po::value()->default_value(500u), "Top k-hypotheses to extract") - ("help,h", "Help"); - po::options_description dcmdline_options; - dcmdline_options.add(opts); - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - bool flag = false; - if (!conf->count("reference")) { - cerr << "Please specify one or more references using -r \n"; - flag = true; - } - if (!conf->count("weights")) { - cerr << "Please specify weights using -w \n"; - flag = true; - } - if (flag || conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -EvaluationMetric* metric = NULL; - -struct RiskObjective { - explicit RiskObjective(const vector& tr, const double temp) : training(tr), T(temp) {} - double operator()(const vector& x, double* g) const { - fill(g, g + x.size(), 0.0); - double obj = 0; - double h = 0; - for (unsigned i = 0; i < training.size(); ++i) { - training::CandidateSetRisk risk(training[i], *metric); - training::CandidateSetEntropy entropy(training[i]); - SparseVector tg, hg; - double r = risk(x, &tg); - double hh = entropy(x, &hg); - h += hh; - obj += r; - for (SparseVector::iterator it = tg.begin(); it != tg.end(); ++it) - g[it->first] += it->second; - if (T) { - for (SparseVector::iterator it = hg.begin(); it != hg.end(); ++it) - g[it->first] += T * it->second; - } - } - cerr << (1-(obj / training.size())) << " H=" << h << endl; - return obj - T * h; - } - const vector& training; - const double T; // temperature for entropy regularization -}; - -double LearnParameters(const vector& training, - const double temp, // > 0 increases the entropy, < 0 decreases the entropy - const double C1, - const unsigned memory_buffers, - vector* px) { - RiskObjective obj(training, temp); - LBFGS lbfgs(px, obj, memory_buffers, C1); - lbfgs.MinimizeFunction(); - return 0; -} - -#if 0 -struct FooLoss { - double operator()(const vector& x, double* g) const { - fill(g, g + x.size(), 0.0); - training::CandidateSet cs; - training::CandidateSetEntropy cse(cs); - cs.cs.resize(3); - cs.cs[0].fmap.set_value(FD::Convert("F1"), -1.0); - cs.cs[1].fmap.set_value(FD::Convert("F2"), 1.0); - cs.cs[2].fmap.set_value(FD::Convert("F1"), 2.0); - cs.cs[2].fmap.set_value(FD::Convert("F2"), 0.5); - SparseVector xx; - double h = cse(x, &xx); - cerr << cse(x, &xx) << endl; cerr << "G: " << xx << endl; - for (SparseVector::iterator i = xx.begin(); i != xx.end(); ++i) - g[i->first] += i->second; - return -h; - } -}; -#endif - -int main(int argc, char** argv) { -#if 0 - training::CandidateSet cs; - training::CandidateSetEntropy cse(cs); - cs.cs.resize(3); - cs.cs[0].fmap.set_value(FD::Convert("F1"), -1.0); - cs.cs[1].fmap.set_value(FD::Convert("F2"), 1.0); - cs.cs[2].fmap.set_value(FD::Convert("F1"), 2.0); - cs.cs[2].fmap.set_value(FD::Convert("F2"), 0.5); - FooLoss foo; - vector ww(FD::NumFeats()); ww[FD::Convert("F1")] = 1.0; - LBFGS lbfgs(&ww, foo, 100, 0.0); - lbfgs.MinimizeFunction(); - return 1; -#endif - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - const string evaluation_metric = conf["evaluation_metric"].as(); - - metric = EvaluationMetric::Instance(evaluation_metric); - DocumentScorer ds(metric, conf["reference"].as >()); - cerr << "Loaded " << ds.size() << " references for scoring with " << evaluation_metric << endl; - - Hypergraph hg; - string last_file; - ReadFile in_read(conf["input"].as()); - string kbest_repo; - if (conf.count("kbest_repository")) { - kbest_repo = conf["kbest_repository"].as(); - MkDirP(kbest_repo); - } - istream &in=*in_read.stream(); - const unsigned kbest_size = conf["kbest_size"].as(); - vector weights; - const string weightsf = conf["weights"].as(); - Weights::InitFromFile(weightsf, &weights); - double t = 0; - for (unsigned i = 0; i < weights.size(); ++i) - t += weights[i] * weights[i]; - if (t > 0) { - for (unsigned i = 0; i < weights.size(); ++i) - weights[i] /= sqrt(t); - } - string line, file; - vector kis; - cerr << "Loading hypergraphs...\n"; - while(getline(in, line)) { - istringstream is(line); - int sent_id; - kis.resize(kis.size() + 1); - training::CandidateSet& curkbest = kis.back(); - string kbest_file; - if (kbest_repo.size()) { - ostringstream os; - os << kbest_repo << "/kbest." << sent_id << ".txt.gz"; - kbest_file = os.str(); - if (FileExists(kbest_file)) - curkbest.ReadFromFile(kbest_file); - } - is >> file >> sent_id; - ReadFile rf(file); - if (kis.size() % 5 == 0) { cerr << '.'; } - if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; } - HypergraphIO::ReadFromJSON(rf.stream(), &hg); - hg.Reweight(weights); - curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]); - if (kbest_file.size()) - curkbest.WriteToFile(kbest_file); - } - cerr << "\nHypergraphs loaded.\n"; - weights.resize(FD::NumFeats()); - - double c1 = conf["l1_strength"].as(); - double temp = conf["temperature"].as(); - unsigned m = conf["memory_buffers"].as(); - LearnParameters(kis, temp, c1, m, &weights); - Weights::WriteToFile("-", weights); - return 0; -} - diff --git a/mira/Makefile.am b/mira/Makefile.am deleted file mode 100644 index 3f8f17cd..00000000 --- a/mira/Makefile.am +++ /dev/null @@ -1,6 +0,0 @@ -bin_PROGRAMS = kbest_mira - -kbest_mira_SOURCES = kbest_mira.cc -kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/mira/kbest_mira.cc b/mira/kbest_mira.cc deleted file mode 100644 index 8b7993dd..00000000 --- a/mira/kbest_mira.cc +++ /dev/null @@ -1,309 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "hg_sampler.h" -#include "sentence_metadata.h" -#include "scorer.h" -#include "verbose.h" -#include "viterbi.h" -#include "hg.h" -#include "prob.h" -#include "kbest.h" -#include "ff_register.h" -#include "decoder.h" -#include "filelib.h" -#include "fdict.h" -#include "weights.h" -#include "sparse_vector.h" -#include "sampler.h" - -using namespace std; -namespace po = boost::program_options; - -bool invert_score; -std::tr1::shared_ptr rng; - -void RandomPermutation(int len, vector* p_ids) { - vector& ids = *p_ids; - ids.resize(len); - for (int i = 0; i < len; ++i) ids[i] = i; - for (int i = len; i > 0; --i) { - int j = rng->next() * i; - if (j == i) i--; - swap(ids[i-1], ids[j]); - } -} - -bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("input_weights,w",po::value(),"Input feature weights file") - ("source,i",po::value(),"Source file for development set") - ("passes,p", po::value()->default_value(15), "Number of passes through the training data") - ("reference,r",po::value >(), "[REQD] Reference translation(s) (tokenized text file)") - ("mt_metric,m",po::value()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)") - ("max_step_size,C", po::value()->default_value(0.01), "regularization strength (C)") - ("mt_metric_scale,s", po::value()->default_value(1.0), "Amount to scale MT loss function by") - ("k_best_size,k", po::value()->default_value(250), "Size of hypothesis list to search for oracles") - ("sample_forest,f", "Instead of a k-best list, sample k hypotheses from the decoder's forest") - ("sample_forest_unit_weight_vector,x", "Before sampling (must use -f option), rescale the weight vector used so it has unit length; this may improve the quality of the samples") - ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") - ("decoder_config,c",po::value(),"Decoder configuration file"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("input_weights") || !conf->count("source") || !conf->count("decoder_config") || !conf->count("reference")) { - cerr << dcmdline_options << endl; - return false; - } - return true; -} - -static const double kMINUS_EPSILON = -1e-6; - -struct HypothesisInfo { - SparseVector features; - double mt_metric; -}; - -struct GoodBadOracle { - std::tr1::shared_ptr good; - std::tr1::shared_ptr bad; -}; - -struct TrainingObserver : public DecoderObserver { - TrainingObserver(const int k, const DocScorer& d, bool sf, vector* o) : ds(d), oracles(*o), kbest_size(k), sample_forest(sf) {} - const DocScorer& ds; - vector& oracles; - std::tr1::shared_ptr cur_best; - const int kbest_size; - const bool sample_forest; - - const HypothesisInfo& GetCurrentBestHypothesis() const { - return *cur_best; - } - - virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { - UpdateOracles(smeta.GetSentenceID(), *hg); - } - - std::tr1::shared_ptr MakeHypothesisInfo(const SparseVector& feats, const double score) { - std::tr1::shared_ptr h(new HypothesisInfo); - h->features = feats; - h->mt_metric = score; - return h; - } - - void UpdateOracles(int sent_id, const Hypergraph& forest) { - std::tr1::shared_ptr& cur_good = oracles[sent_id].good; - std::tr1::shared_ptr& cur_bad = oracles[sent_id].bad; - cur_bad.reset(); // TODO get rid of?? - - if (sample_forest) { - vector cur_prediction; - ViterbiESentence(forest, &cur_prediction); - float sentscore = ds[sent_id]->ScoreCandidate(cur_prediction)->ComputeScore(); - cur_best = MakeHypothesisInfo(ViterbiFeatures(forest), sentscore); - - vector samples; - HypergraphSampler::sample_hypotheses(forest, kbest_size, &*rng, &samples); - for (unsigned i = 0; i < samples.size(); ++i) { - sentscore = ds[sent_id]->ScoreCandidate(samples[i].words)->ComputeScore(); - if (invert_score) sentscore *= -1.0; - if (!cur_good || sentscore > cur_good->mt_metric) - cur_good = MakeHypothesisInfo(samples[i].fmap, sentscore); - if (!cur_bad || sentscore < cur_bad->mt_metric) - cur_bad = MakeHypothesisInfo(samples[i].fmap, sentscore); - } - } else { - KBest::KBestDerivations, ESentenceTraversal> kbest(forest, kbest_size); - for (int i = 0; i < kbest_size; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(forest.nodes_.size() - 1, i); - if (!d) break; - float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); - if (invert_score) sentscore *= -1.0; - // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; - if (i == 0) - cur_best = MakeHypothesisInfo(d->feature_values, sentscore); - if (!cur_good || sentscore > cur_good->mt_metric) - cur_good = MakeHypothesisInfo(d->feature_values, sentscore); - if (!cur_bad || sentscore < cur_bad->mt_metric) - cur_bad = MakeHypothesisInfo(d->feature_values, sentscore); - } - //cerr << "GOOD: " << cur_good->mt_metric << endl; - //cerr << " CUR: " << cur_best->mt_metric << endl; - //cerr << " BAD: " << cur_bad->mt_metric << endl; - } - } -}; - -void ReadTrainingCorpus(const string& fname, vector* c) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - while(in) { - getline(in, line); - if (!in) break; - c->push_back(line); - } -} - -bool ApproxEqual(double a, double b) { - if (a == b) return true; - return (fabs(a-b)/fabs(b)) < 0.000001; -} - -int main(int argc, char** argv) { - register_feature_functions(); - SetSilent(true); // turn off verbose decoder output - - po::variables_map conf; - if (!InitCommandLine(argc, argv, &conf)) return 1; - - if (conf.count("random_seed")) - rng.reset(new MT19937(conf["random_seed"].as())); - else - rng.reset(new MT19937); - const bool sample_forest = conf.count("sample_forest") > 0; - const bool sample_forest_unit_weight_vector = conf.count("sample_forest_unit_weight_vector") > 0; - if (sample_forest_unit_weight_vector && !sample_forest) { - cerr << "Cannot --sample_forest_unit_weight_vector without --sample_forest" << endl; - return 1; - } - vector corpus; - ReadTrainingCorpus(conf["source"].as(), &corpus); - const string metric_name = conf["mt_metric"].as(); - ScoreType type = ScoreTypeFromString(metric_name); - if (type == TER) { - invert_score = true; - } else { - invert_score = false; - } - DocScorer ds(type, conf["reference"].as >(), ""); - cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl; - if (ds.size() != corpus.size()) { - cerr << "Mismatched number of references (" << ds.size() << ") and sources (" << corpus.size() << ")\n"; - return 1; - } - - ReadFile ini_rf(conf["decoder_config"].as()); - Decoder decoder(ini_rf.stream()); - - // load initial weights - vector& dense_weights = decoder.CurrentWeightVector(); - SparseVector lambdas; - Weights::InitFromFile(conf["input_weights"].as(), &dense_weights); - Weights::InitSparseVector(dense_weights, &lambdas); - - const double max_step_size = conf["max_step_size"].as(); - const double mt_metric_scale = conf["mt_metric_scale"].as(); - - assert(corpus.size() > 0); - vector oracles(corpus.size()); - - TrainingObserver observer(conf["k_best_size"].as(), ds, sample_forest, &oracles); - int cur_sent = 0; - int lcount = 0; - int normalizer = 0; - double tot_loss = 0; - int dots = 0; - int cur_pass = 0; - SparseVector tot; - tot += lambdas; // initial weights - normalizer++; // count for initial weights - int max_iteration = conf["passes"].as() * corpus.size(); - string msg = "# MIRA tuned weights"; - string msga = "# MIRA tuned weights AVERAGED"; - vector order; - RandomPermutation(corpus.size(), &order); - while (lcount <= max_iteration) { - lambdas.init_vector(&dense_weights); - if ((cur_sent * 40 / corpus.size()) > dots) { ++dots; cerr << '.'; } - if (corpus.size() == cur_sent) { - cerr << " [AVG METRIC LAST PASS=" << (tot_loss / corpus.size()) << "]\n"; - Weights::ShowLargestFeatures(dense_weights); - cur_sent = 0; - tot_loss = 0; - dots = 0; - ostringstream os; - os << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << ".gz"; - SparseVector x = tot; - x /= normalizer; - ostringstream sa; - sa << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "-avg.gz"; - x.init_vector(&dense_weights); - Weights::WriteToFile(os.str(), dense_weights, true, &msg); - ++cur_pass; - RandomPermutation(corpus.size(), &order); - } - if (cur_sent == 0) { - cerr << "PASS " << (lcount / corpus.size() + 1) << endl; - } - decoder.SetId(order[cur_sent]); - double sc = 1.0; - if (sample_forest_unit_weight_vector) { - sc = lambdas.l2norm(); - if (sc > 0) { - for (unsigned i = 0; i < dense_weights.size(); ++i) - dense_weights[i] /= sc; - } - } - decoder.Decode(corpus[order[cur_sent]], &observer); // update oracles - if (sc && sc != 1.0) { - for (unsigned i = 0; i < dense_weights.size(); ++i) - dense_weights[i] *= sc; - } - const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis(); - const HypothesisInfo& cur_good = *oracles[order[cur_sent]].good; - const HypothesisInfo& cur_bad = *oracles[order[cur_sent]].bad; - tot_loss += cur_hyp.mt_metric; - if (!ApproxEqual(cur_hyp.mt_metric, cur_good.mt_metric)) { - const double loss = cur_bad.features.dot(dense_weights) - cur_good.features.dot(dense_weights) + - mt_metric_scale * (cur_good.mt_metric - cur_bad.mt_metric); - //cerr << "LOSS: " << loss << endl; - if (loss > 0.0) { - SparseVector diff = cur_good.features; - diff -= cur_bad.features; - double step_size = loss / diff.l2norm_sq(); - //cerr << loss << " " << step_size << " " << diff << endl; - if (step_size > max_step_size) step_size = max_step_size; - lambdas += (cur_good.features * step_size); - lambdas -= (cur_bad.features * step_size); - //cerr << "L: " << lambdas << endl; - } - } - tot += lambdas; - ++normalizer; - ++lcount; - ++cur_sent; - } - cerr << endl; - Weights::WriteToFile("weights.mira-final.gz", dense_weights, true, &msg); - tot /= normalizer; - tot.init_vector(dense_weights); - msg = "# MIRA tuned weights (averaged vector)"; - Weights::WriteToFile("weights.mira-final-avg.gz", dense_weights, true, &msg); - cerr << "Optimization complete.\nAVERAGED WEIGHTS: weights.mira-final-avg.gz\n"; - return 0; -} - diff --git a/pro/Makefile.am b/pro/Makefile.am deleted file mode 100644 index 1e9d46b0..00000000 --- a/pro/Makefile.am +++ /dev/null @@ -1,11 +0,0 @@ -bin_PROGRAMS = \ - mr_pro_map \ - mr_pro_reduce - -mr_pro_map_SOURCES = mr_pro_map.cc -mr_pro_map_LDADD = $(top_srcdir)/training/libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz - -mr_pro_reduce_SOURCES = mr_pro_reduce.cc -mr_pro_reduce_LDADD = $(top_srcdir)/training/liblbfgs/liblbfgs.a $(top_srcdir)/utils/libutils.a -lz - -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training diff --git a/pro/README.shared-mem b/pro/README.shared-mem deleted file mode 100644 index 7728efc0..00000000 --- a/pro/README.shared-mem +++ /dev/null @@ -1,9 +0,0 @@ -If you want to run dist-vest.pl on a very large shared memory machine, do the -following: - - ./dist-vest.pl --use-make I --decode-nodes J --weights weights.init --source-file=dev.src --ref-files=dev.ref.* cdec.ini - -This will use I jobs for doing the line search and J jobs to run the decoder. Typically, since the -decoder must load grammars, language models, etc., J should be smaller than I, but this will depend -on the system you are running on and the complexity of the models used for decoding. - diff --git a/pro/mr_pro_generate_mapper_input.pl b/pro/mr_pro_generate_mapper_input.pl deleted file mode 100755 index b30fc4fd..00000000 --- a/pro/mr_pro_generate_mapper_input.pl +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/perl -w -use strict; - -die "Usage: $0 HG_DIR\n" unless scalar @ARGV == 1; -my $d = shift @ARGV; -die "Can't find directory $d" unless -d $d; - -opendir(DIR, $d) or die "Can't read $d: $!"; -my @hgs = grep { /\.gz$/ } readdir(DIR); -closedir DIR; - -for my $hg (@hgs) { - my $file = $hg; - my $id = $hg; - $id =~ s/(\.json)?\.gz//; - print "$d/$file $id\n"; -} - diff --git a/pro/mr_pro_map.cc b/pro/mr_pro_map.cc deleted file mode 100644 index eef40b8a..00000000 --- a/pro/mr_pro_map.cc +++ /dev/null @@ -1,201 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "candidate_set.h" -#include "sampler.h" -#include "filelib.h" -#include "stringlib.h" -#include "weights.h" -#include "inside_outside.h" -#include "hg_io.h" -#include "ns.h" -#include "ns_docscorer.h" - -// This is Figure 4 (Algorithm Sampler) from Hopkins&May (2011) - -using namespace std; -namespace po = boost::program_options; - -boost::shared_ptr rng; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") - ("weights,w",po::value(), "[REQD] Weights files from current iterations") - ("kbest_repository,K",po::value()->default_value("./kbest"),"K-best list repository (directory)") - ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") - ("source,s",po::value()->default_value(""), "Source file (ignored, except for AER)") - ("evaluation_metric,m",po::value()->default_value("IBM_BLEU"), "Evaluation metric (ibm_bleu, koehn_bleu, nist_bleu, ter, meteor, etc.)") - ("kbest_size,k",po::value()->default_value(1500u), "Top k-hypotheses to extract") - ("candidate_pairs,G", po::value()->default_value(5000u), "Number of pairs to sample per hypothesis (Gamma)") - ("best_pairs,X", po::value()->default_value(50u), "Number of pairs, ranked by magnitude of objective delta, to retain (Xi)") - ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") - ("help,h", "Help"); - po::options_description dcmdline_options; - dcmdline_options.add(opts); - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - bool flag = false; - if (!conf->count("reference")) { - cerr << "Please specify one or more references using -r \n"; - flag = true; - } - if (!conf->count("weights")) { - cerr << "Please specify weights using -w \n"; - flag = true; - } - if (flag || conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -struct ThresholdAlpha { - explicit ThresholdAlpha(double t = 0.05) : threshold(t) {} - double operator()(double mag) const { - if (mag < threshold) return 0.0; else return 1.0; - } - const double threshold; -}; - -struct TrainingInstance { - TrainingInstance(const SparseVector& feats, bool positive, float diff) : x(feats), y(positive), gdiff(diff) {} - SparseVector x; -#undef DEBUGGING_PRO -#ifdef DEBUGGING_PRO - vector a; - vector b; -#endif - bool y; - float gdiff; -}; -#ifdef DEBUGGING_PRO -ostream& operator<<(ostream& os, const TrainingInstance& d) { - return os << d.gdiff << " y=" << d.y << "\tA:" << TD::GetString(d.a) << "\n\tB: " << TD::GetString(d.b) << "\n\tX: " << d.x; -} -#endif - -struct DiffOrder { - bool operator()(const TrainingInstance& a, const TrainingInstance& b) const { - return a.gdiff > b.gdiff; - } -}; - -void Sample(const unsigned gamma, - const unsigned xi, - const training::CandidateSet& J_i, - const EvaluationMetric* metric, - vector* pv) { - const bool invert_score = metric->IsErrorMetric(); - vector v1, v2; - float avg_diff = 0; - for (unsigned i = 0; i < gamma; ++i) { - const size_t a = rng->inclusive(0, J_i.size() - 1)(); - const size_t b = rng->inclusive(0, J_i.size() - 1)(); - if (a == b) continue; - float ga = metric->ComputeScore(J_i[a].eval_feats); - float gb = metric->ComputeScore(J_i[b].eval_feats); - bool positive = gb < ga; - if (invert_score) positive = !positive; - const float gdiff = fabs(ga - gb); - if (!gdiff) continue; - avg_diff += gdiff; - SparseVector xdiff = (J_i[a].fmap - J_i[b].fmap).erase_zeros(); - if (xdiff.empty()) { - cerr << "Empty diff:\n " << TD::GetString(J_i[a].ewords) << endl << "x=" << J_i[a].fmap << endl; - cerr << " " << TD::GetString(J_i[b].ewords) << endl << "x=" << J_i[b].fmap << endl; - continue; - } - v1.push_back(TrainingInstance(xdiff, positive, gdiff)); -#ifdef DEBUGGING_PRO - v1.back().a = J_i[a].hyp; - v1.back().b = J_i[b].hyp; - cerr << "N: " << v1.back() << endl; -#endif - } - avg_diff /= v1.size(); - - for (unsigned i = 0; i < v1.size(); ++i) { - double p = 1.0 / (1.0 + exp(-avg_diff - v1[i].gdiff)); - // cerr << "avg_diff=" << avg_diff << " gdiff=" << v1[i].gdiff << " p=" << p << endl; - if (rng->next() < p) v2.push_back(v1[i]); - } - vector::iterator mid = v2.begin() + xi; - if (xi > v2.size()) mid = v2.end(); - partial_sort(v2.begin(), mid, v2.end(), DiffOrder()); - copy(v2.begin(), mid, back_inserter(*pv)); -#ifdef DEBUGGING_PRO - if (v2.size() >= 5) { - for (int i =0; i < (mid - v2.begin()); ++i) { - cerr << v2[i] << endl; - } - cerr << pv->back() << endl; - } -#endif -} - -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - if (conf.count("random_seed")) - rng.reset(new MT19937(conf["random_seed"].as())); - else - rng.reset(new MT19937); - const string evaluation_metric = conf["evaluation_metric"].as(); - - EvaluationMetric* metric = EvaluationMetric::Instance(evaluation_metric); - DocumentScorer ds(metric, conf["reference"].as >()); - cerr << "Loaded " << ds.size() << " references for scoring with " << evaluation_metric << endl; - - Hypergraph hg; - string last_file; - ReadFile in_read(conf["input"].as()); - istream &in=*in_read.stream(); - const unsigned kbest_size = conf["kbest_size"].as(); - const unsigned gamma = conf["candidate_pairs"].as(); - const unsigned xi = conf["best_pairs"].as(); - string weightsf = conf["weights"].as(); - vector weights; - Weights::InitFromFile(weightsf, &weights); - string kbest_repo = conf["kbest_repository"].as(); - MkDirP(kbest_repo); - while(in) { - vector v; - string line; - getline(in, line); - if (line.empty()) continue; - istringstream is(line); - int sent_id; - string file; - // path-to-file (JSON) sent_id - is >> file >> sent_id; - ReadFile rf(file); - ostringstream os; - training::CandidateSet J_i; - os << kbest_repo << "/kbest." << sent_id << ".txt.gz"; - const string kbest_file = os.str(); - if (FileExists(kbest_file)) - J_i.ReadFromFile(kbest_file); - HypergraphIO::ReadFromJSON(rf.stream(), &hg); - hg.Reweight(weights); - J_i.AddKBestCandidates(hg, kbest_size, ds[sent_id]); - J_i.WriteToFile(kbest_file); - - Sample(gamma, xi, J_i, metric, &v); - for (unsigned i = 0; i < v.size(); ++i) { - const TrainingInstance& vi = v[i]; - cout << vi.y << "\t" << vi.x << endl; - cout << (!vi.y) << "\t" << (vi.x * -1.0) << endl; - } - } - return 0; -} - diff --git a/pro/mr_pro_reduce.cc b/pro/mr_pro_reduce.cc deleted file mode 100644 index 5ef9b470..00000000 --- a/pro/mr_pro_reduce.cc +++ /dev/null @@ -1,286 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include - -#include "filelib.h" -#include "weights.h" -#include "sparse_vector.h" -#include "optimize.h" -#include "liblbfgs/lbfgs++.h" - -using namespace std; -namespace po = boost::program_options; - -// since this is a ranking model, there should be equal numbers of -// positive and negative examples, so the bias should be 0 -static const double MAX_BIAS = 1e-10; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("weights,w", po::value(), "Weights from previous iteration (used as initialization and interpolation") - ("regularization_strength,C",po::value()->default_value(500.0), "l2 regularization strength") - ("l1",po::value()->default_value(0.0), "l1 regularization strength") - ("regularize_to_weights,y",po::value()->default_value(5000.0), "Differences in learned weights to previous weights are penalized with an l2 penalty with this strength; 0.0 = no effect") - ("memory_buffers,m",po::value()->default_value(100), "Number of memory buffers (LBFGS)") - ("min_reg,r",po::value()->default_value(0.01), "When tuning (-T) regularization strength, minimum regularization strenght") - ("max_reg,R",po::value()->default_value(1e6), "When tuning (-T) regularization strength, maximum regularization strenght") - ("testset,t",po::value(), "Optional held-out test set") - ("tune_regularizer,T", "Use the held out test set (-t) to tune the regularization strength") - ("interpolate_with_weights,p",po::value()->default_value(1.0), "[deprecated] Output weights are p*w + (1-p)*w_prev; 1.0 = no effect") - ("help,h", "Help"); - po::options_description dcmdline_options; - dcmdline_options.add(opts); - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -void ParseSparseVector(string& line, size_t cur, SparseVector* out) { - SparseVector& x = *out; - size_t last_start = cur; - size_t last_comma = string::npos; - while(cur <= line.size()) { - if (line[cur] == ' ' || cur == line.size()) { - if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { - cerr << "[ERROR] " << line << endl << " position = " << cur << endl; - exit(1); - } - const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); - if (cur < line.size()) line[cur] = 0; - const weight_t val = strtod(&line[last_comma + 1], NULL); - x.set_value(fid, val); - - last_comma = string::npos; - last_start = cur+1; - } else { - if (line[cur] == '=') - last_comma = cur; - } - ++cur; - } -} - -void ReadCorpus(istream* pin, vector > >* corpus) { - istream& in = *pin; - corpus->clear(); - bool flag = false; - int lc = 0; - string line; - SparseVector x; - while(getline(in, line)) { - ++lc; - if (lc % 1000 == 0) { cerr << '.'; flag = true; } - if (lc % 40000 == 0) { cerr << " [" << lc << "]\n"; flag = false; } - if (line.empty()) continue; - const size_t ks = line.find("\t"); - assert(string::npos != ks); - assert(ks == 1); - const bool y = line[0] == '1'; - x.clear(); - ParseSparseVector(line, ks + 1, &x); - corpus->push_back(make_pair(y, x)); - } - if (flag) cerr << endl; -} - -void GradAdd(const SparseVector& v, const double scale, weight_t* acc) { - for (SparseVector::const_iterator it = v.begin(); - it != v.end(); ++it) { - acc[it->first] += it->second * scale; - } -} - -double ApplyRegularizationTerms(const double C, - const double T, - const vector& weights, - const vector& prev_weights, - weight_t* g) { - double reg = 0; - for (size_t i = 0; i < weights.size(); ++i) { - const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0); - const double& w_i = weights[i]; - reg += C * w_i * w_i; - g[i] += 2 * C * w_i; - - const double diff_i = w_i - prev_w_i; - reg += T * diff_i * diff_i; - g[i] += 2 * T * diff_i; - } - return reg; -} - -double TrainingInference(const vector& x, - const vector > >& corpus, - weight_t* g = NULL) { - double cll = 0; - for (int i = 0; i < corpus.size(); ++i) { - const double dotprod = corpus[i].second.dot(x) + (x.size() ? x[0] : weight_t()); // x[0] is bias - double lp_false = dotprod; - double lp_true = -dotprod; - if (0 < lp_true) { - lp_true += log1p(exp(-lp_true)); - lp_false = log1p(exp(lp_false)); - } else { - lp_true = log1p(exp(lp_true)); - lp_false += log1p(exp(-lp_false)); - } - lp_true*=-1; - lp_false*=-1; - if (corpus[i].first) { // true label - cll -= lp_true; - if (g) { - // g -= corpus[i].second * exp(lp_false); - GradAdd(corpus[i].second, -exp(lp_false), g); - g[0] -= exp(lp_false); // bias - } - } else { // false label - cll -= lp_false; - if (g) { - // g += corpus[i].second * exp(lp_true); - GradAdd(corpus[i].second, exp(lp_true), g); - g[0] += exp(lp_true); // bias - } - } - } - return cll; -} - -struct ProLoss { - ProLoss(const vector > >& tr, - const vector > >& te, - const double c, - const double t, - const vector& px) : training(tr), testing(te), C(c), T(t), prev_x(px){} - double operator()(const vector& x, double* g) const { - fill(g, g + x.size(), 0.0); - double cll = TrainingInference(x, training, g); - tppl = 0; - if (testing.size()) - tppl = pow(2.0, TrainingInference(x, testing, g) / (log(2) * testing.size())); - double ppl = cll / log(2); - ppl /= training.size(); - ppl = pow(2.0, ppl); - double reg = ApplyRegularizationTerms(C, T, x, prev_x, g); - return cll + reg; - } - const vector > >& training, testing; - const double C, T; - const vector& prev_x; - mutable double tppl; -}; - -// return held-out log likelihood -double LearnParameters(const vector > >& training, - const vector > >& testing, - const double C, - const double C1, - const double T, - const unsigned memory_buffers, - const vector& prev_x, - vector* px) { - assert(px->size() == prev_x.size()); - ProLoss loss(training, testing, C, T, prev_x); - LBFGS lbfgs(px, loss, memory_buffers, C1); - lbfgs.MinimizeFunction(); - return loss.tppl; -} - -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - string line; - vector > > training, testing; - const bool tune_regularizer = conf.count("tune_regularizer"); - if (tune_regularizer && !conf.count("testset")) { - cerr << "--tune_regularizer requires --testset to be set\n"; - return 1; - } - const double min_reg = conf["min_reg"].as(); - const double max_reg = conf["max_reg"].as(); - double C = conf["regularization_strength"].as(); // will be overridden if parameter is tuned - double C1 = conf["l1"].as(); // will be overridden if parameter is tuned - const double T = conf["regularize_to_weights"].as(); - assert(C >= 0.0); - assert(min_reg >= 0.0); - assert(max_reg >= 0.0); - assert(max_reg > min_reg); - const double psi = conf["interpolate_with_weights"].as(); - if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; return 1; } - ReadCorpus(&cin, &training); - if (conf.count("testset")) { - ReadFile rf(conf["testset"].as()); - ReadCorpus(rf.stream(), &testing); - } - cerr << "Number of features: " << FD::NumFeats() << endl; - - vector x, prev_x; // x[0] is bias - if (conf.count("weights")) { - Weights::InitFromFile(conf["weights"].as(), &x); - x.resize(FD::NumFeats()); - prev_x = x; - } else { - x.resize(FD::NumFeats()); - prev_x = x; - } - cerr << " Number of features: " << x.size() << endl; - cerr << "Number of training examples: " << training.size() << endl; - cerr << "Number of testing examples: " << testing.size() << endl; - double tppl = 0.0; - vector > sp; - vector smoothed; - if (tune_regularizer) { - C = min_reg; - const double steps = 18; - double sweep_factor = exp((log(max_reg) - log(min_reg)) / steps); - cerr << "SWEEP FACTOR: " << sweep_factor << endl; - while(C < max_reg) { - cerr << "C=" << C << "\tT=" <(), prev_x, &x); - sp.push_back(make_pair(C, tppl)); - C *= sweep_factor; - } - smoothed.resize(sp.size(), 0); - smoothed[0] = sp[0].second; - smoothed.back() = sp.back().second; - for (int i = 1; i < sp.size()-1; ++i) { - double prev = sp[i-1].second; - double next = sp[i+1].second; - double cur = sp[i].second; - smoothed[i] = (prev*0.2) + cur * 0.6 + (0.2*next); - } - double best_ppl = 9999999; - unsigned best_i = 0; - for (unsigned i = 0; i < sp.size(); ++i) { - if (smoothed[i] < best_ppl) { - best_ppl = smoothed[i]; - best_i = i; - } - } - C = sp[best_i].first; - } // tune regularizer - tppl = LearnParameters(training, testing, C, C1, T, conf["memory_buffers"].as(), prev_x, &x); - if (conf.count("weights")) { - for (int i = 1; i < x.size(); ++i) { - x[i] = (x[i] * psi) + prev_x[i] * (1.0 - psi); - } - } - cout.precision(15); - cout << "# C=" << C << "\theld out perplexity="; - if (tppl) { cout << tppl << endl; } else { cout << "N/A\n"; } - if (sp.size()) { - cout << "# Parameter sweep:\n"; - for (int i = 0; i < sp.size(); ++i) { - cout << "# " << sp[i].first << "\t" << sp[i].second << "\t" << smoothed[i] << endl; - } - } - Weights::WriteToFile("-", x); - return 0; -} diff --git a/pro/pro.pl b/pro/pro.pl deleted file mode 100755 index 891b7e4c..00000000 --- a/pro/pro.pl +++ /dev/null @@ -1,555 +0,0 @@ -#!/usr/bin/env perl -use strict; -use File::Basename qw(basename); -my @ORIG_ARGV=@ARGV; -use Cwd qw(getcwd); -my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../environment"; } - -# Skip local config (used for distributing jobs) if we're running in local-only mode -use LocalConfig; -use Getopt::Long; -use IPC::Open2; -use POSIX ":sys_wait_h"; -my $QSUB_CMD = qsub_args(mert_memory()); -my $default_jobs = env_default_jobs(); - -my $VEST_DIR="$SCRIPT_DIR/../dpmert"; -require "$VEST_DIR/libcall.pl"; - -# Default settings -my $srcFile; -my $refFiles; -my $bin_dir = $SCRIPT_DIR; -die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; -my $FAST_SCORE="$bin_dir/../mteval/fast_score"; -die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; -my $MAPINPUT = "$bin_dir/mr_pro_generate_mapper_input.pl"; -my $MAPPER = "$bin_dir/mr_pro_map"; -my $REDUCER = "$bin_dir/mr_pro_reduce"; -my $parallelize = "$VEST_DIR/parallelize.pl"; -my $libcall = "$VEST_DIR/libcall.pl"; -my $sentserver = "$VEST_DIR/sentserver"; -my $sentclient = "$VEST_DIR/sentclient"; -my $LocalConfig = "$SCRIPT_DIR/../environment/LocalConfig.pm"; - -my $SCORER = $FAST_SCORE; -die "Can't find $MAPPER" unless -x $MAPPER; -my $cdec = "$bin_dir/../decoder/cdec"; -die "Can't find decoder in $cdec" unless -x $cdec; -die "Can't find $parallelize" unless -x $parallelize; -die "Can't find $libcall" unless -e $libcall; -my $decoder = $cdec; -my $lines_per_mapper = 30; -my $iteration = 1; -my $best_weights; -my $psi = 1; -my $default_max_iter = 30; -my $max_iterations = $default_max_iter; -my $jobs = $default_jobs; # number of decode nodes -my $pmem = "4g"; -my $disable_clean = 0; -my %seen_weights; -my $help = 0; -my $epsilon = 0.0001; -my $dryrun = 0; -my $last_score = -10000000; -my $metric = "ibm_bleu"; -my $dir; -my $iniFile; -my $weights; -my $use_make = 1; # use make to parallelize -my $useqsub = 0; -my $initial_weights; -my $pass_suffix = ''; -my $devset; - -# regularization strength -my $reg = 500; -my $reg_previous = 5000; - -# Process command-line options -if (GetOptions( - "config=s" => \$iniFile, - "weights=s" => \$initial_weights, - "devset=s" => \$devset, - "jobs=i" => \$jobs, - "metric=s" => \$metric, - "pass-suffix=s" => \$pass_suffix, - "qsub" => \$useqsub, - "help" => \$help, - "reg=f" => \$reg, - "reg-previous=f" => \$reg_previous, - "output-dir=s" => \$dir, -) == 0 || @ARGV!=0 || $help) { - print_help(); - exit; -} - -if ($useqsub) { - $use_make = 0; - die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); -} - -my @missing_args = (); -if (!defined $iniFile) { push @missing_args, "--config"; } -if (!defined $devset) { push @missing_args, "--devset"; } -if (!defined $initial_weights) { push @missing_args, "--weights"; } -die "Please specify missing arguments: " . join (', ', @missing_args) . "\n" if (@missing_args); - -if ($metric =~ /^(combi|ter)$/i) { - $lines_per_mapper = 5; -} - -my $host =check_output("hostname"); chomp $host; -my $bleu; -my $interval_count = 0; -my $logfile; -my $projected_score; - -# used in sorting scores -my $DIR_FLAG = '-r'; -if ($metric =~ /^ter$|^aer$/i) { - $DIR_FLAG = ''; -} - -unless ($dir){ - $dir = 'pro'; -} -unless ($dir =~ /^\//){ # convert relative path to absolute path - my $basedir = check_output("pwd"); - chomp $basedir; - $dir = "$basedir/$dir"; -} - -# Initializations and helper functions -srand; - -my @childpids = (); -my @cleanupcmds = (); - -sub cleanup { - print STDERR "Cleanup...\n"; - for my $pid (@childpids){ unchecked_call("kill $pid"); } - for my $cmd (@cleanupcmds){ unchecked_call("$cmd"); } - exit 1; -}; -# Always call cleanup, no matter how we exit -*CORE::GLOBAL::exit = - sub{ cleanup(); }; -$SIG{INT} = "cleanup"; -$SIG{TERM} = "cleanup"; -$SIG{HUP} = "cleanup"; - -my $decoderBase = check_output("basename $decoder"); chomp $decoderBase; -my $newIniFile = "$dir/$decoderBase.ini"; -my $inputFileName = "$dir/input"; -my $user = $ENV{"USER"}; - - -# process ini file --e $iniFile || die "Error: could not open $iniFile for reading\n"; -open(INI, $iniFile); - -if (-e $dir) { - die "ERROR: working dir $dir already exists\n\n"; -} else { - mkdir "$dir" or die "Can't mkdir $dir: $!"; - mkdir "$dir/hgs" or die; - mkdir "$dir/scripts" or die; - print STDERR <) { $devSize++; } -close F; - -unless($best_weights){ $best_weights = $weights; } -unless($projected_score){ $projected_score = 0.0; } -$seen_weights{$weights} = 1; - -my $random_seed = int(time / 1000); -my $lastWeightsFile; -my $lastPScore = 0; -# main optimization loop -my @allweights; -while (1){ - print STDERR "\n\nITERATION $iteration\n==========\n"; - - if ($iteration > $max_iterations){ - print STDERR "\nREACHED STOPPING CRITERION: Maximum iterations\n"; - last; - } - # iteration-specific files - my $runFile="$dir/run.raw.$iteration"; - my $onebestFile="$dir/1best.$iteration"; - my $logdir="$dir/logs.$iteration"; - my $decoderLog="$logdir/decoder.sentserver.log.$iteration"; - my $scorerLog="$logdir/scorer.log.$iteration"; - check_call("mkdir -p $logdir"); - - - #decode - print STDERR "RUNNING DECODER AT "; - print STDERR unchecked_output("date"); - my $im1 = $iteration - 1; - my $weightsFile="$dir/weights.$im1"; - push @allweights, "-w $dir/weights.$im1"; - `rm -f $dir/hgs/*.gz`; - my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; - my $pcmd; - if ($use_make) { - $pcmd = "cat $srcFile | $parallelize --use-fork -p $pmem -e $logdir -j $jobs --"; - } else { - $pcmd = "cat $srcFile | $parallelize -p $pmem -e $logdir -j $jobs --"; - } - my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $runFile"; - print STDERR "COMMAND:\n$cmd\n"; - check_bash_call($cmd); - my $num_hgs; - my $num_topbest; - my $retries = 0; - while($retries < 5) { - $num_hgs = check_output("ls $dir/hgs/*.gz | wc -l"); - $num_topbest = check_output("wc -l < $runFile"); - print STDERR "NUMBER OF HGs: $num_hgs\n"; - print STDERR "NUMBER OF TOP-BEST HYPs: $num_topbest\n"; - if($devSize == $num_hgs && $devSize == $num_topbest) { - last; - } else { - print STDERR "Incorrect number of hypergraphs or topbest. Waiting for distributed filesystem and retrying...\n"; - sleep(3); - } - $retries++; - } - die "Dev set contains $devSize sentences, but we don't have topbest and hypergraphs for all these! Decoder failure? Check $decoderLog\n" if ($devSize != $num_hgs || $devSize != $num_topbest); - my $dec_score = check_output("cat $runFile | $SCORER -r $refs -m $metric"); - chomp $dec_score; - print STDERR "DECODER SCORE: $dec_score\n"; - - # save space - check_call("gzip -f $runFile"); - check_call("gzip -f $decoderLog"); - - # run optimizer - print STDERR "RUNNING OPTIMIZER AT "; - print STDERR unchecked_output("date"); - print STDERR " - GENERATE TRAINING EXEMPLARS\n"; - my $mergeLog="$logdir/prune-merge.log.$iteration"; - - my $score = 0; - my $icc = 0; - my $inweights="$dir/weights.$im1"; - $cmd="$MAPINPUT $dir/hgs > $dir/agenda.$im1"; - print STDERR "COMMAND:\n$cmd\n"; - check_call($cmd); - check_call("mkdir -p $dir/splag.$im1"); - $cmd="split -a 3 -l $lines_per_mapper $dir/agenda.$im1 $dir/splag.$im1/mapinput."; - print STDERR "COMMAND:\n$cmd\n"; - check_call($cmd); - opendir(DIR, "$dir/splag.$im1") or die "Can't open directory: $!"; - my @shards = grep { /^mapinput\./ } readdir(DIR); - closedir DIR; - die "No shards!" unless scalar @shards > 0; - my $joblist = ""; - my $nmappers = 0; - @cleanupcmds = (); - my %o2i = (); - my $first_shard = 1; - my $mkfile; # only used with makefiles - my $mkfilename; - if ($use_make) { - $mkfilename = "$dir/splag.$im1/domap.mk"; - open $mkfile, ">$mkfilename" or die "Couldn't write $mkfilename: $!"; - print $mkfile "all: $dir/splag.$im1/map.done\n\n"; - } - my @mkouts = (); # only used with makefiles - my @mapoutputs = (); - for my $shard (@shards) { - my $mapoutput = $shard; - my $client_name = $shard; - $client_name =~ s/mapinput.//; - $client_name = "pro.$client_name"; - $mapoutput =~ s/mapinput/mapoutput/; - push @mapoutputs, "$dir/splag.$im1/$mapoutput"; - $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; - my $script = "$MAPPER -s $srcFile -m $metric -r $refs -w $inweights -K $dir/kbest < $dir/splag.$im1/$shard > $dir/splag.$im1/$mapoutput"; - if ($use_make) { - my $script_file = "$dir/scripts/map.$shard"; - open F, ">$script_file" or die "Can't write $script_file: $!"; - print F "#!/bin/bash\n"; - print F "$script\n"; - close F; - my $output = "$dir/splag.$im1/$mapoutput"; - push @mkouts, $output; - chmod(0755, $script_file) or die "Can't chmod $script_file: $!"; - if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } - print $mkfile "$output: $dir/splag.$im1/$shard\n\t$script_file\n\n"; - } else { - my $script_file = "$dir/scripts/map.$shard"; - open F, ">$script_file" or die "Can't write $script_file: $!"; - print F "$script\n"; - close F; - if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } - - $nmappers++; - my $qcmd = "$QSUB_CMD -N $client_name -o /dev/null -e $logdir/$client_name.ER $script_file"; - my $jobid = check_output("$qcmd"); - chomp $jobid; - $jobid =~ s/^(\d+)(.*?)$/\1/g; - $jobid =~ s/^Your job (\d+) .*$/\1/; - push(@cleanupcmds, "qdel $jobid 2> /dev/null"); - print STDERR " $jobid"; - if ($joblist == "") { $joblist = $jobid; } - else {$joblist = $joblist . "\|" . $jobid; } - } - } - my @dev_outs = (); - my @devtest_outs = (); - @dev_outs = @mapoutputs; - if ($use_make) { - print $mkfile "$dir/splag.$im1/map.done: @mkouts\n\ttouch $dir/splag.$im1/map.done\n\n"; - close $mkfile; - my $mcmd = "make -j $jobs -f $mkfilename"; - print STDERR "\nExecuting: $mcmd\n"; - check_call($mcmd); - } else { - print STDERR "\nLaunched $nmappers mappers.\n"; - sleep 8; - print STDERR "Waiting for mappers to complete...\n"; - while ($nmappers > 0) { - sleep 5; - my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat | grep -v ' C '"))); - $nmappers = scalar @livejobs; - } - print STDERR "All mappers complete.\n"; - } - my $tol = 0; - my $til = 0; - my $dev_test_file = "$dir/splag.$im1/devtest.gz"; - print STDERR "\nRUNNING CLASSIFIER (REDUCER)\n"; - print STDERR unchecked_output("date"); - $cmd="cat @dev_outs | $REDUCER -w $dir/weights.$im1 -C $reg -y $reg_previous --interpolate_with_weights $psi"; - $cmd .= " > $dir/weights.$iteration"; - print STDERR "COMMAND:\n$cmd\n"; - check_bash_call($cmd); - $lastWeightsFile = "$dir/weights.$iteration"; - $lastPScore = $score; - $iteration++; - print STDERR "\n==========\n"; -} - - -check_call("cp $lastWeightsFile $dir/weights.final"); -print STDERR "\nFINAL WEIGHTS: $dir/weights.final\n(Use -w with the decoder)\n\n"; -print STDOUT "$dir/weights.final\n"; - -exit 0; - -sub read_weights_file { - my ($file) = @_; - open F, "<$file" or die "Couldn't read $file: $!"; - my @r = (); - my $pm = -1; - while() { - next if /^#/; - next if /^\s*$/; - chomp; - if (/^(.+)\s+(.+)$/) { - my $m = $1; - my $w = $2; - die "Weights out of order: $m <= $pm" unless $m > $pm; - push @r, $w; - } else { - warn "Unexpected feature name in weight file: $_"; - } - } - close F; - return join ' ', @r; -} - -sub enseg { - my $src = shift; - my $newsrc = shift; - open(SRC, $src); - open(NEWSRC, ">$newsrc"); - my $i=0; - while (my $line=){ - chomp $line; - if ($line =~ /^\s* tags, you must include a zero-based id attribute"; - } - } else { - print NEWSRC "$line\n"; - } - $i++; - } - close SRC; - close NEWSRC; - die "Empty dev set!" if ($i == 0); -} - -sub print_help { - - my $executable = basename($0); chomp $executable; - print << "Help"; - -Usage: $executable [options] - - $executable [options] - Runs a complete PRO optimization using the ini file specified. - -Required: - - --config - Decoder configuration file. - - --devset - Dev set source and reference data. - - --weights - Initial weights file (use empty file to start from 0) - -General options: - - --help - Print this message and exit. - - --max-iterations - Maximum number of iterations to run. If not specified, defaults - to $default_max_iter. - - --metric - Metric to optimize. - Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi - - --pass-suffix - If the decoder is doing multi-pass decoding, the pass suffix "2", - "3", etc., is used to control what iteration of weights is set. - - --workdir - Directory for intermediate and output files. If not specified, the - name is derived from the ini filename. Assuming that the ini - filename begins with the decoder name and ends with ini, the default - name of the working directory is inferred from the middle part of - the filename. E.g. an ini file named decoder.foo.ini would have - a default working directory name foo. - -Regularization options: - - --reg - l2 regularization strength [default=500]. The greater this value, - the closer to zero the weights will be. - - --reg-previous - l2 penalty for moving away from the weights from the previous - iteration. [default=5000]. The greater this value, the closer - to the previous iteration's weights the next iteration's weights - will be. - -Job control options: - - --jobs - Number of decoder processes to run in parallel. [default=$default_jobs] - - --qsub - Use qsub to run jobs in parallel (qsub must be configured in - environment/LocalEnvironment.pm) - - --pmem - Amount of physical memory requested for parallel decoding jobs - (used with qsub requests only) - -Deprecated options: - - --interpolate-with-weights - [deprecated] At each iteration the resulting weights are - interpolated with the weights from the previous iteration, with - this factor. [default=1.0, i.e., no effect] - -Help -} - -sub convert { - my ($str) = @_; - my @ps = split /;/, $str; - my %dict = (); - for my $p (@ps) { - my ($k, $v) = split /=/, $p; - $dict{$k} = $v; - } - return %dict; -} - - -sub cmdline { - return join ' ',($0,@ORIG_ARGV); -} - -#buggy: last arg gets quoted sometimes? -my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; -my $shell_escape_in_quote=qr{[\\"\$`!]}; - -sub escape_shell { - my ($arg)=@_; - return undef unless defined $arg; - if ($arg =~ /$is_shell_special/) { - $arg =~ s/($shell_escape_in_quote)/\\$1/g; - return "\"$arg\""; - } - return $arg; -} - -sub escaped_shell_args { - return map {local $_=$_;chomp;escape_shell($_)} @_; -} - -sub escaped_shell_args_str { - return join ' ',&escaped_shell_args(@_); -} - -sub escaped_cmdline { - return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); -} - -sub split_devset { - my ($infile, $outsrc, $outref) = @_; - open F, "<$infile" or die "Can't read $infile: $!"; - open S, ">$outsrc" or die "Can't write $outsrc: $!"; - open R, ">$outref" or die "Can't write $outref: $!"; - while() { - chomp; - my ($src, @refs) = split /\s*\|\|\|\s*/; - die "Malformed devset line: $_\n" unless scalar @refs > 0; - print S "$src\n"; - print R join(' ||| ', @refs) . "\n"; - } - close R; - close S; - close F; -} - diff --git a/rampion/Makefile.am b/rampion/Makefile.am deleted file mode 100644 index f4dbb7cc..00000000 --- a/rampion/Makefile.am +++ /dev/null @@ -1,6 +0,0 @@ -bin_PROGRAMS = rampion_cccp - -rampion_cccp_SOURCES = rampion_cccp.cc -rampion_cccp_LDADD = $(top_srcdir)/training/libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz - -AM_CPPFLAGS = -W -Wall $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training diff --git a/rampion/rampion.pl b/rampion/rampion.pl deleted file mode 100755 index 55f7b3f1..00000000 --- a/rampion/rampion.pl +++ /dev/null @@ -1,540 +0,0 @@ -#!/usr/bin/env perl -use strict; -my @ORIG_ARGV=@ARGV; -use Cwd qw(getcwd); -my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../environment"; } - -# Skip local config (used for distributing jobs) if we're running in local-only mode -use LocalConfig; -use Getopt::Long; -use IPC::Open2; -use POSIX ":sys_wait_h"; -my $QSUB_CMD = qsub_args(mert_memory()); -my $default_jobs = env_default_jobs(); - -my $VEST_DIR="$SCRIPT_DIR/../dpmert"; -require "$VEST_DIR/libcall.pl"; - -# Default settings -my $srcFile; -my $refFiles; -my $bin_dir = $SCRIPT_DIR; -die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; -my $FAST_SCORE="$bin_dir/../mteval/fast_score"; -die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; -my $MAPINPUT = "$bin_dir/rampion_generate_input.pl"; -my $MAPPER = "$bin_dir/rampion_cccp"; -my $parallelize = "$VEST_DIR/parallelize.pl"; -my $libcall = "$VEST_DIR/libcall.pl"; -my $sentserver = "$VEST_DIR/sentserver"; -my $sentclient = "$VEST_DIR/sentclient"; -my $LocalConfig = "$SCRIPT_DIR/../environment/LocalConfig.pm"; - -my $SCORER = $FAST_SCORE; -die "Can't find $MAPPER" unless -x $MAPPER; -my $cdec = "$bin_dir/../decoder/cdec"; -die "Can't find decoder in $cdec" unless -x $cdec; -die "Can't find $parallelize" unless -x $parallelize; -die "Can't find $libcall" unless -e $libcall; -my $decoder = $cdec; -my $lines_per_mapper = 30; -my $iteration = 1; -my $best_weights; -my $psi = 1; -my $default_max_iter = 30; -my $max_iterations = $default_max_iter; -my $jobs = $default_jobs; # number of decode nodes -my $pmem = "4g"; -my $disable_clean = 0; -my %seen_weights; -my $help = 0; -my $epsilon = 0.0001; -my $dryrun = 0; -my $last_score = -10000000; -my $metric = "ibm_bleu"; -my $dir; -my $iniFile; -my $weights; -my $use_make = 1; # use make to parallelize -my $useqsub = 0; -my $initial_weights; -my $pass_suffix = ''; -my $cpbin=1; - -# regularization strength -my $tune_regularizer = 0; -my $reg = 500; -my $reg_previous = 5000; -my $dont_accum = 0; - -# Process command-line options -Getopt::Long::Configure("no_auto_abbrev"); -if (GetOptions( - "jobs=i" => \$jobs, - "dont-clean" => \$disable_clean, - "dont-accumulate" => \$dont_accum, - "pass-suffix=s" => \$pass_suffix, - "qsub" => \$useqsub, - "dry-run" => \$dryrun, - "epsilon=s" => \$epsilon, - "help" => \$help, - "weights=s" => \$initial_weights, - "reg=f" => \$reg, - "use-make=i" => \$use_make, - "max-iterations=i" => \$max_iterations, - "pmem=s" => \$pmem, - "cpbin!" => \$cpbin, - "ref-files=s" => \$refFiles, - "metric=s" => \$metric, - "source-file=s" => \$srcFile, - "workdir=s" => \$dir, -) == 0 || @ARGV!=1 || $help) { - print_help(); - exit; -} - -die "--tune-regularizer is no longer supported with --reg-previous and --reg. Please tune manually.\n" if $tune_regularizer; - -if ($useqsub) { - $use_make = 0; - die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); -} - -my @missing_args = (); -if (!defined $srcFile) { push @missing_args, "--source-file"; } -if (!defined $refFiles) { push @missing_args, "--ref-files"; } -if (!defined $initial_weights) { push @missing_args, "--weights"; } -die "Please specify missing arguments: " . join (', ', @missing_args) . "\n" if (@missing_args); - -if ($metric =~ /^(combi|ter)$/i) { - $lines_per_mapper = 5; -} - -($iniFile) = @ARGV; - - -sub write_config; -sub enseg; -sub print_help; - -my $nodelist; -my $host =check_output("hostname"); chomp $host; -my $bleu; -my $interval_count = 0; -my $logfile; -my $projected_score; - -# used in sorting scores -my $DIR_FLAG = '-r'; -if ($metric =~ /^ter$|^aer$/i) { - $DIR_FLAG = ''; -} - -my $refs_comma_sep = get_comma_sep_refs('r',$refFiles); - -unless ($dir){ - $dir = "rampion"; -} -unless ($dir =~ /^\//){ # convert relative path to absolute path - my $basedir = check_output("pwd"); - chomp $basedir; - $dir = "$basedir/$dir"; -} - - -# Initializations and helper functions -srand; - -my @childpids = (); -my @cleanupcmds = (); - -sub cleanup { - print STDERR "Cleanup...\n"; - for my $pid (@childpids){ unchecked_call("kill $pid"); } - for my $cmd (@cleanupcmds){ unchecked_call("$cmd"); } - exit 1; -}; -# Always call cleanup, no matter how we exit -*CORE::GLOBAL::exit = - sub{ cleanup(); }; -$SIG{INT} = "cleanup"; -$SIG{TERM} = "cleanup"; -$SIG{HUP} = "cleanup"; - -my $decoderBase = check_output("basename $decoder"); chomp $decoderBase; -my $newIniFile = "$dir/$decoderBase.ini"; -my $inputFileName = "$dir/input"; -my $user = $ENV{"USER"}; -# process ini file --e $iniFile || die "Error: could not open $iniFile for reading\n"; -open(INI, $iniFile); - -use File::Basename qw(basename); -#pass bindir, refs to vars holding bin -sub modbin { - local $_; - my $bindir=shift; - check_call("mkdir -p $bindir"); - -d $bindir || die "couldn't make bindir $bindir"; - for (@_) { - my $src=$$_; - $$_="$bindir/".basename($src); - check_call("cp -p $src $$_"); - } -} -sub dirsize { - opendir ISEMPTY,$_[0]; - return scalar(readdir(ISEMPTY))-1; -} -my @allweights; -if ($dryrun){ - write_config(*STDERR); - exit 0; -} else { - if (-e $dir && dirsize($dir)>1 && -e "$dir/hgs" ){ # allow preexisting logfile, binaries, but not dist-pro.pl outputs - die "ERROR: working dir $dir already exists\n\n"; - } else { - -e $dir || mkdir $dir; - mkdir "$dir/hgs"; - modbin("$dir/bin",\$LocalConfig,\$cdec,\$SCORER,\$MAPINPUT,\$MAPPER,\$parallelize,\$sentserver,\$sentclient,\$libcall) if $cpbin; - mkdir "$dir/scripts"; - my $cmdfile="$dir/rerun-pro.sh"; - open CMD,'>',$cmdfile; - print CMD "cd ",&getcwd,"\n"; -# print CMD &escaped_cmdline,"\n"; #buggy - last arg is quoted. - my $cline=&cmdline."\n"; - print CMD $cline; - close CMD; - print STDERR $cline; - chmod(0755,$cmdfile); - check_call("cp $initial_weights $dir/weights.0"); - die "Can't find weights.0" unless (-e "$dir/weights.0"); - } - write_config(*STDERR); -} - - -# Generate initial files and values -check_call("cp $iniFile $newIniFile"); -$iniFile = $newIniFile; - -my $newsrc = "$dir/dev.input"; -enseg($srcFile, $newsrc); -$srcFile = $newsrc; -my $devSize = 0; -open F, "<$srcFile" or die "Can't read $srcFile: $!"; -while() { $devSize++; } -close F; - -unless($best_weights){ $best_weights = $weights; } -unless($projected_score){ $projected_score = 0.0; } -$seen_weights{$weights} = 1; -my $kbest = "$dir/kbest"; -if ($dont_accum) { - $kbest = ''; -} else { - check_call("mkdir -p $kbest"); - $kbest = "--kbest_repository $kbest"; -} - -my $random_seed = int(time / 1000); -my $lastWeightsFile; -my $lastPScore = 0; -# main optimization loop -while (1){ - print STDERR "\n\nITERATION $iteration\n==========\n"; - - if ($iteration > $max_iterations){ - print STDERR "\nREACHED STOPPING CRITERION: Maximum iterations\n"; - last; - } - # iteration-specific files - my $runFile="$dir/run.raw.$iteration"; - my $onebestFile="$dir/1best.$iteration"; - my $logdir="$dir/logs.$iteration"; - my $decoderLog="$logdir/decoder.sentserver.log.$iteration"; - my $scorerLog="$logdir/scorer.log.$iteration"; - check_call("mkdir -p $logdir"); - - - #decode - print STDERR "RUNNING DECODER AT "; - print STDERR unchecked_output("date"); - my $im1 = $iteration - 1; - my $weightsFile="$dir/weights.$im1"; - push @allweights, "-w $dir/weights.$im1"; - `rm -f $dir/hgs/*.gz`; - my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; - my $pcmd; - if ($use_make) { - $pcmd = "cat $srcFile | $parallelize --use-fork -p $pmem -e $logdir -j $jobs --"; - } else { - $pcmd = "cat $srcFile | $parallelize -p $pmem -e $logdir -j $jobs --"; - } - my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $runFile"; - print STDERR "COMMAND:\n$cmd\n"; - check_bash_call($cmd); - my $num_hgs; - my $num_topbest; - my $retries = 0; - while($retries < 5) { - $num_hgs = check_output("ls $dir/hgs/*.gz | wc -l"); - $num_topbest = check_output("wc -l < $runFile"); - print STDERR "NUMBER OF HGs: $num_hgs\n"; - print STDERR "NUMBER OF TOP-BEST HYPs: $num_topbest\n"; - if($devSize == $num_hgs && $devSize == $num_topbest) { - last; - } else { - print STDERR "Incorrect number of hypergraphs or topbest. Waiting for distributed filesystem and retrying...\n"; - sleep(3); - } - $retries++; - } - die "Dev set contains $devSize sentences, but we don't have topbest and hypergraphs for all these! Decoder failure? Check $decoderLog\n" if ($devSize != $num_hgs || $devSize != $num_topbest); - my $dec_score = check_output("cat $runFile | $SCORER $refs_comma_sep -m $metric"); - chomp $dec_score; - print STDERR "DECODER SCORE: $dec_score\n"; - - # save space - check_call("gzip -f $runFile"); - check_call("gzip -f $decoderLog"); - - # run optimizer - print STDERR "RUNNING OPTIMIZER AT "; - print STDERR unchecked_output("date"); - print STDERR " - GENERATE TRAINING EXEMPLARS\n"; - my $mergeLog="$logdir/prune-merge.log.$iteration"; - - my $score = 0; - my $icc = 0; - my $inweights="$dir/weights.$im1"; - my $outweights="$dir/weights.$iteration"; - $cmd="$MAPINPUT $dir/hgs > $dir/agenda.$im1"; - print STDERR "COMMAND:\n$cmd\n"; - check_call($cmd); - $cmd="$MAPPER $refs_comma_sep -m $metric -i $dir/agenda.$im1 $kbest -w $inweights > $outweights"; - check_call($cmd); - $lastWeightsFile = $outweights; - $iteration++; - `rm hgs/*.gz`; - print STDERR "\n==========\n"; -} - -print STDERR "\nFINAL WEIGHTS: $lastWeightsFile\n(Use -w with the decoder)\n\n"; - -print STDOUT "$lastWeightsFile\n"; - -exit 0; - -sub get_lines { - my $fn = shift @_; - open FL, "<$fn" or die "Couldn't read $fn: $!"; - my $lc = 0; - while() { $lc++; } - return $lc; -} - -sub get_comma_sep_refs { - my ($r,$p) = @_; - my $o = check_output("echo $p"); - chomp $o; - my @files = split /\s+/, $o; - return "-$r " . join(" -$r ", @files); -} - -sub read_weights_file { - my ($file) = @_; - open F, "<$file" or die "Couldn't read $file: $!"; - my @r = (); - my $pm = -1; - while() { - next if /^#/; - next if /^\s*$/; - chomp; - if (/^(.+)\s+(.+)$/) { - my $m = $1; - my $w = $2; - die "Weights out of order: $m <= $pm" unless $m > $pm; - push @r, $w; - } else { - warn "Unexpected feature name in weight file: $_"; - } - } - close F; - return join ' ', @r; -} - -# subs -sub write_config { - my $fh = shift; - my $cleanup = "yes"; - if ($disable_clean) {$cleanup = "no";} - - print $fh "\n"; - print $fh "DECODER: $decoder\n"; - print $fh "INI FILE: $iniFile\n"; - print $fh "WORKING DIR: $dir\n"; - print $fh "SOURCE (DEV): $srcFile\n"; - print $fh "REFS (DEV): $refFiles\n"; - print $fh "EVAL METRIC: $metric\n"; - print $fh "MAX ITERATIONS: $max_iterations\n"; - print $fh "JOBS: $jobs\n"; - print $fh "HEAD NODE: $host\n"; - print $fh "PMEM (DECODING): $pmem\n"; - print $fh "CLEANUP: $cleanup\n"; -} - -sub update_weights_file { - my ($neww, $rfn, $rpts) = @_; - my @feats = @$rfn; - my @pts = @$rpts; - my $num_feats = scalar @feats; - my $num_pts = scalar @pts; - die "$num_feats (num_feats) != $num_pts (num_pts)" unless $num_feats == $num_pts; - open G, ">$neww" or die; - for (my $i = 0; $i < $num_feats; $i++) { - my $f = $feats[$i]; - my $lambda = $pts[$i]; - print G "$f $lambda\n"; - } - close G; -} - -sub enseg { - my $src = shift; - my $newsrc = shift; - open(SRC, $src); - open(NEWSRC, ">$newsrc"); - my $i=0; - while (my $line=){ - chomp $line; - if ($line =~ /^\s* tags, you must include a zero-based id attribute"; - } - } else { - print NEWSRC "$line\n"; - } - $i++; - } - close SRC; - close NEWSRC; - die "Empty dev set!" if ($i == 0); -} - -sub print_help { - - my $executable = check_output("basename $0"); chomp $executable; - print << "Help"; - -Usage: $executable [options] - - $executable [options] - Runs a complete PRO optimization using the ini file specified. - -Required: - - --ref-files - Dev set ref files. This option takes only a single string argument. - To use multiple files (including file globbing), this argument should - be quoted. - - --source-file - Dev set source file. - - --weights - Initial weights file (use empty file to start from 0) - -General options: - - --help - Print this message and exit. - - --dont-accumulate - Don't accumulate k-best lists from multiple iterations. - - --max-iterations - Maximum number of iterations to run. If not specified, defaults - to $default_max_iter. - - --metric - Metric to optimize. - Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi - - --pass-suffix - If the decoder is doing multi-pass decoding, the pass suffix "2", - "3", etc., is used to control what iteration of weights is set. - - --workdir - Directory for intermediate and output files. If not specified, the - name is derived from the ini filename. Assuming that the ini - filename begins with the decoder name and ends with ini, the default - name of the working directory is inferred from the middle part of - the filename. E.g. an ini file named decoder.foo.ini would have - a default working directory name foo. - -Regularization options: - - --reg - l2 regularization strength [default=500]. The greater this value, - the closer to zero the weights will be. - -Job control options: - - --jobs - Number of decoder processes to run in parallel. [default=$default_jobs] - - --qsub - Use qsub to run jobs in parallel (qsub must be configured in - environment/LocalEnvironment.pm) - - --pmem - Amount of physical memory requested for parallel decoding jobs - (used with qsub requests only) - -Help -} - -sub convert { - my ($str) = @_; - my @ps = split /;/, $str; - my %dict = (); - for my $p (@ps) { - my ($k, $v) = split /=/, $p; - $dict{$k} = $v; - } - return %dict; -} - - -sub cmdline { - return join ' ',($0,@ORIG_ARGV); -} - -#buggy: last arg gets quoted sometimes? -my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; -my $shell_escape_in_quote=qr{[\\"\$`!]}; - -sub escape_shell { - my ($arg)=@_; - return undef unless defined $arg; - if ($arg =~ /$is_shell_special/) { - $arg =~ s/($shell_escape_in_quote)/\\$1/g; - return "\"$arg\""; - } - return $arg; -} - -sub escaped_shell_args { - return map {local $_=$_;chomp;escape_shell($_)} @_; -} - -sub escaped_shell_args_str { - return join ' ',&escaped_shell_args(@_); -} - -sub escaped_cmdline { - return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); -} diff --git a/rampion/rampion_cccp.cc b/rampion/rampion_cccp.cc deleted file mode 100644 index 1e36dc51..00000000 --- a/rampion/rampion_cccp.cc +++ /dev/null @@ -1,168 +0,0 @@ -#include -#include -#include -#include - -#include -#include - -#include "filelib.h" -#include "stringlib.h" -#include "weights.h" -#include "hg_io.h" -#include "kbest.h" -#include "viterbi.h" -#include "ns.h" -#include "ns_docscorer.h" -#include "candidate_set.h" - -using namespace std; -namespace po = boost::program_options; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") - ("weights,w",po::value(), "[REQD] Weights files from current iterations") - ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") - ("evaluation_metric,m",po::value()->default_value("IBM_BLEU"), "Evaluation metric (ibm_bleu, koehn_bleu, nist_bleu, ter, meteor, etc.)") - ("kbest_repository,R",po::value(), "Accumulate k-best lists from previous iterations (parameter is path to repository)") - ("kbest_size,k",po::value()->default_value(500u), "Top k-hypotheses to extract") - ("cccp_iterations,I", po::value()->default_value(10u), "CCCP iterations (T')") - ("ssd_iterations,J", po::value()->default_value(5u), "Stochastic subgradient iterations (T'')") - ("eta", po::value()->default_value(1e-4), "Step size") - ("regularization_strength,C", po::value()->default_value(1.0), "L2 regularization strength") - ("alpha,a", po::value()->default_value(10.0), "Cost scale (alpha); alpha * [1-metric(y,y')]") - ("help,h", "Help"); - po::options_description dcmdline_options; - dcmdline_options.add(opts); - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - bool flag = false; - if (!conf->count("reference")) { - cerr << "Please specify one or more references using -r \n"; - flag = true; - } - if (!conf->count("weights")) { - cerr << "Please specify weights using -w \n"; - flag = true; - } - if (flag || conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -struct GainFunction { - explicit GainFunction(const EvaluationMetric* m) : metric(m) {} - float operator()(const SufficientStats& eval_feats) const { - float g = metric->ComputeScore(eval_feats); - if (!metric->IsErrorMetric()) g = 1 - g; - return g; - } - const EvaluationMetric* metric; -}; - -template -void CostAugmentedSearch(const GainFunc& gain, - const training::CandidateSet& cs, - const SparseVector& w, - double alpha, - SparseVector* fmap) { - unsigned best_i = 0; - double best = -numeric_limits::infinity(); - for (unsigned i = 0; i < cs.size(); ++i) { - double s = cs[i].fmap.dot(w) + alpha * gain(cs[i].eval_feats); - if (s > best) { - best = s; - best_i = i; - } - } - *fmap = cs[best_i].fmap; -} - - - -// runs lines 4--15 of rampion algorithm -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - const string evaluation_metric = conf["evaluation_metric"].as(); - - EvaluationMetric* metric = EvaluationMetric::Instance(evaluation_metric); - DocumentScorer ds(metric, conf["reference"].as >()); - cerr << "Loaded " << ds.size() << " references for scoring with " << evaluation_metric << endl; - double goodsign = -1; - double badsign = -goodsign; - - Hypergraph hg; - string last_file; - ReadFile in_read(conf["input"].as()); - string kbest_repo; - if (conf.count("kbest_repository")) { - kbest_repo = conf["kbest_repository"].as(); - MkDirP(kbest_repo); - } - istream &in=*in_read.stream(); - const unsigned kbest_size = conf["kbest_size"].as(); - const unsigned tp = conf["cccp_iterations"].as(); - const unsigned tpp = conf["ssd_iterations"].as(); - const double eta = conf["eta"].as(); - const double reg = conf["regularization_strength"].as(); - const double alpha = conf["alpha"].as(); - SparseVector weights; - { - vector vweights; - const string weightsf = conf["weights"].as(); - Weights::InitFromFile(weightsf, &vweights); - Weights::InitSparseVector(vweights, &weights); - } - string line, file; - vector kis; - cerr << "Loading hypergraphs...\n"; - while(getline(in, line)) { - istringstream is(line); - int sent_id; - kis.resize(kis.size() + 1); - training::CandidateSet& curkbest = kis.back(); - string kbest_file; - if (kbest_repo.size()) { - ostringstream os; - os << kbest_repo << "/kbest." << sent_id << ".txt.gz"; - kbest_file = os.str(); - if (FileExists(kbest_file)) - curkbest.ReadFromFile(kbest_file); - } - is >> file >> sent_id; - ReadFile rf(file); - if (kis.size() % 5 == 0) { cerr << '.'; } - if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; } - HypergraphIO::ReadFromJSON(rf.stream(), &hg); - hg.Reweight(weights); - curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]); - if (kbest_file.size()) - curkbest.WriteToFile(kbest_file); - } - cerr << "\nHypergraphs loaded.\n"; - - vector > goals(kis.size()); // f(x_i,y+,h+) - SparseVector fear; // f(x,y-,h-) - const GainFunction gain(metric); - for (unsigned iterp = 1; iterp <= tp; ++iterp) { - cerr << "CCCP Iteration " << iterp << endl; - for (unsigned i = 0; i < goals.size(); ++i) - CostAugmentedSearch(gain, kis[i], weights, goodsign * alpha, &goals[i]); - for (unsigned iterpp = 1; iterpp <= tpp; ++iterpp) { - cerr << " SSD Iteration " << iterpp << endl; - for (unsigned i = 0; i < goals.size(); ++i) { - CostAugmentedSearch(gain, kis[i], weights, badsign * alpha, &fear); - weights -= weights * (eta * reg / goals.size()); - weights += (goals[i] - fear) * eta; - } - } - } - vector w; - weights.init_vector(&w); - Weights::WriteToFile("-", w); - return 0; -} - diff --git a/rampion/rampion_generate_input.pl b/rampion/rampion_generate_input.pl deleted file mode 100755 index b30fc4fd..00000000 --- a/rampion/rampion_generate_input.pl +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/perl -w -use strict; - -die "Usage: $0 HG_DIR\n" unless scalar @ARGV == 1; -my $d = shift @ARGV; -die "Can't find directory $d" unless -d $d; - -opendir(DIR, $d) or die "Can't read $d: $!"; -my @hgs = grep { /\.gz$/ } readdir(DIR); -closedir DIR; - -for my $hg (@hgs) { - my $file = $hg; - my $id = $hg; - $id =~ s/(\.json)?\.gz//; - print "$d/$file $id\n"; -} - diff --git a/training/Makefile.am b/training/Makefile.am index f9c25391..e95e045f 100644 --- a/training/Makefile.am +++ b/training/Makefile.am @@ -1,91 +1,11 @@ -bin_PROGRAMS = \ - fast_align \ - lbl_model \ - test_ngram \ - mr_em_map_adapter \ - mr_em_adapted_reduce \ - mr_reduce_to_weights \ - mr_optimize_reduce \ - grammar_convert \ - plftools \ - collapse_weights \ - mpi_extract_reachable \ - mpi_extract_features \ - mpi_online_optimize \ - mpi_flex_optimize \ - mpi_batch_optimize \ - mpi_compute_cllh \ - augment_grammar +SUBDIRS = \ + liblbfgs \ + utils \ + crf \ + minrisk \ + dpmert \ + pro \ + dtrain \ + mira \ + rampion -noinst_PROGRAMS = \ - lbfgs_test \ - optimize_test - -TESTS = lbfgs_test optimize_test - -noinst_LIBRARIES = libtraining.a -libtraining_a_SOURCES = \ - candidate_set.cc \ - entropy.cc \ - optimize.cc \ - online_optimizer.cc \ - risk.cc - -mpi_online_optimize_SOURCES = mpi_online_optimize.cc -mpi_online_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc -mpi_flex_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc -mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -mpi_extract_features_SOURCES = mpi_extract_features.cc -mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc cllh_observer.cc -mpi_batch_optimize_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc cllh_observer.cc -mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -augment_grammar_SOURCES = augment_grammar.cc -augment_grammar_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -test_ngram_SOURCES = test_ngram.cc -test_ngram_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz - -fast_align_SOURCES = fast_align.cc ttables.cc -fast_align_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz - -lbl_model_SOURCES = lbl_model.cc -lbl_model_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz - -grammar_convert_SOURCES = grammar_convert.cc -grammar_convert_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz - -optimize_test_SOURCES = optimize_test.cc -optimize_test_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz - -collapse_weights_SOURCES = collapse_weights.cc -collapse_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz - -lbfgs_test_SOURCES = lbfgs_test.cc -lbfgs_test_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz - -mr_optimize_reduce_SOURCES = mr_optimize_reduce.cc -mr_optimize_reduce_LDADD = libtraining.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz - -mr_em_map_adapter_SOURCES = mr_em_map_adapter.cc -mr_em_map_adapter_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz - -mr_reduce_to_weights_SOURCES = mr_reduce_to_weights.cc -mr_reduce_to_weights_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz - -mr_em_adapted_reduce_SOURCES = mr_em_adapted_reduce.cc -mr_em_adapted_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz - -plftools_SOURCES = plftools.cc -plftools_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/utils/libutils.a -lz - -AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I../klm diff --git a/training/add-model1-features-to-scfg.pl b/training/add-model1-features-to-scfg.pl deleted file mode 100755 index a0074317..00000000 --- a/training/add-model1-features-to-scfg.pl +++ /dev/null @@ -1,93 +0,0 @@ -#!/usr/bin/perl -w - -# [X] ||| so [X,1] die [X,2] der ||| as [X,1] existing [X,2] the ||| 2.47712135315 2.53182387352 5.07100057602 ||| 0-0 2-2 4-4 -# [X] ||| so [X,1] die [X,2] der ||| this [X,1] the [X,2] of ||| 2.47712135315 3.19828724861 2.38270020485 ||| 0-0 2-2 4-4 -# [X] ||| so [X,1] die [X,2] der ||| as [X,1] the [X,2] the ||| 2.47712135315 2.53182387352 1.48463630676 ||| 0-0 2-2 4-4 -# [X] ||| so [X,1] die [X,2] der ||| is [X,1] the [X,2] of the ||| 2.47712135315 3.45197868347 2.64251494408 ||| 0-0 2-2 4-4 4-5 - -die "Usage: $0 model1.f-e model1.e-f < grammar.scfg\n (use trianing/model1 to extract the model files)\n" unless scalar @ARGV == 2; - -my $fm1 = shift @ARGV; -die unless $fm1; -my $frm1 = shift @ARGV; -die unless $frm1; -open M1,"<$fm1" or die; -open RM1,"<$frm1" or die; -print STDERR "Loading Model 1 probs from $fm1...\n"; -my %m1; -while() { - chomp; - my ($f, $e, $lp) = split /\s+/; - $m1{$e}->{$f} = exp($lp); -} -close M1; - -print STDERR "Loading Inverse Model 1 probs from $frm1...\n"; -my %rm1; -while() { - chomp; - my ($e, $f, $lp) = split /\s+/; - $rm1{$f}->{$e} = exp($lp); -} -close RM1; - -my @label = qw( EGivenF LexFGivenE LexEGivenF ); -while(<>) { - chomp; - my ($l, $f, $e, $sscores, $al) = split / \|\|\| /; - my @scores = split /\s+/, $sscores; - unless ($sscores =~ /=/) { - for (my $i=0; $i<3; $i++) { $scores[$i] = "$label[$i]=$scores[$i]"; } - } - push @scores, "RuleCount=1"; - my @fs = split /\s+/, $f; - my @es = split /\s+/, $e; - my $flen = scalar @fs; - my $elen = scalar @es; - my $pgen = 0; - my $nongen = 0; - for (my $i =0; $i < $flen; $i++) { - my $ftot = 0; - next if ($fs[$i] =~ /\[X/); - my $cr = $rm1{$fs[$i]}; - for (my $j=0; $j <= $elen; $j++) { - my $ej = ''; - if ($j < $elen) { $ej = $es[$j]; } - my $p = $cr->{$ej}; - if (defined $p) { $ftot += $p; } - } - if ($ftot == 0) { $nongen = 1; last; } - $pgen += log($ftot) - log($elen); - } - my $bad = 0; - my $good = 0; - unless ($nongen) { push @scores, "RGood=1"; $good++; } else { push @scores, "RBad=1"; $bad++; } - - $nongen = 0; - $pgen = 0; - for (my $i =0; $i < $elen; $i++) { - my $etot = 0; - next if ($es[$i] =~ /\[X/); - my $cr = $m1{$es[$i]}; -# print STDERR "$es[$i]\n"; - for (my $j=0; $j <= $flen; $j++) { - my $fj = ''; - if ($j < $flen) { $fj = $fs[$j]; } - my $p = $cr->{$fj}; -# print STDERR " $fs[$j] : $p\n"; - if (defined $p) { $etot += $p; } - } - if ($etot == 0) { $nongen = 1; last; } - $pgen += log($etot) - log($flen); - } - unless ($nongen) { - push @scores, "FGood=1"; - if ($good) { push @scores, "BothGood=1"; } else { push @scores, "SusDel=1"; } - } else { - push @scores, "FBad=1"; - if ($bad) { push @scores, "BothBad=1"; } else { push @scores, "SusHall=1"; } - } - print "$l ||| $f ||| $e ||| @scores"; - if (defined $al) { print " ||| $al\n"; } else { print "\n"; } -} - diff --git a/training/candidate_set.cc b/training/candidate_set.cc deleted file mode 100644 index 087efec3..00000000 --- a/training/candidate_set.cc +++ /dev/null @@ -1,169 +0,0 @@ -#include "candidate_set.h" - -#include - -#include - -#include "verbose.h" -#include "ns.h" -#include "filelib.h" -#include "wordid.h" -#include "tdict.h" -#include "hg.h" -#include "kbest.h" -#include "viterbi.h" - -using namespace std; - -namespace training { - -struct ApproxVectorHasher { - static const size_t MASK = 0xFFFFFFFFull; - union UType { - double f; // leave as double - size_t i; - }; - static inline double round(const double x) { - UType t; - t.f = x; - size_t r = t.i & MASK; - if ((r << 1) > MASK) - t.i += MASK - r + 1; - else - t.i &= (1ull - MASK); - return t.f; - } - size_t operator()(const SparseVector& x) const { - size_t h = 0x573915839; - for (SparseVector::const_iterator it = x.begin(); it != x.end(); ++it) { - UType t; - t.f = it->second; - if (t.f) { - size_t z = (t.i >> 32); - boost::hash_combine(h, it->first); - boost::hash_combine(h, z); - } - } - return h; - } -}; - -struct ApproxVectorEquals { - bool operator()(const SparseVector& a, const SparseVector& b) const { - SparseVector::const_iterator bit = b.begin(); - for (SparseVector::const_iterator ait = a.begin(); ait != a.end(); ++ait) { - if (bit == b.end() || - ait->first != bit->first || - ApproxVectorHasher::round(ait->second) != ApproxVectorHasher::round(bit->second)) - return false; - ++bit; - } - if (bit != b.end()) return false; - return true; - } -}; - -struct CandidateCompare { - bool operator()(const Candidate& a, const Candidate& b) const { - ApproxVectorEquals eq; - return (a.ewords == b.ewords && eq(a.fmap,b.fmap)); - } -}; - -struct CandidateHasher { - size_t operator()(const Candidate& x) const { - boost::hash > hhasher; - ApproxVectorHasher vhasher; - size_t ha = hhasher(x.ewords); - boost::hash_combine(ha, vhasher(x.fmap)); - return ha; - } -}; - -static void ParseSparseVector(string& line, size_t cur, SparseVector* out) { - SparseVector& x = *out; - size_t last_start = cur; - size_t last_comma = string::npos; - while(cur <= line.size()) { - if (line[cur] == ' ' || cur == line.size()) { - if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { - cerr << "[ERROR] " << line << endl << " position = " << cur << endl; - exit(1); - } - const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); - if (cur < line.size()) line[cur] = 0; - const double val = strtod(&line[last_comma + 1], NULL); - x.set_value(fid, val); - - last_comma = string::npos; - last_start = cur+1; - } else { - if (line[cur] == '=') - last_comma = cur; - } - ++cur; - } -} - -void CandidateSet::WriteToFile(const string& file) const { - WriteFile wf(file); - ostream& out = *wf.stream(); - out.precision(10); - string ss; - for (unsigned i = 0; i < cs.size(); ++i) { - out << TD::GetString(cs[i].ewords) << endl; - out << cs[i].fmap << endl; - cs[i].eval_feats.Encode(&ss); - out << ss << endl; - } -} - -void CandidateSet::ReadFromFile(const string& file) { - if(!SILENT) cerr << "Reading candidates from " << file << endl; - ReadFile rf(file); - istream& in = *rf.stream(); - string cand; - string feats; - string ss; - while(getline(in, cand)) { - getline(in, feats); - getline(in, ss); - assert(in); - cs.push_back(Candidate()); - TD::ConvertSentence(cand, &cs.back().ewords); - ParseSparseVector(feats, 0, &cs.back().fmap); - cs.back().eval_feats = SufficientStats(ss); - } - if(!SILENT) cerr << " read " << cs.size() << " candidates\n"; -} - -void CandidateSet::Dedup() { - if(!SILENT) cerr << "Dedup in=" << cs.size(); - tr1::unordered_set u; - while(cs.size() > 0) { - u.insert(cs.back()); - cs.pop_back(); - } - tr1::unordered_set::iterator it = u.begin(); - while (it != u.end()) { - cs.push_back(*it); - it = u.erase(it); - } - if(!SILENT) cerr << " out=" << cs.size() << endl; -} - -void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) { - KBest::KBestDerivations, ESentenceTraversal> kbest(hg, kbest_size); - - for (unsigned i = 0; i < kbest_size; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(hg.nodes_.size() - 1, i); - if (!d) break; - cs.push_back(Candidate(d->yield, d->feature_values)); - if (scorer) - scorer->Evaluate(d->yield, &cs.back().eval_feats); - } - Dedup(); -} - -} diff --git a/training/candidate_set.h b/training/candidate_set.h deleted file mode 100644 index 9d326ed0..00000000 --- a/training/candidate_set.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef _CANDIDATE_SET_H_ -#define _CANDIDATE_SET_H_ - -#include -#include - -#include "ns.h" -#include "wordid.h" -#include "sparse_vector.h" - -class Hypergraph; - -namespace training { - -struct Candidate { - Candidate() {} - Candidate(const std::vector& e, const SparseVector& fm) : - ewords(e), - fmap(fm) {} - Candidate(const std::vector& e, - const SparseVector& fm, - const SegmentEvaluator& se) : - ewords(e), - fmap(fm) { - se.Evaluate(ewords, &eval_feats); - } - - void swap(Candidate& other) { - eval_feats.swap(other.eval_feats); - ewords.swap(other.ewords); - fmap.swap(other.fmap); - } - - std::vector ewords; - SparseVector fmap; - SufficientStats eval_feats; -}; - -// represents some kind of collection of translation candidates, e.g. -// aggregated k-best lists, sample lists, etc. -class CandidateSet { - public: - CandidateSet() {} - inline size_t size() const { return cs.size(); } - const Candidate& operator[](size_t i) const { return cs[i]; } - - void ReadFromFile(const std::string& file); - void WriteToFile(const std::string& file) const; - void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL); - // TODO add code to do unique k-best - // TODO add code to draw k samples - - private: - void Dedup(); - std::vector cs; -}; - -} - -#endif diff --git a/training/cllh_observer.cc b/training/cllh_observer.cc deleted file mode 100644 index 4ec2fa65..00000000 --- a/training/cllh_observer.cc +++ /dev/null @@ -1,52 +0,0 @@ -#include "cllh_observer.h" - -#include -#include - -#include "inside_outside.h" -#include "hg.h" -#include "sentence_metadata.h" - -using namespace std; - -static const double kMINUS_EPSILON = -1e-6; - -ConditionalLikelihoodObserver::~ConditionalLikelihoodObserver() {} - -void ConditionalLikelihoodObserver::NotifyDecodingStart(const SentenceMetadata&) { - cur_obj = 0; - state = 1; -} - -void ConditionalLikelihoodObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { - assert(state == 1); - state = 2; - SparseVector cur_model_exp; - const prob_t z = InsideOutside, - EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); - cur_obj = log(z); -} - -void ConditionalLikelihoodObserver::NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 2); - state = 3; - SparseVector ref_exp; - const prob_t ref_z = InsideOutside, - EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); - - double log_ref_z = log(ref_z); - - // rounding errors means that <0 is too strict - if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { - cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; - exit(1); - } - assert(!std::isnan(log_ref_z)); - acc_obj += (cur_obj - log_ref_z); - trg_words += smeta.GetReference().size(); -} - diff --git a/training/cllh_observer.h b/training/cllh_observer.h deleted file mode 100644 index 0de47331..00000000 --- a/training/cllh_observer.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _CLLH_OBSERVER_H_ -#define _CLLH_OBSERVER_H_ - -#include "decoder.h" - -struct ConditionalLikelihoodObserver : public DecoderObserver { - - ConditionalLikelihoodObserver() : trg_words(), acc_obj(), cur_obj() {} - ~ConditionalLikelihoodObserver(); - - void Reset() { - acc_obj = 0; - trg_words = 0; - } - - virtual void NotifyDecodingStart(const SentenceMetadata&); - virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg); - virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg); - - unsigned trg_words; - double acc_obj; - double cur_obj; - int state; -}; - -#endif diff --git a/training/collapse_weights.cc b/training/collapse_weights.cc deleted file mode 100644 index c03eb031..00000000 --- a/training/collapse_weights.cc +++ /dev/null @@ -1,110 +0,0 @@ -char const* NOTES = - "ZF_and_E means unnormalized scaled features.\n" - "For grammars with one nonterminal: F_and_E is joint,\n" - "F_given_E and E_given_F are conditional.\n" - "TODO: group rules by root nonterminal and then normalize.\n"; - - -#include -#include -#include - -#include -#include -#include - -#include "prob.h" -#include "filelib.h" -#include "trule.h" -#include "weights.h" - -namespace po = boost::program_options; -using namespace std; - -typedef std::tr1::unordered_map, prob_t, boost::hash > > MarginalMap; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("grammar,g", po::value(), "Grammar file") - ("weights,w", po::value(), "Weights file") - ("unnormalized,u", "Always include ZF_and_E unnormalized score (default: only if sum was >1)") - ; - po::options_description clo("Command line options"); - clo.add_options() - ("config,c", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - const string cfg = (*conf)["config"].as(); - cerr << "Configuration file: " << cfg << endl; - ifstream config(cfg.c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("grammar") || !conf->count("weights")) { - cerr << dcmdline_options << endl; - cerr << NOTES << endl; - exit(1); - } -} - -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - const string wfile = conf["weights"].as(); - const string gfile = conf["grammar"].as(); - vector w; - Weights::InitFromFile(wfile, &w); - MarginalMap e_tots; - MarginalMap f_tots; - prob_t tot; - { - ReadFile rf(gfile); - assert(*rf.stream()); - istream& in = *rf.stream(); - cerr << "Computing marginals...\n"; - int lc = 0; - while(in) { - string line; - getline(in, line); - ++lc; - if (line.empty()) continue; - TRule tr(line, true); - if (tr.GetFeatureValues().empty()) - cerr << "Line " << lc << ": empty features - may introduce bias\n"; - prob_t prob; - prob.logeq(tr.GetFeatureValues().dot(w)); - e_tots[tr.e_] += prob; - f_tots[tr.f_] += prob; - tot += prob; - } - } - bool normalized = (fabs(log(tot)) < 0.001); - cerr << "Total: " << tot << (normalized ? " [normalized]" : " [scaled]") << endl; - ReadFile rf(gfile); - istream&in = *rf.stream(); - while(in) { - string line; - getline(in, line); - if (line.empty()) continue; - TRule tr(line, true); - const double lp = tr.GetFeatureValues().dot(w); - if (std::isinf(lp)) { continue; } - tr.scores_.clear(); - - cout << tr.AsString() << " ||| F_and_E=" << lp - log(tot); - if (!normalized || conf.count("unnormalized")) { - cout << ";ZF_and_E=" << lp; - } - cout << ";F_given_E=" << lp - log(e_tots[tr.e_]) - << ";E_given_F=" << lp - log(f_tots[tr.f_]) << endl; - } - return 0; -} - diff --git a/training/crf/Makefile.am b/training/crf/Makefile.am new file mode 100644 index 00000000..d203df25 --- /dev/null +++ b/training/crf/Makefile.am @@ -0,0 +1,27 @@ +bin_PROGRAMS = \ + mpi_batch_optimize \ + mpi_compute_cllh \ + mpi_extract_features \ + mpi_extract_reachable \ + mpi_flex_optimize \ + mpi_online_optimize + +mpi_online_optimize_SOURCES = mpi_online_optimize.cc +mpi_online_optimize_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz + +mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc +mpi_flex_optimize_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz + +mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc +mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz + +mpi_extract_features_SOURCES = mpi_extract_features.cc +mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz + +mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc cllh_observer.cc +mpi_batch_optimize_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz + +mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc cllh_observer.cc +mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz + +AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare -I$(top_srcdir)/training -I$(top_srcdir)/training/utils -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/crf/cllh_observer.cc b/training/crf/cllh_observer.cc new file mode 100644 index 00000000..4ec2fa65 --- /dev/null +++ b/training/crf/cllh_observer.cc @@ -0,0 +1,52 @@ +#include "cllh_observer.h" + +#include +#include + +#include "inside_outside.h" +#include "hg.h" +#include "sentence_metadata.h" + +using namespace std; + +static const double kMINUS_EPSILON = -1e-6; + +ConditionalLikelihoodObserver::~ConditionalLikelihoodObserver() {} + +void ConditionalLikelihoodObserver::NotifyDecodingStart(const SentenceMetadata&) { + cur_obj = 0; + state = 1; +} + +void ConditionalLikelihoodObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + assert(state == 1); + state = 2; + SparseVector cur_model_exp; + const prob_t z = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); + cur_obj = log(z); +} + +void ConditionalLikelihoodObserver::NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 2); + state = 3; + SparseVector ref_exp; + const prob_t ref_z = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); + + double log_ref_z = log(ref_z); + + // rounding errors means that <0 is too strict + if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; + exit(1); + } + assert(!std::isnan(log_ref_z)); + acc_obj += (cur_obj - log_ref_z); + trg_words += smeta.GetReference().size(); +} + diff --git a/training/crf/cllh_observer.h b/training/crf/cllh_observer.h new file mode 100644 index 00000000..0de47331 --- /dev/null +++ b/training/crf/cllh_observer.h @@ -0,0 +1,26 @@ +#ifndef _CLLH_OBSERVER_H_ +#define _CLLH_OBSERVER_H_ + +#include "decoder.h" + +struct ConditionalLikelihoodObserver : public DecoderObserver { + + ConditionalLikelihoodObserver() : trg_words(), acc_obj(), cur_obj() {} + ~ConditionalLikelihoodObserver(); + + void Reset() { + acc_obj = 0; + trg_words = 0; + } + + virtual void NotifyDecodingStart(const SentenceMetadata&); + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg); + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg); + + unsigned trg_words; + double acc_obj; + double cur_obj; + int state; +}; + +#endif diff --git a/training/crf/mpi_batch_optimize.cc b/training/crf/mpi_batch_optimize.cc new file mode 100644 index 00000000..2eff07e4 --- /dev/null +++ b/training/crf/mpi_batch_optimize.cc @@ -0,0 +1,372 @@ +#include +#include +#include +#include +#include + +#include "config.h" +#ifdef HAVE_MPI +#include +#include +namespace mpi = boost::mpi; +#endif + +#include +#include +#include + +#include "sentence_metadata.h" +#include "cllh_observer.h" +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "stringlib.h" +#include "optimize.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input_weights,w",po::value(),"Input feature weights file") + ("training_data,t",po::value(),"Training data") + ("test_data,T",po::value(),"(optional) test data") + ("decoder_config,c",po::value(),"Decoder configuration file") + ("output_weights,o",po::value()->default_value("-"),"Output feature weights file") + ("optimization_method,m", po::value()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)") + ("correction_buffers,M", po::value()->default_value(10), "Number of gradients for LBFGS to maintain in memory") + ("gaussian_prior,p","Use a Gaussian prior on the weights") + ("sigma_squared", po::value()->default_value(1.0), "Sigma squared term for spherical Gaussian prior") + ("means,u", po::value(), "(optional) file containing the means for Gaussian prior"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data")) || !conf->count("decoder_config")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct TrainingObserver : public DecoderObserver { + void Reset() { + acc_grad.clear(); + acc_obj = 0; + total_complete = 0; + trg_words = 0; + } + + void SetLocalGradientAndObjective(vector* g, double* o) const { + *o = acc_obj; + for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) + (*g)[it->first] = it->second.as_float(); + } + + virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { + cur_model_exp.clear(); + cur_obj = 0; + state = 1; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 1); + state = 2; + const prob_t z = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); + cur_obj = log(z); + cur_model_exp /= z; + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 2); + state = 3; + SparseVector ref_exp; + const prob_t ref_z = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); + ref_exp /= ref_z; + + double log_ref_z; +#if 0 + if (crf_uniform_empirical) { + log_ref_z = ref_exp.dot(feature_weights); + } else { + log_ref_z = log(ref_z); + } +#else + log_ref_z = log(ref_z); +#endif + + // rounding errors means that <0 is too strict + if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; + exit(1); + } + assert(!std::isnan(log_ref_z)); + ref_exp -= cur_model_exp; + acc_grad -= ref_exp; + acc_obj += (cur_obj - log_ref_z); + trg_words += smeta.GetReference().size(); + } + + virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { + if (state == 3) { + ++total_complete; + } else { + } + } + + int total_complete; + SparseVector cur_model_exp; + SparseVector acc_grad; + double acc_obj; + double cur_obj; + unsigned trg_words; + int state; +}; + +void ReadConfig(const string& ini, vector* out) { + ReadFile rf(ini); + istream& in = *rf.stream(); + while(in) { + string line; + getline(in, line); + if (!in) continue; + out->push_back(line); + } +} + +void StoreConfig(const vector& cfg, istringstream* o) { + ostringstream os; + for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; } + o->str(os.str()); +} + +template +struct VectorPlus : public binary_function, vector, vector > { + vector operator()(const vector& a, const vector& b) const { + assert(a.size() == b.size()); + vector v(a.size()); + transform(a.begin(), a.end(), b.begin(), v.begin(), plus()); + return v; + } +}; + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + + // load cdec.ini and set up decoder + vector cdec_ini; + ReadConfig(conf["decoder_config"].as(), &cdec_ini); + istringstream ini; + StoreConfig(cdec_ini, &ini); + if (rank == 0) cerr << "Loading grammar...\n"; + Decoder* decoder = new Decoder(&ini); + if (decoder->GetConf()["input"].as() != "-") { + cerr << "cdec.ini must not set an input file\n"; + return 1; + } + if (rank == 0) cerr << "Done loading grammar!\n"; + + // load initial weights + if (rank == 0) { cerr << "Loading weights...\n"; } + vector& lambdas = decoder->CurrentWeightVector(); + Weights::InitFromFile(conf["input_weights"].as(), &lambdas); + if (rank == 0) { cerr << "Done loading weights.\n"; } + + // freeze feature set (should be optional?) + const bool freeze_feature_set = true; + if (freeze_feature_set) FD::Freeze(); + + const int num_feats = FD::NumFeats(); + if (rank == 0) cerr << "Number of features: " << num_feats << endl; + lambdas.resize(num_feats); + + const bool gaussian_prior = conf.count("gaussian_prior"); + vector means(num_feats, 0); + if (conf.count("means")) { + if (!gaussian_prior) { + cerr << "Don't use --means without --gaussian_prior!\n"; + exit(1); + } + Weights::InitFromFile(conf["means"].as(), &means); + } + boost::shared_ptr o; + if (rank == 0) { + const string omethod = conf["optimization_method"].as(); + if (omethod == "rprop") + o.reset(new RPropOptimizer(num_feats)); // TODO add configuration + else + o.reset(new LBFGSOptimizer(num_feats, conf["correction_buffers"].as())); + cerr << "Optimizer: " << o->Name() << endl; + } + double objective = 0; + vector gradient(num_feats, 0.0); + vector rcv_grad; + rcv_grad.clear(); + bool converged = false; + + vector corpus, test_corpus; + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); + assert(corpus.size() > 0); + if (conf.count("test_data")) + ReadTrainingCorpus(conf["test_data"].as(), rank, size, &test_corpus); + + TrainingObserver observer; + ConditionalLikelihoodObserver cllh_observer; + while (!converged) { + observer.Reset(); + cllh_observer.Reset(); +#ifdef HAVE_MPI + mpi::timer timer; + world.barrier(); +#endif + if (rank == 0) { + cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n"; + cerr << " Testset size: " << test_corpus.size() << " sentences / proc)\n"; + } + for (int i = 0; i < corpus.size(); ++i) + decoder->Decode(corpus[i], &observer); + cerr << " process " << rank << '/' << size << " done\n"; + fill(gradient.begin(), gradient.end(), 0); + observer.SetLocalGradientAndObjective(&gradient, &objective); + + unsigned total_words = 0; +#ifdef HAVE_MPI + double to = 0; + rcv_grad.resize(num_feats, 0.0); + mpi::reduce(world, &gradient[0], gradient.size(), &rcv_grad[0], plus(), 0); + swap(gradient, rcv_grad); + rcv_grad.clear(); + + reduce(world, observer.trg_words, total_words, std::plus(), 0); + mpi::reduce(world, objective, to, plus(), 0); + objective = to; +#else + total_words = observer.trg_words; +#endif + if (rank == 0) + cerr << "TRAINING CORPUS: ln p(f|e)=" << objective << "\t log_2 p(f|e) = " << (objective/log(2)) << "\t cond. entropy = " << (objective/log(2) / total_words) << "\t ppl = " << pow(2, (objective/log(2) / total_words)) << endl; + + for (int i = 0; i < test_corpus.size(); ++i) + decoder->Decode(test_corpus[i], &cllh_observer); + + double test_objective = 0; + unsigned test_total_words = 0; +#ifdef HAVE_MPI + reduce(world, cllh_observer.acc_obj, test_objective, std::plus(), 0); + reduce(world, cllh_observer.trg_words, test_total_words, std::plus(), 0); +#else + test_objective = cllh_observer.acc_obj; + test_total_words = cllh_observer.trg_words; +#endif + + if (rank == 0) { // run optimizer only on rank=0 node + if (test_corpus.size()) + cerr << " TEST CORPUS: ln p(f|e)=" << test_objective << "\t log_2 p(f|e) = " << (test_objective/log(2)) << "\t cond. entropy = " << (test_objective/log(2) / test_total_words) << "\t ppl = " << pow(2, (test_objective/log(2) / test_total_words)) << endl; + if (gaussian_prior) { + const double sigsq = conf["sigma_squared"].as(); + double norm = 0; + for (int k = 1; k < lambdas.size(); ++k) { + const double& lambda_k = lambdas[k]; + if (lambda_k) { + const double param = (lambda_k - means[k]); + norm += param * param; + gradient[k] += param / sigsq; + } + } + const double reg = norm / (2.0 * sigsq); + cerr << "REGULARIZATION TERM: " << reg << endl; + objective += reg; + } + cerr << "EVALUATION #" << o->EvaluationCount() << " OBJECTIVE: " << objective << endl; + double gnorm = 0; + for (int i = 0; i < gradient.size(); ++i) + gnorm += gradient[i] * gradient[i]; + cerr << " GNORM=" << sqrt(gnorm) << endl; + vector old = lambdas; + int c = 0; + while (old == lambdas) { + ++c; + if (c > 1) { cerr << "Same lambdas, repeating optimization\n"; } + o->Optimize(objective, gradient, &lambdas); + assert(c < 5); + } + old.clear(); + Weights::SanityCheck(lambdas); + Weights::ShowLargestFeatures(lambdas); + + converged = o->HasConverged(); + if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } + + string fname = "weights.cur.gz"; + if (converged) { fname = "weights.final.gz"; } + ostringstream vv; + vv << "Objective = " << objective << " (eval count=" << o->EvaluationCount() << ")"; + const string svv = vv.str(); + Weights::WriteToFile(fname, lambdas, true, &svv); + } // rank == 0 + int cint = converged; +#ifdef HAVE_MPI + mpi::broadcast(world, &lambdas[0], lambdas.size(), 0); + mpi::broadcast(world, cint, 0); + if (rank == 0) { cerr << " ELAPSED TIME THIS ITERATION=" << timer.elapsed() << endl; } +#endif + converged = cint; + } + return 0; +} + diff --git a/training/crf/mpi_compute_cllh.cc b/training/crf/mpi_compute_cllh.cc new file mode 100644 index 00000000..066389d0 --- /dev/null +++ b/training/crf/mpi_compute_cllh.cc @@ -0,0 +1,134 @@ +#include +#include +#include +#include + +#include "config.h" +#ifdef HAVE_MPI +#include +#endif +#include +#include + +#include "cllh_observer.h" +#include "sentence_metadata.h" +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "weights.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("weights,w",po::value(),"Input feature weights file") + ("training_data,t",po::value(),"Training data corpus") + ("decoder_config,c",po::value(),"Decoder configuration file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadInstances(const string& fname, int rank, int size, vector* c) { + assert(fname != "-"); + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +#ifdef HAVE_MPI +namespace mpi = boost::mpi; +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return false; + + // load cdec.ini and set up decoder + ReadFile ini_rf(conf["decoder_config"].as()); + Decoder decoder(ini_rf.stream()); + if (decoder.GetConf()["input"].as() != "-") { + cerr << "cdec.ini must not set an input file\n"; + abort(); + } + + // load weights + vector& weights = decoder.CurrentWeightVector(); + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &weights); + + vector corpus; + ReadInstances(conf["training_data"].as(), rank, size, &corpus); + assert(corpus.size() > 0); + + if (rank == 0) + cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; + + ConditionalLikelihoodObserver observer; + for (int i = 0; i < corpus.size(); ++i) + decoder.Decode(corpus[i], &observer); + + double objective = 0; + unsigned total_words = 0; +#ifdef HAVE_MPI + reduce(world, observer.acc_obj, objective, std::plus(), 0); + reduce(world, observer.trg_words, total_words, std::plus(), 0); +#else + objective = observer.acc_obj; +#endif + + if (rank == 0) { + cout << "CONDITIONAL LOG_e LIKELIHOOD: " << objective << endl; + cout << "CONDITIONAL LOG_2 LIKELIHOOD: " << (objective/log(2)) << endl; + cout << " CONDITIONAL ENTROPY: " << (objective/log(2) / total_words) << endl; + cout << " PERPLEXITY: " << pow(2, (objective/log(2) / total_words)) << endl; + } + + return 0; +} + diff --git a/training/crf/mpi_extract_features.cc b/training/crf/mpi_extract_features.cc new file mode 100644 index 00000000..6750aa15 --- /dev/null +++ b/training/crf/mpi_extract_features.cc @@ -0,0 +1,151 @@ +#include +#include +#include +#include + +#include "config.h" +#ifdef HAVE_MPI +#include +#endif +#include +#include + +#include "ff_register.h" +#include "verbose.h" +#include "filelib.h" +#include "fdict.h" +#include "decoder.h" +#include "weights.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("training_data,t",po::value(),"Training data corpus") + ("decoder_config,c",po::value(),"Decoder configuration file") + ("weights,w", po::value(), "(Optional) weights file; weights may affect what features are encountered in pruning configurations") + ("output_prefix,o",po::value()->default_value("features"),"Output path prefix"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { + cerr << "Decode an input set (optionally in parallel using MPI) and write\nout the feature strings encountered.\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct TrainingObserver : public DecoderObserver { + + virtual void NotifyDecodingStart(const SentenceMetadata&) { + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + } +}; + +#ifdef HAVE_MPI +namespace mpi = boost::mpi; +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return false; + + // load cdec.ini and set up decoder + ReadFile ini_rf(conf["decoder_config"].as()); + Decoder decoder(ini_rf.stream()); + if (decoder.GetConf()["input"].as() != "-") { + cerr << "cdec.ini must not set an input file\n"; + abort(); + } + + if (FD::UsingPerfectHashFunction()) { + cerr << "Your configuration file has enabled a cmph hash function. Please disable.\n"; + return 1; + } + + // load optional weights + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &decoder.CurrentWeightVector()); + + vector corpus; + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); + assert(corpus.size() > 0); + + TrainingObserver observer; + + if (rank == 0) + cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; + + for (int i = 0; i < corpus.size(); ++i) + decoder.Decode(corpus[i], &observer); + + { + ostringstream os; + os << conf["output_prefix"].as() << '.' << rank << "_of_" << size; + WriteFile wf(os.str()); + ostream& out = *wf.stream(); + const unsigned num_feats = FD::NumFeats(); + for (unsigned i = 1; i < num_feats; ++i) { + out << FD::Convert(i) << endl; + } + cerr << "Wrote " << os.str() << endl; + } + +#ifdef HAVE_MPI + world.barrier(); +#else +#endif + + return 0; +} + diff --git a/training/crf/mpi_extract_reachable.cc b/training/crf/mpi_extract_reachable.cc new file mode 100644 index 00000000..2a7c2b9d --- /dev/null +++ b/training/crf/mpi_extract_reachable.cc @@ -0,0 +1,163 @@ +#include +#include +#include +#include + +#include "config.h" +#ifdef HAVE_MPI +#include +#endif +#include +#include + +#include "ff_register.h" +#include "verbose.h" +#include "filelib.h" +#include "fdict.h" +#include "decoder.h" +#include "weights.h" + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("training_data,t",po::value(),"Training data corpus") + ("decoder_config,c",po::value(),"Decoder configuration file") + ("weights,w", po::value(), "(Optional) weights file; weights may affect what features are encountered in pruning configurations") + ("output_prefix,o",po::value()->default_value("reachable"),"Output path prefix"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { + cerr << "Decode an input set (optionally in parallel using MPI) and write\nout the inputs that produce reachable parallel parses.\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int lc = 0; + while(in) { + getline(in, line); + if (!in) break; + if (lc % size == rank) c->push_back(line); + ++lc; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct ReachabilityObserver : public DecoderObserver { + + virtual void NotifyDecodingStart(const SentenceMetadata&) { + reachable = false; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + reachable = true; + } + + bool reachable; +}; + +#ifdef HAVE_MPI +namespace mpi = boost::mpi; +#endif + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return false; + + // load cdec.ini and set up decoder + ReadFile ini_rf(conf["decoder_config"].as()); + Decoder decoder(ini_rf.stream()); + if (decoder.GetConf()["input"].as() != "-") { + cerr << "cdec.ini must not set an input file\n"; + abort(); + } + + if (FD::UsingPerfectHashFunction()) { + cerr << "Your configuration file has enabled a cmph hash function. Please disable.\n"; + return 1; + } + + // load optional weights + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &decoder.CurrentWeightVector()); + + vector corpus; + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); + assert(corpus.size() > 0); + + + if (rank == 0) + cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; + + size_t num_reached = 0; + { + ostringstream os; + os << conf["output_prefix"].as() << '.' << rank << "_of_" << size; + WriteFile wf(os.str()); + ostream& out = *wf.stream(); + ReachabilityObserver observer; + for (int i = 0; i < corpus.size(); ++i) { + decoder.Decode(corpus[i], &observer); + if (observer.reachable) { + out << corpus[i] << endl; + ++num_reached; + } + corpus[i].clear(); + } + cerr << "Shard " << rank << '/' << size << " finished, wrote " + << num_reached << " instances to " << os.str() << endl; + } + + size_t total = 0; +#ifdef HAVE_MPI + reduce(world, num_reached, total, std::plus(), 0); +#else + total = num_reached; +#endif + if (rank == 0) { + cerr << "-----------------------------------------\n"; + cerr << "TOTAL = " << total << " instances\n"; + } + return 0; +} + diff --git a/training/crf/mpi_flex_optimize.cc b/training/crf/mpi_flex_optimize.cc new file mode 100644 index 00000000..b52decdc --- /dev/null +++ b/training/crf/mpi_flex_optimize.cc @@ -0,0 +1,386 @@ +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "stringlib.h" +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "optimize.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" +#include "sampler.h" + +#ifdef HAVE_MPI +#include +#include +namespace mpi = boost::mpi; +#endif + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("cdec_config,c",po::value(),"Decoder configuration file") + ("weights,w",po::value(),"Initial feature weights") + ("training_data,d",po::value(),"Training data") + ("minibatch_size_per_proc,s", po::value()->default_value(6), "Number of training instances evaluated per processor in each minibatch") + ("minibatch_iterations,i", po::value()->default_value(10), "Number of optimization iterations per minibatch") + ("iterations,I", po::value()->default_value(50), "Number of passes through the training data before termination") + ("regularization_strength,C", po::value()->default_value(0.2), "Regularization strength") + ("time_series_strength,T", po::value()->default_value(0.0), "Time series regularization strength") + ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") + ("lbfgs_memory_buffers,M", po::value()->default_value(10), "Number of memory buffers for LBFGS history"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("cdec_config")) { + cerr << "LBFGS minibatch online optimizer (MPI support " +#if HAVE_MPI + << "enabled" +#else + << "not enabled" +#endif + << ")\n" << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* order) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int id = 0; + while(in) { + getline(in, line); + if (!in) break; + if (id % size == rank) { + c->push_back(line); + order->push_back(id); + } + ++id; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct CopyHGsObserver : public DecoderObserver { + Hypergraph* hg_; + Hypergraph* gold_hg_; + + // this can free up some memory + void RemoveRules(Hypergraph* h) { + for (unsigned i = 0; i < h->edges_.size(); ++i) + h->edges_[i].rule_.reset(); + } + + void SetCurrentHypergraphs(Hypergraph* h, Hypergraph* gold_h) { + hg_ = h; + gold_hg_ = gold_h; + } + + virtual void NotifyDecodingStart(const SentenceMetadata&) { + state = 1; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { + *hg_ = *hg; + RemoveRules(hg_); + assert(state == 1); + state = 2; + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata&, Hypergraph* hg) { + assert(state == 2); + state = 3; + *gold_hg_ = *hg; + RemoveRules(gold_hg_); + } + + virtual void NotifyDecodingComplete(const SentenceMetadata&) { + if (state == 3) { + } else { + hg_->clear(); + gold_hg_->clear(); + } + } + + int state; +}; + +void ReadConfig(const string& ini, istringstream* out) { + ReadFile rf(ini); + istream& in = *rf.stream(); + ostringstream os; + while(in) { + string line; + getline(in, line); + if (!in) continue; + os << line << endl; + } + out->str(os.str()); +} + +#ifdef HAVE_MPI +namespace boost { namespace mpi { + template<> + struct is_commutative >, SparseVector > + : mpl::true_ { }; +} } // end namespace boost::mpi +#endif + +void AddGrad(const SparseVector x, double s, SparseVector* acc) { + for (SparseVector::const_iterator it = x.begin(); it != x.end(); ++it) + acc->add_value(it->first, it->second.as_float() * s); +} + +double PNorm(const vector& v, const double p) { + double acc = 0; + for (int i = 0; i < v.size(); ++i) + acc += pow(v[i], p); + return pow(acc, 1.0 / p); +} + +void VV(ostream&os, const vector& v) { + for (int i = 1; i < v.size(); ++i) + if (v[i]) os << FD::Convert(i) << "=" << v[i] << " "; +} + +double ApplyRegularizationTerms(const double C, + const double T, + const vector& weights, + const vector& prev_weights, + double* g) { + double reg = 0; + for (size_t i = 0; i < weights.size(); ++i) { + const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0); + const double& w_i = weights[i]; + reg += C * w_i * w_i; + g[i] += 2 * C * w_i; + + reg += T * (w_i - prev_w_i) * (w_i - prev_w_i); + g[i] += 2 * T * (w_i - prev_w_i); + } + return reg; +} + +int main(int argc, char** argv) { +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + MT19937* rng = NULL; + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return 1; + + boost::shared_ptr o; + const unsigned lbfgs_memory_buffers = conf["lbfgs_memory_buffers"].as(); + const unsigned size_per_proc = conf["minibatch_size_per_proc"].as(); + const unsigned minibatch_iterations = conf["minibatch_iterations"].as(); + const double regularization_strength = conf["regularization_strength"].as(); + const double time_series_strength = conf["time_series_strength"].as(); + const bool use_time_series_reg = time_series_strength > 0.0; + const unsigned max_iteration = conf["iterations"].as(); + + vector corpus; + vector ids; + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); + assert(corpus.size() > 0); + + if (size_per_proc > corpus.size()) { + cerr << "Minibatch size (per processor) must be smaller or equal to the local corpus size!\n"; + return 1; + } + + // initialize decoder (loads hash functions if necessary) + istringstream ins; + ReadConfig(conf["cdec_config"].as(), &ins); + Decoder decoder(&ins); + + // load initial weights + vector prev_weights; + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &prev_weights); + + if (conf.count("random_seed")) + rng = new MT19937(conf["random_seed"].as()); + else + rng = new MT19937; + + size_t total_corpus_size = 0; +#ifdef HAVE_MPI + reduce(world, corpus.size(), total_corpus_size, std::plus(), 0); +#else + total_corpus_size = corpus.size(); +#endif + + if (rank == 0) + cerr << "Total corpus size: " << total_corpus_size << endl; + + CopyHGsObserver observer; + + int write_weights_every_ith = 100; // TODO configure + int titer = -1; + + vector& cur_weights = decoder.CurrentWeightVector(); + if (use_time_series_reg) { + cur_weights = prev_weights; + } else { + cur_weights.swap(prev_weights); + prev_weights.clear(); + } + + int iter = -1; + bool converged = false; + vector gg; + while (!converged) { +#ifdef HAVE_MPI + mpi::timer timer; +#endif + ++iter; ++titer; + if (rank == 0) { + converged = (iter == max_iteration); + string fname = "weights.cur.gz"; + if (iter % write_weights_every_ith == 0) { + ostringstream o; o << "weights.epoch_" << iter << ".gz"; + fname = o.str(); + } + if (converged) { fname = "weights.final.gz"; } + ostringstream vv; + vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast(corpus.size())); + const string svv = vv.str(); + Weights::WriteToFile(fname, cur_weights, true, &svv); + } + + vector hgs(size_per_proc); + vector gold_hgs(size_per_proc); + for (int i = 0; i < size_per_proc; ++i) { + int ei = corpus.size() * rng->next(); + int id = ids[ei]; + observer.SetCurrentHypergraphs(&hgs[i], &gold_hgs[i]); + decoder.SetId(id); + decoder.Decode(corpus[ei], &observer); + } + + SparseVector local_grad, g; + double local_obj = 0; + o.reset(); + for (unsigned mi = 0; mi < minibatch_iterations; ++mi) { + local_grad.clear(); + g.clear(); + local_obj = 0; + + for (unsigned i = 0; i < size_per_proc; ++i) { + Hypergraph& hg = hgs[i]; + Hypergraph& hg_gold = gold_hgs[i]; + if (hg.edges_.size() < 2) continue; + + hg.Reweight(cur_weights); + hg_gold.Reweight(cur_weights); + SparseVector model_exp, gold_exp; + const prob_t z = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(hg, &model_exp); + local_obj += log(z); + model_exp /= z; + AddGrad(model_exp, 1.0, &local_grad); + model_exp.clear(); + + const prob_t goldz = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(hg_gold, &gold_exp); + local_obj -= log(goldz); + + if (log(z) - log(goldz) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_model_z < log_gold_z: " << log(z) << " " << log(goldz) << endl; + return 1; + } + + gold_exp /= goldz; + AddGrad(gold_exp, -1.0, &local_grad); + } + + double obj = 0; +#ifdef HAVE_MPI + reduce(world, local_obj, obj, std::plus(), 0); + reduce(world, local_grad, g, std::plus >(), 0); +#else + obj = local_obj; + g.swap(local_grad); +#endif + local_grad.clear(); + if (rank == 0) { + // g /= (size_per_proc * size); + if (!o) + o.reset(new LBFGSOptimizer(FD::NumFeats(), lbfgs_memory_buffers)); + gg.clear(); + gg.resize(FD::NumFeats()); + if (gg.size() != cur_weights.size()) { cur_weights.resize(gg.size()); } + for (SparseVector::iterator it = g.begin(); it != g.end(); ++it) + if (it->first) { gg[it->first] = it->second; } + g.clear(); + double r = ApplyRegularizationTerms(regularization_strength, + time_series_strength, // * (iter == 0 ? 0.0 : 1.0), + cur_weights, + prev_weights, + &gg[0]); + obj += r; + if (mi == 0 || mi == (minibatch_iterations - 1)) { + if (!mi) cerr << iter << ' '; else cerr << ' '; + cerr << "OBJ=" << obj << " (REG=" << r << ")" << " |g|=" << PNorm(gg, 2) << " |w|=" << PNorm(cur_weights, 2); + if (mi > 0) cerr << endl << flush; else cerr << ' '; + } else { cerr << '.' << flush; } + // cerr << "w = "; VV(cerr, cur_weights); cerr << endl; + // cerr << "g = "; VV(cerr, gg); cerr << endl; + o->Optimize(obj, gg, &cur_weights); + } +#ifdef HAVE_MPI + broadcast(world, cur_weights, 0); + broadcast(world, converged, 0); + world.barrier(); +#endif + } + prev_weights = cur_weights; + } + return 0; +} diff --git a/training/crf/mpi_online_optimize.cc b/training/crf/mpi_online_optimize.cc new file mode 100644 index 00000000..d6968848 --- /dev/null +++ b/training/crf/mpi_online_optimize.cc @@ -0,0 +1,374 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "stringlib.h" +#include "verbose.h" +#include "hg.h" +#include "prob.h" +#include "inside_outside.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "online_optimizer.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" +#include "sampler.h" + +#ifdef HAVE_MPI +#include +#include +namespace mpi = boost::mpi; +#endif + +using namespace std; +namespace po = boost::program_options; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input_weights,w",po::value(),"Input feature weights file") + ("frozen_features,z",po::value(), "List of features not to optimize") + ("training_data,t",po::value(),"Training data corpus") + ("training_agenda,a",po::value(), "Text file listing a series of configuration files and the number of iterations to train using each configuration successively") + ("minibatch_size_per_proc,s", po::value()->default_value(5), "Number of training instances evaluated per processor in each minibatch") + ("optimization_method,m", po::value()->default_value("sgd"), "Optimization method (sgd)") + ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") + ("eta_0,e", po::value()->default_value(0.2), "Initial learning rate for SGD (eta_0)") + ("L1,1","Use L1 regularization") + ("regularization_strength,C", po::value()->default_value(1.0), "Regularization strength (C)"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("training_data") || !conf->count("training_agenda")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* order) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + int id = 0; + while(in) { + getline(in, line); + if (!in) break; + if (id % size == rank) { + c->push_back(line); + order->push_back(id); + } + ++id; + } +} + +static const double kMINUS_EPSILON = -1e-6; + +struct TrainingObserver : public DecoderObserver { + void Reset() { + acc_grad.clear(); + acc_obj = 0; + total_complete = 0; + } + + void SetLocalGradientAndObjective(vector* g, double* o) const { + *o = acc_obj; + for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) + (*g)[it->first] = it->second.as_float(); + } + + virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { + cur_model_exp.clear(); + cur_obj = 0; + state = 1; + } + + // compute model expectations, denominator of objective + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 1); + state = 2; + const prob_t z = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); + cur_obj = log(z); + cur_model_exp /= z; + } + + // compute "empirical" expectations, numerator of objective + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { + assert(state == 2); + state = 3; + SparseVector ref_exp; + const prob_t ref_z = InsideOutside, + EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); + ref_exp /= ref_z; + + double log_ref_z; +#if 0 + if (crf_uniform_empirical) { + log_ref_z = ref_exp.dot(feature_weights); + } else { + log_ref_z = log(ref_z); + } +#else + log_ref_z = log(ref_z); +#endif + + // rounding errors means that <0 is too strict + if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; + exit(1); + } + assert(!std::isnan(log_ref_z)); + ref_exp -= cur_model_exp; + acc_grad += ref_exp; + acc_obj += (cur_obj - log_ref_z); + } + + virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { + if (state == 3) { + ++total_complete; + } else { + } + } + + void GetGradient(SparseVector* g) const { + g->clear(); + for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) + g->set_value(it->first, it->second.as_float()); + } + + int total_complete; + SparseVector cur_model_exp; + SparseVector acc_grad; + double acc_obj; + double cur_obj; + int state; +}; + +#ifdef HAVE_MPI +namespace boost { namespace mpi { + template<> + struct is_commutative >, SparseVector > + : mpl::true_ { }; +} } // end namespace boost::mpi +#endif + +bool LoadAgenda(const string& file, vector >* a) { + ReadFile rf(file); + istream& in = *rf.stream(); + string line; + while(in) { + getline(in, line); + if (!in) break; + if (line.empty()) continue; + if (line[0] == '#') continue; + int sc = 0; + if (line.size() < 3) return false; + for (int i = 0; i < line.size(); ++i) { if (line[i] == ' ') ++sc; } + if (sc != 1) { cerr << "Too many spaces in line: " << line << endl; return false; } + size_t d = line.find(" "); + pair x; + x.first = line.substr(0,d); + x.second = atoi(line.substr(d+1).c_str()); + a->push_back(x); + if (!FileExists(x.first)) { + cerr << "Can't find file " << x.first << endl; + return false; + } + } + return true; +} + +int main(int argc, char** argv) { + cerr << "THIS SOFTWARE IS DEPRECATED YOU SHOULD USE mpi_flex_optimize\n"; +#ifdef HAVE_MPI + mpi::environment env(argc, argv); + mpi::communicator world; + const int size = world.size(); + const int rank = world.rank(); +#else + const int size = 1; + const int rank = 0; +#endif + if (size > 1) SetSilent(true); // turn off verbose decoder output + register_feature_functions(); + std::tr1::shared_ptr rng; + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) + return 1; + + vector > agenda; + if (!LoadAgenda(conf["training_agenda"].as(), &agenda)) + return 1; + if (rank == 0) + cerr << "Loaded agenda defining " << agenda.size() << " training epochs\n"; + + assert(agenda.size() > 0); + + if (1) { // hack to load the feature hash functions -- TODO this should not be in cdec.ini + const string& cur_config = agenda[0].first; + const unsigned max_iteration = agenda[0].second; + ReadFile ini_rf(cur_config); + Decoder decoder(ini_rf.stream()); + } + + // load initial weights + vector init_weights; + if (conf.count("input_weights")) + Weights::InitFromFile(conf["input_weights"].as(), &init_weights); + + vector frozen_fids; + if (conf.count("frozen_features")) { + ReadFile rf(conf["frozen_features"].as()); + istream& in = *rf.stream(); + string line; + while(in) { + getline(in, line); + if (line.empty()) continue; + if (line[0] == ' ' || line[line.size() - 1] == ' ') { line = Trim(line); } + frozen_fids.push_back(FD::Convert(line)); + } + if (rank == 0) cerr << "Freezing " << frozen_fids.size() << " features.\n"; + } + + vector corpus; + vector ids; + ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); + assert(corpus.size() > 0); + + std::tr1::shared_ptr o; + std::tr1::shared_ptr lr; + + const unsigned size_per_proc = conf["minibatch_size_per_proc"].as(); + if (size_per_proc > corpus.size()) { + cerr << "Minibatch size must be smaller than corpus size!\n"; + return 1; + } + + size_t total_corpus_size = 0; +#ifdef HAVE_MPI + reduce(world, corpus.size(), total_corpus_size, std::plus(), 0); +#else + total_corpus_size = corpus.size(); +#endif + + if (rank == 0) { + cerr << "Total corpus size: " << total_corpus_size << endl; + const unsigned batch_size = size_per_proc * size; + // TODO config + lr.reset(new ExponentialDecayLearningRate(batch_size, conf["eta_0"].as())); + + const string omethod = conf["optimization_method"].as(); + if (omethod == "sgd") { + const double C = conf["regularization_strength"].as(); + o.reset(new CumulativeL1OnlineOptimizer(lr, total_corpus_size, C, frozen_fids)); + } else { + assert(!"fail"); + } + } + if (conf.count("random_seed")) + rng.reset(new MT19937(conf["random_seed"].as())); + else + rng.reset(new MT19937); + + SparseVector x; + Weights::InitSparseVector(init_weights, &x); + TrainingObserver observer; + + int write_weights_every_ith = 100; // TODO configure + int titer = -1; + + for (int ai = 0; ai < agenda.size(); ++ai) { + const string& cur_config = agenda[ai].first; + const unsigned max_iteration = agenda[ai].second; + if (rank == 0) + cerr << "STARTING TRAINING EPOCH " << (ai+1) << ". CONFIG=" << cur_config << endl; + // load cdec.ini and set up decoder + ReadFile ini_rf(cur_config); + Decoder decoder(ini_rf.stream()); + vector& lambdas = decoder.CurrentWeightVector(); + if (ai == 0) { lambdas.swap(init_weights); init_weights.clear(); } + + if (rank == 0) + o->ResetEpoch(); // resets the learning rate-- TODO is this good? + + int iter = -1; + bool converged = false; + while (!converged) { +#ifdef HAVE_MPI + mpi::timer timer; +#endif + x.init_vector(&lambdas); + ++iter; ++titer; + observer.Reset(); + if (rank == 0) { + converged = (iter == max_iteration); + Weights::SanityCheck(lambdas); + static int cc = 0; ++cc; if (cc > 1) { Weights::ShowLargestFeatures(lambdas); } + string fname = "weights.cur.gz"; + if (iter % write_weights_every_ith == 0) { + ostringstream o; o << "weights.epoch_" << (ai+1) << '.' << iter << ".gz"; + fname = o.str(); + } + if (converged && ((ai+1)==agenda.size())) { fname = "weights.final.gz"; } + ostringstream vv; + vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast(corpus.size())) << " eta=" << lr->eta(titer); + const string svv = vv.str(); + cerr << svv << endl; + Weights::WriteToFile(fname, lambdas, true, &svv); + } + + for (int i = 0; i < size_per_proc; ++i) { + int ei = corpus.size() * rng->next(); + int id = ids[ei]; + decoder.SetId(id); + decoder.Decode(corpus[ei], &observer); + } + SparseVector local_grad, g; + observer.GetGradient(&local_grad); +#ifdef HAVE_MPI + reduce(world, local_grad, g, std::plus >(), 0); +#else + g.swap(local_grad); +#endif + local_grad.clear(); + if (rank == 0) { + g /= (size_per_proc * size); + o->UpdateWeights(g, FD::NumFeats(), &x); + } +#ifdef HAVE_MPI + broadcast(world, x, 0); + broadcast(world, converged, 0); + world.barrier(); + if (rank == 0) { cerr << " ELAPSED TIME THIS ITERATION=" << timer.elapsed() << endl; } +#endif + } + } + return 0; +} diff --git a/training/dep-reorder/conll2reordering-forest.pl b/training/dep-reorder/conll2reordering-forest.pl deleted file mode 100755 index 3cd226be..00000000 --- a/training/dep-reorder/conll2reordering-forest.pl +++ /dev/null @@ -1,65 +0,0 @@ -#!/usr/bin/perl -w -use strict; - -my $script_dir; BEGIN { use Cwd qw/ abs_path cwd /; use File::Basename; $script_dir = dirname(abs_path($0)); push @INC, $script_dir; } -my $FIRST_CONV = "$script_dir/scripts/conll2simplecfg.pl"; -my $CDEC = "$script_dir/../../decoder/cdec"; - -our $tfile1 = "grammar1.$$"; -our $tfile2 = "text.$$"; - -die "Usage: $0 parses.conll\n" unless scalar @ARGV == 1; -open C, "<$ARGV[0]" or die "Can't read $ARGV[0]: $!"; - -END { unlink $tfile1; unlink "$tfile1.cfg"; unlink $tfile2; } - -my $first = 1; -open T, ">$tfile1" or die "Can't write $tfile1: $!"; -my $lc = 0; -my $flag = 0; -my @words = (); -while() { - print T; - chomp; - if (/^$/) { - if ($first) { $first = undef; } else { if ($flag) { print "\n"; $flag = 0; } } - $first = undef; - close T; - open SO, ">$tfile2" or die "Can't write $tfile2: $!"; - print SO "@words\n"; - close SO; - @words=(); - `$FIRST_CONV < $tfile1 > $tfile1.cfg`; - if ($? != 0) { - die "Error code: $?"; - } - my $cfg = `$CDEC -n -S 10000 -f scfg -g $tfile1.cfg -i $tfile2 --show_cfg_search_space 2>/dev/null`; - if ($? != 0) { - die "Error code: $?"; - } - my @rules = split /\n/, $cfg; - shift @rules; # get rid of output - for my $rule (@rules) { - my ($lhs, $f, $e, $feats) = split / \|\|\| /, $rule; - $f =~ s/,\d\]/\]/g; - $feats = 'TOP=1' unless $feats; - if ($lhs =~ /\[Goal_\d+\]/) { $lhs = '[S]'; } - print "$lhs ||| $f ||| $feats\n"; - if ($e eq '[1] [2]') { - my ($a, $b) = split /\s+/, $f; - $feats =~ s/=1$//; - my ($x, $y) = split /_/, $feats; - print "$lhs ||| $b $a ||| ${y}_$x=1\n"; - } - $flag = 1; - } - open T, ">$tfile1" or die "Can't write $tfile1: $!"; - $lc = -1; - } else { - my ($ind, $word, @dmmy) = split /\s+/; - push @words, $word; - } - $lc++; -} -close T; - diff --git a/training/dep-reorder/george.conll b/training/dep-reorder/george.conll deleted file mode 100644 index 7eebb360..00000000 --- a/training/dep-reorder/george.conll +++ /dev/null @@ -1,4 +0,0 @@ -1 George _ GEORGE _ _ 2 X _ _ -2 hates _ HATES _ _ 0 X _ _ -3 broccoli _ BROC _ _ 2 X _ _ - diff --git a/training/dep-reorder/scripts/conll2simplecfg.pl b/training/dep-reorder/scripts/conll2simplecfg.pl deleted file mode 100755 index b101347a..00000000 --- a/training/dep-reorder/scripts/conll2simplecfg.pl +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/perl -w -use strict; - -# 1 在 _ 10 _ _ 4 X _ _ -# 2 门厅 _ 3 _ _ 1 X _ _ -# 3 下面 _ 23 _ _ 4 X _ _ -# 4 。 _ 45 _ _ 0 X _ _ - -my @ldeps; -my @rdeps; -@ldeps=(); for (my $i =0; $i <1000; $i++) { push @ldeps, []; } -@rdeps=(); for (my $i =0; $i <1000; $i++) { push @rdeps, []; } -my $rootcat = 0; -my @cats = ('S'); -my $len = 0; -my @noposcats = ('S'); -while(<>) { - chomp; - if (/^\s*$/) { - write_cfg($len); - $len = 0; - @cats=('S'); - @noposcats = ('S'); - @ldeps=(); for (my $i =0; $i <1000; $i++) { push @ldeps, []; } - @rdeps=(); for (my $i =0; $i <1000; $i++) { push @rdeps, []; } - next; - } - $len++; - my ($pos, $word, $d1, $xcat, $d2, $d3, $headpos, $deptype) = split /\s+/; - my $cat = "C$xcat"; - my $catpos = $cat . "_$pos"; - push @cats, $catpos; - push @noposcats, $cat; - print "[$catpos] ||| $word ||| $word ||| Word=1\n"; - if ($headpos == 0) { $rootcat = $pos; } - if ($pos < $headpos) { - push @{$ldeps[$headpos]}, $pos; - } else { - push @{$rdeps[$headpos]}, $pos; - } -} - -sub write_cfg { - my $len = shift; - for (my $i = 1; $i <= $len; $i++) { - my @lds = @{$ldeps[$i]}; - for my $ld (@lds) { - print "[$cats[$i]] ||| [$cats[$ld],1] [$cats[$i],2] ||| [1] [2] ||| $noposcats[$ld]_$noposcats[$i]=1\n"; - } - my @rds = @{$rdeps[$i]}; - for my $rd (@rds) { - print "[$cats[$i]] ||| [$cats[$i],1] [$cats[$rd],2] ||| [1] [2] ||| $noposcats[$i]_$noposcats[$rd]=1\n"; - } - } - print "[S] ||| [$cats[$rootcat],1] ||| [1] ||| TOP=1\n"; -} - diff --git a/training/dpmert/Makefile.am b/training/dpmert/Makefile.am new file mode 100644 index 00000000..ff318bef --- /dev/null +++ b/training/dpmert/Makefile.am @@ -0,0 +1,25 @@ +bin_PROGRAMS = \ + mr_dpmert_map \ + mr_dpmert_reduce \ + mr_dpmert_generate_mapper_input + +noinst_PROGRAMS = \ + lo_test +TESTS = lo_test + +mr_dpmert_generate_mapper_input_SOURCES = mr_dpmert_generate_mapper_input.cc line_optimizer.cc +mr_dpmert_generate_mapper_input_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz + +# nbest2hg_SOURCES = nbest2hg.cc +# nbest2hg_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lfst -lz + +mr_dpmert_map_SOURCES = mert_geometry.cc ces.cc error_surface.cc mr_dpmert_map.cc line_optimizer.cc +mr_dpmert_map_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz + +mr_dpmert_reduce_SOURCES = error_surface.cc ces.cc mr_dpmert_reduce.cc line_optimizer.cc mert_geometry.cc +mr_dpmert_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz + +lo_test_SOURCES = lo_test.cc ces.cc mert_geometry.cc error_surface.cc line_optimizer.cc +lo_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz + +AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dpmert/ces.cc b/training/dpmert/ces.cc new file mode 100644 index 00000000..157b2d17 --- /dev/null +++ b/training/dpmert/ces.cc @@ -0,0 +1,90 @@ +#include "ces.h" + +#include +#include +#include + +// TODO, if AER is to be optimized again, we will need this +// #include "aligner.h" +#include "lattice.h" +#include "mert_geometry.h" +#include "error_surface.h" +#include "ns.h" + +using namespace std; + +const bool minimize_segments = true; // if adjacent segments have equal scores, merge them + +void ComputeErrorSurface(const SegmentEvaluator& ss, + const ConvexHull& ve, + ErrorSurface* env, + const EvaluationMetric* metric, + const Hypergraph& hg) { + vector prev_trans; + const vector >& ienv = ve.GetSortedSegs(); + env->resize(ienv.size()); + SufficientStats prev_score; // defaults to 0 + int j = 0; + for (unsigned i = 0; i < ienv.size(); ++i) { + const MERTPoint& seg = *ienv[i]; + vector trans; +#if 0 + if (type == AER) { + vector edges(hg.edges_.size(), false); + seg.CollectEdgesUsed(&edges); // get the set of edges in the viterbi + // alignment + ostringstream os; + const string* psrc = ss.GetSource(); + if (psrc == NULL) { + cerr << "AER scoring in VEST requires source, but it is missing!\n"; + abort(); + } + size_t pos = psrc->rfind(" ||| "); + if (pos == string::npos) { + cerr << "Malformed source for AER: expected |||\nINPUT: " << *psrc << endl; + abort(); + } + Lattice src; + Lattice ref; + LatticeTools::ConvertTextOrPLF(psrc->substr(0, pos), &src); + LatticeTools::ConvertTextOrPLF(psrc->substr(pos + 5), &ref); + AlignerTools::WriteAlignment(src, ref, hg, &os, true, 0, &edges); + string tstr = os.str(); + TD::ConvertSentence(tstr.substr(tstr.rfind(" ||| ") + 5), &trans); + } else { +#endif + seg.ConstructTranslation(&trans); + //} + //cerr << "Scoring: " << TD::GetString(trans) << endl; + if (trans == prev_trans) { + if (!minimize_segments) { + ErrorSegment& out = (*env)[j]; + out.delta.fields.clear(); + out.x = seg.x; + ++j; + } + //cerr << "Identical translation, skipping scoring\n"; + } else { + SufficientStats score; + ss.Evaluate(trans, &score); + // cerr << "score= " << score->ComputeScore() << "\n"; + //string x1; score.Encode(&x1); cerr << "STATS: " << x1 << endl; + const SufficientStats delta = score - prev_score; + //string x2; delta.Encode(&x2); cerr << "DELTA: " << x2 << endl; + //string xx; delta.Encode(&xx); cerr << xx << endl; + prev_trans.swap(trans); + prev_score = score; + if ((!minimize_segments) || (!delta.IsAdditiveIdentity())) { + ErrorSegment& out = (*env)[j]; + out.delta = delta; + out.x = seg.x; + ++j; + } + } + } + // cerr << " In segments: " << ienv.size() << endl; + // cerr << "Out segments: " << j << endl; + assert(j > 0); + env->resize(j); +} + diff --git a/training/dpmert/ces.h b/training/dpmert/ces.h new file mode 100644 index 00000000..e4fa2080 --- /dev/null +++ b/training/dpmert/ces.h @@ -0,0 +1,16 @@ +#ifndef _CES_H_ +#define _CES_H_ + +class ConvexHull; +class Hypergraph; +class SegmentEvaluator; +class ErrorSurface; +class EvaluationMetric; + +void ComputeErrorSurface(const SegmentEvaluator& ss, + const ConvexHull& convex_hull, + ErrorSurface* es, + const EvaluationMetric* metric, + const Hypergraph& hg); + +#endif diff --git a/training/dpmert/divide_refs.py b/training/dpmert/divide_refs.py new file mode 100755 index 00000000..b478f918 --- /dev/null +++ b/training/dpmert/divide_refs.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +import sys + +(numRefs, outPrefix) = sys.argv[1:] +numRefs = int(numRefs) + +outs = [open(outPrefix+str(i), "w") for i in range(numRefs)] + +i = 0 +for line in sys.stdin: + outs[i].write(line) + i = (i + 1) % numRefs + +for out in outs: + out.close() diff --git a/training/dpmert/dpmert.pl b/training/dpmert/dpmert.pl new file mode 100755 index 00000000..559420f5 --- /dev/null +++ b/training/dpmert/dpmert.pl @@ -0,0 +1,618 @@ +#!/usr/bin/env perl +use strict; +my @ORIG_ARGV=@ARGV; +use Cwd qw(getcwd); +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment", "$SCRIPT_DIR/../utils"; } + +# Skip local config (used for distributing jobs) if we're running in local-only mode +use LocalConfig; +use Getopt::Long; +use File::Basename qw(basename); +require "libcall.pl"; + +my $QSUB_CMD = qsub_args(mert_memory()); + +# Default settings +my $srcFile; # deprecated +my $refFiles; # deprecated +my $default_jobs = env_default_jobs(); +my $bin_dir = $SCRIPT_DIR; +my $util_dir = "$SCRIPT_DIR/../utils"; +die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; +my $FAST_SCORE="$bin_dir/../../mteval/fast_score"; +die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; +my $MAPINPUT = "$bin_dir/mr_dpmert_generate_mapper_input"; +my $MAPPER = "$bin_dir/mr_dpmert_map"; +my $REDUCER = "$bin_dir/mr_dpmert_reduce"; +my $parallelize = "$util_dir/parallelize.pl"; +my $libcall = "$util_dir/libcall.pl"; +my $sentserver = "$util_dir/sentserver"; +my $sentclient = "$util_dir/sentclient"; +my $LocalConfig = "$SCRIPT_DIR/../../environment/LocalConfig.pm"; + +my $SCORER = $FAST_SCORE; +die "Can't find $MAPPER" unless -x $MAPPER; +my $cdec = "$bin_dir/../../decoder/cdec"; +die "Can't find decoder in $cdec" unless -x $cdec; +die "Can't find $parallelize" unless -x $parallelize; +die "Can't find $libcall" unless -e $libcall; +my $decoder = $cdec; +my $lines_per_mapper = 200; +my $rand_directions = 15; +my $iteration = 1; +my $best_weights; +my $max_iterations = 15; +my $optimization_iters = 6; +my $jobs = $default_jobs; # number of decode nodes +my $pmem = "9g"; +my $disable_clean = 0; +my %seen_weights; +my $help = 0; +my $epsilon = 0.0001; +my $last_score = -10000000; +my $metric = "ibm_bleu"; +my $dir; +my $iniFile; +my $weights; +my $initialWeights; +my $bleu_weight=1; +my $use_make = 1; # use make to parallelize line search +my $useqsub; +my $pass_suffix = ''; +my $devset; +# Process command-line options +if (GetOptions( + "config=s" => \$iniFile, + "weights=s" => \$initialWeights, + "devset=s" => \$devset, + "jobs=i" => \$jobs, + "pass-suffix=s" => \$pass_suffix, + "help" => \$help, + "qsub" => \$useqsub, + "iterations=i" => \$max_iterations, + "pmem=s" => \$pmem, + "random-directions=i" => \$rand_directions, + "metric=s" => \$metric, + "source-file=s" => \$srcFile, + "output-dir=s" => \$dir, +) == 0 || @ARGV!=0 || $help) { + print_help(); + exit; +} + +if ($useqsub) { + $use_make = 0; + die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); +} + +my @missing_args = (); +if (defined $srcFile || defined $refFiles) { + die <) { $devSize++; } +close F; + +unless($best_weights){ $best_weights = $weights; } +unless($projected_score){ $projected_score = 0.0; } +$seen_weights{$weights} = 1; + +my $random_seed = int(time / 1000); +my $lastWeightsFile; +my $lastPScore = 0; +# main optimization loop +while (1){ + print STDERR "\n\nITERATION $iteration\n==========\n"; + + if ($iteration > $max_iterations){ + print STDERR "\nREACHED STOPPING CRITERION: Maximum iterations\n"; + last; + } + # iteration-specific files + my $runFile="$dir/run.raw.$iteration"; + my $onebestFile="$dir/1best.$iteration"; + my $logdir="$dir/logs.$iteration"; + my $decoderLog="$logdir/decoder.sentserver.log.$iteration"; + my $scorerLog="$logdir/scorer.log.$iteration"; + check_call("mkdir -p $logdir"); + + + #decode + print STDERR "RUNNING DECODER AT "; + print STDERR unchecked_output("date"); + my $im1 = $iteration - 1; + my $weightsFile="$dir/weights.$im1"; + my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; + my $pcmd; + if ($use_make) { + $pcmd = "cat $srcFile | $parallelize --workdir $dir --use-fork -p $pmem -e $logdir -j $jobs --"; + } else { + $pcmd = "cat $srcFile | $parallelize --workdir $dir -p $pmem -e $logdir -j $jobs --"; + } + my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $runFile"; + print STDERR "COMMAND:\n$cmd\n"; + check_bash_call($cmd); + my $num_hgs; + my $num_topbest; + my $retries = 0; + while($retries < 5) { + $num_hgs = check_output("ls $dir/hgs/*.gz | wc -l"); + $num_topbest = check_output("wc -l < $runFile"); + print STDERR "NUMBER OF HGs: $num_hgs\n"; + print STDERR "NUMBER OF TOP-BEST HYPs: $num_topbest\n"; + if($devSize == $num_hgs && $devSize == $num_topbest) { + last; + } else { + print STDERR "Incorrect number of hypergraphs or topbest. Waiting for distributed filesystem and retrying...\n"; + sleep(3); + } + $retries++; + } + die "Dev set contains $devSize sentences, but we don't have topbest and hypergraphs for all these! Decoder failure? Check $decoderLog\n" if ($devSize != $num_hgs || $devSize != $num_topbest); + my $dec_score = check_output("cat $runFile | $SCORER $refs -m $metric"); + chomp $dec_score; + print STDERR "DECODER SCORE: $dec_score\n"; + + # save space + check_call("gzip -f $runFile"); + check_call("gzip -f $decoderLog"); + + # run optimizer + print STDERR "RUNNING OPTIMIZER AT "; + print STDERR unchecked_output("date"); + my $mergeLog="$logdir/prune-merge.log.$iteration"; + + my $score = 0; + my $icc = 0; + my $inweights="$dir/weights.$im1"; + for (my $opt_iter=1; $opt_iter<$optimization_iters; $opt_iter++) { + print STDERR "\nGENERATE OPTIMIZATION STRATEGY (OPT-ITERATION $opt_iter/$optimization_iters)\n"; + print STDERR unchecked_output("date"); + $icc++; + $cmd="$MAPINPUT -w $inweights -r $dir/hgs -s $devSize -d $rand_directions > $dir/agenda.$im1-$opt_iter"; + print STDERR "COMMAND:\n$cmd\n"; + check_call($cmd); + check_call("mkdir -p $dir/splag.$im1"); + $cmd="split -a 3 -l $lines_per_mapper $dir/agenda.$im1-$opt_iter $dir/splag.$im1/mapinput."; + print STDERR "COMMAND:\n$cmd\n"; + check_call($cmd); + opendir(DIR, "$dir/splag.$im1") or die "Can't open directory: $!"; + my @shards = grep { /^mapinput\./ } readdir(DIR); + closedir DIR; + die "No shards!" unless scalar @shards > 0; + my $joblist = ""; + my $nmappers = 0; + my @mapoutputs = (); + @cleanupcmds = (); + my %o2i = (); + my $first_shard = 1; + my $mkfile; # only used with makefiles + my $mkfilename; + if ($use_make) { + $mkfilename = "$dir/splag.$im1/domap.mk"; + open $mkfile, ">$mkfilename" or die "Couldn't write $mkfilename: $!"; + print $mkfile "all: $dir/splag.$im1/map.done\n\n"; + } + my @mkouts = (); # only used with makefiles + for my $shard (@shards) { + my $mapoutput = $shard; + my $client_name = $shard; + $client_name =~ s/mapinput.//; + $client_name = "dpmert.$client_name"; + $mapoutput =~ s/mapinput/mapoutput/; + push @mapoutputs, "$dir/splag.$im1/$mapoutput"; + $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; + my $script = "$MAPPER -s $srcFile -m $metric $refs < $dir/splag.$im1/$shard | sort -t \$'\\t' -k 1 > $dir/splag.$im1/$mapoutput"; + if ($use_make) { + my $script_file = "$dir/scripts/map.$shard"; + open F, ">$script_file" or die "Can't write $script_file: $!"; + print F "#!/bin/bash\n"; + print F "$script\n"; + close F; + my $output = "$dir/splag.$im1/$mapoutput"; + push @mkouts, $output; + chmod(0755, $script_file) or die "Can't chmod $script_file: $!"; + if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } + print $mkfile "$output: $dir/splag.$im1/$shard\n\t$script_file\n\n"; + } else { + my $script_file = "$dir/scripts/map.$shard"; + open F, ">$script_file" or die "Can't write $script_file: $!"; + print F "$script\n"; + close F; + if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } + + $nmappers++; + my $qcmd = "$QSUB_CMD -N $client_name -o /dev/null -e $logdir/$client_name.ER $script_file"; + my $jobid = check_output("$qcmd"); + chomp $jobid; + $jobid =~ s/^(\d+)(.*?)$/\1/g; + $jobid =~ s/^Your job (\d+) .*$/\1/; + push(@cleanupcmds, "qdel $jobid 2> /dev/null"); + print STDERR " $jobid"; + if ($joblist == "") { $joblist = $jobid; } + else {$joblist = $joblist . "\|" . $jobid; } + } + } + if ($use_make) { + print $mkfile "$dir/splag.$im1/map.done: @mkouts\n\ttouch $dir/splag.$im1/map.done\n\n"; + close $mkfile; + my $mcmd = "make -j $jobs -f $mkfilename"; + print STDERR "\nExecuting: $mcmd\n"; + check_call($mcmd); + } else { + print STDERR "\nLaunched $nmappers mappers.\n"; + sleep 8; + print STDERR "Waiting for mappers to complete...\n"; + while ($nmappers > 0) { + sleep 5; + my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat | grep -v ' C '"))); + $nmappers = scalar @livejobs; + } + print STDERR "All mappers complete.\n"; + } + my $tol = 0; + my $til = 0; + for my $mo (@mapoutputs) { + my $olines = get_lines($mo); + my $ilines = get_lines($o2i{$mo}); + $tol += $olines; + $til += $ilines; + die "$mo: output lines ($olines) doesn't match input lines ($ilines)" unless $olines==$ilines; + } + print STDERR "Results for $tol/$til lines\n"; + print STDERR "\nSORTING AND RUNNING VEST REDUCER\n"; + print STDERR unchecked_output("date"); + $cmd="sort -t \$'\\t' -k 1 @mapoutputs | $REDUCER -m $metric > $dir/redoutput.$im1"; + print STDERR "COMMAND:\n$cmd\n"; + check_bash_call($cmd); + $cmd="sort -nk3 $DIR_FLAG '-t|' $dir/redoutput.$im1 | head -1"; + # sort returns failure even when it doesn't fail for some reason + my $best=unchecked_output("$cmd"); chomp $best; + print STDERR "$best\n"; + my ($oa, $x, $xscore) = split /\|/, $best; + $score = $xscore; + print STDERR "PROJECTED SCORE: $score\n"; + if (abs($x) < $epsilon) { + print STDERR "\nOPTIMIZER: no score improvement: abs($x) < $epsilon\n"; + last; + } + my $psd = $score - $last_score; + $last_score = $score; + if (abs($psd) < $epsilon) { + print STDERR "\nOPTIMIZER: no score improvement: abs($psd) < $epsilon\n"; + last; + } + my ($origin, $axis) = split /\s+/, $oa; + + my %ori = convert($origin); + my %axi = convert($axis); + + my $finalFile="$dir/weights.$im1-$opt_iter"; + open W, ">$finalFile" or die "Can't write: $finalFile: $!"; + my $norm = 0; + for my $k (sort keys %ori) { + my $dd = $ori{$k} + $axi{$k} * $x; + $norm += $dd * $dd; + } + $norm = sqrt($norm); + $norm = 1; + for my $k (sort keys %ori) { + my $v = ($ori{$k} + $axi{$k} * $x) / $norm; + print W "$k $v\n"; + } + check_call("rm $dir/splag.$im1/*"); + $inweights = $finalFile; + } + $lastWeightsFile = "$dir/weights.$iteration"; + check_call("cp $inweights $lastWeightsFile"); + if ($icc < 2) { + print STDERR "\nREACHED STOPPING CRITERION: score change too little\n"; + last; + } + $lastPScore = $score; + $iteration++; + print STDERR "\n==========\n"; +} + +check_call("cp $lastWeightsFile $dir/weights.final"); +print STDERR "\nFINAL WEIGHTS: $dir/weights.final\n(Use -w with the decoder)\n\n"; +print STDOUT "$dir/weights.final\n"; +exit 0; + + +sub get_lines { + my $fn = shift @_; + open FL, "<$fn" or die "Couldn't read $fn: $!"; + my $lc = 0; + while() { $lc++; } + return $lc; +} + +sub read_weights_file { + my ($file) = @_; + open F, "<$file" or die "Couldn't read $file: $!"; + my @r = (); + my $pm = -1; + while() { + next if /^#/; + next if /^\s*$/; + chomp; + if (/^(.+)\s+(.+)$/) { + my $m = $1; + my $w = $2; + die "Weights out of order: $m <= $pm" unless $m > $pm; + push @r, $w; + } else { + warn "Unexpected feature name in weight file: $_"; + } + } + close F; + return join ' ', @r; +} + +sub update_weights_file { + my ($neww, $rfn, $rpts) = @_; + my @feats = @$rfn; + my @pts = @$rpts; + my $num_feats = scalar @feats; + my $num_pts = scalar @pts; + die "$num_feats (num_feats) != $num_pts (num_pts)" unless $num_feats == $num_pts; + open G, ">$neww" or die; + for (my $i = 0; $i < $num_feats; $i++) { + my $f = $feats[$i]; + my $lambda = $pts[$i]; + print G "$f $lambda\n"; + } + close G; +} + +sub enseg { + my $src = shift; + my $newsrc = shift; + open(SRC, $src); + open(NEWSRC, ">$newsrc"); + my $i=0; + while (my $line=){ + chomp $line; + if ($line =~ /^\s* tags, you must include a zero-based id attribute"; + } + } else { + print NEWSRC "$line\n"; + } + $i++; + } + close SRC; + close NEWSRC; +} + +sub print_help { + + my $executable = basename($0); chomp $executable; + print << "Help"; + +Usage: $executable [options] + + $executable [options] + Runs a complete MERT optimization. Required options are --weights, + --devset, and --config. + +Options: + + --config [-c ] + The decoder configuration file. + + --devset [-d ] + The source *and* references for the development set. + + --weights [-w ] + A file specifying initial feature weights. The format is + FeatureName_1 value1 + FeatureName_2 value2 + **All and only the weights listed in will be optimized!** + + --metric + Metric to optimize. + Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi + + --iterations + Maximum number of iterations to run. If not specified, defaults + to 10. + + --pass-suffix + If the decoder is doing multi-pass decoding, the pass suffix "2", + "3", etc., is used to control what iteration of weights is set. + + --rand-directions + MERT will attempt to optimize along all of the principle directions, + set this parameter to explore other directions. Defaults to 5. + + --output-dir + Directory for intermediate and output files. + + --help + Print this message and exit. + +Job control options: + + --jobs + Number of decoder processes to run in parallel. [default=$default_jobs] + + --qsub + Use qsub to run jobs in parallel (qsub must be configured in + environment/LocalEnvironment.pm) + + --pmem + Amount of physical memory requested for parallel decoding jobs + (used with qsub requests only) + +Help +} + +sub convert { + my ($str) = @_; + my @ps = split /;/, $str; + my %dict = (); + for my $p (@ps) { + my ($k, $v) = split /=/, $p; + $dict{$k} = $v; + } + return %dict; +} + + + +sub cmdline { + return join ' ',($0,@ORIG_ARGV); +} + +#buggy: last arg gets quoted sometimes? +my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; +my $shell_escape_in_quote=qr{[\\"\$`!]}; + +sub escape_shell { + my ($arg)=@_; + return undef unless defined $arg; + if ($arg =~ /$is_shell_special/) { + $arg =~ s/($shell_escape_in_quote)/\\$1/g; + return "\"$arg\""; + } + return $arg; +} + +sub escaped_shell_args { + return map {local $_=$_;chomp;escape_shell($_)} @_; +} + +sub escaped_shell_args_str { + return join ' ',&escaped_shell_args(@_); +} + +sub escaped_cmdline { + return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); +} + +sub split_devset { + my ($infile, $outsrc, $outref) = @_; + open F, "<$infile" or die "Can't read $infile: $!"; + open S, ">$outsrc" or die "Can't write $outsrc: $!"; + open R, ">$outref" or die "Can't write $outref: $!"; + while() { + chomp; + my ($src, @refs) = split /\s*\|\|\|\s*/; + die "Malformed devset line: $_\n" unless scalar @refs > 0; + print S "$src\n"; + print R join(' ||| ', @refs) . "\n"; + } + close R; + close S; + close F; +} + diff --git a/training/dpmert/error_surface.cc b/training/dpmert/error_surface.cc new file mode 100644 index 00000000..515b67f8 --- /dev/null +++ b/training/dpmert/error_surface.cc @@ -0,0 +1,42 @@ +#include "error_surface.h" + +#include +#include + +using namespace std; + +ErrorSurface::~ErrorSurface() {} + +void ErrorSurface::Serialize(std::string* out) const { + const int segments = this->size(); + ostringstream os(ios::binary); + os.write((const char*)&segments,sizeof(segments)); + for (int i = 0; i < segments; ++i) { + const ErrorSegment& cur = (*this)[i]; + string senc; + cur.delta.Encode(&senc); + assert(senc.size() < 1024); + unsigned char len = senc.size(); + os.write((const char*)&cur.x, sizeof(cur.x)); + os.write((const char*)&len, sizeof(len)); + os.write((const char*)&senc[0], len); + } + *out = os.str(); +} + +void ErrorSurface::Deserialize(const std::string& in) { + istringstream is(in, ios::binary); + int segments; + is.read((char*)&segments, sizeof(segments)); + this->resize(segments); + for (int i = 0; i < segments; ++i) { + ErrorSegment& cur = (*this)[i]; + unsigned char len; + is.read((char*)&cur.x, sizeof(cur.x)); + is.read((char*)&len, sizeof(len)); + string senc(len, '\0'); assert(senc.size() == len); + is.read((char*)&senc[0], len); + cur.delta = SufficientStats(senc); + } +} + diff --git a/training/dpmert/error_surface.h b/training/dpmert/error_surface.h new file mode 100644 index 00000000..bb65847b --- /dev/null +++ b/training/dpmert/error_surface.h @@ -0,0 +1,24 @@ +#ifndef _ERROR_SURFACE_H_ +#define _ERROR_SURFACE_H_ + +#include +#include + +#include "ns.h" + +class Score; + +struct ErrorSegment { + double x; + SufficientStats delta; + ErrorSegment() : x(0), delta() {} +}; + +class ErrorSurface : public std::vector { + public: + ~ErrorSurface(); + void Serialize(std::string* out) const; + void Deserialize(const std::string& in); +}; + +#endif diff --git a/training/dpmert/line_mediator.pl b/training/dpmert/line_mediator.pl new file mode 100755 index 00000000..bc2bb24c --- /dev/null +++ b/training/dpmert/line_mediator.pl @@ -0,0 +1,116 @@ +#!/usr/bin/perl -w +#hooks up two processes, 2nd of which has one line of output per line of input, expected by the first, which starts off the communication + +# if you don't know how to fork/exec in a C program, this could be helpful under limited cirmustances (would be ok to liaise with sentserver) + +#WARNING: because it waits for the result from command 2 after sending every line, and especially if command 1 does the same, using sentserver as command 2 won't actually buy you any real parallelism. + +use strict; +use IPC::Open2; +use POSIX qw(pipe dup2 STDIN_FILENO STDOUT_FILENO); + +my $quiet=!$ENV{DEBUG}; +$quiet=1 if $ENV{QUIET}; +sub info { + local $,=' '; + print STDERR @_ unless $quiet; +} + +my $mode='CROSS'; +my $ser='DIRECT'; +$mode='PIPE' if $ENV{PIPE}; +$mode='SNAKE' if $ENV{SNAKE}; +$mode='CROSS' if $ENV{CROSS}; +$ser='SERIAL' if $ENV{SERIAL}; +$ser='DIRECT' if $ENV{DIRECT}; +$ser='SERIAL' if $mode eq 'SNAKE'; +info("mode: $mode\n"); +info("connection: $ser\n"); + + +my @c1; +if (scalar @ARGV) { + do { + push @c1,shift + } while scalar @ARGV && $c1[$#c1] ne '--'; +} +pop @c1; +my @c2=@ARGV; +@ARGV=(); +(scalar @c1 && scalar @c2) || die qq{ +usage: $0 cmd1 args -- cmd2 args +all options are environment variables. +DEBUG=1 env var enables debugging output. +CROSS=1 hooks up two processes, 2nd of which has one line of output per line of input, expected by the first, which starts off the communication. crosses stdin/stderr of cmd1 and cmd2 line by line (both must flush on newline and output. cmd1 initiates the conversation (sends the first line). default: attempts to cross stdin/stdout of c1 and c2 directly (via two unidirectional posix pipes created before fork). +SERIAL=1: (no parallelism possible) but lines exchanged are logged if DEBUG. +if SNAKE then stdin -> c1 -> c2 -> c1 -> stdout. +if PIPE then stdin -> c1 -> c2 -> stdout (same as shell c1|c2, but with SERIAL you can see the intermediate in real time; you could do similar with c1 | tee /dev/fd/2 |c2. +DIRECT=1 (default) will override SERIAL=1. +CROSS=1 (default) will override SNAKE or PIPE. +}; + +info("1 cmd:",@c1,"\n"); +info("2 cmd:",@c2,"\n"); + +sub lineto { + select $_[0]; + $|=1; + shift; + print @_; +} + +if ($ser eq 'SERIAL') { + my ($R1,$W1,$R2,$W2); + my $c1p=open2($R1,$W1,@c1); # Open2 R W backward from Open3. + my $c2p=open2($R2,$W2,@c2); + if ($mode eq 'CROSS') { + while(<$R1>) { + info("1:",$_); + lineto($W2,$_); + last unless defined ($_=<$R2>); + info("1|2:",$_); + lineto($W1,$_); + } + } else { + my $snake=$mode eq 'SNAKE'; + while() { + info("IN:",$_); + lineto($W1,$_); + last unless defined ($_=<$R1>); + info("IN|1:",$_); + lineto($W2,$_); + last unless defined ($_=<$R2>); + info("IN|1|2:",$_); + if ($snake) { + lineto($W1,$_); + last unless defined ($_=<$R1>); + info("IN|1|2|1:",$_); + } + lineto(*STDOUT,$_); + } + } +} else { + info("DIRECT mode\n"); + my @rw1=POSIX::pipe(); + my @rw2=POSIX::pipe(); + my $pid=undef; + $SIG{CHLD} = sub { wait }; + while (not defined ($pid=fork())) { + sleep 1; + } + my $pipe = $mode eq 'PIPE'; + unless ($pipe) { + POSIX::close(STDOUT_FILENO); + POSIX::close(STDIN_FILENO); + } + if ($pid) { + POSIX::dup2($rw1[1],STDOUT_FILENO); + POSIX::dup2($rw2[0],STDIN_FILENO) unless $pipe; + exec @c1; + } else { + POSIX::dup2($rw2[1],STDOUT_FILENO) unless $pipe; + POSIX::dup2($rw1[0],STDIN_FILENO); + exec @c2; + } + while (wait()!=-1) {} +} diff --git a/training/dpmert/line_optimizer.cc b/training/dpmert/line_optimizer.cc new file mode 100644 index 00000000..9cf33502 --- /dev/null +++ b/training/dpmert/line_optimizer.cc @@ -0,0 +1,114 @@ +#include "line_optimizer.h" + +#include +#include + +#include "sparse_vector.h" +#include "ns.h" + +using namespace std; + +typedef ErrorSurface::const_iterator ErrorIter; + +// sort by increasing x-ints +struct IntervalComp { + bool operator() (const ErrorIter& a, const ErrorIter& b) const { + return a->x < b->x; + } +}; + +double LineOptimizer::LineOptimize( + const EvaluationMetric* metric, + const vector& surfaces, + const LineOptimizer::ScoreType type, + float* best_score, + const double epsilon) { + // cerr << "MIN=" << MINIMIZE_SCORE << " MAX=" << MAXIMIZE_SCORE << " MINE=" << type << endl; + vector all_ints; + for (vector::const_iterator i = surfaces.begin(); + i != surfaces.end(); ++i) { + const ErrorSurface& surface = *i; + for (ErrorIter j = surface.begin(); j != surface.end(); ++j) + all_ints.push_back(j); + } + sort(all_ints.begin(), all_ints.end(), IntervalComp()); + double last_boundary = all_ints.front()->x; + SufficientStats acc; + float& cur_best_score = *best_score; + cur_best_score = (type == MAXIMIZE_SCORE ? + -numeric_limits::max() : numeric_limits::max()); + bool left_edge = true; + double pos = numeric_limits::quiet_NaN(); + for (vector::iterator i = all_ints.begin(); + i != all_ints.end(); ++i) { + const ErrorSegment& seg = **i; + if (seg.x - last_boundary > epsilon) { + float sco = metric->ComputeScore(acc); + if ((type == MAXIMIZE_SCORE && sco > cur_best_score) || + (type == MINIMIZE_SCORE && sco < cur_best_score) ) { + cur_best_score = sco; + if (left_edge) { + pos = seg.x - 0.1; + left_edge = false; + } else { + pos = last_boundary + (seg.x - last_boundary) / 2; + } + //cerr << "NEW BEST: " << pos << " (score=" << cur_best_score << ")\n"; + } + // string xx = metric->DetailedScore(acc); cerr << "---- " << xx; +#undef SHOW_ERROR_SURFACES +#ifdef SHOW_ERROR_SURFACES + cerr << "x=" << seg.x << "\ts=" << sco << "\n"; +#endif + last_boundary = seg.x; + } + // cerr << "x-boundary=" << seg.x << "\n"; + //string x2; acc.Encode(&x2); cerr << " ACC: " << x2 << endl; + //string x1; seg.delta.Encode(&x1); cerr << " DELTA: " << x1 << endl; + acc += seg.delta; + } + float sco = metric->ComputeScore(acc); + if ((type == MAXIMIZE_SCORE && sco > cur_best_score) || + (type == MINIMIZE_SCORE && sco < cur_best_score) ) { + cur_best_score = sco; + if (left_edge) { + pos = 0; + } else { + pos = last_boundary + 1000.0; + } + } + return pos; +} + +void LineOptimizer::RandomUnitVector(const vector& features_to_optimize, + SparseVector* axis, + RandomNumberGenerator* rng) { + axis->clear(); + for (int i = 0; i < features_to_optimize.size(); ++i) + axis->set_value(features_to_optimize[i], rng->NextNormal(0.0,1.0)); + (*axis) /= axis->l2norm(); +} + +void LineOptimizer::CreateOptimizationDirections( + const vector& features_to_optimize, + int additional_random_directions, + RandomNumberGenerator* rng, + vector >* dirs + , bool include_orthogonal + ) { + dirs->clear(); + typedef SparseVector Dir; + vector &out=*dirs; + int i=0; + if (include_orthogonal) + for (;i + +#include "sparse_vector.h" +#include "error_surface.h" +#include "sampler.h" + +class EvaluationMetric; +class Weights; + +struct LineOptimizer { + + // use MINIMIZE_SCORE for things like TER, WER + // MAXIMIZE_SCORE for things like BLEU + enum ScoreType { MAXIMIZE_SCORE, MINIMIZE_SCORE }; + + // merge all the error surfaces together into a global + // error surface and find (the middle of) the best segment + static double LineOptimize( + const EvaluationMetric* metric, + const std::vector& envs, + const LineOptimizer::ScoreType type, + float* best_score, + const double epsilon = 1.0/65536.0); + + // return a random vector of length 1 where all dimensions + // not listed in dimensions will be 0. + static void RandomUnitVector(const std::vector& dimensions, + SparseVector* axis, + RandomNumberGenerator* rng); + + // generate a list of directions to optimize; the list will + // contain the orthogonal vectors corresponding to the dimensions in + // primary and then additional_random_directions directions in those + // dimensions as well. All vectors will be length 1. + static void CreateOptimizationDirections( + const std::vector& primary, + int additional_random_directions, + RandomNumberGenerator* rng, + std::vector >* dirs + , bool include_primary=true + ); + +}; + +#endif diff --git a/training/dpmert/lo_test.cc b/training/dpmert/lo_test.cc new file mode 100644 index 00000000..95a08d3d --- /dev/null +++ b/training/dpmert/lo_test.cc @@ -0,0 +1,229 @@ +#define BOOST_TEST_MODULE LineOptimizerTest +#include +#include + +#include +#include +#include + +#include + +#include "ns.h" +#include "ns_docscorer.h" +#include "ces.h" +#include "fdict.h" +#include "hg.h" +#include "kbest.h" +#include "hg_io.h" +#include "filelib.h" +#include "inside_outside.h" +#include "viterbi.h" +#include "mert_geometry.h" +#include "line_optimizer.h" + +using namespace std; + +const char* ref11 = "australia reopens embassy in manila"; +const char* ref12 = "( afp , manila , january 2 ) australia reopened its embassy in the philippines today , which was shut down about seven weeks ago due to what was described as a specific threat of a terrorist attack ."; +const char* ref21 = "australia reopened manila embassy"; +const char* ref22 = "( agence france-presse , manila , 2nd ) - australia reopened its embassy in the philippines today . the embassy was closed seven weeks ago after what was described as a specific threat of a terrorist attack ."; +const char* ref31 = "australia to reopen embassy in manila"; +const char* ref32 = "( afp report from manila , january 2 ) australia reopened its embassy in the philippines today . seven weeks ago , the embassy was shut down due to so - called confirmed terrorist attack threats ."; +const char* ref41 = "australia to re - open its embassy to manila"; +const char* ref42 = "( afp , manila , thursday ) australia reopens its embassy to manila , which was closed for the so - called \" clear \" threat of terrorist attack 7 weeks ago ."; + +BOOST_AUTO_TEST_CASE( TestCheckNaN) { + double x = 0; + double y = 0; + double z = x / y; + BOOST_CHECK_EQUAL(true, std::isnan(z)); +} + +BOOST_AUTO_TEST_CASE(TestConvexHull) { + boost::shared_ptr a1(new MERTPoint(-1, 0)); + boost::shared_ptr b1(new MERTPoint(1, 0)); + boost::shared_ptr a2(new MERTPoint(-1, 1)); + boost::shared_ptr b2(new MERTPoint(1, -1)); + vector > sa; sa.push_back(a1); sa.push_back(b1); + vector > sb; sb.push_back(a2); sb.push_back(b2); + ConvexHull a(sa); + cerr << a << endl; + ConvexHull b(sb); + ConvexHull c = a; + c *= b; + cerr << a << " (*) " << b << " = " << c << endl; + BOOST_CHECK_EQUAL(3, c.size()); +} + +BOOST_AUTO_TEST_CASE(TestConvexHullInside) { + const string json = "{\"rules\":[1,\"[X] ||| a\",2,\"[X] ||| A [1]\",3,\"[X] ||| c\",4,\"[X] ||| C [1]\",5,\"[X] ||| [1] B [2]\",6,\"[X] ||| [1] b [2]\",7,\"[X] ||| X [1]\",8,\"[X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}"; + Hypergraph hg; + istringstream instr(json); + HypergraphIO::ReadFromJSON(&instr, &hg); + SparseVector wts; + wts.set_value(FD::Convert("f1"), 0.4); + wts.set_value(FD::Convert("f2"), 1.0); + hg.Reweight(wts); + vector, prob_t> > list; + std::vector > features; + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 10); + for (int i = 0; i < 10; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + SparseVector dir; dir.set_value(FD::Convert("f1"), 1.0); + ConvexHullWeightFunction wf(wts, dir); + ConvexHull env = Inside(hg, NULL, wf); + cerr << env << endl; + const vector >& segs = env.GetSortedSegs(); + dir *= segs[1]->x; + wts += dir; + hg.Reweight(wts); + KBest::KBestDerivations, ESentenceTraversal> kbest2(hg, 10); + for (int i = 0; i < 10; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest2.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + for (unsigned i = 0; i < segs.size(); ++i) { + cerr << "seg=" << i << endl; + vector trans; + segs[i]->ConstructTranslation(&trans); + cerr << TD::GetString(trans) << endl; + } +} + +BOOST_AUTO_TEST_CASE( TestS1) { + int fPhraseModel_0 = FD::Convert("PhraseModel_0"); + int fPhraseModel_1 = FD::Convert("PhraseModel_1"); + int fPhraseModel_2 = FD::Convert("PhraseModel_2"); + int fLanguageModel = FD::Convert("LanguageModel"); + int fWordPenalty = FD::Convert("WordPenalty"); + int fPassThrough = FD::Convert("PassThrough"); + SparseVector wts; + wts.set_value(fWordPenalty, 4.25); + wts.set_value(fLanguageModel, -1.1165); + wts.set_value(fPhraseModel_0, -0.96); + wts.set_value(fPhraseModel_1, -0.65); + wts.set_value(fPhraseModel_2, -0.77); + wts.set_value(fPassThrough, -10.0); + + vector to_optimize; + to_optimize.push_back(fWordPenalty); + to_optimize.push_back(fLanguageModel); + to_optimize.push_back(fPhraseModel_0); + to_optimize.push_back(fPhraseModel_1); + to_optimize.push_back(fPhraseModel_2); + + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : "test_data"); + + Hypergraph hg; + ReadFile rf(path + "/0.json.gz"); + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + hg.Reweight(wts); + + Hypergraph hg2; + ReadFile rf2(path + "/1.json.gz"); + HypergraphIO::ReadFromJSON(rf2.stream(), &hg2); + hg2.Reweight(wts); + + vector > refs1(4); + TD::ConvertSentence(ref11, &refs1[0]); + TD::ConvertSentence(ref21, &refs1[1]); + TD::ConvertSentence(ref31, &refs1[2]); + TD::ConvertSentence(ref41, &refs1[3]); + vector > refs2(4); + TD::ConvertSentence(ref12, &refs2[0]); + TD::ConvertSentence(ref22, &refs2[1]); + TD::ConvertSentence(ref32, &refs2[2]); + TD::ConvertSentence(ref42, &refs2[3]); + vector envs(2); + + RandomNumberGenerator rng; + + vector > axes; // directions to search + LineOptimizer::CreateOptimizationDirections( + to_optimize, + 10, + &rng, + &axes); + assert(axes.size() == 10 + to_optimize.size()); + for (unsigned i = 0; i < axes.size(); ++i) + cerr << axes[i] << endl; + const SparseVector& axis = axes[0]; + + cerr << "Computing Viterbi envelope using inside algorithm...\n"; + cerr << "axis: " << axis << endl; + clock_t t_start=clock(); + ConvexHullWeightFunction wf(wts, axis); // wts = starting point, axis = search direction + envs[0] = Inside(hg, NULL, wf); + envs[1] = Inside(hg2, NULL, wf); + + vector es(2); + EvaluationMetric* metric = EvaluationMetric::Instance("IBM_BLEU"); + boost::shared_ptr scorer1 = metric->CreateSegmentEvaluator(refs1); + boost::shared_ptr scorer2 = metric->CreateSegmentEvaluator(refs2); + ComputeErrorSurface(*scorer1, envs[0], &es[0], metric, hg); + ComputeErrorSurface(*scorer2, envs[1], &es[1], metric, hg2); + cerr << envs[0].size() << " " << envs[1].size() << endl; + cerr << es[0].size() << " " << es[1].size() << endl; + envs.clear(); + clock_t t_env=clock(); + float score; + double m = LineOptimizer::LineOptimize(metric,es, LineOptimizer::MAXIMIZE_SCORE, &score); + clock_t t_opt=clock(); + cerr << "line optimizer returned: " << m << " (SCORE=" << score << ")\n"; + BOOST_CHECK_CLOSE(0.48719698, score, 1e-5); + SparseVector res = axis; + res *= m; + res += wts; + cerr << "res: " << res << endl; + cerr << "ENVELOPE PROCESSING=" << (static_cast(t_env - t_start) / 1000.0) << endl; + cerr << " LINE OPTIMIZATION=" << (static_cast(t_opt - t_env) / 1000.0) << endl; + hg.Reweight(res); + hg2.Reweight(res); + vector t1,t2; + ViterbiESentence(hg, &t1); + ViterbiESentence(hg2, &t2); + cerr << TD::GetString(t1) << endl; + cerr << TD::GetString(t2) << endl; +} + +BOOST_AUTO_TEST_CASE(TestZeroOrigin) { + const string json = "{\"rules\":[1,\"[X7] ||| blA ||| without ||| LHSProb=3.92173 LexE2F=2.90799 LexF2E=1.85003 GenerativeProb=10.5381 RulePenalty=1 XFE=2.77259 XEF=0.441833 LabelledEF=2.63906 LabelledFE=4.96981 LogRuleCount=0.693147\",2,\"[X7] ||| blA ||| except ||| LHSProb=4.92173 LexE2F=3.90799 LexF2E=1.85003 GenerativeProb=11.5381 RulePenalty=1 XFE=2.77259 XEF=1.44183 LabelledEF=2.63906 LabelledFE=4.96981 LogRuleCount=1.69315\",3,\"[S] ||| [X7,1] ||| [1] ||| GlueTop=1\",4,\"[X28] ||| EnwAn ||| title ||| LHSProb=3.96802 LexE2F=2.22462 LexF2E=1.83258 GenerativeProb=10.0863 RulePenalty=1 XFE=0 XEF=1.20397 LabelledEF=1.20397 LabelledFE=-1.98341e-08 LogRuleCount=1.09861\",5,\"[X0] ||| EnwAn ||| funny ||| LHSProb=3.98479 LexE2F=1.79176 LexF2E=3.21888 GenerativeProb=11.1681 RulePenalty=1 XFE=0 XEF=2.30259 LabelledEF=2.30259 LabelledFE=0 LogRuleCount=0 SingletonRule=1\",6,\"[X8] ||| [X7,1] EnwAn ||| entitled [1] ||| LHSProb=3.82533 LexE2F=3.21888 LexF2E=2.52573 GenerativeProb=11.3276 RulePenalty=1 XFE=1.20397 XEF=1.20397 LabelledEF=2.30259 LabelledFE=2.30259 LogRuleCount=0 SingletonRule=1\",7,\"[S] ||| [S,1] [X28,2] ||| [1] [2] ||| Glue=1\",8,\"[S] ||| [S,1] [X0,2] ||| [1] [2] ||| Glue=1\",9,\"[S] ||| [X8,1] ||| [1] ||| GlueTop=1\",10,\"[Goal] ||| [S,1] ||| [1]\"],\"features\":[\"PassThrough\",\"Glue\",\"GlueTop\",\"LanguageModel\",\"WordPenalty\",\"LHSProb\",\"LexE2F\",\"LexF2E\",\"GenerativeProb\",\"RulePenalty\",\"XFE\",\"XEF\",\"LabelledEF\",\"LabelledFE\",\"LogRuleCount\",\"SingletonRule\"],\"edges\":[{\"tail\":[],\"spans\":[0,1,-1,-1],\"feats\":[5,3.92173,6,2.90799,7,1.85003,8,10.5381,9,1,10,2.77259,11,0.441833,12,2.63906,13,4.96981,14,0.693147],\"rule\":1},{\"tail\":[],\"spans\":[0,1,-1,-1],\"feats\":[5,4.92173,6,3.90799,7,1.85003,8,11.5381,9,1,10,2.77259,11,1.44183,12,2.63906,13,4.96981,14,1.69315],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X7\"},\"edges\":[{\"tail\":[0],\"spans\":[0,1,-1,-1],\"feats\":[2,1],\"rule\":3}],\"node\":{\"in_edges\":[2],\"cat\":\"S\"},\"edges\":[{\"tail\":[],\"spans\":[1,2,-1,-1],\"feats\":[5,3.96802,6,2.22462,7,1.83258,8,10.0863,9,1,11,1.20397,12,1.20397,13,-1.98341e-08,14,1.09861],\"rule\":4}],\"node\":{\"in_edges\":[3],\"cat\":\"X28\"},\"edges\":[{\"tail\":[],\"spans\":[1,2,-1,-1],\"feats\":[5,3.98479,6,1.79176,7,3.21888,8,11.1681,9,1,11,2.30259,12,2.30259,15,1],\"rule\":5}],\"node\":{\"in_edges\":[4],\"cat\":\"X0\"},\"edges\":[{\"tail\":[0],\"spans\":[0,2,-1,-1],\"feats\":[5,3.82533,6,3.21888,7,2.52573,8,11.3276,9,1,10,1.20397,11,1.20397,12,2.30259,13,2.30259,15,1],\"rule\":6}],\"node\":{\"in_edges\":[5],\"cat\":\"X8\"},\"edges\":[{\"tail\":[1,2],\"spans\":[0,2,-1,-1],\"feats\":[1,1],\"rule\":7},{\"tail\":[1,3],\"spans\":[0,2,-1,-1],\"feats\":[1,1],\"rule\":8},{\"tail\":[4],\"spans\":[0,2,-1,-1],\"feats\":[2,1],\"rule\":9}],\"node\":{\"in_edges\":[6,7,8],\"cat\":\"S\"},\"edges\":[{\"tail\":[5],\"spans\":[0,2,-1,-1],\"feats\":[],\"rule\":10}],\"node\":{\"in_edges\":[9],\"cat\":\"Goal\"}}"; + Hypergraph hg; + istringstream instr(json); + HypergraphIO::ReadFromJSON(&instr, &hg); + SparseVector wts; + wts.set_value(FD::Convert("PassThrough"), -0.929201533002898); + hg.Reweight(wts); + + vector, prob_t> > list; + std::vector > features; + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, 10); + for (int i = 0; i < 10; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + + SparseVector axis; axis.set_value(FD::Convert("Glue"),1.0); + ConvexHullWeightFunction wf(wts, axis); // wts = starting point, axis = search direction + vector envs(1); + envs[0] = Inside(hg, NULL, wf); + + vector > mr(4); + TD::ConvertSentence("untitled", &mr[0]); + TD::ConvertSentence("with no title", &mr[1]); + TD::ConvertSentence("without a title", &mr[2]); + TD::ConvertSentence("without title", &mr[3]); + EvaluationMetric* metric = EvaluationMetric::Instance("IBM_BLEU"); + boost::shared_ptr scorer1 = metric->CreateSegmentEvaluator(mr); + vector es(1); + ComputeErrorSurface(*scorer1, envs[0], &es[0], metric, hg); +} + diff --git a/training/dpmert/mert_geometry.cc b/training/dpmert/mert_geometry.cc new file mode 100644 index 00000000..d6973658 --- /dev/null +++ b/training/dpmert/mert_geometry.cc @@ -0,0 +1,185 @@ +#include "mert_geometry.h" + +#include +#include + +using namespace std; + +ConvexHull::ConvexHull(int i) { + if (i == 0) { + // do nothing - <> + } else if (i == 1) { + points.push_back(boost::shared_ptr(new MERTPoint(0, 0, 0, boost::shared_ptr(), boost::shared_ptr()))); + assert(this->IsMultiplicativeIdentity()); + } else { + cerr << "Only can create ConvexHull semiring 0 and 1 with this constructor!\n"; + abort(); + } +} + +const ConvexHull ConvexHullWeightFunction::operator()(const Hypergraph::Edge& e) const { + const double m = direction.dot(e.feature_values_); + const double b = origin.dot(e.feature_values_); + MERTPoint* point = new MERTPoint(m, b, e); + return ConvexHull(1, point); +} + +ostream& operator<<(ostream& os, const ConvexHull& env) { + os << '<'; + const vector >& points = env.GetSortedSegs(); + for (int i = 0; i < points.size(); ++i) + os << (i==0 ? "" : "|") << "x=" << points[i]->x << ",b=" << points[i]->b << ",m=" << points[i]->m << ",p1=" << points[i]->p1 << ",p2=" << points[i]->p2; + return os << '>'; +} + +#define ORIGINAL_MERT_IMPLEMENTATION 1 +#ifdef ORIGINAL_MERT_IMPLEMENTATION + +struct SlopeCompare { + bool operator() (const boost::shared_ptr& a, const boost::shared_ptr& b) const { + return a->m < b->m; + } +}; + +const ConvexHull& ConvexHull::operator+=(const ConvexHull& other) { + if (!other.is_sorted) other.Sort(); + if (points.empty()) { + points = other.points; + return *this; + } + is_sorted = false; + int j = points.size(); + points.resize(points.size() + other.points.size()); + for (int i = 0; i < other.points.size(); ++i) + points[j++] = other.points[i]; + assert(j == points.size()); + return *this; +} + +void ConvexHull::Sort() const { + sort(points.begin(), points.end(), SlopeCompare()); + const int k = points.size(); + int j = 0; + for (int i = 0; i < k; ++i) { + MERTPoint l = *points[i]; + l.x = kMinusInfinity; + // cerr << "m=" << l.m << endl; + if (0 < j) { + if (points[j-1]->m == l.m) { // lines are parallel + if (l.b <= points[j-1]->b) continue; + --j; + } + while(0 < j) { + l.x = (l.b - points[j-1]->b) / (points[j-1]->m - l.m); + if (points[j-1]->x < l.x) break; + --j; + } + if (0 == j) l.x = kMinusInfinity; + } + *points[j++] = l; + } + points.resize(j); + is_sorted = true; +} + +const ConvexHull& ConvexHull::operator*=(const ConvexHull& other) { + if (other.IsMultiplicativeIdentity()) { return *this; } + if (this->IsMultiplicativeIdentity()) { (*this) = other; return *this; } + + if (!is_sorted) Sort(); + if (!other.is_sorted) other.Sort(); + + if (this->IsEdgeEnvelope()) { +// if (other.size() > 1) +// cerr << *this << " (TIMES) " << other << endl; + boost::shared_ptr edge_parent = points[0]; + const double& edge_b = edge_parent->b; + const double& edge_m = edge_parent->m; + points.clear(); + for (int i = 0; i < other.points.size(); ++i) { + const MERTPoint& p = *other.points[i]; + const double m = p.m + edge_m; + const double b = p.b + edge_b; + const double& x = p.x; // x's don't change with * + points.push_back(boost::shared_ptr(new MERTPoint(x, m, b, edge_parent, other.points[i]))); + assert(points.back()->p1->edge); + } +// if (other.size() > 1) +// cerr << " = " << *this << endl; + } else { + vector > new_points; + int this_i = 0; + int other_i = 0; + const int this_size = points.size(); + const int other_size = other.points.size(); + double cur_x = kMinusInfinity; // moves from left to right across the + // real numbers, stopping for all inter- + // sections + double this_next_val = (1 < this_size ? points[1]->x : kPlusInfinity); + double other_next_val = (1 < other_size ? other.points[1]->x : kPlusInfinity); + while (this_i < this_size && other_i < other_size) { + const MERTPoint& this_point = *points[this_i]; + const MERTPoint& other_point= *other.points[other_i]; + const double m = this_point.m + other_point.m; + const double b = this_point.b + other_point.b; + + new_points.push_back(boost::shared_ptr(new MERTPoint(cur_x, m, b, points[this_i], other.points[other_i]))); + int comp = 0; + if (this_next_val < other_next_val) comp = -1; else + if (this_next_val > other_next_val) comp = 1; + if (0 == comp) { // the next values are equal, advance both indices + ++this_i; + ++other_i; + cur_x = this_next_val; // could be other_next_val (they're equal!) + this_next_val = (this_i+1 < this_size ? points[this_i+1]->x : kPlusInfinity); + other_next_val = (other_i+1 < other_size ? other.points[other_i+1]->x : kPlusInfinity); + } else { // advance the i with the lower x, update cur_x + if (-1 == comp) { + ++this_i; + cur_x = this_next_val; + this_next_val = (this_i+1 < this_size ? points[this_i+1]->x : kPlusInfinity); + } else { + ++other_i; + cur_x = other_next_val; + other_next_val = (other_i+1 < other_size ? other.points[other_i+1]->x : kPlusInfinity); + } + } + } + points.swap(new_points); + } + //cerr << "Multiply: result=" << (*this) << endl; + return *this; +} + +// recursively construct translation +void MERTPoint::ConstructTranslation(vector* trans) const { + const MERTPoint* cur = this; + vector > ant_trans; + while(!cur->edge) { + ant_trans.resize(ant_trans.size() + 1); + cur->p2->ConstructTranslation(&ant_trans.back()); + cur = cur->p1.get(); + } + size_t ant_size = ant_trans.size(); + vector*> pants(ant_size); + assert(ant_size == cur->edge->tail_nodes_.size()); + --ant_size; + for (int i = 0; i < pants.size(); ++i) pants[ant_size - i] = &ant_trans[i]; + cur->edge->rule_->ESubstitute(pants, trans); +} + +void MERTPoint::CollectEdgesUsed(std::vector* edges_used) const { + if (edge) { + assert(edge->id_ < edges_used->size()); + (*edges_used)[edge->id_] = true; + } + if (p1) p1->CollectEdgesUsed(edges_used); + if (p2) p2->CollectEdgesUsed(edges_used); +} + +#else + +// THIS IS THE NEW FASTER IMPLEMENTATION OF THE MERT SEMIRING OPERATIONS + +#endif + diff --git a/training/dpmert/mert_geometry.h b/training/dpmert/mert_geometry.h new file mode 100644 index 00000000..a8b6959e --- /dev/null +++ b/training/dpmert/mert_geometry.h @@ -0,0 +1,81 @@ +#ifndef _MERT_GEOMETRY_H_ +#define _MERT_GEOMETRY_H_ + +#include +#include +#include + +#include "hg.h" +#include "sparse_vector.h" + +static const double kMinusInfinity = -std::numeric_limits::infinity(); +static const double kPlusInfinity = std::numeric_limits::infinity(); + +struct MERTPoint { + MERTPoint() : x(), m(), b(), edge() {} + MERTPoint(double _m, double _b) : + x(kMinusInfinity), m(_m), b(_b), edge() {} + MERTPoint(double _x, double _m, double _b, const boost::shared_ptr& p1_, const boost::shared_ptr& p2_) : + x(_x), m(_m), b(_b), p1(p1_), p2(p2_), edge() {} + MERTPoint(double _m, double _b, const Hypergraph::Edge& edge) : + x(kMinusInfinity), m(_m), b(_b), edge(&edge) {} + + double x; // x intersection with previous segment in env, or -inf if none + double m; // this line's slope + double b; // intercept with y-axis + + // we keep a pointer to the "parents" of this segment so we can reconstruct + // the Viterbi translation corresponding to this segment + boost::shared_ptr p1; + boost::shared_ptr p2; + + // only MERTPoints created from an edge using the ConvexHullWeightFunction + // have rules + // TRulePtr rule; + const Hypergraph::Edge* edge; + + // recursively recover the Viterbi translation that will result from setting + // the weights to origin + axis * x, where x is any value from this->x up + // until the next largest x in the containing ConvexHull + void ConstructTranslation(std::vector* trans) const; + void CollectEdgesUsed(std::vector* edges_used) const; +}; + +// this is the semiring value type, +// it defines constructors for 0, 1, and the operations + and * +struct ConvexHull { + // create semiring zero + ConvexHull() : is_sorted(true) {} // zero + // for debugging: + ConvexHull(const std::vector >& s) : points(s) { Sort(); } + // create semiring 1 or 0 + explicit ConvexHull(int i); + ConvexHull(int n, MERTPoint* point) : is_sorted(true), points(n, boost::shared_ptr(point)) {} + const ConvexHull& operator+=(const ConvexHull& other); + const ConvexHull& operator*=(const ConvexHull& other); + bool IsMultiplicativeIdentity() const { + return size() == 1 && (points[0]->b == 0.0 && points[0]->m == 0.0) && (!points[0]->edge) && (!points[0]->p1) && (!points[0]->p2); } + const std::vector >& GetSortedSegs() const { + if (!is_sorted) Sort(); + return points; + } + size_t size() const { return points.size(); } + + private: + bool IsEdgeEnvelope() const { + return points.size() == 1 && points[0]->edge; } + void Sort() const; + mutable bool is_sorted; + mutable std::vector > points; +}; +std::ostream& operator<<(std::ostream& os, const ConvexHull& env); + +struct ConvexHullWeightFunction { + ConvexHullWeightFunction(const SparseVector& ori, + const SparseVector& dir) : origin(ori), direction(dir) {} + const ConvexHull operator()(const Hypergraph::Edge& e) const; + const SparseVector origin; + const SparseVector direction; +}; + +#endif diff --git a/training/dpmert/mr_dpmert_generate_mapper_input.cc b/training/dpmert/mr_dpmert_generate_mapper_input.cc new file mode 100644 index 00000000..199cd23a --- /dev/null +++ b/training/dpmert/mr_dpmert_generate_mapper_input.cc @@ -0,0 +1,81 @@ +#include +#include + +#include +#include + +#include "filelib.h" +#include "weights.h" +#include "line_optimizer.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("dev_set_size,s",po::value(),"[REQD] Development set size (# of parallel sentences)") + ("forest_repository,r",po::value(),"[REQD] Path to forest repository") + ("weights,w",po::value(),"[REQD] Current feature weights file") + ("optimize_feature,o",po::value >(), "Feature to optimize (if none specified, all weights listed in the weights file will be optimized)") + ("random_directions,d",po::value()->default_value(20),"Number of random directions to run the line optimizer in") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (conf->count("dev_set_size") == 0) { + cerr << "Please specify the size of the development set using -d N\n"; + flag = true; + } + if (conf->count("weights") == 0) { + cerr << "Please specify the starting-point weights using -w \n"; + flag = true; + } + if (conf->count("forest_repository") == 0) { + cerr << "Please specify the forest repository location using -r \n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + RandomNumberGenerator rng; + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + vector features; + SparseVector origin; + vector w; + Weights::InitFromFile(conf["weights"].as(), &w, &features); + Weights::InitSparseVector(w, &origin); + const string forest_repository = conf["forest_repository"].as(); + if (!DirectoryExists(forest_repository)) { + cerr << "Forest repository directory " << forest_repository << " not found!\n"; + return 1; + } + if (conf.count("optimize_feature") > 0) + features=conf["optimize_feature"].as >(); + vector > directions; + vector fids(features.size()); + for (unsigned i = 0; i < features.size(); ++i) + fids[i] = FD::Convert(features[i]); + LineOptimizer::CreateOptimizationDirections( + fids, + conf["random_directions"].as(), + &rng, + &directions); + unsigned dev_set_size = conf["dev_set_size"].as(); + for (unsigned i = 0; i < dev_set_size; ++i) { + for (unsigned j = 0; j < directions.size(); ++j) { + cout << forest_repository << '/' << i << ".json.gz " << i << ' '; + print(cout, origin, "=", ";"); + cout << ' '; + print(cout, directions[j], "=", ";"); + cout << endl; + } + } + return 0; +} diff --git a/training/dpmert/mr_dpmert_map.cc b/training/dpmert/mr_dpmert_map.cc new file mode 100644 index 00000000..d1efcf96 --- /dev/null +++ b/training/dpmert/mr_dpmert_map.cc @@ -0,0 +1,112 @@ +#include +#include +#include +#include + +#include +#include + +#include "ns.h" +#include "ns_docscorer.h" +#include "ces.h" +#include "filelib.h" +#include "stringlib.h" +#include "sparse_vector.h" +#include "mert_geometry.h" +#include "inside_outside.h" +#include "error_surface.h" +#include "b64tools.h" +#include "hg_io.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") + ("source,s",po::value(), "Source file (ignored, except for AER)") + ("evaluation_metric,m",po::value()->default_value("ibm_bleu"), "Evaluation metric being optimized") + ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (!conf->count("reference")) { + cerr << "Please specify one or more references using -r \n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +bool ReadSparseVectorString(const string& s, SparseVector* v) { +#if 0 + // this should work, but untested. + std::istringstream i(s); + i>>*v; +#else + vector fields; + Tokenize(s, ';', &fields); + if (fields.empty()) return false; + for (unsigned i = 0; i < fields.size(); ++i) { + vector pair(2); + Tokenize(fields[i], '=', &pair); + if (pair.size() != 2) { + cerr << "Error parsing vector string: " << fields[i] << endl; + return false; + } + v->set_value(FD::Convert(pair[0]), atof(pair[1].c_str())); + } + return true; +#endif +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string evaluation_metric = conf["evaluation_metric"].as(); + EvaluationMetric* metric = EvaluationMetric::Instance(evaluation_metric); + DocumentScorer ds(metric, conf["reference"].as >()); + cerr << "Loaded " << ds.size() << " references for scoring with " << evaluation_metric << endl; + Hypergraph hg; + string last_file; + ReadFile in_read(conf["input"].as()); + istream &in=*in_read.stream(); + while(in) { + string line; + getline(in, line); + if (line.empty()) continue; + istringstream is(line); + int sent_id; + string file, s_origin, s_direction; + // path-to-file (JSON) sent_ed starting-point search-direction + is >> file >> sent_id >> s_origin >> s_direction; + SparseVector origin; + ReadSparseVectorString(s_origin, &origin); + SparseVector direction; + ReadSparseVectorString(s_direction, &direction); + // cerr << "File: " << file << "\nDir: " << direction << "\n X: " << origin << endl; + if (last_file != file) { + last_file = file; + ReadFile rf(file); + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + } + const ConvexHullWeightFunction wf(origin, direction); + const ConvexHull hull = Inside(hg, NULL, wf); + + ErrorSurface es; + ComputeErrorSurface(*ds[sent_id], hull, &es, metric, hg); + //cerr << "Viterbi envelope has " << ve.size() << " segments\n"; + // cerr << "Error surface has " << es.size() << " segments\n"; + string val; + es.Serialize(&val); + cout << 'M' << ' ' << s_origin << ' ' << s_direction << '\t'; + B64::b64encode(val.c_str(), val.size(), &cout); + cout << endl << flush; + } + return 0; +} diff --git a/training/dpmert/mr_dpmert_reduce.cc b/training/dpmert/mr_dpmert_reduce.cc new file mode 100644 index 00000000..31512a03 --- /dev/null +++ b/training/dpmert/mr_dpmert_reduce.cc @@ -0,0 +1,77 @@ +#include +#include +#include +#include + +#include +#include + +#include "sparse_vector.h" +#include "error_surface.h" +#include "line_optimizer.h" +#include "b64tools.h" +#include "stringlib.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("evaluation_metric,m",po::value(), "Evaluation metric (IBM_BLEU, etc.)") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = conf->count("evaluation_metric") == 0; + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string evaluation_metric = conf["evaluation_metric"].as(); + EvaluationMetric* metric = EvaluationMetric::Instance(evaluation_metric); + LineOptimizer::ScoreType opt_type = LineOptimizer::MAXIMIZE_SCORE; + if (metric->IsErrorMetric()) + opt_type = LineOptimizer::MINIMIZE_SCORE; + + vector esv; + string last_key, line, key, val; + while(getline(cin, line)) { + size_t ks = line.find("\t"); + assert(string::npos != ks); + assert(ks > 2); + key = line.substr(2, ks - 2); + val = line.substr(ks + 1); + if (key != last_key) { + if (!last_key.empty()) { + float score; + double x = LineOptimizer::LineOptimize(metric, esv, opt_type, &score); + cout << last_key << "|" << x << "|" << score << endl; + } + last_key.swap(key); + esv.clear(); + } + if (val.size() % 4 != 0) { + cerr << "B64 encoding error 1! Skipping.\n"; + continue; + } + string encoded(val.size() / 4 * 3, '\0'); + if (!B64::b64decode(reinterpret_cast(&val[0]), val.size(), &encoded[0], encoded.size())) { + cerr << "B64 encoding error 2! Skipping.\n"; + continue; + } + esv.push_back(ErrorSurface()); + esv.back().Deserialize(encoded); + } + if (!esv.empty()) { + float score; + double x = LineOptimizer::LineOptimize(metric, esv, opt_type, &score); + cout << last_key << "|" << x << "|" << score << endl; + } + return 0; +} diff --git a/training/dpmert/test_aer/README b/training/dpmert/test_aer/README new file mode 100644 index 00000000..819b2e32 --- /dev/null +++ b/training/dpmert/test_aer/README @@ -0,0 +1,8 @@ +To run the test: + +../dist-vest.pl --local --metric aer cdec.ini --source-file corpus.src --ref-files=ref.0 --weights weights + +This will optimize the parameters of the tiny lexical translation model +so as to minimize the AER of the Viterbi alignment on the development +set in corpus.src according to the reference alignments in ref.0. + diff --git a/training/dpmert/test_aer/cdec.ini b/training/dpmert/test_aer/cdec.ini new file mode 100644 index 00000000..08187848 --- /dev/null +++ b/training/dpmert/test_aer/cdec.ini @@ -0,0 +1,3 @@ +formalism=lextrans +grammar=grammar +aligner=true diff --git a/training/dpmert/test_aer/corpus.src b/training/dpmert/test_aer/corpus.src new file mode 100644 index 00000000..31b23971 --- /dev/null +++ b/training/dpmert/test_aer/corpus.src @@ -0,0 +1,3 @@ +el gato negro ||| the black cat +el gato ||| the cat +el libro ||| the book diff --git a/training/dpmert/test_aer/grammar b/training/dpmert/test_aer/grammar new file mode 100644 index 00000000..9d857824 --- /dev/null +++ b/training/dpmert/test_aer/grammar @@ -0,0 +1,12 @@ +el ||| cat ||| F1=1 +el ||| the ||| F2=1 +el ||| black ||| F3=1 +el ||| book ||| F11=1 +gato ||| cat ||| F4=1 NN=1 +gato ||| black ||| F5=1 +gato ||| the ||| F6=1 +negro ||| the ||| F7=1 +negro ||| cat ||| F8=1 +negro ||| black ||| F9=1 +libro ||| the ||| F10=1 +libro ||| book ||| F12=1 NN=1 diff --git a/training/dpmert/test_aer/ref.0 b/training/dpmert/test_aer/ref.0 new file mode 100644 index 00000000..734a9c5b --- /dev/null +++ b/training/dpmert/test_aer/ref.0 @@ -0,0 +1,3 @@ +0-0 1-2 2-1 +0-0 1-1 +0-0 1-1 diff --git a/training/dpmert/test_aer/weights b/training/dpmert/test_aer/weights new file mode 100644 index 00000000..afc9282e --- /dev/null +++ b/training/dpmert/test_aer/weights @@ -0,0 +1,13 @@ +F1 0.1 +F2 -.5980815 +F3 0.24235 +F4 0.625 +F5 0.4514 +F6 0.112316 +F7 -0.123415 +F8 -0.25390285 +F9 -0.23852 +F10 0.646 +F11 0.413141 +F12 0.343216 +NN -0.1215 diff --git a/training/dpmert/test_data/0.json.gz b/training/dpmert/test_data/0.json.gz new file mode 100644 index 00000000..30f8dd77 Binary files /dev/null and b/training/dpmert/test_data/0.json.gz differ diff --git a/training/dpmert/test_data/1.json.gz b/training/dpmert/test_data/1.json.gz new file mode 100644 index 00000000..c82cc179 Binary files /dev/null and b/training/dpmert/test_data/1.json.gz differ diff --git a/training/dpmert/test_data/c2e.txt.0 b/training/dpmert/test_data/c2e.txt.0 new file mode 100644 index 00000000..12c4abe9 --- /dev/null +++ b/training/dpmert/test_data/c2e.txt.0 @@ -0,0 +1,2 @@ +australia reopens embassy in manila +( afp , manila , january 2 ) australia reopened its embassy in the philippines today , which was shut down about seven weeks ago due to what was described as a specific threat of a terrorist attack . diff --git a/training/dpmert/test_data/c2e.txt.1 b/training/dpmert/test_data/c2e.txt.1 new file mode 100644 index 00000000..4ac12df1 --- /dev/null +++ b/training/dpmert/test_data/c2e.txt.1 @@ -0,0 +1,2 @@ +australia reopened manila embassy +( agence france-presse , manila , 2nd ) - australia reopened its embassy in the philippines today . the embassy was closed seven weeks ago after what was described as a specific threat of a terrorist attack . diff --git a/training/dpmert/test_data/c2e.txt.2 b/training/dpmert/test_data/c2e.txt.2 new file mode 100644 index 00000000..2f67b72f --- /dev/null +++ b/training/dpmert/test_data/c2e.txt.2 @@ -0,0 +1,2 @@ +australia to reopen embassy in manila +( afp report from manila , january 2 ) australia reopened its embassy in the philippines today . seven weeks ago , the embassy was shut down due to so-called confirmed terrorist attack threats . diff --git a/training/dpmert/test_data/c2e.txt.3 b/training/dpmert/test_data/c2e.txt.3 new file mode 100644 index 00000000..5483cef6 --- /dev/null +++ b/training/dpmert/test_data/c2e.txt.3 @@ -0,0 +1,2 @@ +australia to re - open its embassy to manila +( afp , manila , thursday ) australia reopens its embassy to manila , which was closed for the so-called " clear " threat of terrorist attack 7 weeks ago . diff --git a/training/dpmert/test_data/re.txt.0 b/training/dpmert/test_data/re.txt.0 new file mode 100644 index 00000000..86eff087 --- /dev/null +++ b/training/dpmert/test_data/re.txt.0 @@ -0,0 +1,5 @@ +erdogan states turkey to reject any pressures to urge it to recognize cyprus +ankara 12 - 1 ( afp ) - turkish prime minister recep tayyip erdogan announced today , wednesday , that ankara will reject any pressure by the european union to urge it to recognize cyprus . this comes two weeks before the summit of european union state and government heads who will decide whether or nor membership negotiations with ankara should be opened . +erdogan told " ntv " television station that " the european union cannot address us by imposing new conditions on us with regard to cyprus . +we will discuss this dossier in the course of membership negotiations . " +he added " let me be clear , i cannot sidestep turkey , this is something we cannot accept . " diff --git a/training/dpmert/test_data/re.txt.1 b/training/dpmert/test_data/re.txt.1 new file mode 100644 index 00000000..2140f198 --- /dev/null +++ b/training/dpmert/test_data/re.txt.1 @@ -0,0 +1,5 @@ +erdogan confirms turkey will resist any pressure to recognize cyprus +ankara 12 - 1 ( afp ) - the turkish head of government , recep tayyip erdogan , announced today ( wednesday ) that ankara would resist any pressure the european union might exercise in order to force it into recognizing cyprus . this comes two weeks before a summit of european union heads of state and government , who will decide whether or not to open membership negotiations with ankara . +erdogan said to the ntv television channel : " the european union cannot engage with us through imposing new conditions on us with regard to cyprus . +we shall discuss this issue in the course of the membership negotiations . " +he added : " let me be clear - i cannot confine turkey . this is something we do not accept . " diff --git a/training/dpmert/test_data/re.txt.2 b/training/dpmert/test_data/re.txt.2 new file mode 100644 index 00000000..94e46286 --- /dev/null +++ b/training/dpmert/test_data/re.txt.2 @@ -0,0 +1,5 @@ +erdogan confirms that turkey will reject any pressures to encourage it to recognize cyprus +ankara , 12 / 1 ( afp ) - the turkish prime minister recep tayyip erdogan declared today , wednesday , that ankara will reject any pressures that the european union may apply on it to encourage to recognize cyprus . this comes two weeks before a summit of the heads of countries and governments of the european union , who will decide on whether or not to start negotiations on joining with ankara . +erdogan told the ntv television station that " it is not possible for the european union to talk to us by imposing new conditions on us regarding cyprus . +we shall discuss this dossier during the negotiations on joining . " +and he added , " let me be clear . turkey's arm should not be twisted ; this is something we cannot accept . " diff --git a/training/dpmert/test_data/re.txt.3 b/training/dpmert/test_data/re.txt.3 new file mode 100644 index 00000000..f87c3308 --- /dev/null +++ b/training/dpmert/test_data/re.txt.3 @@ -0,0 +1,5 @@ +erdogan stresses that turkey will reject all pressures to force it to recognize cyprus +ankara 12 - 1 ( afp ) - turkish prime minister recep tayyip erdogan announced today , wednesday , that ankara would refuse all pressures applied on it by the european union to force it to recognize cyprus . that came two weeks before the summit of the presidents and prime ministers of the european union , who would decide on whether to open negotiations on joining with ankara or not . +erdogan said to " ntv " tv station that the " european union can not communicate with us by imposing on us new conditions related to cyprus . +we will discuss this file during the negotiations on joining . " +he added , " let me be clear . turkey's arm should not be twisted . this is unacceptable to us . " diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am new file mode 100644 index 00000000..5b48e756 --- /dev/null +++ b/training/dtrain/Makefile.am @@ -0,0 +1,7 @@ +bin_PROGRAMS = dtrain + +dtrain_SOURCES = dtrain.cc score.cc +dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval + diff --git a/training/dtrain/README.md b/training/dtrain/README.md new file mode 100644 index 00000000..7edabbf1 --- /dev/null +++ b/training/dtrain/README.md @@ -0,0 +1,48 @@ +This is a simple (and parallelizable) tuning method for cdec +which is able to train the weights of very many (sparse) features. +It was used here: + "Joint Feature Selection in Distributed Stochastic + Learning for Large-Scale Discriminative Training in + SMT" +(Simianer, Riezler, Dyer; ACL 2012) + + +Building +-------- +Builds when building cdec, see ../BUILDING . +To build only parts needed for dtrain do +``` + autoreconf -ifv + ./configure [--disable-gtest] + cd dtrain/; make +``` + +Running +------- +To run this on a dev set locally: +``` + #define DTRAIN_LOCAL +``` +otherwise remove that line or undef, then recompile. You need a single +grammar file or input annotated with per-sentence grammars (psg) as you +would use with cdec. Additionally you need to give dtrain a file with +references (--refs) when running locally. + +The input for use with hadoop streaming looks like this: +``` + \t\t\t +``` +To convert a psg to this format you need to replace all "\n" +by "\t". Make sure there are no tabs in your data. + +For an example of local usage (with the 'distributed' format) +the see test/example/ . This expects dtrain to be built without +DTRAIN_LOCAL. + +Legal +----- +Copyright (c) 2012 by Patrick Simianer + +See the file ../LICENSE.txt for the licensing terms that this software is +released under. + diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc new file mode 100644 index 00000000..18286668 --- /dev/null +++ b/training/dtrain/dtrain.cc @@ -0,0 +1,657 @@ +#include "dtrain.h" + + +bool +dtrain_init(int argc, char** argv, po::variables_map* cfg) +{ + po::options_description ini("Configuration File Options"); + ini.add_options() + ("input", po::value()->default_value("-"), "input file") + ("output", po::value()->default_value("-"), "output weights file, '-' for STDOUT") + ("input_weights", po::value(), "input weights file (e.g. from previous iteration)") + ("decoder_config", po::value(), "configuration file for cdec") + ("print_weights", po::value(), "weights to print on each iteration") + ("stop_after", po::value()->default_value(0), "stop after X input sentences") + ("tmp", po::value()->default_value("/tmp"), "temp dir to use") + ("keep", po::value()->zero_tokens(), "keep weights files for each iteration") + ("hstreaming", po::value(), "run in hadoop streaming mode, arg is a task id") + ("epochs", po::value()->default_value(10), "# of iterations T (per shard)") + ("k", po::value()->default_value(100), "how many translations to sample") + ("sample_from", po::value()->default_value("kbest"), "where to sample translations from: 'kbest', 'forest'") + ("filter", po::value()->default_value("uniq"), "filter kbest list: 'not', 'uniq'") + ("pair_sampling", po::value()->default_value("XYX"), "how to sample pairs: 'all', 'XYX' or 'PRO'") + ("hi_lo", po::value()->default_value(0.1), "hi and lo (X) for XYX (default 0.1), <= 0.5") + ("pair_threshold", po::value()->default_value(0.), "bleu [0,1] threshold to filter pairs") + ("N", po::value()->default_value(4), "N for Ngrams (BLEU)") + ("scorer", po::value()->default_value("stupid_bleu"), "scoring: bleu, stupid_, smooth_, approx_, lc_") + ("learning_rate", po::value()->default_value(1.0), "learning rate") + ("gamma", po::value()->default_value(0.), "gamma for SVM (0 for perceptron)") + ("select_weights", po::value()->default_value("last"), "output best, last, avg weights ('VOID' to throw away)") + ("rescale", po::value()->zero_tokens(), "rescale weight vector after each input") + ("l1_reg", po::value()->default_value("none"), "apply l1 regularization as in 'Tsuroka et al' (2010)") + ("l1_reg_strength", po::value(), "l1 regularization strength") + ("fselect", po::value()->default_value(-1), "select top x percent (or by threshold) of features after each epoch NOT IMPLEMENTED") // TODO + ("approx_bleu_d", po::value()->default_value(0.9), "discount for approx. BLEU") + ("scale_bleu_diff", po::value()->zero_tokens(), "learning rate <- bleu diff of a misranked pair") + ("loss_margin", po::value()->default_value(0.), "update if no error in pref pair but model scores this near") + ("max_pairs", po::value()->default_value(std::numeric_limits::max()), "max. # of pairs per Sent.") +#ifdef DTRAIN_LOCAL + ("refs,r", po::value(), "references in local mode") +#endif + ("noup", po::value()->zero_tokens(), "do not update weights"); + po::options_description cl("Command Line Options"); + cl.add_options() + ("config,c", po::value(), "dtrain config file") + ("quiet,q", po::value()->zero_tokens(), "be quiet") + ("verbose,v", po::value()->zero_tokens(), "be verbose"); + cl.add(ini); + po::store(parse_command_line(argc, argv, cl), *cfg); + if (cfg->count("config")) { + ifstream ini_f((*cfg)["config"].as().c_str()); + po::store(po::parse_config_file(ini_f, ini), *cfg); + } + po::notify(*cfg); + if (!cfg->count("decoder_config")) { + cerr << cl << endl; + return false; + } + if (cfg->count("hstreaming") && (*cfg)["output"].as() != "-") { + cerr << "When using 'hstreaming' the 'output' param should be '-'." << endl; + return false; + } +#ifdef DTRAIN_LOCAL + if ((*cfg)["input"].as() == "-") { + cerr << "Can't use stdin as input with this binary. Recompile without DTRAIN_LOCAL" << endl; + return false; + } +#endif + if ((*cfg)["sample_from"].as() != "kbest" + && (*cfg)["sample_from"].as() != "forest") { + cerr << "Wrong 'sample_from' param: '" << (*cfg)["sample_from"].as() << "', use 'kbest' or 'forest'." << endl; + return false; + } + if ((*cfg)["sample_from"].as() == "kbest" && (*cfg)["filter"].as() != "uniq" && + (*cfg)["filter"].as() != "not") { + cerr << "Wrong 'filter' param: '" << (*cfg)["filter"].as() << "', use 'uniq' or 'not'." << endl; + return false; + } + if ((*cfg)["pair_sampling"].as() != "all" && (*cfg)["pair_sampling"].as() != "XYX" && + (*cfg)["pair_sampling"].as() != "PRO") { + cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as() << "'." << endl; + return false; + } + if(cfg->count("hi_lo") && (*cfg)["pair_sampling"].as() != "XYX") { + cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl; + } + if((*cfg)["hi_lo"].as() > 0.5 || (*cfg)["hi_lo"].as() < 0.01) { + cerr << "hi_lo must lie in [0.01, 0.5]" << endl; + return false; + } + if ((*cfg)["pair_threshold"].as() < 0) { + cerr << "The threshold must be >= 0!" << endl; + return false; + } + if ((*cfg)["select_weights"].as() != "last" && (*cfg)["select_weights"].as() != "best" && + (*cfg)["select_weights"].as() != "avg" && (*cfg)["select_weights"].as() != "VOID") { + cerr << "Wrong 'select_weights' param: '" << (*cfg)["select_weights"].as() << "', use 'last' or 'best'." << endl; + return false; + } + return true; +} + +int +main(int argc, char** argv) +{ + // handle most parameters + po::variables_map cfg; + if (!dtrain_init(argc, argv, &cfg)) exit(1); // something is wrong + bool quiet = false; + if (cfg.count("quiet")) quiet = true; + bool verbose = false; + if (cfg.count("verbose")) verbose = true; + bool noup = false; + if (cfg.count("noup")) noup = true; + bool hstreaming = false; + string task_id; + if (cfg.count("hstreaming")) { + hstreaming = true; + quiet = true; + task_id = cfg["hstreaming"].as(); + cerr.precision(17); + } + bool rescale = false; + if (cfg.count("rescale")) rescale = true; + HSReporter rep(task_id); + bool keep = false; + if (cfg.count("keep")) keep = true; + + const unsigned k = cfg["k"].as(); + const unsigned N = cfg["N"].as(); + const unsigned T = cfg["epochs"].as(); + const unsigned stop_after = cfg["stop_after"].as(); + const string filter_type = cfg["filter"].as(); + const string sample_from = cfg["sample_from"].as(); + const string pair_sampling = cfg["pair_sampling"].as(); + const score_t pair_threshold = cfg["pair_threshold"].as(); + const string select_weights = cfg["select_weights"].as(); + const float hi_lo = cfg["hi_lo"].as(); + const score_t approx_bleu_d = cfg["approx_bleu_d"].as(); + const unsigned max_pairs = cfg["max_pairs"].as(); + weight_t loss_margin = cfg["loss_margin"].as(); + if (loss_margin > 9998.) loss_margin = std::numeric_limits::max(); + bool scale_bleu_diff = false; + if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true; + bool average = false; + if (select_weights == "avg") + average = true; + vector print_weights; + if (cfg.count("print_weights")) + boost::split(print_weights, cfg["print_weights"].as(), boost::is_any_of(" ")); + + // setup decoder + register_feature_functions(); + SetSilent(true); + ReadFile ini_rf(cfg["decoder_config"].as()); + if (!quiet) + cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as() << "'" << endl; + Decoder decoder(ini_rf.stream()); + + // scoring metric/scorer + string scorer_str = cfg["scorer"].as(); + LocalScorer* scorer; + if (scorer_str == "bleu") { + scorer = dynamic_cast(new BleuScorer); + } else if (scorer_str == "stupid_bleu") { + scorer = dynamic_cast(new StupidBleuScorer); + } else if (scorer_str == "smooth_bleu") { + scorer = dynamic_cast(new SmoothBleuScorer); + } else if (scorer_str == "sum_bleu") { + scorer = dynamic_cast(new SumBleuScorer); + } else if (scorer_str == "sumexp_bleu") { + scorer = dynamic_cast(new SumExpBleuScorer); + } else if (scorer_str == "sumwhatever_bleu") { + scorer = dynamic_cast(new SumWhateverBleuScorer); + } else if (scorer_str == "approx_bleu") { + scorer = dynamic_cast(new ApproxBleuScorer(N, approx_bleu_d)); + } else if (scorer_str == "lc_bleu") { + scorer = dynamic_cast(new LinearBleuScorer(N)); + } else { + cerr << "Don't know scoring metric: '" << scorer_str << "', exiting." << endl; + exit(1); + } + vector bleu_weights; + scorer->Init(N, bleu_weights); + + // setup decoder observer + MT19937 rng; // random number generator, only for forest sampling + HypSampler* observer; + if (sample_from == "kbest") + observer = dynamic_cast(new KBestGetter(k, filter_type)); + else + observer = dynamic_cast(new KSampler(k, &rng)); + observer->SetScorer(scorer); + + // init weights + vector& dense_weights = decoder.CurrentWeightVector(); + SparseVector lambdas, cumulative_penalties, w_average; + if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as(), &dense_weights); + Weights::InitSparseVector(dense_weights, &lambdas); + + // meta params for perceptron, SVM + weight_t eta = cfg["learning_rate"].as(); + weight_t gamma = cfg["gamma"].as(); + + // l1 regularization + bool l1naive = false; + bool l1clip = false; + bool l1cumul = false; + weight_t l1_reg = 0; + if (cfg["l1_reg"].as() != "none") { + string s = cfg["l1_reg"].as(); + if (s == "naive") l1naive = true; + else if (s == "clip") l1clip = true; + else if (s == "cumul") l1cumul = true; + l1_reg = cfg["l1_reg_strength"].as(); + } + + // output + string output_fn = cfg["output"].as(); + // input + string input_fn = cfg["input"].as(); + ReadFile input(input_fn); + // buffer input for t > 0 + vector src_str_buf; // source strings (decoder takes only strings) + vector > ref_ids_buf; // references as WordID vecs + // where temp files go + string tmp_path = cfg["tmp"].as(); +#ifdef DTRAIN_LOCAL + string refs_fn = cfg["refs"].as(); + ReadFile refs(refs_fn); +#else + string grammar_buf_fn = gettmpf(tmp_path, "dtrain-grammars"); + ogzstream grammar_buf_out; + grammar_buf_out.open(grammar_buf_fn.c_str()); +#endif + + unsigned in_sz = std::numeric_limits::max(); // input index, input size + vector > all_scores; + score_t max_score = 0.; + unsigned best_it = 0; + float overall_time = 0.; + + // output cfg + if (!quiet) { + cerr << _p5; + cerr << endl << "dtrain" << endl << "Parameters:" << endl; + cerr << setw(25) << "k " << k << endl; + cerr << setw(25) << "N " << N << endl; + cerr << setw(25) << "T " << T << endl; + cerr << setw(25) << "scorer '" << scorer_str << "'" << endl; + if (scorer_str == "approx_bleu") + cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl; + cerr << setw(25) << "sample from " << "'" << sample_from << "'" << endl; + if (sample_from == "kbest") + cerr << setw(25) << "filter " << "'" << filter_type << "'" << endl; + if (!scale_bleu_diff) cerr << setw(25) << "learning rate " << eta << endl; + else cerr << setw(25) << "learning rate " << "bleu diff" << endl; + cerr << setw(25) << "gamma " << gamma << endl; + cerr << setw(25) << "loss margin " << loss_margin << endl; + cerr << setw(25) << "pairs " << "'" << pair_sampling << "'" << endl; + if (pair_sampling == "XYX") + cerr << setw(25) << "hi lo " << hi_lo << endl; + cerr << setw(25) << "pair threshold " << pair_threshold << endl; + cerr << setw(25) << "select weights " << "'" << select_weights << "'" << endl; + if (cfg.count("l1_reg")) + cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as() << "'" << endl; + if (rescale) + cerr << setw(25) << "rescale " << rescale << endl; + cerr << setw(25) << "max pairs " << max_pairs << endl; + cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as() << "'" << endl; + cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; +#ifdef DTRAIN_LOCAL + cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; +#endif + cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; + if (cfg.count("input_weights")) + cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as() << "'" << endl; + if (stop_after > 0) + cerr << setw(25) << "stop_after " << stop_after << endl; + if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl; + } + + + for (unsigned t = 0; t < T; t++) // T epochs + { + + if (hstreaming) cerr << "reporter:status:Iteration #" << t+1 << " of " << T << endl; + + time_t start, end; + time(&start); +#ifndef DTRAIN_LOCAL + igzstream grammar_buf_in; + if (t > 0) grammar_buf_in.open(grammar_buf_fn.c_str()); +#endif + score_t score_sum = 0.; + score_t model_sum(0); + unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0; + if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl; + + while(true) + { + + string in; + bool next = false, stop = false; // next iteration or premature stop + if (t == 0) { + if(!getline(*input, in)) next = true; + } else { + if (ii == in_sz) next = true; // stop if we reach the end of our input + } + // stop after X sentences (but still go on for those) + if (stop_after > 0 && stop_after == ii && !next) stop = true; + + // produce some pretty output + if (!quiet && !verbose) { + if (ii == 0) cerr << " "; + if ((ii+1) % (DTRAIN_DOTS) == 0) { + cerr << "."; + cerr.flush(); + } + if ((ii+1) % (20*DTRAIN_DOTS) == 0) { + cerr << " " << ii+1 << endl; + if (!next && !stop) cerr << " "; + } + if (stop) { + if (ii % (20*DTRAIN_DOTS) != 0) cerr << " " << ii << endl; + cerr << "Stopping after " << stop_after << " input sentences." << endl; + } else { + if (next) { + if (ii % (20*DTRAIN_DOTS) != 0) cerr << " " << ii << endl; + } + } + } + + // next iteration + if (next || stop) break; + + // weights + lambdas.init_vector(&dense_weights); + + // getting input + vector ref_ids; // reference as vector +#ifndef DTRAIN_LOCAL + vector in_split; // input: sid\tsrc\tref\tpsg + if (t == 0) { + // handling input + split_in(in, in_split); + if (hstreaming && ii == 0) cerr << "reporter:counter:" << task_id << ",First ID," << in_split[0] << endl; + // getting reference + vector ref_tok; + boost::split(ref_tok, in_split[2], boost::is_any_of(" ")); + register_and_convert(ref_tok, ref_ids); + ref_ids_buf.push_back(ref_ids); + // process and set grammar + bool broken_grammar = true; // ignore broken grammars + for (string::iterator it = in.begin(); it != in.end(); it++) { + if (!isspace(*it)) { + broken_grammar = false; + break; + } + } + if (broken_grammar) { + cerr << "Broken grammar for " << ii+1 << "! Ignoring this input." << endl; + continue; + } + boost::replace_all(in, "\t", "\n"); + in += "\n"; + grammar_buf_out << in << DTRAIN_GRAMMAR_DELIM << " " << in_split[0] << endl; + decoder.AddSupplementalGrammarFromString(in); + src_str_buf.push_back(in_split[1]); + // decode + observer->SetRef(ref_ids); + decoder.Decode(in_split[1], observer); + } else { + // get buffered grammar + string grammar_str; + while (true) { + string rule; + getline(grammar_buf_in, rule); + if (boost::starts_with(rule, DTRAIN_GRAMMAR_DELIM)) break; + grammar_str += rule + "\n"; + } + decoder.AddSupplementalGrammarFromString(grammar_str); + // decode + observer->SetRef(ref_ids_buf[ii]); + decoder.Decode(src_str_buf[ii], observer); + } +#else + if (t == 0) { + string r_; + getline(*refs, r_); + vector ref_tok; + boost::split(ref_tok, r_, boost::is_any_of(" ")); + register_and_convert(ref_tok, ref_ids); + ref_ids_buf.push_back(ref_ids); + src_str_buf.push_back(in); + } else { + ref_ids = ref_ids_buf[ii]; + } + observer->SetRef(ref_ids); + if (t == 0) + decoder.Decode(in, observer); + else + decoder.Decode(src_str_buf[ii], observer); +#endif + + // get (scored) samples + vector* samples = observer->GetSamples(); + + if (verbose) { + cerr << "--- ref for " << ii << ": "; + if (t > 0) printWordIDVec(ref_ids_buf[ii]); + else printWordIDVec(ref_ids); + cerr << endl; + for (unsigned u = 0; u < samples->size(); u++) { + cerr << _p2 << _np << "[" << u << ". '"; + printWordIDVec((*samples)[u].w); + cerr << "'" << endl; + cerr << "SCORE=" << (*samples)[u].score << ",model="<< (*samples)[u].model << endl; + cerr << "F{" << (*samples)[u].f << "} ]" << endl << endl; + } + } + + score_sum += (*samples)[0].score; // stats for 1best + model_sum += (*samples)[0].model; + + f_count += observer->get_f_count(); + list_sz += observer->get_sz(); + + // weight updates + if (!noup) { + // get pairs + vector > pairs; + if (pair_sampling == "all") + all_pairs(samples, pairs, pair_threshold, max_pairs); + if (pair_sampling == "XYX") + partXYX(samples, pairs, pair_threshold, max_pairs, hi_lo); + if (pair_sampling == "PRO") + PROsampling(samples, pairs, pair_threshold, max_pairs); + npairs += pairs.size(); + + for (vector >::iterator it = pairs.begin(); + it != pairs.end(); it++) { +#ifdef DTRAIN_FASTER_PERCEPTRON + bool rank_error = true; // pair sampling already did this for us + rank_errors++; + score_t margin = std::numeric_limits::max(); +#else + bool rank_error = it->first.model <= it->second.model; + if (rank_error) rank_errors++; + score_t margin = fabs(fabs(it->first.model) - fabs(it->second.model)); + if (!rank_error && margin < loss_margin) margin_violations++; +#endif + if (scale_bleu_diff) eta = it->first.score - it->second.score; + if (rank_error || margin < loss_margin) { + SparseVector diff_vec = it->first.f - it->second.f; + lambdas.plus_eq_v_times_s(diff_vec, eta); + if (gamma) + lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); + } + } + + // l1 regularization + if (l1naive) { + for (unsigned d = 0; d < lambdas.size(); d++) { + weight_t v = lambdas.get(d); + lambdas.set_value(d, v - sign(v) * l1_reg); + } + } else if (l1clip) { + for (unsigned d = 0; d < lambdas.size(); d++) { + if (lambdas.nonzero(d)) { + weight_t v = lambdas.get(d); + if (v > 0) { + lambdas.set_value(d, max(0., v - l1_reg)); + } else { + lambdas.set_value(d, min(0., v + l1_reg)); + } + } + } + } else if (l1cumul) { + weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input + for (unsigned d = 0; d < lambdas.size(); d++) { + if (lambdas.nonzero(d)) { + weight_t v = lambdas.get(d); + weight_t penalty = 0; + if (v > 0) { + penalty = max(0., v-(acc_penalty + cumulative_penalties.get(d))); + } else { + penalty = min(0., v+(acc_penalty - cumulative_penalties.get(d))); + } + lambdas.set_value(d, penalty); + cumulative_penalties.set_value(d, cumulative_penalties.get(d)+penalty); + } + } + } + + } + + if (rescale) lambdas /= lambdas.l2norm(); + + ++ii; + + if (hstreaming) { + rep.update_counter("Seen #"+boost::lexical_cast(t+1), 1u); + rep.update_counter("Seen", 1u); + } + + } // input loop + + if (average) w_average += lambdas; + + if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset(); + + if (t == 0) { + in_sz = ii; // remember size of input (# lines) + if (hstreaming) { + rep.update_counter("|Input|", ii); + rep.update_gcounter("|Input|", ii); + rep.update_gcounter("Shards", 1u); + } + } + +#ifndef DTRAIN_LOCAL + if (t == 0) { + grammar_buf_out.close(); + } else { + grammar_buf_in.close(); + } +#endif + + // print some stats + score_t score_avg = score_sum/(score_t)in_sz; + score_t model_avg = model_sum/(score_t)in_sz; + score_t score_diff, model_diff; + if (t > 0) { + score_diff = score_avg - all_scores[t-1].first; + model_diff = model_avg - all_scores[t-1].second; + } else { + score_diff = score_avg; + model_diff = model_avg; + } + + unsigned nonz = 0; + if (!quiet || hstreaming) nonz = (unsigned)lambdas.num_nonzero(); + + if (!quiet) { + cerr << _p5 << _p << "WEIGHTS" << endl; + for (vector::iterator it = print_weights.begin(); it != print_weights.end(); it++) { + cerr << setw(18) << *it << " = " << lambdas.get(FD::Convert(*it)) << endl; + } + cerr << " ---" << endl; + cerr << _np << " 1best avg score: " << score_avg; + cerr << _p << " (" << score_diff << ")" << endl; + cerr << _np << " 1best avg model score: " << model_avg; + cerr << _p << " (" << model_diff << ")" << endl; + cerr << " avg # pairs: "; + cerr << _np << npairs/(float)in_sz << endl; + cerr << " avg # rank err: "; + cerr << rank_errors/(float)in_sz << endl; +#ifndef DTRAIN_FASTER_PERCEPTRON + cerr << " avg # margin viol: "; + cerr << margin_violations/(float)in_sz << endl; +#endif + cerr << " non0 feature count: " << nonz << endl; + cerr << " avg list sz: " << list_sz/(float)in_sz << endl; + cerr << " avg f count: " << f_count/(float)list_sz << endl; + } + + if (hstreaming) { + rep.update_counter("Score 1best avg #"+boost::lexical_cast(t+1), (unsigned)(score_avg*DTRAIN_SCALE)); + rep.update_counter("Model 1best avg #"+boost::lexical_cast(t+1), (unsigned)(model_avg*DTRAIN_SCALE)); + rep.update_counter("Pairs avg #"+boost::lexical_cast(t+1), (unsigned)((npairs/(weight_t)in_sz)*DTRAIN_SCALE)); + rep.update_counter("Rank errors avg #"+boost::lexical_cast(t+1), (unsigned)((rank_errors/(weight_t)in_sz)*DTRAIN_SCALE)); + rep.update_counter("Margin violations avg #"+boost::lexical_cast(t+1), (unsigned)((margin_violations/(weight_t)in_sz)*DTRAIN_SCALE)); + rep.update_counter("Non zero feature count #"+boost::lexical_cast(t+1), nonz); + rep.update_gcounter("Non zero feature count #"+boost::lexical_cast(t+1), nonz); + } + + pair remember; + remember.first = score_avg; + remember.second = model_avg; + all_scores.push_back(remember); + if (score_avg > max_score) { + max_score = score_avg; + best_it = t; + } + time (&end); + float time_diff = difftime(end, start); + overall_time += time_diff; + if (!quiet) { + cerr << _p2 << _np << "(time " << time_diff/60. << " min, "; + cerr << time_diff/in_sz << " s/S)" << endl; + } + if (t+1 != T && !quiet) cerr << endl; + + if (noup) break; + + // write weights to file + if (select_weights == "best" || keep) { + lambdas.init_vector(&dense_weights); + string w_fn = "weights." + boost::lexical_cast(t) + ".gz"; + Weights::WriteToFile(w_fn, dense_weights, true); + } + + } // outer loop + + if (average) w_average /= (weight_t)T; + +#ifndef DTRAIN_LOCAL + unlink(grammar_buf_fn.c_str()); +#endif + + if (!noup) { + if (!quiet) cerr << endl << "Writing weights file to '" << output_fn << "' ..." << endl; + if (select_weights == "last" || average) { // last, average + WriteFile of(output_fn); // works with '-' + ostream& o = *of.stream(); + o.precision(17); + o << _np; + if (average) { + for (SparseVector::iterator it = w_average.begin(); it != w_average.end(); ++it) { + if (it->second == 0) continue; + o << FD::Convert(it->first) << '\t' << it->second << endl; + } + } else { + for (SparseVector::iterator it = lambdas.begin(); it != lambdas.end(); ++it) { + if (it->second == 0) continue; + o << FD::Convert(it->first) << '\t' << it->second << endl; + } + } + } else if (select_weights == "VOID") { // do nothing with the weights + } else { // best + if (output_fn != "-") { + CopyFile("weights."+boost::lexical_cast(best_it)+".gz", output_fn); + } else { + ReadFile bestw("weights."+boost::lexical_cast(best_it)+".gz"); + string o; + cout.precision(17); + cout << _np; + while(getline(*bestw, o)) cout << o << endl; + } + if (!keep) { + for (unsigned i = 0; i < T; i++) { + string s = "weights." + boost::lexical_cast(i) + ".gz"; + unlink(s.c_str()); + } + } + } + if (output_fn == "-" && hstreaming) cout << "__SHARD_COUNT__\t1" << endl; + if (!quiet) cerr << "done" << endl; + } + + if (!quiet) { + cerr << _p5 << _np << endl << "---" << endl << "Best iteration: "; + cerr << best_it+1 << " [SCORE '" << scorer_str << "'=" << max_score << "]." << endl; + cerr << "This took " << overall_time/60. << " min." << endl; + } +} + diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h new file mode 100644 index 00000000..4b6f415c --- /dev/null +++ b/training/dtrain/dtrain.h @@ -0,0 +1,97 @@ +#ifndef _DTRAIN_H_ +#define _DTRAIN_H_ + +#undef DTRAIN_FASTER_PERCEPTRON // only look at misranked pairs + // DO NOT USE WITH SVM! +//#define DTRAIN_LOCAL +#define DTRAIN_DOTS 10 // after how many inputs to display a '.' +#define DTRAIN_GRAMMAR_DELIM "########EOS########" +#define DTRAIN_SCALE 100000 + + +#include +#include +#include + +#include +#include + +#include "ksampler.h" +#include "pairsampling.h" + +#include "filelib.h" + + +using namespace std; +using namespace dtrain; +namespace po = boost::program_options; + +inline void register_and_convert(const vector& strs, vector& ids) +{ + vector::const_iterator it; + for (it = strs.begin(); it < strs.end(); it++) + ids.push_back(TD::Convert(*it)); +} + +inline string gettmpf(const string path, const string infix) +{ + char fn[path.size() + infix.size() + 8]; + strcpy(fn, path.c_str()); + strcat(fn, "/"); + strcat(fn, infix.c_str()); + strcat(fn, "-XXXXXX"); + if (!mkstemp(fn)) { + cerr << "Cannot make temp file in" << path << " , exiting." << endl; + exit(1); + } + return string(fn); +} + +inline void split_in(string& s, vector& parts) +{ + unsigned f = 0; + for(unsigned i = 0; i < 3; i++) { + unsigned e = f; + f = s.find("\t", f+1); + if (e != 0) parts.push_back(s.substr(e+1, f-e-1)); + else parts.push_back(s.substr(0, f)); + } + s.erase(0, f+1); +} + +struct HSReporter +{ + string task_id_; + + HSReporter(string task_id) : task_id_(task_id) {} + + inline void update_counter(string name, unsigned amount) { + cerr << "reporter:counter:" << task_id_ << "," << name << "," << amount << endl; + } + inline void update_gcounter(string name, unsigned amount) { + cerr << "reporter:counter:Global," << name << "," << amount << endl; + } +}; + +inline ostream& _np(ostream& out) { return out << resetiosflags(ios::showpos); } +inline ostream& _p(ostream& out) { return out << setiosflags(ios::showpos); } +inline ostream& _p2(ostream& out) { return out << setprecision(2); } +inline ostream& _p5(ostream& out) { return out << setprecision(5); } + +inline void printWordIDVec(vector& v) +{ + for (unsigned i = 0; i < v.size(); i++) { + cerr << TD::Convert(v[i]); + if (i < v.size()-1) cerr << " "; + } +} + +template +inline T sign(T z) +{ + if (z == 0) return 0; + return z < 0 ? -1 : +1; +} + +#endif + diff --git a/training/dtrain/hstreaming/avg.rb b/training/dtrain/hstreaming/avg.rb new file mode 100755 index 00000000..2599c732 --- /dev/null +++ b/training/dtrain/hstreaming/avg.rb @@ -0,0 +1,32 @@ +#!/usr/bin/env ruby +# first arg may be an int of custom shard count + +shard_count_key = "__SHARD_COUNT__" + +STDIN.set_encoding 'utf-8' +STDOUT.set_encoding 'utf-8' + +w = {} +c = {} +w.default = 0 +c.default = 0 +while line = STDIN.gets + key, val = line.split /\s/ + w[key] += val.to_f + c[key] += 1 +end + +if ARGV.size == 0 + shard_count = w["__SHARD_COUNT__"] +else + shard_count = ARGV[0].to_f +end +w.each_key { |k| + if k == shard_count_key + next + else + puts "#{k}\t#{w[k]/shard_count}" + #puts "# #{c[k]}" + end +} + diff --git a/training/dtrain/hstreaming/cdec.ini b/training/dtrain/hstreaming/cdec.ini new file mode 100644 index 00000000..d4f5cecd --- /dev/null +++ b/training/dtrain/hstreaming/cdec.ini @@ -0,0 +1,22 @@ +formalism=scfg +add_pass_through_rules=true +scfg_max_span_limit=15 +intersection_strategy=cube_pruning +cubepruning_pop_limit=30 +feature_function=WordPenalty +feature_function=KLanguageModel nc-wmt11.en.srilm.gz +#feature_function=ArityPenalty +#feature_function=CMR2008ReorderingFeatures +#feature_function=Dwarf +#feature_function=InputIndicator +#feature_function=LexNullJump +#feature_function=NewJump +#feature_function=NgramFeatures +#feature_function=NonLatinCount +#feature_function=OutputIndicator +#feature_function=RuleIdentityFeatures +#feature_function=RuleNgramFeatures +#feature_function=RuleShape +#feature_function=SourceSpanSizeFeatures +#feature_function=SourceWordPenalty +#feature_function=SpanFeatures diff --git a/training/dtrain/hstreaming/dtrain.ini b/training/dtrain/hstreaming/dtrain.ini new file mode 100644 index 00000000..a2c219a1 --- /dev/null +++ b/training/dtrain/hstreaming/dtrain.ini @@ -0,0 +1,15 @@ +input=- +output=- +decoder_config=cdec.ini +tmp=/var/hadoop/mapred/local/ +epochs=1 +k=100 +N=4 +learning_rate=0.0001 +gamma=0 +scorer=stupid_bleu +sample_from=kbest +filter=uniq +pair_sampling=XYX +pair_threshold=0 +select_weights=last diff --git a/training/dtrain/hstreaming/dtrain.sh b/training/dtrain/hstreaming/dtrain.sh new file mode 100755 index 00000000..877ff94c --- /dev/null +++ b/training/dtrain/hstreaming/dtrain.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# script to run dtrain with a task id + +pushd . &>/dev/null +cd .. +ID=$(basename $(pwd)) # attempt_... +popd &>/dev/null +./dtrain -c dtrain.ini --hstreaming $ID + diff --git a/training/dtrain/hstreaming/hadoop-streaming-job.sh b/training/dtrain/hstreaming/hadoop-streaming-job.sh new file mode 100755 index 00000000..92419956 --- /dev/null +++ b/training/dtrain/hstreaming/hadoop-streaming-job.sh @@ -0,0 +1,30 @@ +#!/bin/sh + +EXP=a_simple_test + +# change these vars to fit your hadoop installation +HADOOP_HOME=/usr/lib/hadoop-0.20 +JAR=contrib/streaming/hadoop-streaming-0.20.2-cdh3u1.jar +HSTREAMING="$HADOOP_HOME/bin/hadoop jar $HADOOP_HOME/$JAR" + + IN=input_on_hdfs +OUT=output_weights_on_hdfs + +# you can -reducer to NONE if you want to +# do feature selection/averaging locally (e.g. to +# keep weights of all epochs) +$HSTREAMING \ + -mapper "dtrain.sh" \ + -reducer "ruby lplp.rb l2 select_k 100000" \ + -input $IN \ + -output $OUT \ + -file dtrain.sh \ + -file lplp.rb \ + -file ../dtrain \ + -file dtrain.ini \ + -file cdec.ini \ + -file ../test/example/nc-wmt11.en.srilm.gz \ + -jobconf mapred.reduce.tasks=30 \ + -jobconf mapred.max.map.failures.percent=0 \ + -jobconf mapred.job.name="dtrain $EXP" + diff --git a/training/dtrain/hstreaming/lplp.rb b/training/dtrain/hstreaming/lplp.rb new file mode 100755 index 00000000..f0cd58c5 --- /dev/null +++ b/training/dtrain/hstreaming/lplp.rb @@ -0,0 +1,131 @@ +# lplp.rb + +# norms +def l0(feature_column, n) + if feature_column.size >= n then return 1 else return 0 end +end + +def l1(feature_column, n=-1) + return feature_column.map { |i| i.abs }.reduce { |sum,i| sum+i } +end + +def l2(feature_column, n=-1) + return Math.sqrt feature_column.map { |i| i.abs2 }.reduce { |sum,i| sum+i } +end + +def linfty(feature_column, n=-1) + return feature_column.map { |i| i.abs }.max +end + +# stats +def median(feature_column, n) + return feature_column.concat(0.step(n-feature_column.size-1).map{|i|0}).sort[feature_column.size/2] +end + +def mean(feature_column, n) + return feature_column.reduce { |sum, i| sum+i } / n +end + +# selection +def select_k(weights, norm_fun, n, k=10000) + weights.sort{|a,b| norm_fun.call(b[1], n) <=> norm_fun.call(a[1], n)}.each { |p| + puts "#{p[0]}\t#{mean(p[1], n)}" + k -= 1 + if k == 0 then break end + } +end + +def cut(weights, norm_fun, n, epsilon=0.0001) + weights.each { |k,v| + if norm_fun.call(v, n).abs >= epsilon + puts "#{k}\t#{mean(v, n)}" + end + } +end + +# test +def _test() + puts + w = {} + w["a"] = [1, 2, 3] + w["b"] = [1, 2] + w["c"] = [66] + w["d"] = [10, 20, 30] + n = 3 + puts w.to_s + puts + puts "select_k" + puts "l0 expect ad" + select_k(w, method(:l0), n, 2) + puts "l1 expect cd" + select_k(w, method(:l1), n, 2) + puts "l2 expect c" + select_k(w, method(:l2), n, 1) + puts + puts "cut" + puts "l1 expect cd" + cut(w, method(:l1), n, 7) + puts + puts "median" + a = [1,2,3,4,5] + puts a.to_s + puts median(a, 5) + puts + puts "#{median(a, 7)} <- that's because we add missing 0s:" + puts a.concat(0.step(7-a.size-1).map{|i|0}).to_s + puts + puts "mean expect bc" + w.clear + w["a"] = [2] + w["b"] = [2.1] + w["c"] = [2.2] + cut(w, method(:mean), 1, 2.05) + exit +end +#_test() + +# actually do something +def usage() + puts "lplp.rb [n] < " + puts " l0...: norms for selection" + puts "select_k: only output top k (according to the norm of their column vector) features" + puts " cut: output features with weight >= threshold" + puts " n: if we do not have a shard count use this number for averaging" + exit +end + +if ARGV.size < 3 then usage end +norm_fun = method(ARGV[0].to_sym) +type = ARGV[1] +x = ARGV[2].to_f + +shard_count_key = "__SHARD_COUNT__" + +STDIN.set_encoding 'utf-8' +STDOUT.set_encoding 'utf-8' + +w = {} +shard_count = 0 +while line = STDIN.gets + key, val = line.split /\s+/ + if key == shard_count_key + shard_count += 1 + next + end + if w.has_key? key + w[key].push val.to_f + else + w[key] = [val.to_f] + end +end + +if ARGV.size == 4 then shard_count = ARGV[3].to_f end + +if type == 'cut' + cut(w, norm_fun, shard_count, x) +elsif type == 'select_k' + select_k(w, norm_fun, shard_count, x) +else + puts "oh oh" +end + diff --git a/training/dtrain/hstreaming/red-test b/training/dtrain/hstreaming/red-test new file mode 100644 index 00000000..2623d697 --- /dev/null +++ b/training/dtrain/hstreaming/red-test @@ -0,0 +1,9 @@ +a 1 +b 2 +c 3.5 +a 1 +b 2 +c 3.5 +d 1 +e 2 +__SHARD_COUNT__ 2 diff --git a/training/dtrain/kbestget.h b/training/dtrain/kbestget.h new file mode 100644 index 00000000..dd8882e1 --- /dev/null +++ b/training/dtrain/kbestget.h @@ -0,0 +1,152 @@ +#ifndef _DTRAIN_KBESTGET_H_ +#define _DTRAIN_KBESTGET_H_ + +#include "kbest.h" // cdec +#include "sentence_metadata.h" + +#include "verbose.h" +#include "viterbi.h" +#include "ff_register.h" +#include "decoder.h" +#include "weights.h" +#include "logval.h" + +using namespace std; + +namespace dtrain +{ + + +typedef double score_t; + +struct ScoredHyp +{ + vector w; + SparseVector f; + score_t model; + score_t score; + unsigned rank; +}; + +struct LocalScorer +{ + unsigned N_; + vector w_; + + virtual score_t + Score(vector& hyp, vector& ref, const unsigned rank, const unsigned src_len)=0; + + void Reset() {} // only for approx bleu + + inline void + Init(unsigned N, vector weights) + { + assert(N > 0); + N_ = N; + if (weights.empty()) for (unsigned i = 0; i < N_; i++) w_.push_back(1./N_); + else w_ = weights; + } + + inline score_t + brevity_penalty(const unsigned hyp_len, const unsigned ref_len) + { + if (hyp_len > ref_len) return 1; + return exp(1 - (score_t)ref_len/hyp_len); + } +}; + +struct HypSampler : public DecoderObserver +{ + LocalScorer* scorer_; + vector* ref_; + unsigned f_count_, sz_; + virtual vector* GetSamples()=0; + inline void SetScorer(LocalScorer* scorer) { scorer_ = scorer; } + inline void SetRef(vector& ref) { ref_ = &ref; } + inline unsigned get_f_count() { return f_count_; } + inline unsigned get_sz() { return sz_; } +}; +//////////////////////////////////////////////////////////////////////////////// + + + + +struct KBestGetter : public HypSampler +{ + const unsigned k_; + const string filter_type_; + vector s_; + unsigned src_len_; + + KBestGetter(const unsigned k, const string filter_type) : + k_(k), filter_type_(filter_type) {} + + virtual void + NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) + { + src_len_ = smeta.GetSourceLength(); + KBestScored(*hg); + } + + vector* GetSamples() { return &s_; } + + void + KBestScored(const Hypergraph& forest) + { + if (filter_type_ == "uniq") { + KBestUnique(forest); + } else if (filter_type_ == "not") { + KBestNoFilter(forest); + } + } + + void + KBestUnique(const Hypergraph& forest) + { + s_.clear(); sz_ = f_count_ = 0; + KBest::KBestDerivations, ESentenceTraversal, + KBest::FilterUnique, prob_t, EdgeProb> kbest(forest, k_); + for (unsigned i = 0; i < k_; ++i) { + const KBest::KBestDerivations, ESentenceTraversal, KBest::FilterUnique, + prob_t, EdgeProb>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + ScoredHyp h; + h.w = d->yield; + h.f = d->feature_values; + h.model = log(d->score); + h.rank = i; + h.score = scorer_->Score(h.w, *ref_, i, src_len_); + s_.push_back(h); + sz_++; + f_count_ += h.f.size(); + } + } + + void + KBestNoFilter(const Hypergraph& forest) + { + s_.clear(); sz_ = f_count_ = 0; + KBest::KBestDerivations, ESentenceTraversal> kbest(forest, k_); + for (unsigned i = 0; i < k_; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + ScoredHyp h; + h.w = d->yield; + h.f = d->feature_values; + h.model = log(d->score); + h.rank = i; + h.score = scorer_->Score(h.w, *ref_, i, src_len_); + s_.push_back(h); + sz_++; + f_count_ += h.f.size(); + } + } +}; + + +} // namespace + +#endif + diff --git a/training/dtrain/ksampler.h b/training/dtrain/ksampler.h new file mode 100644 index 00000000..bc2f56cd --- /dev/null +++ b/training/dtrain/ksampler.h @@ -0,0 +1,61 @@ +#ifndef _DTRAIN_KSAMPLER_H_ +#define _DTRAIN_KSAMPLER_H_ + +#include "hg_sampler.h" // cdec +#include "kbestget.h" +#include "score.h" + +namespace dtrain +{ + +bool +cmp_hyp_by_model_d(ScoredHyp a, ScoredHyp b) +{ + return a.model > b.model; +} + +struct KSampler : public HypSampler +{ + const unsigned k_; + vector s_; + MT19937* prng_; + score_t (*scorer)(NgramCounts&, const unsigned, const unsigned, unsigned, vector); + unsigned src_len_; + + explicit KSampler(const unsigned k, MT19937* prng) : + k_(k), prng_(prng) {} + + virtual void + NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) + { + src_len_ = smeta.GetSourceLength(); + ScoredSamples(*hg); + } + + vector* GetSamples() { return &s_; } + + void ScoredSamples(const Hypergraph& forest) { + s_.clear(); sz_ = f_count_ = 0; + std::vector samples; + HypergraphSampler::sample_hypotheses(forest, k_, prng_, &samples); + for (unsigned i = 0; i < k_; ++i) { + ScoredHyp h; + h.w = samples[i].words; + h.f = samples[i].fmap; + h.model = log(samples[i].model_score); + h.rank = i; + h.score = scorer_->Score(h.w, *ref_, i, src_len_); + s_.push_back(h); + sz_++; + f_count_ += h.f.size(); + } + sort(s_.begin(), s_.end(), cmp_hyp_by_model_d); + for (unsigned i = 0; i < s_.size(); i++) s_[i].rank = i; + } +}; + + +} // namespace + +#endif + diff --git a/training/dtrain/pairsampling.h b/training/dtrain/pairsampling.h new file mode 100644 index 00000000..84be1efb --- /dev/null +++ b/training/dtrain/pairsampling.h @@ -0,0 +1,149 @@ +#ifndef _DTRAIN_PAIRSAMPLING_H_ +#define _DTRAIN_PAIRSAMPLING_H_ + +namespace dtrain +{ + + +bool +accept_pair(score_t a, score_t b, score_t threshold) +{ + if (fabs(a - b) < threshold) return false; + return true; +} + +bool +cmp_hyp_by_score_d(ScoredHyp a, ScoredHyp b) +{ + return a.score > b.score; +} + +inline void +all_pairs(vector* s, vector >& training, score_t threshold, unsigned max, float _unused=1) +{ + sort(s->begin(), s->end(), cmp_hyp_by_score_d); + unsigned sz = s->size(); + bool b = false; + unsigned count = 0; + for (unsigned i = 0; i < sz-1; i++) { + for (unsigned j = i+1; j < sz; j++) { + if (threshold > 0) { + if (accept_pair((*s)[i].score, (*s)[j].score, threshold)) + training.push_back(make_pair((*s)[i], (*s)[j])); + } else { + if ((*s)[i].score != (*s)[j].score) + training.push_back(make_pair((*s)[i], (*s)[j])); + } + if (++count == max) { + b = true; + break; + } + } + if (b) break; + } +} + +/* + * multipartite ranking + * sort (descending) by bleu + * compare top X to middle Y and low X + * cmp middle Y to low X + */ + +inline void +partXYX(vector* s, vector >& training, score_t threshold, unsigned max, float hi_lo) +{ + unsigned sz = s->size(); + if (sz < 2) return; + sort(s->begin(), s->end(), cmp_hyp_by_score_d); + unsigned sep = round(sz*hi_lo); + unsigned sep_hi = sep; + if (sz > 4) while (sep_hi < sz && (*s)[sep_hi-1].score == (*s)[sep_hi].score) ++sep_hi; + else sep_hi = 1; + bool b = false; + unsigned count = 0; + for (unsigned i = 0; i < sep_hi; i++) { + for (unsigned j = sep_hi; j < sz; j++) { +#ifdef DTRAIN_FASTER_PERCEPTRON + if ((*s)[i].model <= (*s)[j].model) { +#endif + if (threshold > 0) { + if (accept_pair((*s)[i].score, (*s)[j].score, threshold)) + training.push_back(make_pair((*s)[i], (*s)[j])); + } else { + if ((*s)[i].score != (*s)[j].score) + training.push_back(make_pair((*s)[i], (*s)[j])); + } + if (++count == max) { + b = true; + break; + } +#ifdef DTRAIN_FASTER_PERCEPTRON + } +#endif + } + if (b) break; + } + unsigned sep_lo = sz-sep; + while (sep_lo > 0 && (*s)[sep_lo-1].score == (*s)[sep_lo].score) --sep_lo; + for (unsigned i = sep_hi; i < sz-sep_lo; i++) { + for (unsigned j = sz-sep_lo; j < sz; j++) { +#ifdef DTRAIN_FASTER_PERCEPTRON + if ((*s)[i].model <= (*s)[j].model) { +#endif + if (threshold > 0) { + if (accept_pair((*s)[i].score, (*s)[j].score, threshold)) + training.push_back(make_pair((*s)[i], (*s)[j])); + } else { + if ((*s)[i].score != (*s)[j].score) + training.push_back(make_pair((*s)[i], (*s)[j])); + } + if (++count == max) return; +#ifdef DTRAIN_FASTER_PERCEPTRON + } +#endif + } + } +} + +/* + * pair sampling as in + * 'Tuning as Ranking' (Hopkins & May, 2011) + * count = 5000 + * threshold = 5% BLEU (0.05 for param 3) + * cut = top 50 + */ +bool +_PRO_cmp_pair_by_diff_d(pair a, pair b) +{ + return (fabs(a.first.score - a.second.score)) > (fabs(b.first.score - b.second.score)); +} +inline void +PROsampling(vector* s, vector >& training, score_t threshold, unsigned max, float _unused=1) +{ + unsigned max_count = 5000, count = 0, sz = s->size(); + bool b = false; + for (unsigned i = 0; i < sz-1; i++) { + for (unsigned j = i+1; j < sz; j++) { + if (accept_pair((*s)[i].score, (*s)[j].score, threshold)) { + training.push_back(make_pair((*s)[i], (*s)[j])); + if (++count == max_count) { + b = true; + break; + } + } + } + if (b) break; + } + if (training.size() > 50) { + sort(training.begin(), training.end(), _PRO_cmp_pair_by_diff_d); + training.erase(training.begin()+50, training.end()); + } + return; +} + + +} // namespace + +#endif + diff --git a/training/dtrain/parallelize.rb b/training/dtrain/parallelize.rb new file mode 100755 index 00000000..1d277ff6 --- /dev/null +++ b/training/dtrain/parallelize.rb @@ -0,0 +1,79 @@ +#!/usr/bin/env ruby + + +if ARGV.size != 5 + STDERR.write "Usage: " + STDERR.write "ruby parallelize.rb <#shards> \n" + exit +end + +dtrain_bin = '/home/pks/bin/dtrain_local' +ruby = '/usr/bin/ruby' +lplp_rb = '/home/pks/mt/cdec-dtrain/dtrain/hstreaming/lplp.rb' +lplp_args = 'l2 select_k 100000' +gzip = '/bin/gzip' + +num_shards = ARGV[0].to_i +input = ARGV[1] +refs = ARGV[2] +epochs = ARGV[3].to_i +ini = ARGV[4] + + +`mkdir work` + +def make_shards(input, refs, num_shards) + lc = `wc -l #{input}`.split.first.to_i + shard_sz = lc / num_shards + leftover = lc % num_shards + in_f = File.new input, 'r' + refs_f = File.new refs, 'r' + shard_in_files = [] + shard_refs_files = [] + 0.upto(num_shards-1) { |shard| + shard_in = File.new "work/shard.#{shard}.in", 'w+' + shard_refs = File.new "work/shard.#{shard}.refs", 'w+' + 0.upto(shard_sz-1) { |i| + shard_in.write in_f.gets + shard_refs.write refs_f.gets + } + shard_in_files << shard_in + shard_refs_files << shard_refs + } + while leftover > 0 + shard_in_files[-1].write in_f.gets + shard_refs_files[-1].write refs_f.gets + leftover -= 1 + end + (shard_in_files + shard_refs_files).each do |f| f.close end + in_f.close + refs_f.close +end + +make_shards input, refs, num_shards + +0.upto(epochs-1) { |epoch| + pids = [] + input_weights = '' + if epoch > 0 then input_weights = "--input_weights work/weights.#{epoch-1}" end + weights_files = [] + 0.upto(num_shards-1) { |shard| + pids << Kernel.fork { + `#{dtrain_bin} -c #{ini}\ + --input work/shard.#{shard}.in\ + --refs work/shard.#{shard}.refs #{input_weights}\ + --output work/weights.#{shard}.#{epoch}\ + &> work/out.#{shard}.#{epoch}` + } + weights_files << "work/weights.#{shard}.#{epoch}" + } + pids.each { |pid| Process.wait(pid) } + cat = File.new('work/weights_cat', 'w+') + weights_files.each { |f| cat.write File.new(f, 'r').read } + cat.close + `#{ruby} #{lplp_rb} #{lplp_args} #{num_shards} < work/weights_cat &> work/weights.#{epoch}` +} + +`rm work/weights_cat` +`#{gzip} work/*` + diff --git a/training/dtrain/parallelize/test/cdec.ini b/training/dtrain/parallelize/test/cdec.ini new file mode 100644 index 00000000..72e99dc5 --- /dev/null +++ b/training/dtrain/parallelize/test/cdec.ini @@ -0,0 +1,22 @@ +formalism=scfg +add_pass_through_rules=true +intersection_strategy=cube_pruning +cubepruning_pop_limit=200 +scfg_max_span_limit=15 +feature_function=WordPenalty +feature_function=KLanguageModel /stor/dat/wmt12/en/news_only/m/wmt12.news.en.3.kenv5 +#feature_function=ArityPenalty +#feature_function=CMR2008ReorderingFeatures +#feature_function=Dwarf +#feature_function=InputIndicator +#feature_function=LexNullJump +#feature_function=NewJump +#feature_function=NgramFeatures +#feature_function=NonLatinCount +#feature_function=OutputIndicator +#feature_function=RuleIdentityFeatures +#feature_function=RuleNgramFeatures +#feature_function=RuleShape +#feature_function=SourceSpanSizeFeatures +#feature_function=SourceWordPenalty +#feature_function=SpanFeatures diff --git a/training/dtrain/parallelize/test/dtrain.ini b/training/dtrain/parallelize/test/dtrain.ini new file mode 100644 index 00000000..03f9d240 --- /dev/null +++ b/training/dtrain/parallelize/test/dtrain.ini @@ -0,0 +1,15 @@ +k=100 +N=4 +learning_rate=0.0001 +gamma=0 +loss_margin=0 +epochs=1 +scorer=stupid_bleu +sample_from=kbest +filter=uniq +pair_sampling=XYX +hi_lo=0.1 +select_weights=last +print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PhraseModel_5 PhraseModel_6 PassThrough +tmp=/tmp +decoder_config=cdec.ini diff --git a/training/dtrain/parallelize/test/in b/training/dtrain/parallelize/test/in new file mode 100644 index 00000000..a312809f --- /dev/null +++ b/training/dtrain/parallelize/test/in @@ -0,0 +1,10 @@ +barack obama erhält als vierter us @-@ präsident den frieden nobelpreis +der amerikanische präsident barack obama kommt für 26 stunden nach oslo , norwegen , um hier als vierter us @-@ präsident in der geschichte den frieden nobelpreis entgegen zunehmen . +darüber hinaus erhält er das diplom sowie die medaille und einen scheck über 1,4 mio. dollar für seine außer gewöhnlichen bestrebungen um die intensivierung der welt diplomatie und zusammen arbeit unter den völkern . +der chef des weißen hauses kommt morgen zusammen mit seiner frau michelle in der nordwegischen metropole an und wird die ganze zeit beschäftigt sein . +zunächst stattet er dem nobel @-@ institut einen besuch ab , wo er überhaupt zum ersten mal mit den fünf ausschuss mitglieder zusammen trifft , die ihn im oktober aus 172 leuten und 33 organisationen gewählt haben . +das präsidenten paar hat danach ein treffen mit dem norwegischen könig harald v. und königin sonja eingeplant . +nachmittags erreicht dann der besuch seinen höhepunkt mit der zeremonie , bei der obama den prestige preis übernimmt . +diesen erhält er als der vierte us @-@ präsident , aber erst als der dritte , der den preis direkt im amt entgegen nimmt . +das weiße haus avisierte schon , dass obama bei der übernahme des preises über den afghanistan krieg sprechen wird . +der präsident will diesem thema nicht ausweichen , weil er weiß , dass er den preis als ein präsident übernimmt , der zur zeit krieg in zwei ländern führt . diff --git a/training/dtrain/parallelize/test/refs b/training/dtrain/parallelize/test/refs new file mode 100644 index 00000000..4d3128cb --- /dev/null +++ b/training/dtrain/parallelize/test/refs @@ -0,0 +1,10 @@ +barack obama becomes the fourth american president to receive the nobel peace prize +the american president barack obama will fly into oslo , norway for 26 hours to receive the nobel peace prize , the fourth american president in history to do so . +he will receive a diploma , medal and cheque for 1.4 million dollars for his exceptional efforts to improve global diplomacy and encourage international cooperation , amongst other things . +the head of the white house will be flying into the norwegian city in the morning with his wife michelle and will have a busy schedule . +first , he will visit the nobel institute , where he will have his first meeting with the five committee members who selected him from 172 people and 33 organisations . +the presidential couple then has a meeting scheduled with king harald v and queen sonja of norway . +then , in the afternoon , the visit will culminate in a grand ceremony , at which obama will receive the prestigious award . +he will be the fourth american president to be awarded the prize , and only the third to have received it while actually in office . +the white house has stated that , when he accepts the prize , obama will speak about the war in afghanistan . +the president does not want to skirt around this topic , as he realises that he is accepting the prize as a president whose country is currently at war in two countries . diff --git a/training/dtrain/score.cc b/training/dtrain/score.cc new file mode 100644 index 00000000..34fc86a9 --- /dev/null +++ b/training/dtrain/score.cc @@ -0,0 +1,254 @@ +#include "score.h" + +namespace dtrain +{ + + +/* + * bleu + * + * as in "BLEU: a Method for Automatic Evaluation + * of Machine Translation" + * (Papineni et al. '02) + * + * NOTE: 0 if for one n \in {1..N} count is 0 + */ +score_t +BleuScorer::Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len) +{ + if (hyp_len == 0 || ref_len == 0) return 0.; + unsigned M = N_; + vector v = w_; + if (ref_len < N_) { + M = ref_len; + for (unsigned i = 0; i < M; i++) v[i] = 1/((score_t)M); + } + score_t sum = 0; + for (unsigned i = 0; i < M; i++) { + if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) return 0.; + sum += v[i] * log((score_t)counts.clipped_[i]/counts.sum_[i]); + } + return brevity_penalty(hyp_len, ref_len) * exp(sum); +} + +score_t +BleuScorer::Score(vector& hyp, vector& ref, + const unsigned /*rank*/, const unsigned /*src_len*/) +{ + unsigned hyp_len = hyp.size(), ref_len = ref.size(); + if (hyp_len == 0 || ref_len == 0) return 0.; + NgramCounts counts = make_ngram_counts(hyp, ref, N_); + return Bleu(counts, hyp_len, ref_len); +} + +/* + * 'stupid' bleu + * + * as in "ORANGE: a Method for Evaluating + * Automatic Evaluation Metrics + * for Machine Translation" + * (Lin & Och '04) + * + * NOTE: 0 iff no 1gram match + */ +score_t +StupidBleuScorer::Score(vector& hyp, vector& ref, + const unsigned /*rank*/, const unsigned /*src_len*/) +{ + unsigned hyp_len = hyp.size(), ref_len = ref.size(); + if (hyp_len == 0 || ref_len == 0) return 0.; + NgramCounts counts = make_ngram_counts(hyp, ref, N_); + unsigned M = N_; + vector v = w_; + if (ref_len < N_) { + M = ref_len; + for (unsigned i = 0; i < M; i++) v[i] = 1/((score_t)M); + } + score_t sum = 0, add = 0; + for (unsigned i = 0; i < M; i++) { + if (i == 0 && (counts.sum_[i] == 0 || counts.clipped_[i] == 0)) return 0.; + if (i == 1) add = 1; + sum += v[i] * log(((score_t)counts.clipped_[i] + add)/((counts.sum_[i] + add))); + } + return brevity_penalty(hyp_len, ref_len) * exp(sum); +} + +/* + * smooth bleu + * + * as in "An End-to-End Discriminative Approach + * to Machine Translation" + * (Liang et al. '06) + * + * NOTE: max is 0.9375 (with N=4) + */ +score_t +SmoothBleuScorer::Score(vector& hyp, vector& ref, + const unsigned /*rank*/, const unsigned /*src_len*/) +{ + unsigned hyp_len = hyp.size(), ref_len = ref.size(); + if (hyp_len == 0 || ref_len == 0) return 0.; + NgramCounts counts = make_ngram_counts(hyp, ref, N_); + unsigned M = N_; + if (ref_len < N_) M = ref_len; + score_t sum = 0.; + vector i_bleu; + for (unsigned i = 0; i < M; i++) i_bleu.push_back(0.); + for (unsigned i = 0; i < M; i++) { + if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) { + break; + } else { + score_t i_ng = log((score_t)counts.clipped_[i]/counts.sum_[i]); + for (unsigned j = i; j < M; j++) { + i_bleu[j] += (1/((score_t)j+1)) * i_ng; + } + } + sum += exp(i_bleu[i])/pow(2.0, (double)(N_-i)); + } + return brevity_penalty(hyp_len, ref_len) * sum; +} + +/* + * 'sum' bleu + * + * sum up Ngram precisions + */ +score_t +SumBleuScorer::Score(vector& hyp, vector& ref, + const unsigned /*rank*/, const unsigned /*src_len*/) +{ + unsigned hyp_len = hyp.size(), ref_len = ref.size(); + if (hyp_len == 0 || ref_len == 0) return 0.; + NgramCounts counts = make_ngram_counts(hyp, ref, N_); + unsigned M = N_; + if (ref_len < N_) M = ref_len; + score_t sum = 0.; + unsigned j = 1; + for (unsigned i = 0; i < M; i++) { + if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) break; + sum += ((score_t)counts.clipped_[i]/counts.sum_[i])/pow(2.0, (double) (N_-j+1)); + j++; + } + return brevity_penalty(hyp_len, ref_len) * sum; +} + +/* + * 'sum' (exp) bleu + * + * sum up exp(Ngram precisions) + */ +score_t +SumExpBleuScorer::Score(vector& hyp, vector& ref, + const unsigned /*rank*/, const unsigned /*src_len*/) +{ + unsigned hyp_len = hyp.size(), ref_len = ref.size(); + if (hyp_len == 0 || ref_len == 0) return 0.; + NgramCounts counts = make_ngram_counts(hyp, ref, N_); + unsigned M = N_; + if (ref_len < N_) M = ref_len; + score_t sum = 0.; + unsigned j = 1; + for (unsigned i = 0; i < M; i++) { + if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) break; + sum += exp(((score_t)counts.clipped_[i]/counts.sum_[i]))/pow(2.0, (double) (N_-j+1)); + j++; + } + return brevity_penalty(hyp_len, ref_len) * sum; +} + +/* + * 'sum' (whatever) bleu + * + * sum up exp(weight * log(Ngram precisions)) + */ +score_t +SumWhateverBleuScorer::Score(vector& hyp, vector& ref, + const unsigned /*rank*/, const unsigned /*src_len*/) +{ + unsigned hyp_len = hyp.size(), ref_len = ref.size(); + if (hyp_len == 0 || ref_len == 0) return 0.; + NgramCounts counts = make_ngram_counts(hyp, ref, N_); + unsigned M = N_; + vector v = w_; + if (ref_len < N_) { + M = ref_len; + for (unsigned i = 0; i < M; i++) v[i] = 1/((score_t)M); + } + score_t sum = 0.; + unsigned j = 1; + for (unsigned i = 0; i < M; i++) { + if (counts.sum_[i] == 0 || counts.clipped_[i] == 0) break; + sum += exp(v[i] * log(((score_t)counts.clipped_[i]/counts.sum_[i])))/pow(2.0, (double) (N_-j+1)); + j++; + } + return brevity_penalty(hyp_len, ref_len) * sum; +} + +/* + * approx. bleu + * + * as in "Online Large-Margin Training of Syntactic + * and Structural Translation Features" + * (Chiang et al. '08) + * + * NOTE: Needs some more code in dtrain.cc . + * No scaling by src len. + */ +score_t +ApproxBleuScorer::Score(vector& hyp, vector& ref, + const unsigned rank, const unsigned src_len) +{ + unsigned hyp_len = hyp.size(), ref_len = ref.size(); + if (ref_len == 0) return 0.; + score_t score = 0.; + NgramCounts counts(N_); + if (hyp_len > 0) { + counts = make_ngram_counts(hyp, ref, N_); + NgramCounts tmp = glob_onebest_counts_ + counts; + score = Bleu(tmp, hyp_len, ref_len); + } + if (rank == 0) { // 'context of 1best translations' + glob_onebest_counts_ += counts; + glob_onebest_counts_ *= discount_; + glob_hyp_len_ = discount_ * (glob_hyp_len_ + hyp_len); + glob_ref_len_ = discount_ * (glob_ref_len_ + ref_len); + glob_src_len_ = discount_ * (glob_src_len_ + src_len); + } + return score; +} + +/* + * Linear (Corpus) Bleu + * + * as in "Lattice Minimum Bayes-Risk Decoding + * for Statistical Machine Translation" + * (Tromble et al. '08) + * + */ +score_t +LinearBleuScorer::Score(vector& hyp, vector& ref, + const unsigned rank, const unsigned /*src_len*/) +{ + unsigned hyp_len = hyp.size(), ref_len = ref.size(); + if (ref_len == 0) return 0.; + unsigned M = N_; + if (ref_len < N_) M = ref_len; + NgramCounts counts(M); + if (hyp_len > 0) + counts = make_ngram_counts(hyp, ref, M); + score_t ret = 0.; + for (unsigned i = 0; i < M; i++) { + if (counts.sum_[i] == 0 || onebest_counts_.sum_[i] == 0) break; + ret += counts.sum_[i]/onebest_counts_.sum_[i]; + } + ret = -(hyp_len/(score_t)onebest_len_) + (1./M) * ret; + if (rank == 0) { + onebest_len_ += hyp_len; + onebest_counts_ += counts; + } + return ret; +} + + +} // namespace + diff --git a/training/dtrain/score.h b/training/dtrain/score.h new file mode 100644 index 00000000..f317c903 --- /dev/null +++ b/training/dtrain/score.h @@ -0,0 +1,212 @@ +#ifndef _DTRAIN_SCORE_H_ +#define _DTRAIN_SCORE_H_ + +#include "kbestget.h" + +using namespace std; + +namespace dtrain +{ + + +struct NgramCounts +{ + unsigned N_; + map clipped_; + map sum_; + + NgramCounts(const unsigned N) : N_(N) { Zero(); } + + inline void + operator+=(const NgramCounts& rhs) + { + if (rhs.N_ > N_) Resize(rhs.N_); + for (unsigned i = 0; i < N_; i++) { + this->clipped_[i] += rhs.clipped_.find(i)->second; + this->sum_[i] += rhs.sum_.find(i)->second; + } + } + + inline const NgramCounts + operator+(const NgramCounts &other) const + { + NgramCounts result = *this; + result += other; + return result; + } + + inline void + operator*=(const score_t rhs) + { + for (unsigned i = 0; i < N_; i++) { + this->clipped_[i] *= rhs; + this->sum_[i] *= rhs; + } + } + + inline void + Add(const unsigned count, const unsigned ref_count, const unsigned i) + { + assert(i < N_); + if (count > ref_count) { + clipped_[i] += ref_count; + } else { + clipped_[i] += count; + } + sum_[i] += count; + } + + inline void + Zero() + { + for (unsigned i = 0; i < N_; i++) { + clipped_[i] = 0.; + sum_[i] = 0.; + } + } + + inline void + One() + { + for (unsigned i = 0; i < N_; i++) { + clipped_[i] = 1.; + sum_[i] = 1.; + } + } + + inline void + Print() + { + for (unsigned i = 0; i < N_; i++) { + cout << i+1 << "grams (clipped):\t" << clipped_[i] << endl; + cout << i+1 << "grams:\t\t\t" << sum_[i] << endl; + } + } + + inline void Resize(unsigned N) + { + if (N == N_) return; + else if (N > N_) { + for (unsigned i = N_; i < N; i++) { + clipped_[i] = 0.; + sum_[i] = 0.; + } + } else { // N < N_ + for (unsigned i = N_-1; i > N-1; i--) { + clipped_.erase(i); + sum_.erase(i); + } + } + N_ = N; + } +}; + +typedef map, unsigned> Ngrams; + +inline Ngrams +make_ngrams(const vector& s, const unsigned N) +{ + Ngrams ngrams; + vector ng; + for (size_t i = 0; i < s.size(); i++) { + ng.clear(); + for (unsigned j = i; j < min(i+N, s.size()); j++) { + ng.push_back(s[j]); + ngrams[ng]++; + } + } + return ngrams; +} + +inline NgramCounts +make_ngram_counts(const vector& hyp, const vector& ref, const unsigned N) +{ + Ngrams hyp_ngrams = make_ngrams(hyp, N); + Ngrams ref_ngrams = make_ngrams(ref, N); + NgramCounts counts(N); + Ngrams::iterator it; + Ngrams::iterator ti; + for (it = hyp_ngrams.begin(); it != hyp_ngrams.end(); it++) { + ti = ref_ngrams.find(it->first); + if (ti != ref_ngrams.end()) { + counts.Add(it->second, ti->second, it->first.size() - 1); + } else { + counts.Add(it->second, 0, it->first.size() - 1); + } + } + return counts; +} + +struct BleuScorer : public LocalScorer +{ + score_t Bleu(NgramCounts& counts, const unsigned hyp_len, const unsigned ref_len); + score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct StupidBleuScorer : public LocalScorer +{ + score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct SmoothBleuScorer : public LocalScorer +{ + score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct SumBleuScorer : public LocalScorer +{ + score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct SumExpBleuScorer : public LocalScorer +{ + score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct SumWhateverBleuScorer : public LocalScorer +{ + score_t Score(vector& hyp, vector& ref, const unsigned /*rank*/, const unsigned /*src_len*/); +}; + +struct ApproxBleuScorer : public BleuScorer +{ + NgramCounts glob_onebest_counts_; + unsigned glob_hyp_len_, glob_ref_len_, glob_src_len_; + score_t discount_; + + ApproxBleuScorer(unsigned N, score_t d) : glob_onebest_counts_(NgramCounts(N)), discount_(d) + { + glob_hyp_len_ = glob_ref_len_ = glob_src_len_ = 0; + } + + inline void Reset() { + glob_onebest_counts_.Zero(); + glob_hyp_len_ = glob_ref_len_ = glob_src_len_ = 0.; + } + + score_t Score(vector& hyp, vector& ref, const unsigned rank, const unsigned src_len); +}; + +struct LinearBleuScorer : public BleuScorer +{ + unsigned onebest_len_; + NgramCounts onebest_counts_; + + LinearBleuScorer(unsigned N) : onebest_len_(1), onebest_counts_(N) + { + onebest_counts_.One(); + } + + score_t Score(vector& hyp, vector& ref, const unsigned rank, const unsigned /*src_len*/); + + inline void Reset() { + onebest_len_ = 1; + onebest_counts_.One(); + } +}; + + +} // namespace + +#endif + diff --git a/training/dtrain/test/example/README b/training/dtrain/test/example/README new file mode 100644 index 00000000..6937b11b --- /dev/null +++ b/training/dtrain/test/example/README @@ -0,0 +1,8 @@ +Small example of input format for distributed training. +Call dtrain from cdec/dtrain/ with ./dtrain -c test/example/dtrain.ini . + +For this to work, undef 'DTRAIN_LOCAL' in dtrain.h +and recompile. + +Data is here: http://simianer.de/#dtrain + diff --git a/training/dtrain/test/example/cdec.ini b/training/dtrain/test/example/cdec.ini new file mode 100644 index 00000000..d5955f0e --- /dev/null +++ b/training/dtrain/test/example/cdec.ini @@ -0,0 +1,25 @@ +formalism=scfg +add_pass_through_rules=true +scfg_max_span_limit=15 +intersection_strategy=cube_pruning +cubepruning_pop_limit=30 +feature_function=WordPenalty +feature_function=KLanguageModel test/example/nc-wmt11.en.srilm.gz +# all currently working feature functions for translation: +# (with those features active that were used in the ACL paper) +#feature_function=ArityPenalty +#feature_function=CMR2008ReorderingFeatures +#feature_function=Dwarf +#feature_function=InputIndicator +#feature_function=LexNullJump +#feature_function=NewJump +#feature_function=NgramFeatures +#feature_function=NonLatinCount +#feature_function=OutputIndicator +feature_function=RuleIdentityFeatures +feature_function=RuleSourceBigramFeatures +feature_function=RuleTargetBigramFeatures +feature_function=RuleShape +#feature_function=SourceSpanSizeFeatures +#feature_function=SourceWordPenalty +#feature_function=SpanFeatures diff --git a/training/dtrain/test/example/dtrain.ini b/training/dtrain/test/example/dtrain.ini new file mode 100644 index 00000000..72d50ca1 --- /dev/null +++ b/training/dtrain/test/example/dtrain.ini @@ -0,0 +1,22 @@ +input=test/example/nc-wmt11.1k.gz # use '-' for STDIN +output=- # a weights file (add .gz for gzip compression) or STDOUT '-' +select_weights=VOID # don't output weights +decoder_config=test/example/cdec.ini # config for cdec +# weights for these features will be printed on each iteration +print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PhraseModel_5 PhraseModel_6 PassThrough +tmp=/tmp +stop_after=10 # stop epoch after 10 inputs + +# interesting stuff +epochs=2 # run over input 2 times +k=100 # use 100best lists +N=4 # optimize (approx) BLEU4 +scorer=stupid_bleu # use 'stupid' BLEU+1 +learning_rate=1.0 # learning rate, don't care if gamma=0 (perceptron) +gamma=0 # use SVM reg +sample_from=kbest # use kbest lists (as opposed to forest) +filter=uniq # only unique entries in kbest (surface form) +pair_sampling=XYX +hi_lo=0.1 # 10 vs 80 vs 10 and 80 vs 10 here +pair_threshold=0 # minimum distance in BLEU (this will still only use pairs with diff > 0) +loss_margin=0 diff --git a/training/dtrain/test/example/expected-output b/training/dtrain/test/example/expected-output new file mode 100644 index 00000000..05326763 --- /dev/null +++ b/training/dtrain/test/example/expected-output @@ -0,0 +1,89 @@ + cdec cfg 'test/example/cdec.ini' +Loading the LM will be faster if you build a binary file. +Reading test/example/nc-wmt11.en.srilm.gz +----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100 +**************************************************************************************************** + Example feature: Shape_S00000_T00000 +Seeding random number sequence to 2912000813 + +dtrain +Parameters: + k 100 + N 4 + T 2 + scorer 'stupid_bleu' + sample from 'kbest' + filter 'uniq' + learning rate 1 + gamma 0 + loss margin 0 + pairs 'XYX' + hi lo 0.1 + pair threshold 0 + select weights 'VOID' + l1 reg 0 'none' + max pairs 4294967295 + cdec cfg 'test/example/cdec.ini' + input 'test/example/nc-wmt11.1k.gz' + output '-' + stop_after 10 +(a dot represents 10 inputs) +Iteration #1 of 2. + . 10 +Stopping after 10 input sentences. +WEIGHTS + Glue = -637 + WordPenalty = +1064 + LanguageModel = +1175.3 + LanguageModel_OOV = -1437 + PhraseModel_0 = +1935.6 + PhraseModel_1 = +2499.3 + PhraseModel_2 = +964.96 + PhraseModel_3 = +1410.8 + PhraseModel_4 = -5977.9 + PhraseModel_5 = +522 + PhraseModel_6 = +1089 + PassThrough = -1308 + --- + 1best avg score: 0.16963 (+0.16963) + 1best avg model score: 64485 (+64485) + avg # pairs: 1494.4 + avg # rank err: 702.6 + avg # margin viol: 0 + non0 feature count: 528 + avg list sz: 85.7 + avg f count: 102.75 +(time 0.083 min, 0.5 s/S) + +Iteration #2 of 2. + . 10 +WEIGHTS + Glue = -1196 + WordPenalty = +809.52 + LanguageModel = +3112.1 + LanguageModel_OOV = -1464 + PhraseModel_0 = +3895.5 + PhraseModel_1 = +4683.4 + PhraseModel_2 = +1092.8 + PhraseModel_3 = +1079.6 + PhraseModel_4 = -6827.7 + PhraseModel_5 = -888 + PhraseModel_6 = +142 + PassThrough = -1335 + --- + 1best avg score: 0.277 (+0.10736) + 1best avg model score: -3110.5 (-67595) + avg # pairs: 1144.2 + avg # rank err: 529.1 + avg # margin viol: 0 + non0 feature count: 859 + avg list sz: 74.9 + avg f count: 112.84 +(time 0.067 min, 0.4 s/S) + +Writing weights file to '-' ... +done + +--- +Best iteration: 2 [SCORE 'stupid_bleu'=0.277]. +This took 0.15 min. diff --git a/training/dtrain/test/parallelize/cdec.ini b/training/dtrain/test/parallelize/cdec.ini new file mode 100644 index 00000000..72e99dc5 --- /dev/null +++ b/training/dtrain/test/parallelize/cdec.ini @@ -0,0 +1,22 @@ +formalism=scfg +add_pass_through_rules=true +intersection_strategy=cube_pruning +cubepruning_pop_limit=200 +scfg_max_span_limit=15 +feature_function=WordPenalty +feature_function=KLanguageModel /stor/dat/wmt12/en/news_only/m/wmt12.news.en.3.kenv5 +#feature_function=ArityPenalty +#feature_function=CMR2008ReorderingFeatures +#feature_function=Dwarf +#feature_function=InputIndicator +#feature_function=LexNullJump +#feature_function=NewJump +#feature_function=NgramFeatures +#feature_function=NonLatinCount +#feature_function=OutputIndicator +#feature_function=RuleIdentityFeatures +#feature_function=RuleNgramFeatures +#feature_function=RuleShape +#feature_function=SourceSpanSizeFeatures +#feature_function=SourceWordPenalty +#feature_function=SpanFeatures diff --git a/training/dtrain/test/parallelize/dtrain.ini b/training/dtrain/test/parallelize/dtrain.ini new file mode 100644 index 00000000..03f9d240 --- /dev/null +++ b/training/dtrain/test/parallelize/dtrain.ini @@ -0,0 +1,15 @@ +k=100 +N=4 +learning_rate=0.0001 +gamma=0 +loss_margin=0 +epochs=1 +scorer=stupid_bleu +sample_from=kbest +filter=uniq +pair_sampling=XYX +hi_lo=0.1 +select_weights=last +print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 PhraseModel_1 PhraseModel_2 PhraseModel_3 PhraseModel_4 PhraseModel_5 PhraseModel_6 PassThrough +tmp=/tmp +decoder_config=cdec.ini diff --git a/training/dtrain/test/parallelize/in b/training/dtrain/test/parallelize/in new file mode 100644 index 00000000..a312809f --- /dev/null +++ b/training/dtrain/test/parallelize/in @@ -0,0 +1,10 @@ +barack obama erhält als vierter us @-@ präsident den frieden nobelpreis +der amerikanische präsident barack obama kommt für 26 stunden nach oslo , norwegen , um hier als vierter us @-@ präsident in der geschichte den frieden nobelpreis entgegen zunehmen . +darüber hinaus erhält er das diplom sowie die medaille und einen scheck über 1,4 mio. dollar für seine außer gewöhnlichen bestrebungen um die intensivierung der welt diplomatie und zusammen arbeit unter den völkern . +der chef des weißen hauses kommt morgen zusammen mit seiner frau michelle in der nordwegischen metropole an und wird die ganze zeit beschäftigt sein . +zunächst stattet er dem nobel @-@ institut einen besuch ab , wo er überhaupt zum ersten mal mit den fünf ausschuss mitglieder zusammen trifft , die ihn im oktober aus 172 leuten und 33 organisationen gewählt haben . +das präsidenten paar hat danach ein treffen mit dem norwegischen könig harald v. und königin sonja eingeplant . +nachmittags erreicht dann der besuch seinen höhepunkt mit der zeremonie , bei der obama den prestige preis übernimmt . +diesen erhält er als der vierte us @-@ präsident , aber erst als der dritte , der den preis direkt im amt entgegen nimmt . +das weiße haus avisierte schon , dass obama bei der übernahme des preises über den afghanistan krieg sprechen wird . +der präsident will diesem thema nicht ausweichen , weil er weiß , dass er den preis als ein präsident übernimmt , der zur zeit krieg in zwei ländern führt . diff --git a/training/dtrain/test/parallelize/refs b/training/dtrain/test/parallelize/refs new file mode 100644 index 00000000..4d3128cb --- /dev/null +++ b/training/dtrain/test/parallelize/refs @@ -0,0 +1,10 @@ +barack obama becomes the fourth american president to receive the nobel peace prize +the american president barack obama will fly into oslo , norway for 26 hours to receive the nobel peace prize , the fourth american president in history to do so . +he will receive a diploma , medal and cheque for 1.4 million dollars for his exceptional efforts to improve global diplomacy and encourage international cooperation , amongst other things . +the head of the white house will be flying into the norwegian city in the morning with his wife michelle and will have a busy schedule . +first , he will visit the nobel institute , where he will have his first meeting with the five committee members who selected him from 172 people and 33 organisations . +the presidential couple then has a meeting scheduled with king harald v and queen sonja of norway . +then , in the afternoon , the visit will culminate in a grand ceremony , at which obama will receive the prestigious award . +he will be the fourth american president to be awarded the prize , and only the third to have received it while actually in office . +the white house has stated that , when he accepts the prize , obama will speak about the war in afghanistan . +the president does not want to skirt around this topic , as he realises that he is accepting the prize as a president whose country is currently at war in two countries . diff --git a/training/dtrain/test/toy/cdec.ini b/training/dtrain/test/toy/cdec.ini new file mode 100644 index 00000000..98b02d44 --- /dev/null +++ b/training/dtrain/test/toy/cdec.ini @@ -0,0 +1,2 @@ +formalism=scfg +add_pass_through_rules=true diff --git a/training/dtrain/test/toy/dtrain.ini b/training/dtrain/test/toy/dtrain.ini new file mode 100644 index 00000000..a091732f --- /dev/null +++ b/training/dtrain/test/toy/dtrain.ini @@ -0,0 +1,12 @@ +decoder_config=test/toy/cdec.ini +input=test/toy/input +output=- +print_weights=logp shell_rule house_rule small_rule little_rule PassThrough +k=4 +N=4 +epochs=2 +scorer=bleu +sample_from=kbest +filter=uniq +pair_sampling=all +learning_rate=1 diff --git a/training/dtrain/test/toy/input b/training/dtrain/test/toy/input new file mode 100644 index 00000000..4d10a9ea --- /dev/null +++ b/training/dtrain/test/toy/input @@ -0,0 +1,2 @@ +0 ich sah ein kleines haus i saw a little house [S] ||| [NP,1] [VP,2] ||| [1] [2] ||| logp=0 [NP] ||| ich ||| i ||| logp=0 [NP] ||| ein [NN,1] ||| a [1] ||| logp=0 [NN] ||| [JJ,1] haus ||| [1] house ||| logp=0 house_rule=1 [NN] ||| [JJ,1] haus ||| [1] shell ||| logp=0 shell_rule=1 [JJ] ||| kleines ||| small ||| logp=0 small_rule=1 [JJ] ||| kleines ||| little ||| logp=0 little_rule=1 [JJ] ||| grosses ||| big ||| logp=0 [JJ] ||| grosses ||| large ||| logp=0 [VP] ||| [V,1] [NP,2] ||| [1] [2] ||| logp=0 [V] ||| sah ||| saw ||| logp=0 [V] ||| fand ||| found ||| logp=0 +1 ich fand ein kleines haus i found a little house [S] ||| [NP,1] [VP,2] ||| [1] [2] ||| logp=0 [NP] ||| ich ||| i ||| logp=0 [NP] ||| ein [NN,1] ||| a [1] ||| logp=0 [NN] ||| [JJ,1] haus ||| [1] house ||| logp=0 house_rule=1 [NN] ||| [JJ,1] haus ||| [1] shell ||| logp=0 shell_rule=1 [JJ] ||| kleines ||| small ||| logp=0 small_rule=1 [JJ] ||| kleines ||| little ||| logp=0 little_rule=1 [JJ] ||| grosses ||| big ||| logp=0 [JJ] ||| grosses ||| large ||| logp=0 [VP] ||| [V,1] [NP,2] ||| [1] [2] ||| logp=0 [V] ||| sah ||| saw ||| logp=0 [V] ||| fand ||| found ||| logp=0 diff --git a/training/entropy.cc b/training/entropy.cc deleted file mode 100644 index 4fdbe2be..00000000 --- a/training/entropy.cc +++ /dev/null @@ -1,41 +0,0 @@ -#include "entropy.h" - -#include "prob.h" -#include "candidate_set.h" - -using namespace std; - -namespace training { - -// see Mann and McCallum "Efficient Computation of Entropy Gradient ..." for -// a mostly clear derivation of: -// g = E[ F(x,y) * log p(y|x) ] + H(y | x) * E[ F(x,y) ] -double CandidateSetEntropy::operator()(const vector& params, - SparseVector* g) const { - prob_t z; - vector dps(cands_.size()); - for (unsigned i = 0; i < cands_.size(); ++i) { - dps[i] = cands_[i].fmap.dot(params); - const prob_t u(dps[i], init_lnx()); - z += u; - } - const double log_z = log(z); - - SparseVector exp_feats; - double entropy = 0; - for (unsigned i = 0; i < cands_.size(); ++i) { - const double log_prob = cands_[i].fmap.dot(params) - log_z; - const double prob = exp(log_prob); - const double e_logprob = prob * log_prob; - entropy -= e_logprob; - if (g) { - (*g) += cands_[i].fmap * e_logprob; - exp_feats += cands_[i].fmap * prob; - } - } - if (g) (*g) += exp_feats * entropy; - return entropy; -} - -} - diff --git a/training/entropy.h b/training/entropy.h deleted file mode 100644 index 796589ca..00000000 --- a/training/entropy.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef _CSENTROPY_H_ -#define _CSENTROPY_H_ - -#include -#include "sparse_vector.h" - -namespace training { - class CandidateSet; - - class CandidateSetEntropy { - public: - explicit CandidateSetEntropy(const CandidateSet& cs) : cands_(cs) {} - // compute the entropy (expected log likelihood) of a CandidateSet - // (optional) the gradient of the entropy with respect to params - double operator()(const std::vector& params, - SparseVector* g = NULL) const; - private: - const CandidateSet& cands_; - }; -}; - -#endif diff --git a/training/fast_align.cc b/training/fast_align.cc deleted file mode 100644 index 7492d26f..00000000 --- a/training/fast_align.cc +++ /dev/null @@ -1,281 +0,0 @@ -#include -#include - -#include -#include - -#include "m.h" -#include "corpus_tools.h" -#include "stringlib.h" -#include "filelib.h" -#include "ttables.h" -#include "tdict.h" - -namespace po = boost::program_options; -using namespace std; - -bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("input,i",po::value(),"Parallel corpus input file") - ("reverse,r","Reverse estimation (swap source and target during training)") - ("iterations,I",po::value()->default_value(5),"Number of iterations of EM training") - //("bidir,b", "Run bidirectional alignment") - ("favor_diagonal,d", "Use a static alignment distribution that assigns higher probabilities to alignments near the diagonal") - ("prob_align_null", po::value()->default_value(0.08), "When --favor_diagonal is set, what's the probability of a null alignment?") - ("diagonal_tension,T", po::value()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (<1 = flat >1 = sharp)") - ("variational_bayes,v","Infer VB estimate of parameters under a symmetric Dirichlet prior") - ("alpha,a", po::value()->default_value(0.01), "Hyperparameter for optional Dirichlet prior") - ("no_null_word,N","Do not generate from a null token") - ("output_parameters,p", "Write model parameters instead of alignments") - ("beam_threshold,t",po::value()->default_value(-4),"When writing parameters, log_10 of beam threshold for writing parameter (-10000 to include everything, 0 max parameter only)") - ("hide_training_alignments,H", "Hide training alignments (only useful if you want to use -x option and just compute testset statistics)") - ("testset,x", po::value(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model") - ("no_add_viterbi,V","When writing model parameters, do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || conf->count("input") == 0) { - cerr << "Usage " << argv[0] << " [OPTIONS] -i corpus.fr-en\n"; - cerr << dcmdline_options << endl; - return false; - } - return true; -} - -int main(int argc, char** argv) { - po::variables_map conf; - if (!InitCommandLine(argc, argv, &conf)) return 1; - const string fname = conf["input"].as(); - const bool reverse = conf.count("reverse") > 0; - const int ITERATIONS = conf["iterations"].as(); - const double BEAM_THRESHOLD = pow(10.0, conf["beam_threshold"].as()); - const bool use_null = (conf.count("no_null_word") == 0); - const WordID kNULL = TD::Convert(""); - const bool add_viterbi = (conf.count("no_add_viterbi") == 0); - const bool variational_bayes = (conf.count("variational_bayes") > 0); - const bool write_alignments = (conf.count("output_parameters") == 0); - const double diagonal_tension = conf["diagonal_tension"].as(); - const double prob_align_null = conf["prob_align_null"].as(); - const bool hide_training_alignments = (conf.count("hide_training_alignments") > 0); - string testset; - if (conf.count("testset")) testset = conf["testset"].as(); - const double prob_align_not_null = 1.0 - prob_align_null; - const double alpha = conf["alpha"].as(); - const bool favor_diagonal = conf.count("favor_diagonal"); - if (variational_bayes && alpha <= 0.0) { - cerr << "--alpha must be > 0\n"; - return 1; - } - - TTable s2t, t2s; - TTable::Word2Word2Double s2t_viterbi; - double tot_len_ratio = 0; - double mean_srclen_multiplier = 0; - vector unnormed_a_i; - for (int iter = 0; iter < ITERATIONS; ++iter) { - const bool final_iteration = (iter == (ITERATIONS - 1)); - cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; - ReadFile rf(fname); - istream& in = *rf.stream(); - double likelihood = 0; - double denom = 0.0; - int lc = 0; - bool flag = false; - string line; - string ssrc, strg; - vector src, trg; - while(true) { - getline(in, line); - if (!in) break; - ++lc; - if (lc % 1000 == 0) { cerr << '.'; flag = true; } - if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; } - src.clear(); trg.clear(); - CorpusTools::ReadLine(line, &src, &trg); - if (reverse) swap(src, trg); - if (src.size() == 0 || trg.size() == 0) { - cerr << "Error: " << lc << "\n" << line << endl; - return 1; - } - if (src.size() > unnormed_a_i.size()) - unnormed_a_i.resize(src.size()); - if (iter == 0) - tot_len_ratio += static_cast(trg.size()) / static_cast(src.size()); - denom += trg.size(); - vector probs(src.size() + 1); - bool first_al = true; // used for write_alignments - for (int j = 0; j < trg.size(); ++j) { - const WordID& f_j = trg[j]; - double sum = 0; - const double j_over_ts = double(j) / trg.size(); - double prob_a_i = 1.0 / (src.size() + use_null); // uniform (model 1) - if (use_null) { - if (favor_diagonal) prob_a_i = prob_align_null; - probs[0] = s2t.prob(kNULL, f_j) * prob_a_i; - sum += probs[0]; - } - double az = 0; - if (favor_diagonal) { - for (int ta = 0; ta < src.size(); ++ta) { - unnormed_a_i[ta] = exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); - az += unnormed_a_i[ta]; - } - az /= prob_align_not_null; - } - for (int i = 1; i <= src.size(); ++i) { - if (favor_diagonal) - prob_a_i = unnormed_a_i[i-1] / az; - probs[i] = s2t.prob(src[i-1], f_j) * prob_a_i; - sum += probs[i]; - } - if (final_iteration) { - if (add_viterbi || write_alignments) { - WordID max_i = 0; - double max_p = -1; - int max_index = -1; - if (use_null) { - max_i = kNULL; - max_index = 0; - max_p = probs[0]; - } - for (int i = 1; i <= src.size(); ++i) { - if (probs[i] > max_p) { - max_index = i; - max_p = probs[i]; - max_i = src[i-1]; - } - } - if (!hide_training_alignments && write_alignments) { - if (max_index > 0) { - if (first_al) first_al = false; else cout << ' '; - if (reverse) - cout << j << '-' << (max_index - 1); - else - cout << (max_index - 1) << '-' << j; - } - } - s2t_viterbi[max_i][f_j] = 1.0; - } - } else { - if (use_null) - s2t.Increment(kNULL, f_j, probs[0] / sum); - for (int i = 1; i <= src.size(); ++i) - s2t.Increment(src[i-1], f_j, probs[i] / sum); - } - likelihood += log(sum); - } - if (write_alignments && final_iteration && !hide_training_alignments) cout << endl; - } - - // log(e) = 1.0 - double base2_likelihood = likelihood / log(2); - - if (flag) { cerr << endl; } - if (iter == 0) { - mean_srclen_multiplier = tot_len_ratio / lc; - cerr << "expected target length = source length * " << mean_srclen_multiplier << endl; - } - cerr << " log_e likelihood: " << likelihood << endl; - cerr << " log_2 likelihood: " << base2_likelihood << endl; - cerr << " cross entropy: " << (-base2_likelihood / denom) << endl; - cerr << " perplexity: " << pow(2.0, -base2_likelihood / denom) << endl; - if (!final_iteration) { - if (variational_bayes) - s2t.NormalizeVB(alpha); - else - s2t.Normalize(); - } - } - if (testset.size()) { - ReadFile rf(testset); - istream& in = *rf.stream(); - int lc = 0; - double tlp = 0; - string line; - while (getline(in, line)) { - ++lc; - vector src, trg; - CorpusTools::ReadLine(line, &src, &trg); - cout << TD::GetString(src) << " ||| " << TD::GetString(trg) << " |||"; - if (reverse) swap(src, trg); - double log_prob = Md::log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier); - if (src.size() > unnormed_a_i.size()) - unnormed_a_i.resize(src.size()); - - // compute likelihood - for (int j = 0; j < trg.size(); ++j) { - const WordID& f_j = trg[j]; - double sum = 0; - int a_j = 0; - double max_pat = 0; - const double j_over_ts = double(j) / trg.size(); - double prob_a_i = 1.0 / (src.size() + use_null); // uniform (model 1) - if (use_null) { - if (favor_diagonal) prob_a_i = prob_align_null; - max_pat = s2t.prob(kNULL, f_j) * prob_a_i; - sum += max_pat; - } - double az = 0; - if (favor_diagonal) { - for (int ta = 0; ta < src.size(); ++ta) { - unnormed_a_i[ta] = exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); - az += unnormed_a_i[ta]; - } - az /= prob_align_not_null; - } - for (int i = 1; i <= src.size(); ++i) { - if (favor_diagonal) - prob_a_i = unnormed_a_i[i-1] / az; - double pat = s2t.prob(src[i-1], f_j) * prob_a_i; - if (pat > max_pat) { max_pat = pat; a_j = i; } - sum += pat; - } - log_prob += log(sum); - if (write_alignments) { - if (a_j > 0) { - cout << ' '; - if (reverse) - cout << j << '-' << (a_j - 1); - else - cout << (a_j - 1) << '-' << j; - } - } - } - tlp += log_prob; - cout << " ||| " << log_prob << endl << flush; - } // loop over test set sentences - cerr << "TOTAL LOG PROB " << tlp << endl; - } - - if (write_alignments) return 0; - - for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) { - const TTable::Word2Double& cpd = ei->second; - const TTable::Word2Double& vit = s2t_viterbi[ei->first]; - const string& esym = TD::Convert(ei->first); - double max_p = -1; - for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) - if (fi->second > max_p) max_p = fi->second; - const double threshold = max_p * BEAM_THRESHOLD; - for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) { - if (fi->second > threshold || (vit.find(fi->first) != vit.end())) { - cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl; - } - } - } - return 0; -} - diff --git a/training/feature_expectations.cc b/training/feature_expectations.cc deleted file mode 100644 index f1a85495..00000000 --- a/training/feature_expectations.cc +++ /dev/null @@ -1,232 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "verbose.h" -#include "hg.h" -#include "prob.h" -#include "inside_outside.h" -#include "ff_register.h" -#include "decoder.h" -#include "filelib.h" -#include "online_optimizer.h" -#include "fdict.h" -#include "weights.h" -#include "sparse_vector.h" -#include "sampler.h" - -#ifdef HAVE_MPI -#include -#include -namespace mpi = boost::mpi; -#endif - -using namespace std; -namespace po = boost::program_options; - -struct FComp { - const vector& w_; - FComp(const vector& w) : w_(w) {} - bool operator()(int a, int b) const { - return fabs(w_[a]) > fabs(w_[b]); - } -}; - -void ShowFeatures(const vector& w) { - vector fnums(w.size()); - for (int i = 0; i < w.size(); ++i) - fnums[i] = i; - sort(fnums.begin(), fnums.end(), FComp(w)); - for (vector::iterator i = fnums.begin(); i != fnums.end(); ++i) { - if (w[*i]) cout << FD::Convert(*i) << ' ' << w[*i] << endl; - } -} - -void ReadConfig(const string& ini, vector* out) { - ReadFile rf(ini); - istream& in = *rf.stream(); - while(in) { - string line; - getline(in, line); - if (!in) continue; - out->push_back(line); - } -} - -void StoreConfig(const vector& cfg, istringstream* o) { - ostringstream os; - for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; } - o->str(os.str()); -} - -bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("input,i",po::value(),"Corpus of source language sentences") - ("weights,w",po::value(),"Input feature weights file") - ("decoder_config,c",po::value(), "cdec.ini file"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("input") || !conf->count("decoder_config")) { - cerr << dcmdline_options << endl; - return false; - } - return true; -} - -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* order) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - int id = 0; - while(in) { - getline(in, line); - if (!in) break; - if (id % size == rank) { - c->push_back(line); - order->push_back(id); - } - ++id; - } -} - -static const double kMINUS_EPSILON = -1e-6; - -struct TrainingObserver : public DecoderObserver { - void Reset() { - acc_exp.clear(); - total_complete = 0; - } - - virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { - cur_model_exp.clear(); - state = 1; - } - - // compute model expectations, denominator of objective - virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 1); - state = 2; - const prob_t z = InsideOutside, - EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); - cur_model_exp /= z; - acc_exp += cur_model_exp; - } - - virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { - cerr << "IGNORING ALIGNMENT FOREST!\n"; - } - - virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { - if (state == 2) { - ++total_complete; - } - } - - void GetExpectations(SparseVector* g) const { - g->clear(); - for (SparseVector::const_iterator it = acc_exp.begin(); it != acc_exp.end(); ++it) - g->set_value(it->first, it->second); - } - - int total_complete; - SparseVector cur_model_exp; - SparseVector acc_exp; - int state; -}; - -#ifdef HAVE_MPI -namespace boost { namespace mpi { - template<> - struct is_commutative >, SparseVector > - : mpl::true_ { }; -} } // end namespace boost::mpi -#endif - -int main(int argc, char** argv) { -#ifdef HAVE_MPI - mpi::environment env(argc, argv); - mpi::communicator world; - const int size = world.size(); - const int rank = world.rank(); -#else - const int size = 1; - const int rank = 0; -#endif - if (size > 1) SetSilent(true); // turn off verbose decoder output - register_feature_functions(); - - po::variables_map conf; - if (!InitCommandLine(argc, argv, &conf)) - return 1; - - // load initial weights - Weights weights; - if (conf.count("weights")) - weights.InitFromFile(conf["weights"].as()); - - vector corpus; - vector ids; - ReadTrainingCorpus(conf["input"].as(), rank, size, &corpus, &ids); - assert(corpus.size() > 0); - - vector cdec_ini; - ReadConfig(conf["decoder_config"].as(), &cdec_ini); - istringstream ini; - StoreConfig(cdec_ini, &ini); - Decoder decoder(&ini); - if (decoder.GetConf()["input"].as() != "-") { - cerr << "cdec.ini must not set an input file\n"; - return 1; - } - - SparseVector x; - weights.InitSparseVector(&x); - TrainingObserver observer; - - weights.InitFromVector(x); - vector lambdas; - weights.InitVector(&lambdas); - decoder.SetWeights(lambdas); - observer.Reset(); - for (unsigned i = 0; i < corpus.size(); ++i) { - int id = ids[i]; - decoder.SetId(id); - decoder.Decode(corpus[i], &observer); - } - SparseVector local_exps, exps; - observer.GetExpectations(&local_exps); -#ifdef HAVE_MPI - reduce(world, local_exps, exps, std::plus >(), 0); -#else - exps.swap(local_exps); -#endif - - weights.InitFromVector(exps); - weights.InitVector(&lambdas); - ShowFeatures(lambdas); - - return 0; -} diff --git a/training/grammar_convert.cc b/training/grammar_convert.cc deleted file mode 100644 index 607a7cb9..00000000 --- a/training/grammar_convert.cc +++ /dev/null @@ -1,348 +0,0 @@ -/* - this program modifies cfg hypergraphs (forests) and extracts kbests? - what are: json, split ? - */ -#include -#include -#include - -#include -#include - -#include "inside_outside.h" -#include "tdict.h" -#include "filelib.h" -#include "hg.h" -#include "hg_io.h" -#include "kbest.h" -#include "viterbi.h" -#include "weights.h" - -namespace po = boost::program_options; -using namespace std; - -WordID kSTART; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("input,i", po::value()->default_value("-"), "Input file") - ("format,f", po::value()->default_value("cfg"), "Input format. Values: cfg, json, split") - ("output,o", po::value()->default_value("json"), "Output command. Values: json, 1best") - ("reorder,r", "Add Yamada & Knight (2002) reorderings") - ("weights,w", po::value(), "Feature weights for k-best derivations [optional]") - ("collapse_weights,C", "Collapse order features into a single feature whose value is all of the locally applying feature weights") - ("k_derivations,k", po::value(), "Show k derivations and their features") - ("max_reorder,m", po::value()->default_value(999), "Move a constituent at most this far") - ("help,h", "Print this help message and exit"); - po::options_description clo("Command line options"); - po::options_description dcmdline_options; - dcmdline_options.add(opts); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - po::notify(*conf); - - if (conf->count("help") || conf->count("input") == 0) { - cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n"; - cerr << dcmdline_options << endl; - exit(1); - } -} - -int GetOrCreateNode(const WordID& lhs, map* lhs2node, Hypergraph* hg) { - int& node_id = (*lhs2node)[lhs]; - if (!node_id) - node_id = hg->AddNode(lhs)->id_ + 1; - return node_id - 1; -} - -void FilterAndCheckCorrectness(int goal, Hypergraph* hg) { - if (goal < 0) { - cerr << "Error! [S] not found in grammar!\n"; - exit(1); - } - if (hg->nodes_[goal].in_edges_.size() != 1) { - cerr << "Error! [S] has more than one rewrite!\n"; - exit(1); - } - int old_size = hg->nodes_.size(); - hg->TopologicallySortNodesAndEdges(goal); - if (hg->nodes_.size() != old_size) { - cerr << "Warning! During sorting " << (old_size - hg->nodes_.size()) << " disappeared!\n"; - } - vector inside; // inside score at each node - double p = Inside(*hg, &inside); - if (!p) { - cerr << "Warning! Grammar defines the empty language!\n"; - hg->clear(); - return; - } - vector prune(hg->edges_.size(), false); - int bad_edges = 0; - for (unsigned i = 0; i < hg->edges_.size(); ++i) { - Hypergraph::Edge& edge = hg->edges_[i]; - bool bad = false; - for (unsigned j = 0; j < edge.tail_nodes_.size(); ++j) { - if (!inside[edge.tail_nodes_[j]]) { - bad = true; - ++bad_edges; - } - } - prune[i] = bad; - } - cerr << "Removing " << bad_edges << " bad edges from the grammar.\n"; - for (unsigned i = 0; i < hg->edges_.size(); ++i) { - if (prune[i]) - cerr << " " << hg->edges_[i].rule_->AsString() << endl; - } - hg->PruneEdges(prune); -} - -void CreateEdge(const TRulePtr& r, const Hypergraph::TailNodeVector& tail, Hypergraph::Node* head_node, Hypergraph* hg) { - Hypergraph::Edge* new_edge = hg->AddEdge(r, tail); - hg->ConnectEdgeToHeadNode(new_edge, head_node); - new_edge->feature_values_ = r->scores_; -} - -// from a category label like "NP_2", return "NP" -string PureCategory(WordID cat) { - assert(cat < 0); - string c = TD::Convert(cat*-1); - size_t p = c.find("_"); - if (p == string::npos) return c; - return c.substr(0, p); -}; - -string ConstituentOrderFeature(const TRule& rule, const vector& pi) { - const static string kTERM_VAR = "x"; - const vector& f = rule.f(); - map used; - vector terms(f.size()); - for (int i = 0; i < f.size(); ++i) { - const string term = (f[i] < 0 ? PureCategory(f[i]) : kTERM_VAR); - int& count = used[term]; - if (!count) { - terms[i] = term; - } else { - ostringstream os; - os << term << count; - terms[i] = os.str(); - } - ++count; - } - ostringstream os; - os << PureCategory(rule.GetLHS()) << ':'; - for (int i = 0; i < f.size(); ++i) { - if (i > 0) os << '_'; - os << terms[pi[i]]; - } - return os.str(); -} - -bool CheckPermutationMask(const vector& mask, const vector& pi) { - assert(mask.size() == pi.size()); - - int req_min = -1; - int cur_max = 0; - int cur_mask = -1; - for (int i = 0; i < mask.size(); ++i) { - if (mask[i] != cur_mask) { - cur_mask = mask[i]; - req_min = cur_max - 1; - } - if (pi[i] > req_min) { - if (pi[i] > cur_max) cur_max = pi[i]; - } else { - return false; - } - } - - return true; -} - -void PermuteYKRecursive(int nodeid, const WordID& parent, const int max_reorder, Hypergraph* hg) { - // Hypergraph tmp = *hg; - Hypergraph::Node* node = &hg->nodes_[nodeid]; - if (node->in_edges_.size() != 1) { - cerr << "Multiple rewrites of [" << TD::Convert(node->cat_ * -1) << "] (parent is [" << TD::Convert(parent*-1) << "])\n"; - cerr << " not recursing!\n"; - return; - } -// for (int eii = 0; eii < node->in_edges_.size(); ++eii) { - const int oe_index = node->in_edges_.front(); - const TRule& rule = *hg->edges_[oe_index].rule_; - const Hypergraph::TailNodeVector orig_tail = hg->edges_[oe_index].tail_nodes_; - const int tail_size = orig_tail.size(); - for (int i = 0; i < tail_size; ++i) { - PermuteYKRecursive(hg->edges_[oe_index].tail_nodes_[i], node->cat_, max_reorder, hg); - } - const vector& of = rule.f_; - if (of.size() == 1) return; - // cerr << "Permuting [" << TD::Convert(node->cat_ * -1) << "]\n"; - // cerr << "ORIG: " << rule.AsString() << endl; - vector pi(of.size(), 0); - for (int i = 0; i < pi.size(); ++i) pi[i] = i; - - vector permutation_mask(of.size(), 0); - const bool dont_reorder_across_PU = true; // TODO add configuration - if (dont_reorder_across_PU) { - int cur = 0; - for (int i = 0; i < pi.size(); ++i) { - if (of[i] >= 0) continue; - const string cat = PureCategory(of[i]); - if (cat == "PU" || cat == "PU!H" || cat == "PUNC" || cat == "PUNC!H" || cat == "CC") { - ++cur; - permutation_mask[i] = cur; - ++cur; - } else { - permutation_mask[i] = cur; - } - } - } - int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); - hg->edges_[oe_index].feature_values_.set_value(fid, 1.0); - while (next_permutation(pi.begin(), pi.end())) { - if (!CheckPermutationMask(permutation_mask, pi)) - continue; - vector nf(pi.size(), 0); - Hypergraph::TailNodeVector tail(pi.size(), 0); - bool skip = false; - for (int i = 0; i < pi.size(); ++i) { - int dist = pi[i] - i; if (dist < 0) dist *= -1; - if (dist > max_reorder) { skip = true; break; } - nf[i] = of[pi[i]]; - tail[i] = orig_tail[pi[i]]; - } - if (skip) continue; - TRulePtr nr(new TRule(rule)); - nr->f_ = nf; - int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); - nr->scores_.set_value(fid, 1.0); - // cerr << "PERM: " << nr->AsString() << endl; - CreateEdge(nr, tail, node, hg); - } - // } -} - -void PermuteYamadaAndKnight(Hypergraph* hg, int max_reorder) { - assert(hg->nodes_.back().cat_ == kSTART); - assert(hg->nodes_.back().in_edges_.size() == 1); - PermuteYKRecursive(hg->nodes_.size() - 1, kSTART, max_reorder, hg); -} - -void CollapseWeights(Hypergraph* hg) { - int fid = FD::Convert("Reordering"); - for (int i = 0; i < hg->edges_.size(); ++i) { - Hypergraph::Edge& edge = hg->edges_[i]; - edge.feature_values_.clear(); - if (edge.edge_prob_ != prob_t::Zero()) { - edge.feature_values_.set_value(fid, log(edge.edge_prob_)); - } - } -} - -void ProcessHypergraph(const vector& w, const po::variables_map& conf, const string& ref, Hypergraph* hg) { - if (conf.count("reorder")) - PermuteYamadaAndKnight(hg, conf["max_reorder"].as()); - if (w.size() > 0) { hg->Reweight(w); } - if (conf.count("collapse_weights")) CollapseWeights(hg); - if (conf["output"].as() == "json") { - HypergraphIO::WriteToJSON(*hg, false, &cout); - if (!ref.empty()) { cerr << "REF: " << ref << endl; } - } else { - vector onebest; - ViterbiESentence(*hg, &onebest); - if (ref.empty()) { - cout << TD::GetString(onebest) << endl; - } else { - cout << TD::GetString(onebest) << " ||| " << ref << endl; - } - } - if (conf.count("k_derivations")) { - const int k = conf["k_derivations"].as(); - KBest::KBestDerivations, ESentenceTraversal> kbest(*hg, k); - for (int i = 0; i < k; ++i) { - const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = - kbest.LazyKthBest(hg->nodes_.size() - 1, i); - if (!d) break; - cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; - } - } -} - -int main(int argc, char **argv) { - kSTART = TD::Convert("S") * -1; - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - string infile = conf["input"].as(); - const bool is_split_input = (conf["format"].as() == "split"); - const bool is_json_input = is_split_input || (conf["format"].as() == "json"); - const bool collapse_weights = conf.count("collapse_weights"); - vector w; - if (conf.count("weights")) - Weights::InitFromFile(conf["weights"].as(), &w); - - if (collapse_weights && !w.size()) { - cerr << "--collapse_weights requires a weights file to be specified!\n"; - exit(1); - } - ReadFile rf(infile); - istream* in = rf.stream(); - assert(*in); - int lc = 0; - Hypergraph hg; - map lhs2node; - while(*in) { - string line; - ++lc; - getline(*in, line); - if (is_json_input) { - if (line.empty() || line[0] == '#') continue; - string ref; - if (is_split_input) { - size_t pos = line.rfind("}}"); - assert(pos != string::npos); - size_t rstart = line.find("||| ", pos); - assert(rstart != string::npos); - ref = line.substr(rstart + 4); - line = line.substr(0, pos + 2); - } - istringstream is(line); - if (HypergraphIO::ReadFromJSON(&is, &hg)) { - ProcessHypergraph(w, conf, ref, &hg); - hg.clear(); - } else { - cerr << "Error reading grammar from JSON: line " << lc << endl; - exit(1); - } - } else { - if (line.empty()) { - int goal = lhs2node[kSTART] - 1; - FilterAndCheckCorrectness(goal, &hg); - ProcessHypergraph(w, conf, "", &hg); - hg.clear(); - lhs2node.clear(); - continue; - } - if (line[0] == '#') continue; - if (line[0] != '[') { - cerr << "Line " << lc << ": bad format\n"; - exit(1); - } - TRulePtr tr(TRule::CreateRuleMonolingual(line)); - Hypergraph::TailNodeVector tail; - for (int i = 0; i < tr->f_.size(); ++i) { - WordID var_cat = tr->f_[i]; - if (var_cat < 0) - tail.push_back(GetOrCreateNode(var_cat, &lhs2node, &hg)); - } - const WordID lhs = tr->GetLHS(); - int head = GetOrCreateNode(lhs, &lhs2node, &hg); - Hypergraph::Edge* edge = hg.AddEdge(tr, tail); - edge->feature_values_ = tr->scores_; - Hypergraph::Node* node = &hg.nodes_[head]; - hg.ConnectEdgeToHeadNode(edge, node); - } - } -} - diff --git a/training/lbfgs.h b/training/lbfgs.h deleted file mode 100644 index e8baecab..00000000 --- a/training/lbfgs.h +++ /dev/null @@ -1,1459 +0,0 @@ -#ifndef SCITBX_LBFGS_H -#define SCITBX_LBFGS_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace scitbx { - -//! Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) %minimizer. -/*! Implementation of the - Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) - algorithm for large-scale multidimensional minimization - problems. - - This code was manually derived from Java code which was - in turn derived from the Fortran program - lbfgs.f. The Java translation was - effected mostly mechanically, with some manual - clean-up; in particular, array indices start at 0 - instead of 1. Most of the comments from the Fortran - code have been pasted in. - - Information on the original LBFGS Fortran source code is - available at - http://www.netlib.org/opt/lbfgs_um.shar . The following - information is taken verbatim from the Netlib documentation - for the Fortran source. - -
-    file    opt/lbfgs_um.shar
-    for     unconstrained optimization problems
-    alg     limited memory BFGS method
-    by      J. Nocedal
-    contact nocedal@eecs.nwu.edu
-    ref     D. C. Liu and J. Nocedal, ``On the limited memory BFGS method for
-    ,       large scale optimization methods'' Mathematical Programming 45
-    ,       (1989), pp. 503-528.
-    ,       (Postscript file of this paper is available via anonymous ftp
-    ,       to eecs.nwu.edu in the directory pub/%lbfgs/lbfgs_um.)
-    
- - @author Jorge Nocedal: original Fortran version, including comments - (July 1990).
- Robert Dodier: Java translation, August 1997.
- Ralf W. Grosse-Kunstleve: C++ port, March 2002.
- Chris Dyer: serialize/deserialize functionality - */ -namespace lbfgs { - - //! Generic exception class for %lbfgs %error messages. - /*! All exceptions thrown by the minimizer are derived from this class. - */ - class error : public std::exception { - public: - //! Constructor. - error(std::string const& msg) throw() - : msg_("lbfgs error: " + msg) - {} - //! Access to error message. - virtual const char* what() const throw() { return msg_.c_str(); } - protected: - virtual ~error() throw() {} - std::string msg_; - public: - static std::string itoa(unsigned long i) { - std::ostringstream os; - os << i; - return os.str(); - } - }; - - //! Specific exception class. - class error_internal_error : public error { - public: - //! Constructor. - error_internal_error(const char* file, unsigned long line) throw() - : error( - "Internal Error: " + std::string(file) + "(" + itoa(line) + ")") - {} - }; - - //! Specific exception class. - class error_improper_input_parameter : public error { - public: - //! Constructor. - error_improper_input_parameter(std::string const& msg) throw() - : error("Improper input parameter: " + msg) - {} - }; - - //! Specific exception class. - class error_improper_input_data : public error { - public: - //! Constructor. - error_improper_input_data(std::string const& msg) throw() - : error("Improper input data: " + msg) - {} - }; - - //! Specific exception class. - class error_search_direction_not_descent : public error { - public: - //! Constructor. - error_search_direction_not_descent() throw() - : error("The search direction is not a descent direction.") - {} - }; - - //! Specific exception class. - class error_line_search_failed : public error { - public: - //! Constructor. - error_line_search_failed(std::string const& msg) throw() - : error("Line search failed: " + msg) - {} - }; - - //! Specific exception class. - class error_line_search_failed_rounding_errors - : public error_line_search_failed { - public: - //! Constructor. - error_line_search_failed_rounding_errors(std::string const& msg) throw() - : error_line_search_failed(msg) - {} - }; - - namespace detail { - - template - inline - NumType - pow2(NumType const& x) { return x * x; } - - template - inline - NumType - abs(NumType const& x) { - if (x < NumType(0)) return -x; - return x; - } - - // This class implements an algorithm for multi-dimensional line search. - template - class mcsrch - { - protected: - int infoc; - FloatType dginit; - bool brackt; - bool stage1; - FloatType finit; - FloatType dgtest; - FloatType width; - FloatType width1; - FloatType stx; - FloatType fx; - FloatType dgx; - FloatType sty; - FloatType fy; - FloatType dgy; - FloatType stmin; - FloatType stmax; - - static FloatType const& max3( - FloatType const& x, - FloatType const& y, - FloatType const& z) - { - return x < y ? (y < z ? z : y ) : (x < z ? z : x ); - } - - public: - /* Minimize a function along a search direction. This code is - a Java translation of the function MCSRCH from - lbfgs.f, which in turn is a slight modification - of the subroutine CSRCH of More' and Thuente. - The changes are to allow reverse communication, and do not - affect the performance of the routine. This function, in turn, - calls mcstep.

- - The Java translation was effected mostly mechanically, with - some manual clean-up; in particular, array indices start at 0 - instead of 1. Most of the comments from the Fortran code have - been pasted in here as well.

- - The purpose of mcsrch is to find a step which - satisfies a sufficient decrease condition and a curvature - condition.

- - At each stage this function updates an interval of uncertainty - with endpoints stx and sty. The - interval of uncertainty is initially chosen so that it - contains a minimizer of the modified function -

-                f(x+stp*s) - f(x) - ftol*stp*(gradf(x)'s).
-           
- If a step is obtained for which the modified function has a - nonpositive function value and nonnegative derivative, then - the interval of uncertainty is chosen so that it contains a - minimizer of f(x+stp*s).

- - The algorithm is designed to find a step which satisfies - the sufficient decrease condition -

-                 f(x+stp*s) <= f(X) + ftol*stp*(gradf(x)'s),
-           
- and the curvature condition -
-                 abs(gradf(x+stp*s)'s)) <= gtol*abs(gradf(x)'s).
-           
- If ftol is less than gtol and if, - for example, the function is bounded below, then there is - always a step which satisfies both conditions. If no step can - be found which satisfies both conditions, then the algorithm - usually stops when rounding errors prevent further progress. - In this case stp only satisfies the sufficient - decrease condition.

- - @author Original Fortran version by Jorge J. More' and - David J. Thuente as part of the Minpack project, June 1983, - Argonne National Laboratory. Java translation by Robert - Dodier, August 1997. - - @param n The number of variables. - - @param x On entry this contains the base point for the line - search. On exit it contains x + stp*s. - - @param f On entry this contains the value of the objective - function at x. On exit it contains the value - of the objective function at x + stp*s. - - @param g On entry this contains the gradient of the objective - function at x. On exit it contains the gradient - at x + stp*s. - - @param s The search direction. - - @param stp On entry this contains an initial estimate of a - satifactory step length. On exit stp contains - the final estimate. - - @param ftol Tolerance for the sufficient decrease condition. - - @param xtol Termination occurs when the relative width of the - interval of uncertainty is at most xtol. - - @param maxfev Termination occurs when the number of evaluations - of the objective function is at least maxfev by - the end of an iteration. - - @param info This is an output variable, which can have these - values: -

    -
  • info = -1 A return is made to compute - the function and gradient. -
  • info = 1 The sufficient decrease condition - and the directional derivative condition hold. -
- - @param nfev On exit, this is set to the number of function - evaluations. - - @param wa Temporary storage array, of length n. - */ - void run( - FloatType const& gtol, - FloatType const& stpmin, - FloatType const& stpmax, - SizeType n, - FloatType* x, - FloatType f, - const FloatType* g, - FloatType* s, - SizeType is0, - FloatType& stp, - FloatType ftol, - FloatType xtol, - SizeType maxfev, - int& info, - SizeType& nfev, - FloatType* wa); - - /* The purpose of this function is to compute a safeguarded step - for a linesearch and to update an interval of uncertainty for - a minimizer of the function.

- - The parameter stx contains the step with the - least function value. The parameter stp contains - the current step. It is assumed that the derivative at - stx is negative in the direction of the step. If - brackt is true when - mcstep returns then a minimizer has been - bracketed in an interval of uncertainty with endpoints - stx and sty.

- - Variables that must be modified by mcstep are - implemented as 1-element arrays. - - @param stx Step at the best step obtained so far. - This variable is modified by mcstep. - @param fx Function value at the best step obtained so far. - This variable is modified by mcstep. - @param dx Derivative at the best step obtained so far. - The derivative must be negative in the direction of the - step, that is, dx and stp-stx must - have opposite signs. This variable is modified by - mcstep. - - @param sty Step at the other endpoint of the interval of - uncertainty. This variable is modified by mcstep. - @param fy Function value at the other endpoint of the interval - of uncertainty. This variable is modified by - mcstep. - - @param dy Derivative at the other endpoint of the interval of - uncertainty. This variable is modified by mcstep. - - @param stp Step at the current step. If brackt is set - then on input stp must be between stx - and sty. On output stp is set to the - new step. - @param fp Function value at the current step. - @param dp Derivative at the current step. - - @param brackt Tells whether a minimizer has been bracketed. - If the minimizer has not been bracketed, then on input this - variable must be set false. If the minimizer has - been bracketed, then on output this variable is - true. - - @param stpmin Lower bound for the step. - @param stpmax Upper bound for the step. - - If the return value is 1, 2, 3, or 4, then the step has - been computed successfully. A return value of 0 indicates - improper input parameters. - - @author Jorge J. More, David J. Thuente: original Fortran version, - as part of Minpack project. Argonne Nat'l Laboratory, June 1983. - Robert Dodier: Java translation, August 1997. - */ - static int mcstep( - FloatType& stx, - FloatType& fx, - FloatType& dx, - FloatType& sty, - FloatType& fy, - FloatType& dy, - FloatType& stp, - FloatType fp, - FloatType dp, - bool& brackt, - FloatType stpmin, - FloatType stpmax); - - void serialize(std::ostream* out) const { - out->write((const char*)&infoc,sizeof(infoc)); - out->write((const char*)&dginit,sizeof(dginit)); - out->write((const char*)&brackt,sizeof(brackt)); - out->write((const char*)&stage1,sizeof(stage1)); - out->write((const char*)&finit,sizeof(finit)); - out->write((const char*)&dgtest,sizeof(dgtest)); - out->write((const char*)&width,sizeof(width)); - out->write((const char*)&width1,sizeof(width1)); - out->write((const char*)&stx,sizeof(stx)); - out->write((const char*)&fx,sizeof(fx)); - out->write((const char*)&dgx,sizeof(dgx)); - out->write((const char*)&sty,sizeof(sty)); - out->write((const char*)&fy,sizeof(fy)); - out->write((const char*)&dgy,sizeof(dgy)); - out->write((const char*)&stmin,sizeof(stmin)); - out->write((const char*)&stmax,sizeof(stmax)); - } - - void deserialize(std::istream* in) const { - in->read((char*)&infoc, sizeof(infoc)); - in->read((char*)&dginit, sizeof(dginit)); - in->read((char*)&brackt, sizeof(brackt)); - in->read((char*)&stage1, sizeof(stage1)); - in->read((char*)&finit, sizeof(finit)); - in->read((char*)&dgtest, sizeof(dgtest)); - in->read((char*)&width, sizeof(width)); - in->read((char*)&width1, sizeof(width1)); - in->read((char*)&stx, sizeof(stx)); - in->read((char*)&fx, sizeof(fx)); - in->read((char*)&dgx, sizeof(dgx)); - in->read((char*)&sty, sizeof(sty)); - in->read((char*)&fy, sizeof(fy)); - in->read((char*)&dgy, sizeof(dgy)); - in->read((char*)&stmin, sizeof(stmin)); - in->read((char*)&stmax, sizeof(stmax)); - } - }; - - template - void mcsrch::run( - FloatType const& gtol, - FloatType const& stpmin, - FloatType const& stpmax, - SizeType n, - FloatType* x, - FloatType f, - const FloatType* g, - FloatType* s, - SizeType is0, - FloatType& stp, - FloatType ftol, - FloatType xtol, - SizeType maxfev, - int& info, - SizeType& nfev, - FloatType* wa) - { - if (info != -1) { - infoc = 1; - if ( n == 0 - || maxfev == 0 - || gtol < FloatType(0) - || xtol < FloatType(0) - || stpmin < FloatType(0) - || stpmax < stpmin) { - throw error_internal_error(__FILE__, __LINE__); - } - if (stp <= FloatType(0) || ftol < FloatType(0)) { - throw error_internal_error(__FILE__, __LINE__); - } - // Compute the initial gradient in the search direction - // and check that s is a descent direction. - dginit = FloatType(0); - for (SizeType j = 0; j < n; j++) { - dginit += g[j] * s[is0+j]; - } - if (dginit >= FloatType(0)) { - throw error_search_direction_not_descent(); - } - brackt = false; - stage1 = true; - nfev = 0; - finit = f; - dgtest = ftol*dginit; - width = stpmax - stpmin; - width1 = FloatType(2) * width; - std::copy(x, x+n, wa); - // The variables stx, fx, dgx contain the values of the step, - // function, and directional derivative at the best step. - // The variables sty, fy, dgy contain the value of the step, - // function, and derivative at the other endpoint of - // the interval of uncertainty. - // The variables stp, f, dg contain the values of the step, - // function, and derivative at the current step. - stx = FloatType(0); - fx = finit; - dgx = dginit; - sty = FloatType(0); - fy = finit; - dgy = dginit; - } - for (;;) { - if (info != -1) { - // Set the minimum and maximum steps to correspond - // to the present interval of uncertainty. - if (brackt) { - stmin = std::min(stx, sty); - stmax = std::max(stx, sty); - } - else { - stmin = stx; - stmax = stp + FloatType(4) * (stp - stx); - } - // Force the step to be within the bounds stpmax and stpmin. - stp = std::max(stp, stpmin); - stp = std::min(stp, stpmax); - // If an unusual termination is to occur then let - // stp be the lowest point obtained so far. - if ( (brackt && (stp <= stmin || stp >= stmax)) - || nfev >= maxfev - 1 || infoc == 0 - || (brackt && stmax - stmin <= xtol * stmax)) { - stp = stx; - } - // Evaluate the function and gradient at stp - // and compute the directional derivative. - // We return to main program to obtain F and G. - for (SizeType j = 0; j < n; j++) { - x[j] = wa[j] + stp * s[is0+j]; - } - info=-1; - break; - } - info = 0; - nfev++; - FloatType dg(0); - for (SizeType j = 0; j < n; j++) { - dg += g[j] * s[is0+j]; - } - FloatType ftest1 = finit + stp*dgtest; - // Test for convergence. - if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) { - throw error_line_search_failed_rounding_errors( - "Rounding errors prevent further progress." - " There may not be a step which satisfies the" - " sufficient decrease and curvature conditions." - " Tolerances may be too small."); - } - if (stp == stpmax && f <= ftest1 && dg <= dgtest) { - throw error_line_search_failed( - "The step is at the upper bound stpmax()."); - } - if (stp == stpmin && (f > ftest1 || dg >= dgtest)) { - throw error_line_search_failed( - "The step is at the lower bound stpmin()."); - } - if (nfev >= maxfev) { - throw error_line_search_failed( - "Number of function evaluations has reached maxfev()."); - } - if (brackt && stmax - stmin <= xtol * stmax) { - throw error_line_search_failed( - "Relative width of the interval of uncertainty" - " is at most xtol()."); - } - // Check for termination. - if (f <= ftest1 && abs(dg) <= gtol * (-dginit)) { - info = 1; - break; - } - // In the first stage we seek a step for which the modified - // function has a nonpositive value and nonnegative derivative. - if ( stage1 && f <= ftest1 - && dg >= std::min(ftol, gtol) * dginit) { - stage1 = false; - } - // A modified function is used to predict the step only if - // we have not obtained a step for which the modified - // function has a nonpositive function value and nonnegative - // derivative, and if a lower function value has been - // obtained but the decrease is not sufficient. - if (stage1 && f <= fx && f > ftest1) { - // Define the modified function and derivative values. - FloatType fm = f - stp*dgtest; - FloatType fxm = fx - stx*dgtest; - FloatType fym = fy - sty*dgtest; - FloatType dgm = dg - dgtest; - FloatType dgxm = dgx - dgtest; - FloatType dgym = dgy - dgtest; - // Call cstep to update the interval of uncertainty - // and to compute the new step. - infoc = mcstep(stx, fxm, dgxm, sty, fym, dgym, stp, fm, dgm, - brackt, stmin, stmax); - // Reset the function and gradient values for f. - fx = fxm + stx*dgtest; - fy = fym + sty*dgtest; - dgx = dgxm + dgtest; - dgy = dgym + dgtest; - } - else { - // Call mcstep to update the interval of uncertainty - // and to compute the new step. - infoc = mcstep(stx, fx, dgx, sty, fy, dgy, stp, f, dg, - brackt, stmin, stmax); - } - // Force a sufficient decrease in the size of the - // interval of uncertainty. - if (brackt) { - if (abs(sty - stx) >= FloatType(0.66) * width1) { - stp = stx + FloatType(0.5) * (sty - stx); - } - width1 = width; - width = abs(sty - stx); - } - } - } - - template - int mcsrch::mcstep( - FloatType& stx, - FloatType& fx, - FloatType& dx, - FloatType& sty, - FloatType& fy, - FloatType& dy, - FloatType& stp, - FloatType fp, - FloatType dp, - bool& brackt, - FloatType stpmin, - FloatType stpmax) - { - bool bound; - FloatType gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta; - int info = 0; - if ( ( brackt && (stp <= std::min(stx, sty) - || stp >= std::max(stx, sty))) - || dx * (stp - stx) >= FloatType(0) || stpmax < stpmin) { - return 0; - } - // Determine if the derivatives have opposite sign. - sgnd = dp * (dx / abs(dx)); - if (fp > fx) { - // First case. A higher function value. - // The minimum is bracketed. If the cubic step is closer - // to stx than the quadratic step, the cubic step is taken, - // else the average of the cubic and quadratic steps is taken. - info = 1; - bound = true; - theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; - s = max3(abs(theta), abs(dx), abs(dp)); - gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); - if (stp < stx) gamma = - gamma; - p = (gamma - dx) + theta; - q = ((gamma - dx) + gamma) + dp; - r = p/q; - stpc = stx + r * (stp - stx); - stpq = stx - + ((dx / ((fx - fp) / (stp - stx) + dx)) / FloatType(2)) - * (stp - stx); - if (abs(stpc - stx) < abs(stpq - stx)) { - stpf = stpc; - } - else { - stpf = stpc + (stpq - stpc) / FloatType(2); - } - brackt = true; - } - else if (sgnd < FloatType(0)) { - // Second case. A lower function value and derivatives of - // opposite sign. The minimum is bracketed. If the cubic - // step is closer to stx than the quadratic (secant) step, - // the cubic step is taken, else the quadratic step is taken. - info = 2; - bound = false; - theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; - s = max3(abs(theta), abs(dx), abs(dp)); - gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); - if (stp > stx) gamma = - gamma; - p = (gamma - dp) + theta; - q = ((gamma - dp) + gamma) + dx; - r = p/q; - stpc = stp + r * (stx - stp); - stpq = stp + (dp / (dp - dx)) * (stx - stp); - if (abs(stpc - stp) > abs(stpq - stp)) { - stpf = stpc; - } - else { - stpf = stpq; - } - brackt = true; - } - else if (abs(dp) < abs(dx)) { - // Third case. A lower function value, derivatives of the - // same sign, and the magnitude of the derivative decreases. - // The cubic step is only used if the cubic tends to infinity - // in the direction of the step or if the minimum of the cubic - // is beyond stp. Otherwise the cubic step is defined to be - // either stpmin or stpmax. The quadratic (secant) step is also - // computed and if the minimum is bracketed then the the step - // closest to stx is taken, else the step farthest away is taken. - info = 3; - bound = true; - theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; - s = max3(abs(theta), abs(dx), abs(dp)); - gamma = s * std::sqrt( - std::max(FloatType(0), pow2(theta / s) - (dx / s) * (dp / s))); - if (stp > stx) gamma = -gamma; - p = (gamma - dp) + theta; - q = (gamma + (dx - dp)) + gamma; - r = p/q; - if (r < FloatType(0) && gamma != FloatType(0)) { - stpc = stp + r * (stx - stp); - } - else if (stp > stx) { - stpc = stpmax; - } - else { - stpc = stpmin; - } - stpq = stp + (dp / (dp - dx)) * (stx - stp); - if (brackt) { - if (abs(stp - stpc) < abs(stp - stpq)) { - stpf = stpc; - } - else { - stpf = stpq; - } - } - else { - if (abs(stp - stpc) > abs(stp - stpq)) { - stpf = stpc; - } - else { - stpf = stpq; - } - } - } - else { - // Fourth case. A lower function value, derivatives of the - // same sign, and the magnitude of the derivative does - // not decrease. If the minimum is not bracketed, the step - // is either stpmin or stpmax, else the cubic step is taken. - info = 4; - bound = false; - if (brackt) { - theta = FloatType(3) * (fp - fy) / (sty - stp) + dy + dp; - s = max3(abs(theta), abs(dy), abs(dp)); - gamma = s * std::sqrt(pow2(theta / s) - (dy / s) * (dp / s)); - if (stp > sty) gamma = -gamma; - p = (gamma - dp) + theta; - q = ((gamma - dp) + gamma) + dy; - r = p/q; - stpc = stp + r * (sty - stp); - stpf = stpc; - } - else if (stp > stx) { - stpf = stpmax; - } - else { - stpf = stpmin; - } - } - // Update the interval of uncertainty. This update does not - // depend on the new step or the case analysis above. - if (fp > fx) { - sty = stp; - fy = fp; - dy = dp; - } - else { - if (sgnd < FloatType(0)) { - sty = stx; - fy = fx; - dy = dx; - } - stx = stp; - fx = fp; - dx = dp; - } - // Compute the new step and safeguard it. - stpf = std::min(stpmax, stpf); - stpf = std::max(stpmin, stpf); - stp = stpf; - if (brackt && bound) { - if (sty > stx) { - stp = std::min(stx + FloatType(0.66) * (sty - stx), stp); - } - else { - stp = std::max(stx + FloatType(0.66) * (sty - stx), stp); - } - } - return info; - } - - /* Compute the sum of a vector times a scalar plus another vector. - Adapted from the subroutine daxpy in - lbfgs.f. - */ - template - void daxpy( - SizeType n, - FloatType da, - const FloatType* dx, - SizeType ix0, - SizeType incx, - FloatType* dy, - SizeType iy0, - SizeType incy) - { - SizeType i, ix, iy, m; - if (n == 0) return; - if (da == FloatType(0)) return; - if (!(incx == 1 && incy == 1)) { - ix = 0; - iy = 0; - for (i = 0; i < n; i++) { - dy[iy0+iy] += da * dx[ix0+ix]; - ix += incx; - iy += incy; - } - return; - } - m = n % 4; - for (i = 0; i < m; i++) { - dy[iy0+i] += da * dx[ix0+i]; - } - for (; i < n;) { - dy[iy0+i] += da * dx[ix0+i]; i++; - dy[iy0+i] += da * dx[ix0+i]; i++; - dy[iy0+i] += da * dx[ix0+i]; i++; - dy[iy0+i] += da * dx[ix0+i]; i++; - } - } - - template - inline - void daxpy( - SizeType n, - FloatType da, - const FloatType* dx, - SizeType ix0, - FloatType* dy) - { - daxpy(n, da, dx, ix0, SizeType(1), dy, SizeType(0), SizeType(1)); - } - - /* Compute the dot product of two vectors. - Adapted from the subroutine ddot - in lbfgs.f. - */ - template - FloatType ddot( - SizeType n, - const FloatType* dx, - SizeType ix0, - SizeType incx, - const FloatType* dy, - SizeType iy0, - SizeType incy) - { - SizeType i, ix, iy, m; - FloatType dtemp(0); - if (n == 0) return FloatType(0); - if (!(incx == 1 && incy == 1)) { - ix = 0; - iy = 0; - for (i = 0; i < n; i++) { - dtemp += dx[ix0+ix] * dy[iy0+iy]; - ix += incx; - iy += incy; - } - return dtemp; - } - m = n % 5; - for (i = 0; i < m; i++) { - dtemp += dx[ix0+i] * dy[iy0+i]; - } - for (; i < n;) { - dtemp += dx[ix0+i] * dy[iy0+i]; i++; - dtemp += dx[ix0+i] * dy[iy0+i]; i++; - dtemp += dx[ix0+i] * dy[iy0+i]; i++; - dtemp += dx[ix0+i] * dy[iy0+i]; i++; - dtemp += dx[ix0+i] * dy[iy0+i]; i++; - } - return dtemp; - } - - template - inline - FloatType ddot( - SizeType n, - const FloatType* dx, - const FloatType* dy) - { - return ddot( - n, dx, SizeType(0), SizeType(1), dy, SizeType(0), SizeType(1)); - } - - } // namespace detail - - //! Interface to the LBFGS %minimizer. - /*! This class solves the unconstrained minimization problem -

-          min f(x),  x = (x1,x2,...,x_n),
-      
- using the limited-memory BFGS method. The routine is - especially effective on problems involving a large number of - variables. In a typical iteration of this method an - approximation Hk to the inverse of the Hessian - is obtained by applying m BFGS updates to a - diagonal matrix Hk0, using information from the - previous m steps. The user specifies the number - m, which determines the amount of storage - required by the routine. The user may also provide the - diagonal matrices Hk0 (parameter diag in the run() - function) if not satisfied with the default choice. The - algorithm is described in "On the limited memory BFGS method for - large scale optimization", by D. Liu and J. Nocedal, Mathematical - Programming B 45 (1989) 503-528. - - The user is required to calculate the function value - f and its gradient g. In order to - allow the user complete control over these computations, - reverse communication is used. The routine must be called - repeatedly under the control of the member functions - requests_f_and_g(), - requests_diag(). - If neither requests_f_and_g() nor requests_diag() is - true the user should check for convergence - (using class traditional_convergence_test or any - other custom test). If the convergence test is negative, - the minimizer may be called again for the next iteration. - - The steplength (stp()) is determined at each iteration - by means of the line search routine mcsrch, which is - a slight modification of the routine CSRCH written - by More' and Thuente. - - The only variables that are machine-dependent are - xtol, - stpmin and - stpmax. - - Fatal errors cause error exceptions to be thrown. - The generic class error is sub-classed (e.g. - class error_line_search_failed) to facilitate - granular %error handling. - - A note on performance: Using Compaq Fortran V5.4 and - Compaq C++ V6.5, the C++ implementation is about 15% slower - than the Fortran implementation. - */ - template - class minimizer - { - public: - //! Default constructor. Some members are not initialized! - minimizer() - : n_(0), m_(0), maxfev_(0), - gtol_(0), xtol_(0), - stpmin_(0), stpmax_(0), - ispt(0), iypt(0) - {} - - //! Constructor. - /*! @param n The number of variables in the minimization problem. - Restriction: n > 0. - - @param m The number of corrections used in the BFGS update. - Values of m less than 3 are not recommended; - large values of m will result in excessive - computing time. 3 <= m <= 7 is - recommended. - Restriction: m > 0. - - @param maxfev Maximum number of function evaluations - per line search. - Termination occurs when the number of evaluations - of the objective function is at least maxfev by - the end of an iteration. - - @param gtol Controls the accuracy of the line search. - If the function and gradient evaluations are inexpensive with - respect to the cost of the iteration (which is sometimes the - case when solving very large problems) it may be advantageous - to set gtol to a small value. A typical small - value is 0.1. - Restriction: gtol should be greater than 1e-4. - - @param xtol An estimate of the machine precision (e.g. 10e-16 - on a SUN station 3/60). The line search routine will - terminate if the relative width of the interval of - uncertainty is less than xtol. - - @param stpmin Specifies the lower bound for the step - in the line search. - The default value is 1e-20. This value need not be modified - unless the exponent is too large for the machine being used, - or unless the problem is extremely badly scaled (in which - case the exponent should be increased). - - @param stpmax specifies the upper bound for the step - in the line search. - The default value is 1e20. This value need not be modified - unless the exponent is too large for the machine being used, - or unless the problem is extremely badly scaled (in which - case the exponent should be increased). - */ - explicit - minimizer( - SizeType n, - SizeType m = 5, - SizeType maxfev = 20, - FloatType gtol = FloatType(0.9), - FloatType xtol = FloatType(1.e-16), - FloatType stpmin = FloatType(1.e-20), - FloatType stpmax = FloatType(1.e20)) - : n_(n), m_(m), maxfev_(maxfev), - gtol_(gtol), xtol_(xtol), - stpmin_(stpmin), stpmax_(stpmax), - iflag_(0), requests_f_and_g_(false), requests_diag_(false), - iter_(0), nfun_(0), stp_(0), - stp1(0), ftol(0.0001), ys(0), point(0), npt(0), - ispt(n+2*m), iypt((n+2*m)+n*m), - info(0), bound(0), nfev(0) - { - if (n_ == 0) { - throw error_improper_input_parameter("n = 0."); - } - if (m_ == 0) { - throw error_improper_input_parameter("m = 0."); - } - if (maxfev_ == 0) { - throw error_improper_input_parameter("maxfev = 0."); - } - if (gtol_ <= FloatType(1.e-4)) { - throw error_improper_input_parameter("gtol <= 1.e-4."); - } - if (xtol_ < FloatType(0)) { - throw error_improper_input_parameter("xtol < 0."); - } - if (stpmin_ < FloatType(0)) { - throw error_improper_input_parameter("stpmin < 0."); - } - if (stpmax_ < stpmin) { - throw error_improper_input_parameter("stpmax < stpmin"); - } - w_.resize(n_*(2*m_+1)+2*m_); - scratch_array_.resize(n_); - } - - //! Number of free parameters (as passed to the constructor). - SizeType n() const { return n_; } - - //! Number of corrections kept (as passed to the constructor). - SizeType m() const { return m_; } - - /*! \brief Maximum number of evaluations of the objective function - per line search (as passed to the constructor). - */ - SizeType maxfev() const { return maxfev_; } - - /*! \brief Control of the accuracy of the line search. - (as passed to the constructor). - */ - FloatType gtol() const { return gtol_; } - - //! Estimate of the machine precision (as passed to the constructor). - FloatType xtol() const { return xtol_; } - - /*! \brief Lower bound for the step in the line search. - (as passed to the constructor). - */ - FloatType stpmin() const { return stpmin_; } - - /*! \brief Upper bound for the step in the line search. - (as passed to the constructor). - */ - FloatType stpmax() const { return stpmax_; } - - //! Status indicator for reverse communication. - /*! true if the run() function returns to request - evaluation of the objective function (f) and - gradients (g) for the current point - (x). To continue the minimization the - run() function is called again with the updated values for - f and g. -

- See also: requests_diag() - */ - bool requests_f_and_g() const { return requests_f_and_g_; } - - //! Status indicator for reverse communication. - /*! true if the run() function returns to request - evaluation of the diagonal matrix (diag) - for the current point (x). - To continue the minimization the run() function is called - again with the updated values for diag. -

- See also: requests_f_and_g() - */ - bool requests_diag() const { return requests_diag_; } - - //! Number of iterations so far. - /*! Note that one iteration may involve multiple evaluations - of the objective function. -

- See also: nfun() - */ - SizeType iter() const { return iter_; } - - //! Total number of evaluations of the objective function so far. - /*! The total number of function evaluations increases by the - number of evaluations required for the line search. The total - is only increased after a successful line search. -

- See also: iter() - */ - SizeType nfun() const { return nfun_; } - - //! Norm of gradient given gradient array of length n(). - FloatType euclidean_norm(const FloatType* a) const { - return std::sqrt(detail::ddot(n_, a, a)); - } - - //! Current stepsize. - FloatType stp() const { return stp_; } - - //! Execution of one step of the minimization. - /*! @param x On initial entry this must be set by the user to - the values of the initial estimate of the solution vector. - - @param f Before initial entry or on re-entry under the - control of requests_f_and_g(), f must be set - by the user to contain the value of the objective function - at the current point x. - - @param g Before initial entry or on re-entry under the - control of requests_f_and_g(), g must be set - by the user to contain the components of the gradient at - the current point x. - - The return value is true if either - requests_f_and_g() or requests_diag() is true. - Otherwise the user should check for convergence - (e.g. using class traditional_convergence_test) and - call the run() function again to continue the minimization. - If the return value is false the user - should not update f, g or - diag (other overload) before calling - the run() function again. - - Note that x is always modified by the run() - function. Depending on the situation it can therefore be - necessary to evaluate the objective function one more time - after the minimization is terminated. - */ - bool run( - FloatType* x, - FloatType f, - const FloatType* g) - { - return generic_run(x, f, g, false, 0); - } - - //! Execution of one step of the minimization. - /*! @param x See other overload. - - @param f See other overload. - - @param g See other overload. - - @param diag On initial entry or on re-entry under the - control of requests_diag(), diag must be set by - the user to contain the values of the diagonal matrix Hk0. - The routine will return at each iteration of the algorithm - with requests_diag() set to true. -

- Restriction: all elements of diag must be - positive. - */ - bool run( - FloatType* x, - FloatType f, - const FloatType* g, - const FloatType* diag) - { - return generic_run(x, f, g, true, diag); - } - - void serialize(std::ostream* out) const { - out->write((const char*)&n_, sizeof(n_)); // sanity check - out->write((const char*)&m_, sizeof(m_)); // sanity check - SizeType fs = sizeof(FloatType); - out->write((const char*)&fs, sizeof(fs)); // sanity check - - mcsrch_instance.serialize(out); - out->write((const char*)&iflag_, sizeof(iflag_)); - out->write((const char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); - out->write((const char*)&requests_diag_, sizeof(requests_diag_)); - out->write((const char*)&iter_, sizeof(iter_)); - out->write((const char*)&nfun_, sizeof(nfun_)); - out->write((const char*)&stp_, sizeof(stp_)); - out->write((const char*)&stp1, sizeof(stp1)); - out->write((const char*)&ftol, sizeof(ftol)); - out->write((const char*)&ys, sizeof(ys)); - out->write((const char*)&point, sizeof(point)); - out->write((const char*)&npt, sizeof(npt)); - out->write((const char*)&info, sizeof(info)); - out->write((const char*)&bound, sizeof(bound)); - out->write((const char*)&nfev, sizeof(nfev)); - out->write((const char*)&w_[0], sizeof(FloatType) * w_.size()); - out->write((const char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); - } - - void deserialize(std::istream* in) { - SizeType n, m, fs; - in->read((char*)&n, sizeof(n)); - in->read((char*)&m, sizeof(m)); - in->read((char*)&fs, sizeof(fs)); - assert(n == n_); - assert(m == m_); - assert(fs == sizeof(FloatType)); - - mcsrch_instance.deserialize(in); - in->read((char*)&iflag_, sizeof(iflag_)); - in->read((char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); - in->read((char*)&requests_diag_, sizeof(requests_diag_)); - in->read((char*)&iter_, sizeof(iter_)); - in->read((char*)&nfun_, sizeof(nfun_)); - in->read((char*)&stp_, sizeof(stp_)); - in->read((char*)&stp1, sizeof(stp1)); - in->read((char*)&ftol, sizeof(ftol)); - in->read((char*)&ys, sizeof(ys)); - in->read((char*)&point, sizeof(point)); - in->read((char*)&npt, sizeof(npt)); - in->read((char*)&info, sizeof(info)); - in->read((char*)&bound, sizeof(bound)); - in->read((char*)&nfev, sizeof(nfev)); - in->read((char*)&w_[0], sizeof(FloatType) * w_.size()); - in->read((char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); - } - - protected: - static void throw_diagonal_element_not_positive(SizeType i) { - throw error_improper_input_data( - "The " + error::itoa(i) + ". diagonal element of the" - " inverse Hessian approximation is not positive."); - } - - bool generic_run( - FloatType* x, - FloatType f, - const FloatType* g, - bool diagco, - const FloatType* diag); - - detail::mcsrch mcsrch_instance; - const SizeType n_; - const SizeType m_; - const SizeType maxfev_; - const FloatType gtol_; - const FloatType xtol_; - const FloatType stpmin_; - const FloatType stpmax_; - int iflag_; - bool requests_f_and_g_; - bool requests_diag_; - SizeType iter_; - SizeType nfun_; - FloatType stp_; - FloatType stp1; - FloatType ftol; - FloatType ys; - SizeType point; - SizeType npt; - const SizeType ispt; - const SizeType iypt; - int info; - SizeType bound; - SizeType nfev; - std::vector w_; - std::vector scratch_array_; - }; - - template - bool minimizer::generic_run( - FloatType* x, - FloatType f, - const FloatType* g, - bool diagco, - const FloatType* diag) - { - bool execute_entire_while_loop = false; - if (!(requests_f_and_g_ || requests_diag_)) { - execute_entire_while_loop = true; - } - requests_f_and_g_ = false; - requests_diag_ = false; - FloatType* w = &(*(w_.begin())); - if (iflag_ == 0) { // Initialize. - nfun_ = 1; - if (diagco) { - for (SizeType i = 0; i < n_; i++) { - if (diag[i] <= FloatType(0)) { - throw_diagonal_element_not_positive(i); - } - } - } - else { - std::fill_n(scratch_array_.begin(), n_, FloatType(1)); - diag = &(*(scratch_array_.begin())); - } - for (SizeType i = 0; i < n_; i++) { - w[ispt + i] = -g[i] * diag[i]; - } - FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); - if (gnorm == FloatType(0)) return false; - stp1 = FloatType(1) / gnorm; - execute_entire_while_loop = true; - } - if (execute_entire_while_loop) { - bound = iter_; - iter_++; - info = 0; - if (iter_ != 1) { - if (iter_ > m_) bound = m_; - ys = detail::ddot( - n_, w, iypt + npt, SizeType(1), w, ispt + npt, SizeType(1)); - if (!diagco) { - FloatType yy = detail::ddot( - n_, w, iypt + npt, SizeType(1), w, iypt + npt, SizeType(1)); - std::fill_n(scratch_array_.begin(), n_, ys / yy); - diag = &(*(scratch_array_.begin())); - } - else { - iflag_ = 2; - requests_diag_ = true; - return true; - } - } - } - if (execute_entire_while_loop || iflag_ == 2) { - if (iter_ != 1) { - if (diag == 0) { - throw error_internal_error(__FILE__, __LINE__); - } - if (diagco) { - for (SizeType i = 0; i < n_; i++) { - if (diag[i] <= FloatType(0)) { - throw_diagonal_element_not_positive(i); - } - } - } - SizeType cp = point; - if (point == 0) cp = m_; - w[n_ + cp -1] = 1 / ys; - SizeType i; - for (i = 0; i < n_; i++) { - w[i] = -g[i]; - } - cp = point; - for (i = 0; i < bound; i++) { - if (cp == 0) cp = m_; - cp--; - FloatType sq = detail::ddot( - n_, w, ispt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); - SizeType inmc=n_+m_+cp; - SizeType iycn=iypt+cp*n_; - w[inmc] = w[n_ + cp] * sq; - detail::daxpy(n_, -w[inmc], w, iycn, w); - } - for (i = 0; i < n_; i++) { - w[i] *= diag[i]; - } - for (i = 0; i < bound; i++) { - FloatType yr = detail::ddot( - n_, w, iypt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); - FloatType beta = w[n_ + cp] * yr; - SizeType inmc=n_+m_+cp; - beta = w[inmc] - beta; - SizeType iscn=ispt+cp*n_; - detail::daxpy(n_, beta, w, iscn, w); - cp++; - if (cp == m_) cp = 0; - } - std::copy(w, w+n_, w+(ispt + point * n_)); - } - stp_ = FloatType(1); - if (iter_ == 1) stp_ = stp1; - std::copy(g, g+n_, w); - } - mcsrch_instance.run( - gtol_, stpmin_, stpmax_, n_, x, f, g, w, ispt + point * n_, - stp_, ftol, xtol_, maxfev_, info, nfev, &(*(scratch_array_.begin()))); - if (info == -1) { - iflag_ = 1; - requests_f_and_g_ = true; - return true; - } - if (info != 1) { - throw error_internal_error(__FILE__, __LINE__); - } - nfun_ += nfev; - npt = point*n_; - for (SizeType i = 0; i < n_; i++) { - w[ispt + npt + i] = stp_ * w[ispt + npt + i]; - w[iypt + npt + i] = g[i] - w[i]; - } - point++; - if (point == m_) point = 0; - return false; - } - - //! Traditional LBFGS convergence test. - /*! This convergence test is equivalent to the test embedded - in the lbfgs.f Fortran code. The test assumes that - there is a meaningful relation between the Euclidean norm of the - parameter vector x and the norm of the gradient - vector g. Therefore this test should not be used if - this assumption is not correct for a given problem. - */ - template - class traditional_convergence_test - { - public: - //! Default constructor. - traditional_convergence_test() - : n_(0), eps_(0) - {} - - //! Constructor. - /*! @param n The number of variables in the minimization problem. - Restriction: n > 0. - - @param eps Determines the accuracy with which the solution - is to be found. - */ - explicit - traditional_convergence_test( - SizeType n, - FloatType eps = FloatType(1.e-5)) - : n_(n), eps_(eps) - { - if (n_ == 0) { - throw error_improper_input_parameter("n = 0."); - } - if (eps_ < FloatType(0)) { - throw error_improper_input_parameter("eps < 0."); - } - } - - //! Number of free parameters (as passed to the constructor). - SizeType n() const { return n_; } - - /*! \brief Accuracy with which the solution is to be found - (as passed to the constructor). - */ - FloatType eps() const { return eps_; } - - //! Execution of the convergence test for the given parameters. - /*! Returns true if -

-            ||g|| < eps * max(1,||x||),
-          
- where ||.|| denotes the Euclidean norm. - - @param x Current solution vector. - - @param g Components of the gradient at the current - point x. - */ - bool - operator()(const FloatType* x, const FloatType* g) const - { - FloatType xnorm = std::sqrt(detail::ddot(n_, x, x)); - FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); - if (gnorm <= eps_ * std::max(FloatType(1), xnorm)) return true; - return false; - } - protected: - const SizeType n_; - const FloatType eps_; - }; - -}} // namespace scitbx::lbfgs - -template -std::ostream& operator<<(std::ostream& os, const scitbx::lbfgs::minimizer& min) { - return os << "ITER=" << min.iter() << "\tNFUN=" << min.nfun() << "\tSTP=" << min.stp() << "\tDIAG=" << min.requests_diag() << "\tF&G=" << min.requests_f_and_g(); -} - - -#endif // SCITBX_LBFGS_H diff --git a/training/lbfgs_test.cc b/training/lbfgs_test.cc deleted file mode 100644 index 9678e788..00000000 --- a/training/lbfgs_test.cc +++ /dev/null @@ -1,117 +0,0 @@ -#include -#include -#include -#include -#include "lbfgs.h" -#include "sparse_vector.h" -#include "fdict.h" - -using namespace std; - -double TestOptimizer() { - cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; - - // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 - // df/dx1 = 8*x1 + x2 - // df/dx2 = 2*x2 + x1 - // df/dx3 = 2*x3 + 6 - double x[3]; - double g[3]; - scitbx::lbfgs::minimizer opt(3); - scitbx::lbfgs::traditional_convergence_test converged(3); - x[0] = 8; - x[1] = 8; - x[2] = 8; - double obj = 0; - do { - g[0] = 8 * x[0] + x[1]; - g[1] = 2 * x[1] + x[0]; - g[2] = 2 * x[2] + 6; - obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; - opt.run(x, obj, g); - if (!opt.requests_f_and_g()) { - if (converged(x,g)) break; - opt.run(x, obj, g); - } - cerr << x[0] << " " << x[1] << " " << x[2] << endl; - cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; - cerr << opt << endl; - } while (true); - return obj; -} - -double TestPersistentOptimizer() { - cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; - // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 - // df/dx1 = 8*x1 + x2 - // df/dx2 = 2*x2 + x1 - // df/dx3 = 2*x3 + 6 - double x[3]; - double g[3]; - scitbx::lbfgs::traditional_convergence_test converged(3); - x[0] = 8; - x[1] = 8; - x[2] = 8; - double obj = 0; - string state; - do { - g[0] = 8 * x[0] + x[1]; - g[1] = 2 * x[1] + x[0]; - g[2] = 2 * x[2] + 6; - obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; - - { - scitbx::lbfgs::minimizer opt(3); - if (state.size() > 0) { - istringstream is(state, ios::binary); - opt.deserialize(&is); - } - opt.run(x, obj, g); - ostringstream os(ios::binary); opt.serialize(&os); state = os.str(); - } - - cerr << x[0] << " " << x[1] << " " << x[2] << endl; - cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; - } while (!converged(x, g)); - return obj; -} - -void TestSparseVector() { - cerr << "Testing SparseVector serialization.\n"; - int f1 = FD::Convert("Feature_1"); - int f2 = FD::Convert("Feature_2"); - FD::Convert("LanguageModel"); - int f4 = FD::Convert("SomeFeature"); - int f5 = FD::Convert("SomeOtherFeature"); - SparseVector g; - g.set_value(f2, log(0.5)); - g.set_value(f4, log(0.125)); - g.set_value(f1, 0); - g.set_value(f5, 23.777); - ostringstream os; - double iobj = 1.5; - B64::Encode(iobj, g, &os); - cerr << iobj << "\t" << g << endl; - string data = os.str(); - cout << data << endl; - SparseVector v; - double obj; - bool decode_b64 = B64::Decode(&obj, &v, &data[0], data.size()); - cerr << obj << "\t" << v << endl; - assert(decode_b64); - assert(obj == iobj); - assert(g.size() == v.size()); -} - -int main() { - double o1 = TestOptimizer(); - double o2 = TestPersistentOptimizer(); - if (fabs(o1 - o2) > 1e-5) { - cerr << "OPTIMIZERS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; - return 1; - } - TestSparseVector(); - cerr << "SUCCESS\n"; - return 0; -} - diff --git a/training/lbl_model.cc b/training/lbl_model.cc deleted file mode 100644 index a46ce33c..00000000 --- a/training/lbl_model.cc +++ /dev/null @@ -1,421 +0,0 @@ -#include - -#include "config.h" -#ifndef HAVE_EIGEN - int main() { std::cerr << "Please rebuild with --with-eigen PATH\n"; return 1; } -#else - -#include -#include -#include -#include -#include // memset -#include - -#ifdef HAVE_MPI -#include -#include -#include -namespace mpi = boost::mpi; -#endif -#include -#include -#include -#include - -#include "corpus_tools.h" -#include "optimize.h" -#include "array2d.h" -#include "m.h" -#include "lattice.h" -#include "stringlib.h" -#include "filelib.h" -#include "tdict.h" - -namespace po = boost::program_options; -using namespace std; - -#define kDIMENSIONS 10 -typedef Eigen::Matrix RVector; -typedef Eigen::Matrix RTVector; -typedef Eigen::Matrix TMatrix; -vector r_src, r_trg; - -#if HAVE_MPI -namespace boost { -namespace serialization { - -template -void serialize(Archive & ar, RVector & v, const unsigned int version) { - for (unsigned i = 0; i < kDIMENSIONS; ++i) - ar & v[i]; -} - -} // namespace serialization -} // namespace boost -#endif - -bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("input,i",po::value(),"Input file") - ("iterations,I",po::value()->default_value(1000),"Number of iterations of training") - ("regularization_strength,C",po::value()->default_value(0.1),"L2 regularization strength (0 for no regularization)") - ("eta", po::value()->default_value(0.1f), "Eta for SGD") - ("source_embeddings,f", po::value(), "File containing source embeddings (if unset, random vectors will be used)") - ("target_embeddings,e", po::value(), "File containing target embeddings (if unset, random vectors will be used)") - ("random_seed,s", po::value(), "Random seed") - ("diagonal_tension,T", po::value()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (0 = uniform, >0 sharpens)") - ("testset,x", po::value(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (argc < 2 || conf->count("help")) { - cerr << "Usage " << argv[0] << " [OPTIONS] -i corpus.fr-en\n"; - cerr << dcmdline_options << endl; - return false; - } - return true; -} - -void Normalize(RVector* v) { - double norm = v->norm(); - assert(norm > 0.0f); - *v /= norm; -} - -void Flatten(const TMatrix& m, vector* v) { - unsigned c = 0; - v->resize(kDIMENSIONS * kDIMENSIONS); - for (unsigned i = 0; i < kDIMENSIONS; ++i) - for (unsigned j = 0; j < kDIMENSIONS; ++j) { - assert(boost::math::isfinite(m(i, j))); - (*v)[c++] = m(i,j); - } -} - -void Unflatten(const vector& v, TMatrix* m) { - unsigned c = 0; - for (unsigned i = 0; i < kDIMENSIONS; ++i) - for (unsigned j = 0; j < kDIMENSIONS; ++j) { - assert(boost::math::isfinite(v[c])); - (*m)(i, j) = v[c++]; - } -} - -double ApplyRegularization(const double C, - const vector& weights, - vector* g) { - assert(weights.size() == g->size()); - double reg = 0; - for (size_t i = 0; i < weights.size(); ++i) { - const double& w_i = weights[i]; - double& g_i = (*g)[i]; - reg += C * w_i * w_i; - g_i += 2 * C * w_i; - } - return reg; -} - -void LoadEmbeddings(const string& filename, vector* pv) { - vector& v = *pv; - cerr << "Reading embeddings from " << filename << " ...\n"; - ReadFile rf(filename); - istream& in = *rf.stream(); - string line; - unsigned lc = 0; - while(getline(in, line)) { - ++lc; - size_t cur = line.find(' '); - if (cur == string::npos || cur == 0) { - cerr << "Parse error reading line " << lc << ":\n" << line << endl; - abort(); - } - WordID w = TD::Convert(line.substr(0, cur)); - if (w >= v.size()) continue; - RVector& curv = v[w]; - line[cur] = 0; - size_t start = cur + 1; - cur = start + 1; - size_t c = 0; - while(cur < line.size()) { - if (line[cur] == ' ') { - line[cur] = 0; - curv[c++] = strtod(&line[start], NULL); - start = cur + 1; - cur = start; - if (c == kDIMENSIONS) break; - } - ++cur; - } - if (c < kDIMENSIONS && cur != start) { - if (cur < line.size()) line[cur] = 0; - curv[c++] = strtod(&line[start], NULL); - } - if (c != kDIMENSIONS) { - static bool first = true; - if (first) { - cerr << " read " << c << " dimensions from embedding file, but built with " << kDIMENSIONS << " (filling in with random values)\n"; - first = false; - } - for (; c < kDIMENSIONS; ++c) curv[c] = rand(); - } - if (c == kDIMENSIONS && cur != line.size()) { - static bool first = true; - if (first) { - cerr << " embedding file contains more dimensions than configured with, truncating.\n"; - first = false; - } - } - } -} - -int main(int argc, char** argv) { -#ifdef HAVE_MPI - std::cerr << "**MPI enabled.\n"; - mpi::environment env(argc, argv); - mpi::communicator world; - const int size = world.size(); - const int rank = world.rank(); -#else - std::cerr << "**MPI disabled.\n"; - const int rank = 0; - const int size = 1; -#endif - po::variables_map conf; - if (!InitCommandLine(argc, argv, &conf)) return 1; - const string fname = conf["input"].as(); - const double reg_strength = conf["regularization_strength"].as(); - const bool has_l2 = reg_strength; - assert(reg_strength >= 0.0f); - const int ITERATIONS = conf["iterations"].as(); - const double eta = conf["eta"].as(); - const double diagonal_tension = conf["diagonal_tension"].as(); - bool SGD = false; - if (diagonal_tension < 0.0) { - cerr << "Invalid value for diagonal_tension: must be >= 0\n"; - return 1; - } - string testset; - if (conf.count("testset")) testset = conf["testset"].as(); - - unsigned lc = 0; - vector unnormed_a_i; - bool flag = false; - vector > srcs, trgs; - vector vocab_e; - { - set svocab_e, svocab_f; - CorpusTools::ReadFromFile(fname, &srcs, NULL, &trgs, &svocab_e, rank, size); - copy(svocab_e.begin(), svocab_e.end(), back_inserter(vocab_e)); - } - cerr << "Number of target word types: " << vocab_e.size() << endl; - const double num_examples = lc; - - boost::shared_ptr lbfgs; - if (rank == 0) - lbfgs.reset(new LBFGSOptimizer(kDIMENSIONS * kDIMENSIONS, 100)); - r_trg.resize(TD::NumWords() + 1); - r_src.resize(TD::NumWords() + 1); - vector > trg_pos(TD::NumWords() + 1); - - if (conf.count("random_seed")) { - srand(conf["random_seed"].as()); - } else { - unsigned seed = time(NULL) + rank * 100; - cerr << "Random seed: " << seed << endl; - srand(seed); - } - - TMatrix t = TMatrix::Zero(); - if (rank == 0) { - t = TMatrix::Random() / 50.0; - for (unsigned i = 1; i < r_trg.size(); ++i) { - r_trg[i] = RVector::Random(); - r_src[i] = RVector::Random(); - } - if (conf.count("source_embeddings")) - LoadEmbeddings(conf["source_embeddings"].as(), &r_src); - if (conf.count("target_embeddings")) - LoadEmbeddings(conf["target_embeddings"].as(), &r_trg); - } - - // do optimization - TMatrix g = TMatrix::Zero(); - vector exp_src; - vector z_src; - vector flat_g, flat_t, rcv_grad; - Flatten(t, &flat_t); - bool converged = false; -#if HAVE_MPI - mpi::broadcast(world, &flat_t[0], flat_t.size(), 0); - mpi::broadcast(world, r_trg, 0); - mpi::broadcast(world, r_src, 0); -#endif - cerr << "rank=" << rank << ": " << r_trg[0][4] << endl; - for (int iter = 0; !converged && iter < ITERATIONS; ++iter) { - if (rank == 0) cerr << "ITERATION " << (iter + 1) << endl; - Unflatten(flat_t, &t); - double likelihood = 0; - double denom = 0.0; - lc = 0; - flag = false; - g *= 0; - for (unsigned i = 0; i < srcs.size(); ++i) { - const vector& src = srcs[i]; - const vector& trg = trgs[i]; - ++lc; - if (rank == 0 && lc % 1000 == 0) { cerr << '.'; flag = true; } - if (rank == 0 && lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; } - denom += trg.size(); - - exp_src.clear(); exp_src.resize(src.size(), TMatrix::Zero()); - z_src.clear(); z_src.resize(src.size(), 0.0); - Array2D exp_refs(src.size(), trg.size(), TMatrix::Zero()); - Array2D z_refs(src.size(), trg.size(), 0.0); - for (unsigned j = 0; j < trg.size(); ++j) - trg_pos[trg[j]].insert(j); - - for (unsigned i = 0; i < src.size(); ++i) { - const RVector& r_s = r_src[src[i]]; - const RTVector pred = r_s.transpose() * t; - TMatrix& exp_m = exp_src[i]; - double& z = z_src[i]; - for (unsigned k = 0; k < vocab_e.size(); ++k) { - const WordID v_k = vocab_e[k]; - const RVector& r_t = r_trg[v_k]; - const double dot_prod = pred * r_t; - const double u = exp(dot_prod); - z += u; - const TMatrix v = r_s * r_t.transpose() * u; - exp_m += v; - set& ref_locs = trg_pos[v_k]; - if (!ref_locs.empty()) { - for (set::iterator it = ref_locs.begin(); it != ref_locs.end(); ++it) { - TMatrix& exp_ref_ij = exp_refs(i, *it); - double& z_ref_ij = z_refs(i, *it); - z_ref_ij += u; - exp_ref_ij += v; - } - } - } - } - for (unsigned j = 0; j < trg.size(); ++j) - trg_pos[trg[j]].clear(); - - // model expectations for a single target generation with - // uniform alignment prior - // TODO: when using a non-uniform alignment, m_exp will be - // a function of j (below) - double m_z = 0; - TMatrix m_exp = TMatrix::Zero(); - for (unsigned i = 0; i < src.size(); ++i) { - m_exp += exp_src[i]; - m_z += z_src[i]; - } - m_exp /= m_z; - - Array2D al(src.size(), trg.size(), false); - for (unsigned j = 0; j < trg.size(); ++j) { - double ref_z = 0; - TMatrix ref_exp = TMatrix::Zero(); - int max_i = 0; - double max_s = -9999999; - for (unsigned i = 0; i < src.size(); ++i) { - ref_exp += exp_refs(i, j); - ref_z += z_refs(i, j); - if (log(z_refs(i, j)) > max_s) { - max_s = log(z_refs(i, j)); - max_i = i; - } - // TODO handle alignment prob - } - if (ref_z <= 0) { - cerr << "TRG=" << TD::Convert(trg[j]) << endl; - cerr << " LINE=" << lc << " (RANK=" << rank << "/" << size << ")" << endl; - cerr << " REF_EXP=\n" << ref_exp << endl; - cerr << " M_EXP=\n" << m_exp << endl; - abort(); - } - al(max_i, j) = true; - ref_exp /= ref_z; - g += m_exp - ref_exp; - likelihood += log(ref_z) - log(m_z); - if (SGD) { - t -= g * eta / num_examples; - g *= 0; - } - } - - if (rank == 0 && (iter == (ITERATIONS - 1) || lc < 12)) { cerr << al << endl; } - } - if (flag && rank == 0) { cerr << endl; } - - double obj = 0; - if (!SGD) { - Flatten(g, &flat_g); - obj = -likelihood; -#if HAVE_MPI - rcv_grad.resize(flat_g.size(), 0.0); - mpi::reduce(world, &flat_g[0], flat_g.size(), &rcv_grad[0], plus(), 0); - swap(flat_g, rcv_grad); - rcv_grad.clear(); - - double to = 0; - mpi::reduce(world, obj, to, plus(), 0); - obj = to; - double tlh = 0; - mpi::reduce(world, likelihood, tlh, plus(), 0); - likelihood = tlh; - double td = 0; - mpi::reduce(world, denom, td, plus(), 0); - denom = td; -#endif - } - - if (rank == 0) { - double gn = 0; - for (unsigned i = 0; i < flat_g.size(); ++i) - gn += flat_g[i]*flat_g[i]; - const double base2_likelihood = likelihood / log(2); - cerr << " log_e likelihood: " << likelihood << endl; - cerr << " log_2 likelihood: " << base2_likelihood << endl; - cerr << " cross entropy: " << (-base2_likelihood / denom) << endl; - cerr << " perplexity: " << pow(2.0, -base2_likelihood / denom) << endl; - cerr << " gradient norm: " << sqrt(gn) << endl; - if (!SGD) { - if (has_l2) { - const double r = ApplyRegularization(reg_strength, - flat_t, - &flat_g); - obj += r; - cerr << " regularization: " << r << endl; - } - lbfgs->Optimize(obj, flat_g, &flat_t); - converged = (lbfgs->HasConverged()); - } - } -#ifdef HAVE_MPI - mpi::broadcast(world, &flat_t[0], flat_t.size(), 0); - mpi::broadcast(world, converged, 0); -#endif - } - if (rank == 0) - cerr << "TRANSLATION MATRIX:" << endl << t << endl; - return 0; -} - -#endif - diff --git a/training/minrisk/Makefile.am b/training/minrisk/Makefile.am new file mode 100644 index 00000000..a15e821e --- /dev/null +++ b/training/minrisk/Makefile.am @@ -0,0 +1,6 @@ +bin_PROGRAMS = minrisk_optimize + +minrisk_optimize_SOURCES = minrisk_optimize.cc +minrisk_optimize_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/training/liblbfgs/liblbfgs.a -lz + +AM_CPPFLAGS = -W -Wall $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training -I$(top_srcdir)/training/utils diff --git a/training/minrisk/minrisk.pl b/training/minrisk/minrisk.pl new file mode 100755 index 00000000..0f8bacd0 --- /dev/null +++ b/training/minrisk/minrisk.pl @@ -0,0 +1,540 @@ +#!/usr/bin/env perl +use strict; +my @ORIG_ARGV=@ARGV; +use Cwd qw(getcwd); +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment", "$SCRIPT_DIR/../utils"; } + +# Skip local config (used for distributing jobs) if we're running in local-only mode +use LocalConfig; +use Getopt::Long; +use IPC::Open2; +use POSIX ":sys_wait_h"; +my $QSUB_CMD = qsub_args(mert_memory()); +my $default_jobs = env_default_jobs(); + +my $UTILS_DIR="$SCRIPT_DIR/../utils"; +require "$UTILS_DIR/libcall.pl"; + +# Default settings +my $srcFile; +my $refFiles; +my $bin_dir = $SCRIPT_DIR; +die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; +my $FAST_SCORE="$bin_dir/../../mteval/fast_score"; +die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; +my $MAPINPUT = "$bin_dir/minrisk_generate_input.pl"; +my $MAPPER = "$bin_dir/minrisk_optimize"; +my $parallelize = "$UTILS_DIR/parallelize.pl"; +my $libcall = "$UTILS_DIR/libcall.pl"; +my $sentserver = "$UTILS_DIR/sentserver"; +my $sentclient = "$UTILS_DIR/sentclient"; +my $LocalConfig = "$SCRIPT_DIR/../../environment/LocalConfig.pm"; + +my $SCORER = $FAST_SCORE; +die "Can't find $MAPPER" unless -x $MAPPER; +my $cdec = "$bin_dir/../../decoder/cdec"; +die "Can't find decoder in $cdec" unless -x $cdec; +die "Can't find $parallelize" unless -x $parallelize; +die "Can't find $libcall" unless -e $libcall; +my $decoder = $cdec; +my $lines_per_mapper = 30; +my $iteration = 1; +my $best_weights; +my $psi = 1; +my $default_max_iter = 30; +my $max_iterations = $default_max_iter; +my $jobs = $default_jobs; # number of decode nodes +my $pmem = "4g"; +my $disable_clean = 0; +my %seen_weights; +my $help = 0; +my $epsilon = 0.0001; +my $dryrun = 0; +my $last_score = -10000000; +my $metric = "ibm_bleu"; +my $dir; +my $iniFile; +my $weights; +my $use_make = 1; # use make to parallelize +my $useqsub = 0; +my $initial_weights; +my $pass_suffix = ''; +my $cpbin=1; + +# regularization strength +my $tune_regularizer = 0; +my $reg = 500; +my $reg_previous = 5000; +my $dont_accum = 0; + +# Process command-line options +Getopt::Long::Configure("no_auto_abbrev"); +if (GetOptions( + "jobs=i" => \$jobs, + "dont-clean" => \$disable_clean, + "dont-accumulate" => \$dont_accum, + "pass-suffix=s" => \$pass_suffix, + "qsub" => \$useqsub, + "dry-run" => \$dryrun, + "epsilon=s" => \$epsilon, + "help" => \$help, + "weights=s" => \$initial_weights, + "reg=f" => \$reg, + "use-make=i" => \$use_make, + "max-iterations=i" => \$max_iterations, + "pmem=s" => \$pmem, + "cpbin!" => \$cpbin, + "ref-files=s" => \$refFiles, + "metric=s" => \$metric, + "source-file=s" => \$srcFile, + "workdir=s" => \$dir, +) == 0 || @ARGV!=1 || $help) { + print_help(); + exit; +} + +die "--tune-regularizer is no longer supported with --reg-previous and --reg. Please tune manually.\n" if $tune_regularizer; + +if ($useqsub) { + $use_make = 0; + die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); +} + +my @missing_args = (); +if (!defined $srcFile) { push @missing_args, "--source-file"; } +if (!defined $refFiles) { push @missing_args, "--ref-files"; } +if (!defined $initial_weights) { push @missing_args, "--weights"; } +die "Please specify missing arguments: " . join (', ', @missing_args) . "\n" if (@missing_args); + +if ($metric =~ /^(combi|ter)$/i) { + $lines_per_mapper = 5; +} + +($iniFile) = @ARGV; + + +sub write_config; +sub enseg; +sub print_help; + +my $nodelist; +my $host =check_output("hostname"); chomp $host; +my $bleu; +my $interval_count = 0; +my $logfile; +my $projected_score; + +# used in sorting scores +my $DIR_FLAG = '-r'; +if ($metric =~ /^ter$|^aer$/i) { + $DIR_FLAG = ''; +} + +my $refs_comma_sep = get_comma_sep_refs('r',$refFiles); + +unless ($dir){ + $dir = "minrisk"; +} +unless ($dir =~ /^\//){ # convert relative path to absolute path + my $basedir = check_output("pwd"); + chomp $basedir; + $dir = "$basedir/$dir"; +} + + +# Initializations and helper functions +srand; + +my @childpids = (); +my @cleanupcmds = (); + +sub cleanup { + print STDERR "Cleanup...\n"; + for my $pid (@childpids){ unchecked_call("kill $pid"); } + for my $cmd (@cleanupcmds){ unchecked_call("$cmd"); } + exit 1; +}; +# Always call cleanup, no matter how we exit +*CORE::GLOBAL::exit = + sub{ cleanup(); }; +$SIG{INT} = "cleanup"; +$SIG{TERM} = "cleanup"; +$SIG{HUP} = "cleanup"; + +my $decoderBase = check_output("basename $decoder"); chomp $decoderBase; +my $newIniFile = "$dir/$decoderBase.ini"; +my $inputFileName = "$dir/input"; +my $user = $ENV{"USER"}; +# process ini file +-e $iniFile || die "Error: could not open $iniFile for reading\n"; +open(INI, $iniFile); + +use File::Basename qw(basename); +#pass bindir, refs to vars holding bin +sub modbin { + local $_; + my $bindir=shift; + check_call("mkdir -p $bindir"); + -d $bindir || die "couldn't make bindir $bindir"; + for (@_) { + my $src=$$_; + $$_="$bindir/".basename($src); + check_call("cp -p $src $$_"); + } +} +sub dirsize { + opendir ISEMPTY,$_[0]; + return scalar(readdir(ISEMPTY))-1; +} +my @allweights; +if ($dryrun){ + write_config(*STDERR); + exit 0; +} else { + if (-e $dir && dirsize($dir)>1 && -e "$dir/hgs" ){ # allow preexisting logfile, binaries, but not dist-pro.pl outputs + die "ERROR: working dir $dir already exists\n\n"; + } else { + -e $dir || mkdir $dir; + mkdir "$dir/hgs"; + modbin("$dir/bin",\$LocalConfig,\$cdec,\$SCORER,\$MAPINPUT,\$MAPPER,\$parallelize,\$sentserver,\$sentclient,\$libcall) if $cpbin; + mkdir "$dir/scripts"; + my $cmdfile="$dir/rerun-pro.sh"; + open CMD,'>',$cmdfile; + print CMD "cd ",&getcwd,"\n"; +# print CMD &escaped_cmdline,"\n"; #buggy - last arg is quoted. + my $cline=&cmdline."\n"; + print CMD $cline; + close CMD; + print STDERR $cline; + chmod(0755,$cmdfile); + check_call("cp $initial_weights $dir/weights.0"); + die "Can't find weights.0" unless (-e "$dir/weights.0"); + } + write_config(*STDERR); +} + + +# Generate initial files and values +check_call("cp $iniFile $newIniFile"); +$iniFile = $newIniFile; + +my $newsrc = "$dir/dev.input"; +enseg($srcFile, $newsrc); +$srcFile = $newsrc; +my $devSize = 0; +open F, "<$srcFile" or die "Can't read $srcFile: $!"; +while() { $devSize++; } +close F; + +unless($best_weights){ $best_weights = $weights; } +unless($projected_score){ $projected_score = 0.0; } +$seen_weights{$weights} = 1; +my $kbest = "$dir/kbest"; +if ($dont_accum) { + $kbest = ''; +} else { + check_call("mkdir -p $kbest"); + $kbest = "--kbest_repository $kbest"; +} + +my $random_seed = int(time / 1000); +my $lastWeightsFile; +my $lastPScore = 0; +# main optimization loop +while (1){ + print STDERR "\n\nITERATION $iteration\n==========\n"; + + if ($iteration > $max_iterations){ + print STDERR "\nREACHED STOPPING CRITERION: Maximum iterations\n"; + last; + } + # iteration-specific files + my $runFile="$dir/run.raw.$iteration"; + my $onebestFile="$dir/1best.$iteration"; + my $logdir="$dir/logs.$iteration"; + my $decoderLog="$logdir/decoder.sentserver.log.$iteration"; + my $scorerLog="$logdir/scorer.log.$iteration"; + check_call("mkdir -p $logdir"); + + + #decode + print STDERR "RUNNING DECODER AT "; + print STDERR unchecked_output("date"); + my $im1 = $iteration - 1; + my $weightsFile="$dir/weights.$im1"; + push @allweights, "-w $dir/weights.$im1"; + `rm -f $dir/hgs/*.gz`; + my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; + my $pcmd; + if ($use_make) { + $pcmd = "cat $srcFile | $parallelize --use-fork -p $pmem -e $logdir -j $jobs --"; + } else { + $pcmd = "cat $srcFile | $parallelize -p $pmem -e $logdir -j $jobs --"; + } + my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $runFile"; + print STDERR "COMMAND:\n$cmd\n"; + check_bash_call($cmd); + my $num_hgs; + my $num_topbest; + my $retries = 0; + while($retries < 5) { + $num_hgs = check_output("ls $dir/hgs/*.gz | wc -l"); + $num_topbest = check_output("wc -l < $runFile"); + print STDERR "NUMBER OF HGs: $num_hgs\n"; + print STDERR "NUMBER OF TOP-BEST HYPs: $num_topbest\n"; + if($devSize == $num_hgs && $devSize == $num_topbest) { + last; + } else { + print STDERR "Incorrect number of hypergraphs or topbest. Waiting for distributed filesystem and retrying...\n"; + sleep(3); + } + $retries++; + } + die "Dev set contains $devSize sentences, but we don't have topbest and hypergraphs for all these! Decoder failure? Check $decoderLog\n" if ($devSize != $num_hgs || $devSize != $num_topbest); + my $dec_score = check_output("cat $runFile | $SCORER $refs_comma_sep -m $metric"); + chomp $dec_score; + print STDERR "DECODER SCORE: $dec_score\n"; + + # save space + check_call("gzip -f $runFile"); + check_call("gzip -f $decoderLog"); + + # run optimizer + print STDERR "RUNNING OPTIMIZER AT "; + print STDERR unchecked_output("date"); + print STDERR " - GENERATE TRAINING EXEMPLARS\n"; + my $mergeLog="$logdir/prune-merge.log.$iteration"; + + my $score = 0; + my $icc = 0; + my $inweights="$dir/weights.$im1"; + my $outweights="$dir/weights.$iteration"; + $cmd="$MAPINPUT $dir/hgs > $dir/agenda.$im1"; + print STDERR "COMMAND:\n$cmd\n"; + check_call($cmd); + $cmd="$MAPPER $refs_comma_sep -m $metric -i $dir/agenda.$im1 $kbest -w $inweights > $outweights"; + check_call($cmd); + $lastWeightsFile = $outweights; + $iteration++; + `rm hgs/*.gz`; + print STDERR "\n==========\n"; +} + +print STDERR "\nFINAL WEIGHTS: $lastWeightsFile\n(Use -w with the decoder)\n\n"; + +print STDOUT "$lastWeightsFile\n"; + +exit 0; + +sub get_lines { + my $fn = shift @_; + open FL, "<$fn" or die "Couldn't read $fn: $!"; + my $lc = 0; + while() { $lc++; } + return $lc; +} + +sub get_comma_sep_refs { + my ($r,$p) = @_; + my $o = check_output("echo $p"); + chomp $o; + my @files = split /\s+/, $o; + return "-$r " . join(" -$r ", @files); +} + +sub read_weights_file { + my ($file) = @_; + open F, "<$file" or die "Couldn't read $file: $!"; + my @r = (); + my $pm = -1; + while() { + next if /^#/; + next if /^\s*$/; + chomp; + if (/^(.+)\s+(.+)$/) { + my $m = $1; + my $w = $2; + die "Weights out of order: $m <= $pm" unless $m > $pm; + push @r, $w; + } else { + warn "Unexpected feature name in weight file: $_"; + } + } + close F; + return join ' ', @r; +} + +# subs +sub write_config { + my $fh = shift; + my $cleanup = "yes"; + if ($disable_clean) {$cleanup = "no";} + + print $fh "\n"; + print $fh "DECODER: $decoder\n"; + print $fh "INI FILE: $iniFile\n"; + print $fh "WORKING DIR: $dir\n"; + print $fh "SOURCE (DEV): $srcFile\n"; + print $fh "REFS (DEV): $refFiles\n"; + print $fh "EVAL METRIC: $metric\n"; + print $fh "MAX ITERATIONS: $max_iterations\n"; + print $fh "JOBS: $jobs\n"; + print $fh "HEAD NODE: $host\n"; + print $fh "PMEM (DECODING): $pmem\n"; + print $fh "CLEANUP: $cleanup\n"; +} + +sub update_weights_file { + my ($neww, $rfn, $rpts) = @_; + my @feats = @$rfn; + my @pts = @$rpts; + my $num_feats = scalar @feats; + my $num_pts = scalar @pts; + die "$num_feats (num_feats) != $num_pts (num_pts)" unless $num_feats == $num_pts; + open G, ">$neww" or die; + for (my $i = 0; $i < $num_feats; $i++) { + my $f = $feats[$i]; + my $lambda = $pts[$i]; + print G "$f $lambda\n"; + } + close G; +} + +sub enseg { + my $src = shift; + my $newsrc = shift; + open(SRC, $src); + open(NEWSRC, ">$newsrc"); + my $i=0; + while (my $line=){ + chomp $line; + if ($line =~ /^\s* tags, you must include a zero-based id attribute"; + } + } else { + print NEWSRC "$line\n"; + } + $i++; + } + close SRC; + close NEWSRC; + die "Empty dev set!" if ($i == 0); +} + +sub print_help { + + my $executable = check_output("basename $0"); chomp $executable; + print << "Help"; + +Usage: $executable [options] + + $executable [options] + Runs a complete PRO optimization using the ini file specified. + +Required: + + --ref-files + Dev set ref files. This option takes only a single string argument. + To use multiple files (including file globbing), this argument should + be quoted. + + --source-file + Dev set source file. + + --weights + Initial weights file (use empty file to start from 0) + +General options: + + --help + Print this message and exit. + + --dont-accumulate + Don't accumulate k-best lists from multiple iterations. + + --max-iterations + Maximum number of iterations to run. If not specified, defaults + to $default_max_iter. + + --metric + Metric to optimize. + Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi + + --pass-suffix + If the decoder is doing multi-pass decoding, the pass suffix "2", + "3", etc., is used to control what iteration of weights is set. + + --workdir + Directory for intermediate and output files. If not specified, the + name is derived from the ini filename. Assuming that the ini + filename begins with the decoder name and ends with ini, the default + name of the working directory is inferred from the middle part of + the filename. E.g. an ini file named decoder.foo.ini would have + a default working directory name foo. + +Regularization options: + + --reg + l2 regularization strength [default=500]. The greater this value, + the closer to zero the weights will be. + +Job control options: + + --jobs + Number of decoder processes to run in parallel. [default=$default_jobs] + + --qsub + Use qsub to run jobs in parallel (qsub must be configured in + environment/LocalEnvironment.pm) + + --pmem + Amount of physical memory requested for parallel decoding jobs + (used with qsub requests only) + +Help +} + +sub convert { + my ($str) = @_; + my @ps = split /;/, $str; + my %dict = (); + for my $p (@ps) { + my ($k, $v) = split /=/, $p; + $dict{$k} = $v; + } + return %dict; +} + + +sub cmdline { + return join ' ',($0,@ORIG_ARGV); +} + +#buggy: last arg gets quoted sometimes? +my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; +my $shell_escape_in_quote=qr{[\\"\$`!]}; + +sub escape_shell { + my ($arg)=@_; + return undef unless defined $arg; + if ($arg =~ /$is_shell_special/) { + $arg =~ s/($shell_escape_in_quote)/\\$1/g; + return "\"$arg\""; + } + return $arg; +} + +sub escaped_shell_args { + return map {local $_=$_;chomp;escape_shell($_)} @_; +} + +sub escaped_shell_args_str { + return join ' ',&escaped_shell_args(@_); +} + +sub escaped_cmdline { + return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); +} diff --git a/training/minrisk/minrisk_generate_input.pl b/training/minrisk/minrisk_generate_input.pl new file mode 100755 index 00000000..b30fc4fd --- /dev/null +++ b/training/minrisk/minrisk_generate_input.pl @@ -0,0 +1,18 @@ +#!/usr/bin/perl -w +use strict; + +die "Usage: $0 HG_DIR\n" unless scalar @ARGV == 1; +my $d = shift @ARGV; +die "Can't find directory $d" unless -d $d; + +opendir(DIR, $d) or die "Can't read $d: $!"; +my @hgs = grep { /\.gz$/ } readdir(DIR); +closedir DIR; + +for my $hg (@hgs) { + my $file = $hg; + my $id = $hg; + $id =~ s/(\.json)?\.gz//; + print "$d/$file $id\n"; +} + diff --git a/training/minrisk/minrisk_optimize.cc b/training/minrisk/minrisk_optimize.cc new file mode 100644 index 00000000..da8b5260 --- /dev/null +++ b/training/minrisk/minrisk_optimize.cc @@ -0,0 +1,197 @@ +#include +#include +#include +#include + +#include +#include + +#include "liblbfgs/lbfgs++.h" +#include "filelib.h" +#include "stringlib.h" +#include "weights.h" +#include "hg_io.h" +#include "kbest.h" +#include "viterbi.h" +#include "ns.h" +#include "ns_docscorer.h" +#include "candidate_set.h" +#include "risk.h" +#include "entropy.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") + ("weights,w",po::value(), "[REQD] Weights files from current iterations") + ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") + ("evaluation_metric,m",po::value()->default_value("IBM_BLEU"), "Evaluation metric (ibm_bleu, koehn_bleu, nist_bleu, ter, meteor, etc.)") + ("temperature,T",po::value()->default_value(0.0), "Temperature parameter for objective (>0 increases the entropy)") + ("l1_strength,C",po::value()->default_value(0.0), "L1 regularization strength") + ("memory_buffers,M",po::value()->default_value(20), "Memory buffers used in LBFGS") + ("kbest_repository,R",po::value(), "Accumulate k-best lists from previous iterations (parameter is path to repository)") + ("kbest_size,k",po::value()->default_value(500u), "Top k-hypotheses to extract") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (!conf->count("reference")) { + cerr << "Please specify one or more references using -r \n"; + flag = true; + } + if (!conf->count("weights")) { + cerr << "Please specify weights using -w \n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +EvaluationMetric* metric = NULL; + +struct RiskObjective { + explicit RiskObjective(const vector& tr, const double temp) : training(tr), T(temp) {} + double operator()(const vector& x, double* g) const { + fill(g, g + x.size(), 0.0); + double obj = 0; + double h = 0; + for (unsigned i = 0; i < training.size(); ++i) { + training::CandidateSetRisk risk(training[i], *metric); + training::CandidateSetEntropy entropy(training[i]); + SparseVector tg, hg; + double r = risk(x, &tg); + double hh = entropy(x, &hg); + h += hh; + obj += r; + for (SparseVector::iterator it = tg.begin(); it != tg.end(); ++it) + g[it->first] += it->second; + if (T) { + for (SparseVector::iterator it = hg.begin(); it != hg.end(); ++it) + g[it->first] += T * it->second; + } + } + cerr << (1-(obj / training.size())) << " H=" << h << endl; + return obj - T * h; + } + const vector& training; + const double T; // temperature for entropy regularization +}; + +double LearnParameters(const vector& training, + const double temp, // > 0 increases the entropy, < 0 decreases the entropy + const double C1, + const unsigned memory_buffers, + vector* px) { + RiskObjective obj(training, temp); + LBFGS lbfgs(px, obj, memory_buffers, C1); + lbfgs.MinimizeFunction(); + return 0; +} + +#if 0 +struct FooLoss { + double operator()(const vector& x, double* g) const { + fill(g, g + x.size(), 0.0); + training::CandidateSet cs; + training::CandidateSetEntropy cse(cs); + cs.cs.resize(3); + cs.cs[0].fmap.set_value(FD::Convert("F1"), -1.0); + cs.cs[1].fmap.set_value(FD::Convert("F2"), 1.0); + cs.cs[2].fmap.set_value(FD::Convert("F1"), 2.0); + cs.cs[2].fmap.set_value(FD::Convert("F2"), 0.5); + SparseVector xx; + double h = cse(x, &xx); + cerr << cse(x, &xx) << endl; cerr << "G: " << xx << endl; + for (SparseVector::iterator i = xx.begin(); i != xx.end(); ++i) + g[i->first] += i->second; + return -h; + } +}; +#endif + +int main(int argc, char** argv) { +#if 0 + training::CandidateSet cs; + training::CandidateSetEntropy cse(cs); + cs.cs.resize(3); + cs.cs[0].fmap.set_value(FD::Convert("F1"), -1.0); + cs.cs[1].fmap.set_value(FD::Convert("F2"), 1.0); + cs.cs[2].fmap.set_value(FD::Convert("F1"), 2.0); + cs.cs[2].fmap.set_value(FD::Convert("F2"), 0.5); + FooLoss foo; + vector ww(FD::NumFeats()); ww[FD::Convert("F1")] = 1.0; + LBFGS lbfgs(&ww, foo, 100, 0.0); + lbfgs.MinimizeFunction(); + return 1; +#endif + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string evaluation_metric = conf["evaluation_metric"].as(); + + metric = EvaluationMetric::Instance(evaluation_metric); + DocumentScorer ds(metric, conf["reference"].as >()); + cerr << "Loaded " << ds.size() << " references for scoring with " << evaluation_metric << endl; + + Hypergraph hg; + string last_file; + ReadFile in_read(conf["input"].as()); + string kbest_repo; + if (conf.count("kbest_repository")) { + kbest_repo = conf["kbest_repository"].as(); + MkDirP(kbest_repo); + } + istream &in=*in_read.stream(); + const unsigned kbest_size = conf["kbest_size"].as(); + vector weights; + const string weightsf = conf["weights"].as(); + Weights::InitFromFile(weightsf, &weights); + double t = 0; + for (unsigned i = 0; i < weights.size(); ++i) + t += weights[i] * weights[i]; + if (t > 0) { + for (unsigned i = 0; i < weights.size(); ++i) + weights[i] /= sqrt(t); + } + string line, file; + vector kis; + cerr << "Loading hypergraphs...\n"; + while(getline(in, line)) { + istringstream is(line); + int sent_id; + kis.resize(kis.size() + 1); + training::CandidateSet& curkbest = kis.back(); + string kbest_file; + if (kbest_repo.size()) { + ostringstream os; + os << kbest_repo << "/kbest." << sent_id << ".txt.gz"; + kbest_file = os.str(); + if (FileExists(kbest_file)) + curkbest.ReadFromFile(kbest_file); + } + is >> file >> sent_id; + ReadFile rf(file); + if (kis.size() % 5 == 0) { cerr << '.'; } + if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; } + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + hg.Reweight(weights); + curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]); + if (kbest_file.size()) + curkbest.WriteToFile(kbest_file); + } + cerr << "\nHypergraphs loaded.\n"; + weights.resize(FD::NumFeats()); + + double c1 = conf["l1_strength"].as(); + double temp = conf["temperature"].as(); + unsigned m = conf["memory_buffers"].as(); + LearnParameters(kis, temp, c1, m, &weights); + Weights::WriteToFile("-", weights); + return 0; +} + diff --git a/training/mira/Makefile.am b/training/mira/Makefile.am new file mode 100644 index 00000000..ae609ede --- /dev/null +++ b/training/mira/Makefile.am @@ -0,0 +1,6 @@ +bin_PROGRAMS = kbest_mira + +kbest_mira_SOURCES = kbest_mira.cc +kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/mira/kbest_mira.cc b/training/mira/kbest_mira.cc new file mode 100644 index 00000000..8b7993dd --- /dev/null +++ b/training/mira/kbest_mira.cc @@ -0,0 +1,309 @@ +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "hg_sampler.h" +#include "sentence_metadata.h" +#include "scorer.h" +#include "verbose.h" +#include "viterbi.h" +#include "hg.h" +#include "prob.h" +#include "kbest.h" +#include "ff_register.h" +#include "decoder.h" +#include "filelib.h" +#include "fdict.h" +#include "weights.h" +#include "sparse_vector.h" +#include "sampler.h" + +using namespace std; +namespace po = boost::program_options; + +bool invert_score; +std::tr1::shared_ptr rng; + +void RandomPermutation(int len, vector* p_ids) { + vector& ids = *p_ids; + ids.resize(len); + for (int i = 0; i < len; ++i) ids[i] = i; + for (int i = len; i > 0; --i) { + int j = rng->next() * i; + if (j == i) i--; + swap(ids[i-1], ids[j]); + } +} + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input_weights,w",po::value(),"Input feature weights file") + ("source,i",po::value(),"Source file for development set") + ("passes,p", po::value()->default_value(15), "Number of passes through the training data") + ("reference,r",po::value >(), "[REQD] Reference translation(s) (tokenized text file)") + ("mt_metric,m",po::value()->default_value("ibm_bleu"), "Scoring metric (ibm_bleu, nist_bleu, koehn_bleu, ter, combi)") + ("max_step_size,C", po::value()->default_value(0.01), "regularization strength (C)") + ("mt_metric_scale,s", po::value()->default_value(1.0), "Amount to scale MT loss function by") + ("k_best_size,k", po::value()->default_value(250), "Size of hypothesis list to search for oracles") + ("sample_forest,f", "Instead of a k-best list, sample k hypotheses from the decoder's forest") + ("sample_forest_unit_weight_vector,x", "Before sampling (must use -f option), rescale the weight vector used so it has unit length; this may improve the quality of the samples") + ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") + ("decoder_config,c",po::value(),"Decoder configuration file"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || !conf->count("input_weights") || !conf->count("source") || !conf->count("decoder_config") || !conf->count("reference")) { + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +static const double kMINUS_EPSILON = -1e-6; + +struct HypothesisInfo { + SparseVector features; + double mt_metric; +}; + +struct GoodBadOracle { + std::tr1::shared_ptr good; + std::tr1::shared_ptr bad; +}; + +struct TrainingObserver : public DecoderObserver { + TrainingObserver(const int k, const DocScorer& d, bool sf, vector* o) : ds(d), oracles(*o), kbest_size(k), sample_forest(sf) {} + const DocScorer& ds; + vector& oracles; + std::tr1::shared_ptr cur_best; + const int kbest_size; + const bool sample_forest; + + const HypothesisInfo& GetCurrentBestHypothesis() const { + return *cur_best; + } + + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { + UpdateOracles(smeta.GetSentenceID(), *hg); + } + + std::tr1::shared_ptr MakeHypothesisInfo(const SparseVector& feats, const double score) { + std::tr1::shared_ptr h(new HypothesisInfo); + h->features = feats; + h->mt_metric = score; + return h; + } + + void UpdateOracles(int sent_id, const Hypergraph& forest) { + std::tr1::shared_ptr& cur_good = oracles[sent_id].good; + std::tr1::shared_ptr& cur_bad = oracles[sent_id].bad; + cur_bad.reset(); // TODO get rid of?? + + if (sample_forest) { + vector cur_prediction; + ViterbiESentence(forest, &cur_prediction); + float sentscore = ds[sent_id]->ScoreCandidate(cur_prediction)->ComputeScore(); + cur_best = MakeHypothesisInfo(ViterbiFeatures(forest), sentscore); + + vector samples; + HypergraphSampler::sample_hypotheses(forest, kbest_size, &*rng, &samples); + for (unsigned i = 0; i < samples.size(); ++i) { + sentscore = ds[sent_id]->ScoreCandidate(samples[i].words)->ComputeScore(); + if (invert_score) sentscore *= -1.0; + if (!cur_good || sentscore > cur_good->mt_metric) + cur_good = MakeHypothesisInfo(samples[i].fmap, sentscore); + if (!cur_bad || sentscore < cur_bad->mt_metric) + cur_bad = MakeHypothesisInfo(samples[i].fmap, sentscore); + } + } else { + KBest::KBestDerivations, ESentenceTraversal> kbest(forest, kbest_size); + for (int i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(forest.nodes_.size() - 1, i); + if (!d) break; + float sentscore = ds[sent_id]->ScoreCandidate(d->yield)->ComputeScore(); + if (invert_score) sentscore *= -1.0; + // cerr << TD::GetString(d->yield) << " ||| " << d->score << " ||| " << sentscore << endl; + if (i == 0) + cur_best = MakeHypothesisInfo(d->feature_values, sentscore); + if (!cur_good || sentscore > cur_good->mt_metric) + cur_good = MakeHypothesisInfo(d->feature_values, sentscore); + if (!cur_bad || sentscore < cur_bad->mt_metric) + cur_bad = MakeHypothesisInfo(d->feature_values, sentscore); + } + //cerr << "GOOD: " << cur_good->mt_metric << endl; + //cerr << " CUR: " << cur_best->mt_metric << endl; + //cerr << " BAD: " << cur_bad->mt_metric << endl; + } + } +}; + +void ReadTrainingCorpus(const string& fname, vector* c) { + ReadFile rf(fname); + istream& in = *rf.stream(); + string line; + while(in) { + getline(in, line); + if (!in) break; + c->push_back(line); + } +} + +bool ApproxEqual(double a, double b) { + if (a == b) return true; + return (fabs(a-b)/fabs(b)) < 0.000001; +} + +int main(int argc, char** argv) { + register_feature_functions(); + SetSilent(true); // turn off verbose decoder output + + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + + if (conf.count("random_seed")) + rng.reset(new MT19937(conf["random_seed"].as())); + else + rng.reset(new MT19937); + const bool sample_forest = conf.count("sample_forest") > 0; + const bool sample_forest_unit_weight_vector = conf.count("sample_forest_unit_weight_vector") > 0; + if (sample_forest_unit_weight_vector && !sample_forest) { + cerr << "Cannot --sample_forest_unit_weight_vector without --sample_forest" << endl; + return 1; + } + vector corpus; + ReadTrainingCorpus(conf["source"].as(), &corpus); + const string metric_name = conf["mt_metric"].as(); + ScoreType type = ScoreTypeFromString(metric_name); + if (type == TER) { + invert_score = true; + } else { + invert_score = false; + } + DocScorer ds(type, conf["reference"].as >(), ""); + cerr << "Loaded " << ds.size() << " references for scoring with " << metric_name << endl; + if (ds.size() != corpus.size()) { + cerr << "Mismatched number of references (" << ds.size() << ") and sources (" << corpus.size() << ")\n"; + return 1; + } + + ReadFile ini_rf(conf["decoder_config"].as()); + Decoder decoder(ini_rf.stream()); + + // load initial weights + vector& dense_weights = decoder.CurrentWeightVector(); + SparseVector lambdas; + Weights::InitFromFile(conf["input_weights"].as(), &dense_weights); + Weights::InitSparseVector(dense_weights, &lambdas); + + const double max_step_size = conf["max_step_size"].as(); + const double mt_metric_scale = conf["mt_metric_scale"].as(); + + assert(corpus.size() > 0); + vector oracles(corpus.size()); + + TrainingObserver observer(conf["k_best_size"].as(), ds, sample_forest, &oracles); + int cur_sent = 0; + int lcount = 0; + int normalizer = 0; + double tot_loss = 0; + int dots = 0; + int cur_pass = 0; + SparseVector tot; + tot += lambdas; // initial weights + normalizer++; // count for initial weights + int max_iteration = conf["passes"].as() * corpus.size(); + string msg = "# MIRA tuned weights"; + string msga = "# MIRA tuned weights AVERAGED"; + vector order; + RandomPermutation(corpus.size(), &order); + while (lcount <= max_iteration) { + lambdas.init_vector(&dense_weights); + if ((cur_sent * 40 / corpus.size()) > dots) { ++dots; cerr << '.'; } + if (corpus.size() == cur_sent) { + cerr << " [AVG METRIC LAST PASS=" << (tot_loss / corpus.size()) << "]\n"; + Weights::ShowLargestFeatures(dense_weights); + cur_sent = 0; + tot_loss = 0; + dots = 0; + ostringstream os; + os << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << ".gz"; + SparseVector x = tot; + x /= normalizer; + ostringstream sa; + sa << "weights.mira-pass" << (cur_pass < 10 ? "0" : "") << cur_pass << "-avg.gz"; + x.init_vector(&dense_weights); + Weights::WriteToFile(os.str(), dense_weights, true, &msg); + ++cur_pass; + RandomPermutation(corpus.size(), &order); + } + if (cur_sent == 0) { + cerr << "PASS " << (lcount / corpus.size() + 1) << endl; + } + decoder.SetId(order[cur_sent]); + double sc = 1.0; + if (sample_forest_unit_weight_vector) { + sc = lambdas.l2norm(); + if (sc > 0) { + for (unsigned i = 0; i < dense_weights.size(); ++i) + dense_weights[i] /= sc; + } + } + decoder.Decode(corpus[order[cur_sent]], &observer); // update oracles + if (sc && sc != 1.0) { + for (unsigned i = 0; i < dense_weights.size(); ++i) + dense_weights[i] *= sc; + } + const HypothesisInfo& cur_hyp = observer.GetCurrentBestHypothesis(); + const HypothesisInfo& cur_good = *oracles[order[cur_sent]].good; + const HypothesisInfo& cur_bad = *oracles[order[cur_sent]].bad; + tot_loss += cur_hyp.mt_metric; + if (!ApproxEqual(cur_hyp.mt_metric, cur_good.mt_metric)) { + const double loss = cur_bad.features.dot(dense_weights) - cur_good.features.dot(dense_weights) + + mt_metric_scale * (cur_good.mt_metric - cur_bad.mt_metric); + //cerr << "LOSS: " << loss << endl; + if (loss > 0.0) { + SparseVector diff = cur_good.features; + diff -= cur_bad.features; + double step_size = loss / diff.l2norm_sq(); + //cerr << loss << " " << step_size << " " << diff << endl; + if (step_size > max_step_size) step_size = max_step_size; + lambdas += (cur_good.features * step_size); + lambdas -= (cur_bad.features * step_size); + //cerr << "L: " << lambdas << endl; + } + } + tot += lambdas; + ++normalizer; + ++lcount; + ++cur_sent; + } + cerr << endl; + Weights::WriteToFile("weights.mira-final.gz", dense_weights, true, &msg); + tot /= normalizer; + tot.init_vector(dense_weights); + msg = "# MIRA tuned weights (averaged vector)"; + Weights::WriteToFile("weights.mira-final-avg.gz", dense_weights, true, &msg); + cerr << "Optimization complete.\nAVERAGED WEIGHTS: weights.mira-final-avg.gz\n"; + return 0; +} + diff --git a/training/mpi_batch_optimize.cc b/training/mpi_batch_optimize.cc deleted file mode 100644 index 2eff07e4..00000000 --- a/training/mpi_batch_optimize.cc +++ /dev/null @@ -1,372 +0,0 @@ -#include -#include -#include -#include -#include - -#include "config.h" -#ifdef HAVE_MPI -#include -#include -namespace mpi = boost::mpi; -#endif - -#include -#include -#include - -#include "sentence_metadata.h" -#include "cllh_observer.h" -#include "verbose.h" -#include "hg.h" -#include "prob.h" -#include "inside_outside.h" -#include "ff_register.h" -#include "decoder.h" -#include "filelib.h" -#include "stringlib.h" -#include "optimize.h" -#include "fdict.h" -#include "weights.h" -#include "sparse_vector.h" - -using namespace std; -namespace po = boost::program_options; - -bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("input_weights,w",po::value(),"Input feature weights file") - ("training_data,t",po::value(),"Training data") - ("test_data,T",po::value(),"(optional) test data") - ("decoder_config,c",po::value(),"Decoder configuration file") - ("output_weights,o",po::value()->default_value("-"),"Output feature weights file") - ("optimization_method,m", po::value()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)") - ("correction_buffers,M", po::value()->default_value(10), "Number of gradients for LBFGS to maintain in memory") - ("gaussian_prior,p","Use a Gaussian prior on the weights") - ("sigma_squared", po::value()->default_value(1.0), "Sigma squared term for spherical Gaussian prior") - ("means,u", po::value(), "(optional) file containing the means for Gaussian prior"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("input_weights") || !(conf->count("training_data")) || !conf->count("decoder_config")) { - cerr << dcmdline_options << endl; - return false; - } - return true; -} - -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - int lc = 0; - while(in) { - getline(in, line); - if (!in) break; - if (lc % size == rank) c->push_back(line); - ++lc; - } -} - -static const double kMINUS_EPSILON = -1e-6; - -struct TrainingObserver : public DecoderObserver { - void Reset() { - acc_grad.clear(); - acc_obj = 0; - total_complete = 0; - trg_words = 0; - } - - void SetLocalGradientAndObjective(vector* g, double* o) const { - *o = acc_obj; - for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) - (*g)[it->first] = it->second.as_float(); - } - - virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { - cur_model_exp.clear(); - cur_obj = 0; - state = 1; - } - - // compute model expectations, denominator of objective - virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 1); - state = 2; - const prob_t z = InsideOutside, - EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); - cur_obj = log(z); - cur_model_exp /= z; - } - - // compute "empirical" expectations, numerator of objective - virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 2); - state = 3; - SparseVector ref_exp; - const prob_t ref_z = InsideOutside, - EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); - ref_exp /= ref_z; - - double log_ref_z; -#if 0 - if (crf_uniform_empirical) { - log_ref_z = ref_exp.dot(feature_weights); - } else { - log_ref_z = log(ref_z); - } -#else - log_ref_z = log(ref_z); -#endif - - // rounding errors means that <0 is too strict - if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { - cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; - exit(1); - } - assert(!std::isnan(log_ref_z)); - ref_exp -= cur_model_exp; - acc_grad -= ref_exp; - acc_obj += (cur_obj - log_ref_z); - trg_words += smeta.GetReference().size(); - } - - virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { - if (state == 3) { - ++total_complete; - } else { - } - } - - int total_complete; - SparseVector cur_model_exp; - SparseVector acc_grad; - double acc_obj; - double cur_obj; - unsigned trg_words; - int state; -}; - -void ReadConfig(const string& ini, vector* out) { - ReadFile rf(ini); - istream& in = *rf.stream(); - while(in) { - string line; - getline(in, line); - if (!in) continue; - out->push_back(line); - } -} - -void StoreConfig(const vector& cfg, istringstream* o) { - ostringstream os; - for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; } - o->str(os.str()); -} - -template -struct VectorPlus : public binary_function, vector, vector > { - vector operator()(const vector& a, const vector& b) const { - assert(a.size() == b.size()); - vector v(a.size()); - transform(a.begin(), a.end(), b.begin(), v.begin(), plus()); - return v; - } -}; - -int main(int argc, char** argv) { -#ifdef HAVE_MPI - mpi::environment env(argc, argv); - mpi::communicator world; - const int size = world.size(); - const int rank = world.rank(); -#else - const int size = 1; - const int rank = 0; -#endif - SetSilent(true); // turn off verbose decoder output - register_feature_functions(); - - po::variables_map conf; - if (!InitCommandLine(argc, argv, &conf)) return 1; - - // load cdec.ini and set up decoder - vector cdec_ini; - ReadConfig(conf["decoder_config"].as(), &cdec_ini); - istringstream ini; - StoreConfig(cdec_ini, &ini); - if (rank == 0) cerr << "Loading grammar...\n"; - Decoder* decoder = new Decoder(&ini); - if (decoder->GetConf()["input"].as() != "-") { - cerr << "cdec.ini must not set an input file\n"; - return 1; - } - if (rank == 0) cerr << "Done loading grammar!\n"; - - // load initial weights - if (rank == 0) { cerr << "Loading weights...\n"; } - vector& lambdas = decoder->CurrentWeightVector(); - Weights::InitFromFile(conf["input_weights"].as(), &lambdas); - if (rank == 0) { cerr << "Done loading weights.\n"; } - - // freeze feature set (should be optional?) - const bool freeze_feature_set = true; - if (freeze_feature_set) FD::Freeze(); - - const int num_feats = FD::NumFeats(); - if (rank == 0) cerr << "Number of features: " << num_feats << endl; - lambdas.resize(num_feats); - - const bool gaussian_prior = conf.count("gaussian_prior"); - vector means(num_feats, 0); - if (conf.count("means")) { - if (!gaussian_prior) { - cerr << "Don't use --means without --gaussian_prior!\n"; - exit(1); - } - Weights::InitFromFile(conf["means"].as(), &means); - } - boost::shared_ptr o; - if (rank == 0) { - const string omethod = conf["optimization_method"].as(); - if (omethod == "rprop") - o.reset(new RPropOptimizer(num_feats)); // TODO add configuration - else - o.reset(new LBFGSOptimizer(num_feats, conf["correction_buffers"].as())); - cerr << "Optimizer: " << o->Name() << endl; - } - double objective = 0; - vector gradient(num_feats, 0.0); - vector rcv_grad; - rcv_grad.clear(); - bool converged = false; - - vector corpus, test_corpus; - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); - assert(corpus.size() > 0); - if (conf.count("test_data")) - ReadTrainingCorpus(conf["test_data"].as(), rank, size, &test_corpus); - - TrainingObserver observer; - ConditionalLikelihoodObserver cllh_observer; - while (!converged) { - observer.Reset(); - cllh_observer.Reset(); -#ifdef HAVE_MPI - mpi::timer timer; - world.barrier(); -#endif - if (rank == 0) { - cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n"; - cerr << " Testset size: " << test_corpus.size() << " sentences / proc)\n"; - } - for (int i = 0; i < corpus.size(); ++i) - decoder->Decode(corpus[i], &observer); - cerr << " process " << rank << '/' << size << " done\n"; - fill(gradient.begin(), gradient.end(), 0); - observer.SetLocalGradientAndObjective(&gradient, &objective); - - unsigned total_words = 0; -#ifdef HAVE_MPI - double to = 0; - rcv_grad.resize(num_feats, 0.0); - mpi::reduce(world, &gradient[0], gradient.size(), &rcv_grad[0], plus(), 0); - swap(gradient, rcv_grad); - rcv_grad.clear(); - - reduce(world, observer.trg_words, total_words, std::plus(), 0); - mpi::reduce(world, objective, to, plus(), 0); - objective = to; -#else - total_words = observer.trg_words; -#endif - if (rank == 0) - cerr << "TRAINING CORPUS: ln p(f|e)=" << objective << "\t log_2 p(f|e) = " << (objective/log(2)) << "\t cond. entropy = " << (objective/log(2) / total_words) << "\t ppl = " << pow(2, (objective/log(2) / total_words)) << endl; - - for (int i = 0; i < test_corpus.size(); ++i) - decoder->Decode(test_corpus[i], &cllh_observer); - - double test_objective = 0; - unsigned test_total_words = 0; -#ifdef HAVE_MPI - reduce(world, cllh_observer.acc_obj, test_objective, std::plus(), 0); - reduce(world, cllh_observer.trg_words, test_total_words, std::plus(), 0); -#else - test_objective = cllh_observer.acc_obj; - test_total_words = cllh_observer.trg_words; -#endif - - if (rank == 0) { // run optimizer only on rank=0 node - if (test_corpus.size()) - cerr << " TEST CORPUS: ln p(f|e)=" << test_objective << "\t log_2 p(f|e) = " << (test_objective/log(2)) << "\t cond. entropy = " << (test_objective/log(2) / test_total_words) << "\t ppl = " << pow(2, (test_objective/log(2) / test_total_words)) << endl; - if (gaussian_prior) { - const double sigsq = conf["sigma_squared"].as(); - double norm = 0; - for (int k = 1; k < lambdas.size(); ++k) { - const double& lambda_k = lambdas[k]; - if (lambda_k) { - const double param = (lambda_k - means[k]); - norm += param * param; - gradient[k] += param / sigsq; - } - } - const double reg = norm / (2.0 * sigsq); - cerr << "REGULARIZATION TERM: " << reg << endl; - objective += reg; - } - cerr << "EVALUATION #" << o->EvaluationCount() << " OBJECTIVE: " << objective << endl; - double gnorm = 0; - for (int i = 0; i < gradient.size(); ++i) - gnorm += gradient[i] * gradient[i]; - cerr << " GNORM=" << sqrt(gnorm) << endl; - vector old = lambdas; - int c = 0; - while (old == lambdas) { - ++c; - if (c > 1) { cerr << "Same lambdas, repeating optimization\n"; } - o->Optimize(objective, gradient, &lambdas); - assert(c < 5); - } - old.clear(); - Weights::SanityCheck(lambdas); - Weights::ShowLargestFeatures(lambdas); - - converged = o->HasConverged(); - if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } - - string fname = "weights.cur.gz"; - if (converged) { fname = "weights.final.gz"; } - ostringstream vv; - vv << "Objective = " << objective << " (eval count=" << o->EvaluationCount() << ")"; - const string svv = vv.str(); - Weights::WriteToFile(fname, lambdas, true, &svv); - } // rank == 0 - int cint = converged; -#ifdef HAVE_MPI - mpi::broadcast(world, &lambdas[0], lambdas.size(), 0); - mpi::broadcast(world, cint, 0); - if (rank == 0) { cerr << " ELAPSED TIME THIS ITERATION=" << timer.elapsed() << endl; } -#endif - converged = cint; - } - return 0; -} - diff --git a/training/mpi_compute_cllh.cc b/training/mpi_compute_cllh.cc deleted file mode 100644 index 066389d0..00000000 --- a/training/mpi_compute_cllh.cc +++ /dev/null @@ -1,134 +0,0 @@ -#include -#include -#include -#include - -#include "config.h" -#ifdef HAVE_MPI -#include -#endif -#include -#include - -#include "cllh_observer.h" -#include "sentence_metadata.h" -#include "verbose.h" -#include "hg.h" -#include "prob.h" -#include "inside_outside.h" -#include "ff_register.h" -#include "decoder.h" -#include "filelib.h" -#include "weights.h" - -using namespace std; -namespace po = boost::program_options; - -bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("weights,w",po::value(),"Input feature weights file") - ("training_data,t",po::value(),"Training data corpus") - ("decoder_config,c",po::value(),"Decoder configuration file"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { - cerr << dcmdline_options << endl; - return false; - } - return true; -} - -void ReadInstances(const string& fname, int rank, int size, vector* c) { - assert(fname != "-"); - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - int lc = 0; - while(in) { - getline(in, line); - if (!in) break; - if (lc % size == rank) c->push_back(line); - ++lc; - } -} - -static const double kMINUS_EPSILON = -1e-6; - -#ifdef HAVE_MPI -namespace mpi = boost::mpi; -#endif - -int main(int argc, char** argv) { -#ifdef HAVE_MPI - mpi::environment env(argc, argv); - mpi::communicator world; - const int size = world.size(); - const int rank = world.rank(); -#else - const int size = 1; - const int rank = 0; -#endif - if (size > 1) SetSilent(true); // turn off verbose decoder output - register_feature_functions(); - - po::variables_map conf; - if (!InitCommandLine(argc, argv, &conf)) - return false; - - // load cdec.ini and set up decoder - ReadFile ini_rf(conf["decoder_config"].as()); - Decoder decoder(ini_rf.stream()); - if (decoder.GetConf()["input"].as() != "-") { - cerr << "cdec.ini must not set an input file\n"; - abort(); - } - - // load weights - vector& weights = decoder.CurrentWeightVector(); - if (conf.count("weights")) - Weights::InitFromFile(conf["weights"].as(), &weights); - - vector corpus; - ReadInstances(conf["training_data"].as(), rank, size, &corpus); - assert(corpus.size() > 0); - - if (rank == 0) - cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; - - ConditionalLikelihoodObserver observer; - for (int i = 0; i < corpus.size(); ++i) - decoder.Decode(corpus[i], &observer); - - double objective = 0; - unsigned total_words = 0; -#ifdef HAVE_MPI - reduce(world, observer.acc_obj, objective, std::plus(), 0); - reduce(world, observer.trg_words, total_words, std::plus(), 0); -#else - objective = observer.acc_obj; -#endif - - if (rank == 0) { - cout << "CONDITIONAL LOG_e LIKELIHOOD: " << objective << endl; - cout << "CONDITIONAL LOG_2 LIKELIHOOD: " << (objective/log(2)) << endl; - cout << " CONDITIONAL ENTROPY: " << (objective/log(2) / total_words) << endl; - cout << " PERPLEXITY: " << pow(2, (objective/log(2) / total_words)) << endl; - } - - return 0; -} - diff --git a/training/mpi_em_optimize.cc b/training/mpi_em_optimize.cc deleted file mode 100644 index 48683b15..00000000 --- a/training/mpi_em_optimize.cc +++ /dev/null @@ -1,389 +0,0 @@ -#include -#include -#include -#include -#include - -#ifdef HAVE_MPI -#include -#endif - -#include -#include -#include - -#include "verbose.h" -#include "hg.h" -#include "prob.h" -#include "inside_outside.h" -#include "ff_register.h" -#include "decoder.h" -#include "filelib.h" -#include "optimize.h" -#include "fdict.h" -#include "weights.h" -#include "sparse_vector.h" - -using namespace std; -using boost::shared_ptr; -namespace po = boost::program_options; - -void SanityCheck(const vector& w) { - for (int i = 0; i < w.size(); ++i) { - assert(!isnan(w[i])); - assert(!isinf(w[i])); - } -} - -struct FComp { - const vector& w_; - FComp(const vector& w) : w_(w) {} - bool operator()(int a, int b) const { - return fabs(w_[a]) > fabs(w_[b]); - } -}; - -void ShowLargestFeatures(const vector& w) { - vector fnums(w.size()); - for (int i = 0; i < w.size(); ++i) - fnums[i] = i; - vector::iterator mid = fnums.begin(); - mid += (w.size() > 10 ? 10 : w.size()); - partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); - cerr << "TOP FEATURES:"; - for (vector::iterator i = fnums.begin(); i != mid; ++i) { - cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; - } - cerr << endl; -} - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("input_weights,w",po::value(),"Input feature weights file") - ("training_data,t",po::value(),"Training data") - ("decoder_config,c",po::value(),"Decoder configuration file") - ("output_weights,o",po::value()->default_value("-"),"Output feature weights file"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !(conf->count("training_data")) || !conf->count("decoder_config")) { - cerr << dcmdline_options << endl; -#ifdef HAVE_MPI - MPI::Finalize(); -#endif - exit(1); - } -} - -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - int lc = 0; - while(in) { - getline(in, line); - if (!in) break; - if (lc % size == rank) c->push_back(line); - ++lc; - } -} - -static const double kMINUS_EPSILON = -1e-6; - -struct TrainingObserver : public DecoderObserver { - void Reset() { - total_complete = 0; - cur_obj = 0; - tot_obj = 0; - tot.clear(); - } - - void SetLocalGradientAndObjective(SparseVector* g, double* o) const { - *o = tot_obj; - *g = tot; - } - - virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { - cur_obj = 0; - state = 1; - } - - void ExtractExpectedCounts(Hypergraph* hg) { - vector posts; - cur.clear(); - const prob_t z = hg->ComputeEdgePosteriors(1.0, &posts); - cur_obj = log(z); - for (int i = 0; i < posts.size(); ++i) { - const SparseVector& efeats = hg->edges_[i].feature_values_; - const double post = static_cast(posts[i] / z); - for (SparseVector::const_iterator j = efeats.begin(); j != efeats.end(); ++j) - cur.add_value(j->first, post); - } - } - - // compute model expectations, denominator of objective - virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 1); - state = 2; - ExtractExpectedCounts(hg); - } - - // replace translation forest, since we're doing EM training (we don't know which) - virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 2); - state = 3; - ExtractExpectedCounts(hg); - } - - virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { - ++total_complete; - tot_obj += cur_obj; - tot += cur; - } - - int total_complete; - double cur_obj; - double tot_obj; - SparseVector cur, tot; - int state; -}; - -void ReadConfig(const string& ini, vector* out) { - ReadFile rf(ini); - istream& in = *rf.stream(); - while(in) { - string line; - getline(in, line); - if (!in) continue; - out->push_back(line); - } -} - -void StoreConfig(const vector& cfg, istringstream* o) { - ostringstream os; - for (int i = 0; i < cfg.size(); ++i) { os << cfg[i] << endl; } - o->str(os.str()); -} - -struct OptimizableMultinomialFamily { - struct CPD { - CPD() : z() {} - double z; - map c2counts; - }; - map counts; - double Value(WordID conditioning, WordID generated) const { - map::const_iterator it = counts.find(conditioning); - assert(it != counts.end()); - map::const_iterator r = it->second.c2counts.find(generated); - if (r == it->second.c2counts.end()) return 0; - return r->second; - } - void Increment(WordID conditioning, WordID generated, double count) { - CPD& cc = counts[conditioning]; - cc.z += count; - cc.c2counts[generated] += count; - } - void Optimize() { - for (map::iterator i = counts.begin(); i != counts.end(); ++i) { - CPD& cpd = i->second; - for (map::iterator j = cpd.c2counts.begin(); j != cpd.c2counts.end(); ++j) { - j->second /= cpd.z; - // cerr << "P(" << TD::Convert(j->first) << " | " << TD::Convert(i->first) << " ) = " << j->second << endl; - } - } - } - void Clear() { - counts.clear(); - } -}; - -struct CountManager { - CountManager(size_t num_types) : oms_(num_types) {} - virtual ~CountManager(); - virtual void AddCounts(const SparseVector& c) = 0; - void Optimize(SparseVector* weights) { - for (int i = 0; i < oms_.size(); ++i) { - oms_[i].Optimize(); - } - GetOptimalValues(weights); - for (int i = 0; i < oms_.size(); ++i) { - oms_[i].Clear(); - } - } - virtual void GetOptimalValues(SparseVector* wv) const = 0; - vector oms_; -}; -CountManager::~CountManager() {} - -struct TaggerCountManager : public CountManager { - // 0 = transitions, 2 = emissions - TaggerCountManager() : CountManager(2) {} - void AddCounts(const SparseVector& c); - void GetOptimalValues(SparseVector* wv) const { - for (set::const_iterator it = fids_.begin(); it != fids_.end(); ++it) { - int ftype; - WordID cond, gen; - bool is_optimized = TaggerCountManager::GetFeature(*it, &ftype, &cond, &gen); - assert(is_optimized); - wv->set_value(*it, log(oms_[ftype].Value(cond, gen))); - } - } - // Id:0:a=1 Bi:a_b=1 Bi:b_c=1 Bi:c_d=1 Uni:a=1 Uni:b=1 Uni:c=1 Uni:d=1 Id:1:b=1 Bi:BOS_a=1 Id:2:c=1 - static bool GetFeature(const int fid, int* feature_type, WordID* cond, WordID* gen) { - const string& feat = FD::Convert(fid); - if (feat.size() > 5 && feat[0] == 'I' && feat[1] == 'd' && feat[2] == ':') { - // emission - const size_t p = feat.rfind(':'); - assert(p != string::npos); - *cond = TD::Convert(feat.substr(p+1)); - *gen = TD::Convert(feat.substr(3, p - 3)); - *feature_type = 1; - return true; - } else if (feat[0] == 'B' && feat.size() > 5 && feat[2] == ':' && feat[1] == 'i') { - // transition - const size_t p = feat.rfind('_'); - assert(p != string::npos); - *gen = TD::Convert(feat.substr(p+1)); - *cond = TD::Convert(feat.substr(3, p - 3)); - *feature_type = 0; - return true; - } else if (feat[0] == 'U' && feat.size() > 4 && feat[1] == 'n' && feat[2] == 'i' && feat[3] == ':') { - // ignore - return false; - } else { - cerr << "Don't know how to deal with feature of type: " << feat << endl; - abort(); - } - } - set fids_; -}; - -void TaggerCountManager::AddCounts(const SparseVector& c) { - for (SparseVector::const_iterator it = c.begin(); it != c.end(); ++it) { - const double& val = it->second; - int ftype; - WordID cond, gen; - if (GetFeature(it->first, &ftype, &cond, &gen)) { - oms_[ftype].Increment(cond, gen, val); - fids_.insert(it->first); - } - } -} - -int main(int argc, char** argv) { -#ifdef HAVE_MPI - MPI::Init(argc, argv); - const int size = MPI::COMM_WORLD.Get_size(); - const int rank = MPI::COMM_WORLD.Get_rank(); -#else - const int size = 1; - const int rank = 0; -#endif - SetSilent(true); // turn off verbose decoder output - register_feature_functions(); - - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - - TaggerCountManager tcm; - - // load cdec.ini and set up decoder - vector cdec_ini; - ReadConfig(conf["decoder_config"].as(), &cdec_ini); - istringstream ini; - StoreConfig(cdec_ini, &ini); - if (rank == 0) cerr << "Loading grammar...\n"; - Decoder* decoder = new Decoder(&ini); - if (decoder->GetConf()["input"].as() != "-") { - cerr << "cdec.ini must not set an input file\n"; -#ifdef HAVE_MPI - MPI::COMM_WORLD.Abort(1); -#endif - } - if (rank == 0) cerr << "Done loading grammar!\n"; - Weights w; - if (conf.count("input_weights")) - w.InitFromFile(conf["input_weights"].as()); - - double objective = 0; - bool converged = false; - - vector lambdas; - w.InitVector(&lambdas); - vector corpus; - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); - assert(corpus.size() > 0); - - int iteration = 0; - TrainingObserver observer; - while (!converged) { - ++iteration; - observer.Reset(); - if (rank == 0) { - cerr << "Starting decoding... (~" << corpus.size() << " sentences / proc)\n"; - } - decoder->SetWeights(lambdas); - for (int i = 0; i < corpus.size(); ++i) - decoder->Decode(corpus[i], &observer); - - SparseVector x; - observer.SetLocalGradientAndObjective(&x, &objective); - cerr << "COUNTS = " << x << endl; - cerr << " OBJ = " << objective << endl; - tcm.AddCounts(x); - -#if 0 -#ifdef HAVE_MPI - MPI::COMM_WORLD.Reduce(const_cast(&gradient.data()[0]), &rcv_grad[0], num_feats, MPI::DOUBLE, MPI::SUM, 0); - MPI::COMM_WORLD.Reduce(&objective, &to, 1, MPI::DOUBLE, MPI::SUM, 0); - swap(gradient, rcv_grad); - objective = to; -#endif -#endif - - if (rank == 0) { - SparseVector wsv; - tcm.Optimize(&wsv); - - w.InitFromVector(wsv); - w.InitVector(&lambdas); - - ShowLargestFeatures(lambdas); - - converged = iteration > 100; - if (converged) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } - - string fname = "weights.cur.gz"; - if (converged) { fname = "weights.final.gz"; } - ostringstream vv; - vv << "Objective = " << objective << " (ITERATION=" << iteration << ")"; - const string svv = vv.str(); - w.WriteToFile(fname, true, &svv); - } // rank == 0 - int cint = converged; -#ifdef HAVE_MPI - MPI::COMM_WORLD.Bcast(const_cast(&lambdas.data()[0]), num_feats, MPI::DOUBLE, 0); - MPI::COMM_WORLD.Bcast(&cint, 1, MPI::INT, 0); - MPI::COMM_WORLD.Barrier(); -#endif - converged = cint; - } -#ifdef HAVE_MPI - MPI::Finalize(); -#endif - return 0; -} diff --git a/training/mpi_extract_features.cc b/training/mpi_extract_features.cc deleted file mode 100644 index 6750aa15..00000000 --- a/training/mpi_extract_features.cc +++ /dev/null @@ -1,151 +0,0 @@ -#include -#include -#include -#include - -#include "config.h" -#ifdef HAVE_MPI -#include -#endif -#include -#include - -#include "ff_register.h" -#include "verbose.h" -#include "filelib.h" -#include "fdict.h" -#include "decoder.h" -#include "weights.h" - -using namespace std; -namespace po = boost::program_options; - -bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("training_data,t",po::value(),"Training data corpus") - ("decoder_config,c",po::value(),"Decoder configuration file") - ("weights,w", po::value(), "(Optional) weights file; weights may affect what features are encountered in pruning configurations") - ("output_prefix,o",po::value()->default_value("features"),"Output path prefix"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { - cerr << "Decode an input set (optionally in parallel using MPI) and write\nout the feature strings encountered.\n"; - cerr << dcmdline_options << endl; - return false; - } - return true; -} - -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - int lc = 0; - while(in) { - getline(in, line); - if (!in) break; - if (lc % size == rank) c->push_back(line); - ++lc; - } -} - -static const double kMINUS_EPSILON = -1e-6; - -struct TrainingObserver : public DecoderObserver { - - virtual void NotifyDecodingStart(const SentenceMetadata&) { - } - - // compute model expectations, denominator of objective - virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { - } - - // compute "empirical" expectations, numerator of objective - virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { - } -}; - -#ifdef HAVE_MPI -namespace mpi = boost::mpi; -#endif - -int main(int argc, char** argv) { -#ifdef HAVE_MPI - mpi::environment env(argc, argv); - mpi::communicator world; - const int size = world.size(); - const int rank = world.rank(); -#else - const int size = 1; - const int rank = 0; -#endif - if (size > 1) SetSilent(true); // turn off verbose decoder output - register_feature_functions(); - - po::variables_map conf; - if (!InitCommandLine(argc, argv, &conf)) - return false; - - // load cdec.ini and set up decoder - ReadFile ini_rf(conf["decoder_config"].as()); - Decoder decoder(ini_rf.stream()); - if (decoder.GetConf()["input"].as() != "-") { - cerr << "cdec.ini must not set an input file\n"; - abort(); - } - - if (FD::UsingPerfectHashFunction()) { - cerr << "Your configuration file has enabled a cmph hash function. Please disable.\n"; - return 1; - } - - // load optional weights - if (conf.count("weights")) - Weights::InitFromFile(conf["weights"].as(), &decoder.CurrentWeightVector()); - - vector corpus; - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); - assert(corpus.size() > 0); - - TrainingObserver observer; - - if (rank == 0) - cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; - - for (int i = 0; i < corpus.size(); ++i) - decoder.Decode(corpus[i], &observer); - - { - ostringstream os; - os << conf["output_prefix"].as() << '.' << rank << "_of_" << size; - WriteFile wf(os.str()); - ostream& out = *wf.stream(); - const unsigned num_feats = FD::NumFeats(); - for (unsigned i = 1; i < num_feats; ++i) { - out << FD::Convert(i) << endl; - } - cerr << "Wrote " << os.str() << endl; - } - -#ifdef HAVE_MPI - world.barrier(); -#else -#endif - - return 0; -} - diff --git a/training/mpi_extract_reachable.cc b/training/mpi_extract_reachable.cc deleted file mode 100644 index 2a7c2b9d..00000000 --- a/training/mpi_extract_reachable.cc +++ /dev/null @@ -1,163 +0,0 @@ -#include -#include -#include -#include - -#include "config.h" -#ifdef HAVE_MPI -#include -#endif -#include -#include - -#include "ff_register.h" -#include "verbose.h" -#include "filelib.h" -#include "fdict.h" -#include "decoder.h" -#include "weights.h" - -using namespace std; -namespace po = boost::program_options; - -bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("training_data,t",po::value(),"Training data corpus") - ("decoder_config,c",po::value(),"Decoder configuration file") - ("weights,w", po::value(), "(Optional) weights file; weights may affect what features are encountered in pruning configurations") - ("output_prefix,o",po::value()->default_value("reachable"),"Output path prefix"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("training_data") || !conf->count("decoder_config")) { - cerr << "Decode an input set (optionally in parallel using MPI) and write\nout the inputs that produce reachable parallel parses.\n"; - cerr << dcmdline_options << endl; - return false; - } - return true; -} - -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - int lc = 0; - while(in) { - getline(in, line); - if (!in) break; - if (lc % size == rank) c->push_back(line); - ++lc; - } -} - -static const double kMINUS_EPSILON = -1e-6; - -struct ReachabilityObserver : public DecoderObserver { - - virtual void NotifyDecodingStart(const SentenceMetadata&) { - reachable = false; - } - - // compute model expectations, denominator of objective - virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { - } - - // compute "empirical" expectations, numerator of objective - virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { - reachable = true; - } - - bool reachable; -}; - -#ifdef HAVE_MPI -namespace mpi = boost::mpi; -#endif - -int main(int argc, char** argv) { -#ifdef HAVE_MPI - mpi::environment env(argc, argv); - mpi::communicator world; - const int size = world.size(); - const int rank = world.rank(); -#else - const int size = 1; - const int rank = 0; -#endif - if (size > 1) SetSilent(true); // turn off verbose decoder output - register_feature_functions(); - - po::variables_map conf; - if (!InitCommandLine(argc, argv, &conf)) - return false; - - // load cdec.ini and set up decoder - ReadFile ini_rf(conf["decoder_config"].as()); - Decoder decoder(ini_rf.stream()); - if (decoder.GetConf()["input"].as() != "-") { - cerr << "cdec.ini must not set an input file\n"; - abort(); - } - - if (FD::UsingPerfectHashFunction()) { - cerr << "Your configuration file has enabled a cmph hash function. Please disable.\n"; - return 1; - } - - // load optional weights - if (conf.count("weights")) - Weights::InitFromFile(conf["weights"].as(), &decoder.CurrentWeightVector()); - - vector corpus; - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus); - assert(corpus.size() > 0); - - - if (rank == 0) - cerr << "Each processor is decoding ~" << corpus.size() << " training examples...\n"; - - size_t num_reached = 0; - { - ostringstream os; - os << conf["output_prefix"].as() << '.' << rank << "_of_" << size; - WriteFile wf(os.str()); - ostream& out = *wf.stream(); - ReachabilityObserver observer; - for (int i = 0; i < corpus.size(); ++i) { - decoder.Decode(corpus[i], &observer); - if (observer.reachable) { - out << corpus[i] << endl; - ++num_reached; - } - corpus[i].clear(); - } - cerr << "Shard " << rank << '/' << size << " finished, wrote " - << num_reached << " instances to " << os.str() << endl; - } - - size_t total = 0; -#ifdef HAVE_MPI - reduce(world, num_reached, total, std::plus(), 0); -#else - total = num_reached; -#endif - if (rank == 0) { - cerr << "-----------------------------------------\n"; - cerr << "TOTAL = " << total << " instances\n"; - } - return 0; -} - diff --git a/training/mpi_flex_optimize.cc b/training/mpi_flex_optimize.cc deleted file mode 100644 index b52decdc..00000000 --- a/training/mpi_flex_optimize.cc +++ /dev/null @@ -1,386 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "stringlib.h" -#include "verbose.h" -#include "hg.h" -#include "prob.h" -#include "inside_outside.h" -#include "ff_register.h" -#include "decoder.h" -#include "filelib.h" -#include "optimize.h" -#include "fdict.h" -#include "weights.h" -#include "sparse_vector.h" -#include "sampler.h" - -#ifdef HAVE_MPI -#include -#include -namespace mpi = boost::mpi; -#endif - -using namespace std; -namespace po = boost::program_options; - -bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("cdec_config,c",po::value(),"Decoder configuration file") - ("weights,w",po::value(),"Initial feature weights") - ("training_data,d",po::value(),"Training data") - ("minibatch_size_per_proc,s", po::value()->default_value(6), "Number of training instances evaluated per processor in each minibatch") - ("minibatch_iterations,i", po::value()->default_value(10), "Number of optimization iterations per minibatch") - ("iterations,I", po::value()->default_value(50), "Number of passes through the training data before termination") - ("regularization_strength,C", po::value()->default_value(0.2), "Regularization strength") - ("time_series_strength,T", po::value()->default_value(0.0), "Time series regularization strength") - ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") - ("lbfgs_memory_buffers,M", po::value()->default_value(10), "Number of memory buffers for LBFGS history"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("training_data") || !conf->count("cdec_config")) { - cerr << "LBFGS minibatch online optimizer (MPI support " -#if HAVE_MPI - << "enabled" -#else - << "not enabled" -#endif - << ")\n" << dcmdline_options << endl; - return false; - } - return true; -} - -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* order) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - int id = 0; - while(in) { - getline(in, line); - if (!in) break; - if (id % size == rank) { - c->push_back(line); - order->push_back(id); - } - ++id; - } -} - -static const double kMINUS_EPSILON = -1e-6; - -struct CopyHGsObserver : public DecoderObserver { - Hypergraph* hg_; - Hypergraph* gold_hg_; - - // this can free up some memory - void RemoveRules(Hypergraph* h) { - for (unsigned i = 0; i < h->edges_.size(); ++i) - h->edges_[i].rule_.reset(); - } - - void SetCurrentHypergraphs(Hypergraph* h, Hypergraph* gold_h) { - hg_ = h; - gold_hg_ = gold_h; - } - - virtual void NotifyDecodingStart(const SentenceMetadata&) { - state = 1; - } - - // compute model expectations, denominator of objective - virtual void NotifyTranslationForest(const SentenceMetadata&, Hypergraph* hg) { - *hg_ = *hg; - RemoveRules(hg_); - assert(state == 1); - state = 2; - } - - // compute "empirical" expectations, numerator of objective - virtual void NotifyAlignmentForest(const SentenceMetadata&, Hypergraph* hg) { - assert(state == 2); - state = 3; - *gold_hg_ = *hg; - RemoveRules(gold_hg_); - } - - virtual void NotifyDecodingComplete(const SentenceMetadata&) { - if (state == 3) { - } else { - hg_->clear(); - gold_hg_->clear(); - } - } - - int state; -}; - -void ReadConfig(const string& ini, istringstream* out) { - ReadFile rf(ini); - istream& in = *rf.stream(); - ostringstream os; - while(in) { - string line; - getline(in, line); - if (!in) continue; - os << line << endl; - } - out->str(os.str()); -} - -#ifdef HAVE_MPI -namespace boost { namespace mpi { - template<> - struct is_commutative >, SparseVector > - : mpl::true_ { }; -} } // end namespace boost::mpi -#endif - -void AddGrad(const SparseVector x, double s, SparseVector* acc) { - for (SparseVector::const_iterator it = x.begin(); it != x.end(); ++it) - acc->add_value(it->first, it->second.as_float() * s); -} - -double PNorm(const vector& v, const double p) { - double acc = 0; - for (int i = 0; i < v.size(); ++i) - acc += pow(v[i], p); - return pow(acc, 1.0 / p); -} - -void VV(ostream&os, const vector& v) { - for (int i = 1; i < v.size(); ++i) - if (v[i]) os << FD::Convert(i) << "=" << v[i] << " "; -} - -double ApplyRegularizationTerms(const double C, - const double T, - const vector& weights, - const vector& prev_weights, - double* g) { - double reg = 0; - for (size_t i = 0; i < weights.size(); ++i) { - const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0); - const double& w_i = weights[i]; - reg += C * w_i * w_i; - g[i] += 2 * C * w_i; - - reg += T * (w_i - prev_w_i) * (w_i - prev_w_i); - g[i] += 2 * T * (w_i - prev_w_i); - } - return reg; -} - -int main(int argc, char** argv) { -#ifdef HAVE_MPI - mpi::environment env(argc, argv); - mpi::communicator world; - const int size = world.size(); - const int rank = world.rank(); -#else - const int size = 1; - const int rank = 0; -#endif - if (size > 1) SetSilent(true); // turn off verbose decoder output - register_feature_functions(); - MT19937* rng = NULL; - - po::variables_map conf; - if (!InitCommandLine(argc, argv, &conf)) - return 1; - - boost::shared_ptr o; - const unsigned lbfgs_memory_buffers = conf["lbfgs_memory_buffers"].as(); - const unsigned size_per_proc = conf["minibatch_size_per_proc"].as(); - const unsigned minibatch_iterations = conf["minibatch_iterations"].as(); - const double regularization_strength = conf["regularization_strength"].as(); - const double time_series_strength = conf["time_series_strength"].as(); - const bool use_time_series_reg = time_series_strength > 0.0; - const unsigned max_iteration = conf["iterations"].as(); - - vector corpus; - vector ids; - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); - assert(corpus.size() > 0); - - if (size_per_proc > corpus.size()) { - cerr << "Minibatch size (per processor) must be smaller or equal to the local corpus size!\n"; - return 1; - } - - // initialize decoder (loads hash functions if necessary) - istringstream ins; - ReadConfig(conf["cdec_config"].as(), &ins); - Decoder decoder(&ins); - - // load initial weights - vector prev_weights; - if (conf.count("weights")) - Weights::InitFromFile(conf["weights"].as(), &prev_weights); - - if (conf.count("random_seed")) - rng = new MT19937(conf["random_seed"].as()); - else - rng = new MT19937; - - size_t total_corpus_size = 0; -#ifdef HAVE_MPI - reduce(world, corpus.size(), total_corpus_size, std::plus(), 0); -#else - total_corpus_size = corpus.size(); -#endif - - if (rank == 0) - cerr << "Total corpus size: " << total_corpus_size << endl; - - CopyHGsObserver observer; - - int write_weights_every_ith = 100; // TODO configure - int titer = -1; - - vector& cur_weights = decoder.CurrentWeightVector(); - if (use_time_series_reg) { - cur_weights = prev_weights; - } else { - cur_weights.swap(prev_weights); - prev_weights.clear(); - } - - int iter = -1; - bool converged = false; - vector gg; - while (!converged) { -#ifdef HAVE_MPI - mpi::timer timer; -#endif - ++iter; ++titer; - if (rank == 0) { - converged = (iter == max_iteration); - string fname = "weights.cur.gz"; - if (iter % write_weights_every_ith == 0) { - ostringstream o; o << "weights.epoch_" << iter << ".gz"; - fname = o.str(); - } - if (converged) { fname = "weights.final.gz"; } - ostringstream vv; - vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast(corpus.size())); - const string svv = vv.str(); - Weights::WriteToFile(fname, cur_weights, true, &svv); - } - - vector hgs(size_per_proc); - vector gold_hgs(size_per_proc); - for (int i = 0; i < size_per_proc; ++i) { - int ei = corpus.size() * rng->next(); - int id = ids[ei]; - observer.SetCurrentHypergraphs(&hgs[i], &gold_hgs[i]); - decoder.SetId(id); - decoder.Decode(corpus[ei], &observer); - } - - SparseVector local_grad, g; - double local_obj = 0; - o.reset(); - for (unsigned mi = 0; mi < minibatch_iterations; ++mi) { - local_grad.clear(); - g.clear(); - local_obj = 0; - - for (unsigned i = 0; i < size_per_proc; ++i) { - Hypergraph& hg = hgs[i]; - Hypergraph& hg_gold = gold_hgs[i]; - if (hg.edges_.size() < 2) continue; - - hg.Reweight(cur_weights); - hg_gold.Reweight(cur_weights); - SparseVector model_exp, gold_exp; - const prob_t z = InsideOutside, - EdgeFeaturesAndProbWeightFunction>(hg, &model_exp); - local_obj += log(z); - model_exp /= z; - AddGrad(model_exp, 1.0, &local_grad); - model_exp.clear(); - - const prob_t goldz = InsideOutside, - EdgeFeaturesAndProbWeightFunction>(hg_gold, &gold_exp); - local_obj -= log(goldz); - - if (log(z) - log(goldz) < kMINUS_EPSILON) { - cerr << "DIFF. ERR! log_model_z < log_gold_z: " << log(z) << " " << log(goldz) << endl; - return 1; - } - - gold_exp /= goldz; - AddGrad(gold_exp, -1.0, &local_grad); - } - - double obj = 0; -#ifdef HAVE_MPI - reduce(world, local_obj, obj, std::plus(), 0); - reduce(world, local_grad, g, std::plus >(), 0); -#else - obj = local_obj; - g.swap(local_grad); -#endif - local_grad.clear(); - if (rank == 0) { - // g /= (size_per_proc * size); - if (!o) - o.reset(new LBFGSOptimizer(FD::NumFeats(), lbfgs_memory_buffers)); - gg.clear(); - gg.resize(FD::NumFeats()); - if (gg.size() != cur_weights.size()) { cur_weights.resize(gg.size()); } - for (SparseVector::iterator it = g.begin(); it != g.end(); ++it) - if (it->first) { gg[it->first] = it->second; } - g.clear(); - double r = ApplyRegularizationTerms(regularization_strength, - time_series_strength, // * (iter == 0 ? 0.0 : 1.0), - cur_weights, - prev_weights, - &gg[0]); - obj += r; - if (mi == 0 || mi == (minibatch_iterations - 1)) { - if (!mi) cerr << iter << ' '; else cerr << ' '; - cerr << "OBJ=" << obj << " (REG=" << r << ")" << " |g|=" << PNorm(gg, 2) << " |w|=" << PNorm(cur_weights, 2); - if (mi > 0) cerr << endl << flush; else cerr << ' '; - } else { cerr << '.' << flush; } - // cerr << "w = "; VV(cerr, cur_weights); cerr << endl; - // cerr << "g = "; VV(cerr, gg); cerr << endl; - o->Optimize(obj, gg, &cur_weights); - } -#ifdef HAVE_MPI - broadcast(world, cur_weights, 0); - broadcast(world, converged, 0); - world.barrier(); -#endif - } - prev_weights = cur_weights; - } - return 0; -} diff --git a/training/mpi_online_optimize.cc b/training/mpi_online_optimize.cc deleted file mode 100644 index d6968848..00000000 --- a/training/mpi_online_optimize.cc +++ /dev/null @@ -1,374 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "stringlib.h" -#include "verbose.h" -#include "hg.h" -#include "prob.h" -#include "inside_outside.h" -#include "ff_register.h" -#include "decoder.h" -#include "filelib.h" -#include "online_optimizer.h" -#include "fdict.h" -#include "weights.h" -#include "sparse_vector.h" -#include "sampler.h" - -#ifdef HAVE_MPI -#include -#include -namespace mpi = boost::mpi; -#endif - -using namespace std; -namespace po = boost::program_options; - -bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("input_weights,w",po::value(),"Input feature weights file") - ("frozen_features,z",po::value(), "List of features not to optimize") - ("training_data,t",po::value(),"Training data corpus") - ("training_agenda,a",po::value(), "Text file listing a series of configuration files and the number of iterations to train using each configuration successively") - ("minibatch_size_per_proc,s", po::value()->default_value(5), "Number of training instances evaluated per processor in each minibatch") - ("optimization_method,m", po::value()->default_value("sgd"), "Optimization method (sgd)") - ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") - ("eta_0,e", po::value()->default_value(0.2), "Initial learning rate for SGD (eta_0)") - ("L1,1","Use L1 regularization") - ("regularization_strength,C", po::value()->default_value(1.0), "Regularization strength (C)"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("training_data") || !conf->count("training_agenda")) { - cerr << dcmdline_options << endl; - return false; - } - return true; -} - -void ReadTrainingCorpus(const string& fname, int rank, int size, vector* c, vector* order) { - ReadFile rf(fname); - istream& in = *rf.stream(); - string line; - int id = 0; - while(in) { - getline(in, line); - if (!in) break; - if (id % size == rank) { - c->push_back(line); - order->push_back(id); - } - ++id; - } -} - -static const double kMINUS_EPSILON = -1e-6; - -struct TrainingObserver : public DecoderObserver { - void Reset() { - acc_grad.clear(); - acc_obj = 0; - total_complete = 0; - } - - void SetLocalGradientAndObjective(vector* g, double* o) const { - *o = acc_obj; - for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) - (*g)[it->first] = it->second.as_float(); - } - - virtual void NotifyDecodingStart(const SentenceMetadata& smeta) { - cur_model_exp.clear(); - cur_obj = 0; - state = 1; - } - - // compute model expectations, denominator of objective - virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 1); - state = 2; - const prob_t z = InsideOutside, - EdgeFeaturesAndProbWeightFunction>(*hg, &cur_model_exp); - cur_obj = log(z); - cur_model_exp /= z; - } - - // compute "empirical" expectations, numerator of objective - virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg) { - assert(state == 2); - state = 3; - SparseVector ref_exp; - const prob_t ref_z = InsideOutside, - EdgeFeaturesAndProbWeightFunction>(*hg, &ref_exp); - ref_exp /= ref_z; - - double log_ref_z; -#if 0 - if (crf_uniform_empirical) { - log_ref_z = ref_exp.dot(feature_weights); - } else { - log_ref_z = log(ref_z); - } -#else - log_ref_z = log(ref_z); -#endif - - // rounding errors means that <0 is too strict - if ((cur_obj - log_ref_z) < kMINUS_EPSILON) { - cerr << "DIFF. ERR! log_model_z < log_ref_z: " << cur_obj << " " << log_ref_z << endl; - exit(1); - } - assert(!std::isnan(log_ref_z)); - ref_exp -= cur_model_exp; - acc_grad += ref_exp; - acc_obj += (cur_obj - log_ref_z); - } - - virtual void NotifyDecodingComplete(const SentenceMetadata& smeta) { - if (state == 3) { - ++total_complete; - } else { - } - } - - void GetGradient(SparseVector* g) const { - g->clear(); - for (SparseVector::const_iterator it = acc_grad.begin(); it != acc_grad.end(); ++it) - g->set_value(it->first, it->second.as_float()); - } - - int total_complete; - SparseVector cur_model_exp; - SparseVector acc_grad; - double acc_obj; - double cur_obj; - int state; -}; - -#ifdef HAVE_MPI -namespace boost { namespace mpi { - template<> - struct is_commutative >, SparseVector > - : mpl::true_ { }; -} } // end namespace boost::mpi -#endif - -bool LoadAgenda(const string& file, vector >* a) { - ReadFile rf(file); - istream& in = *rf.stream(); - string line; - while(in) { - getline(in, line); - if (!in) break; - if (line.empty()) continue; - if (line[0] == '#') continue; - int sc = 0; - if (line.size() < 3) return false; - for (int i = 0; i < line.size(); ++i) { if (line[i] == ' ') ++sc; } - if (sc != 1) { cerr << "Too many spaces in line: " << line << endl; return false; } - size_t d = line.find(" "); - pair x; - x.first = line.substr(0,d); - x.second = atoi(line.substr(d+1).c_str()); - a->push_back(x); - if (!FileExists(x.first)) { - cerr << "Can't find file " << x.first << endl; - return false; - } - } - return true; -} - -int main(int argc, char** argv) { - cerr << "THIS SOFTWARE IS DEPRECATED YOU SHOULD USE mpi_flex_optimize\n"; -#ifdef HAVE_MPI - mpi::environment env(argc, argv); - mpi::communicator world; - const int size = world.size(); - const int rank = world.rank(); -#else - const int size = 1; - const int rank = 0; -#endif - if (size > 1) SetSilent(true); // turn off verbose decoder output - register_feature_functions(); - std::tr1::shared_ptr rng; - - po::variables_map conf; - if (!InitCommandLine(argc, argv, &conf)) - return 1; - - vector > agenda; - if (!LoadAgenda(conf["training_agenda"].as(), &agenda)) - return 1; - if (rank == 0) - cerr << "Loaded agenda defining " << agenda.size() << " training epochs\n"; - - assert(agenda.size() > 0); - - if (1) { // hack to load the feature hash functions -- TODO this should not be in cdec.ini - const string& cur_config = agenda[0].first; - const unsigned max_iteration = agenda[0].second; - ReadFile ini_rf(cur_config); - Decoder decoder(ini_rf.stream()); - } - - // load initial weights - vector init_weights; - if (conf.count("input_weights")) - Weights::InitFromFile(conf["input_weights"].as(), &init_weights); - - vector frozen_fids; - if (conf.count("frozen_features")) { - ReadFile rf(conf["frozen_features"].as()); - istream& in = *rf.stream(); - string line; - while(in) { - getline(in, line); - if (line.empty()) continue; - if (line[0] == ' ' || line[line.size() - 1] == ' ') { line = Trim(line); } - frozen_fids.push_back(FD::Convert(line)); - } - if (rank == 0) cerr << "Freezing " << frozen_fids.size() << " features.\n"; - } - - vector corpus; - vector ids; - ReadTrainingCorpus(conf["training_data"].as(), rank, size, &corpus, &ids); - assert(corpus.size() > 0); - - std::tr1::shared_ptr o; - std::tr1::shared_ptr lr; - - const unsigned size_per_proc = conf["minibatch_size_per_proc"].as(); - if (size_per_proc > corpus.size()) { - cerr << "Minibatch size must be smaller than corpus size!\n"; - return 1; - } - - size_t total_corpus_size = 0; -#ifdef HAVE_MPI - reduce(world, corpus.size(), total_corpus_size, std::plus(), 0); -#else - total_corpus_size = corpus.size(); -#endif - - if (rank == 0) { - cerr << "Total corpus size: " << total_corpus_size << endl; - const unsigned batch_size = size_per_proc * size; - // TODO config - lr.reset(new ExponentialDecayLearningRate(batch_size, conf["eta_0"].as())); - - const string omethod = conf["optimization_method"].as(); - if (omethod == "sgd") { - const double C = conf["regularization_strength"].as(); - o.reset(new CumulativeL1OnlineOptimizer(lr, total_corpus_size, C, frozen_fids)); - } else { - assert(!"fail"); - } - } - if (conf.count("random_seed")) - rng.reset(new MT19937(conf["random_seed"].as())); - else - rng.reset(new MT19937); - - SparseVector x; - Weights::InitSparseVector(init_weights, &x); - TrainingObserver observer; - - int write_weights_every_ith = 100; // TODO configure - int titer = -1; - - for (int ai = 0; ai < agenda.size(); ++ai) { - const string& cur_config = agenda[ai].first; - const unsigned max_iteration = agenda[ai].second; - if (rank == 0) - cerr << "STARTING TRAINING EPOCH " << (ai+1) << ". CONFIG=" << cur_config << endl; - // load cdec.ini and set up decoder - ReadFile ini_rf(cur_config); - Decoder decoder(ini_rf.stream()); - vector& lambdas = decoder.CurrentWeightVector(); - if (ai == 0) { lambdas.swap(init_weights); init_weights.clear(); } - - if (rank == 0) - o->ResetEpoch(); // resets the learning rate-- TODO is this good? - - int iter = -1; - bool converged = false; - while (!converged) { -#ifdef HAVE_MPI - mpi::timer timer; -#endif - x.init_vector(&lambdas); - ++iter; ++titer; - observer.Reset(); - if (rank == 0) { - converged = (iter == max_iteration); - Weights::SanityCheck(lambdas); - static int cc = 0; ++cc; if (cc > 1) { Weights::ShowLargestFeatures(lambdas); } - string fname = "weights.cur.gz"; - if (iter % write_weights_every_ith == 0) { - ostringstream o; o << "weights.epoch_" << (ai+1) << '.' << iter << ".gz"; - fname = o.str(); - } - if (converged && ((ai+1)==agenda.size())) { fname = "weights.final.gz"; } - ostringstream vv; - vv << "total iter=" << titer << " (of current config iter=" << iter << ") minibatch=" << size_per_proc << " sentences/proc x " << size << " procs. num_feats=" << x.size() << '/' << FD::NumFeats() << " passes_thru_data=" << (titer * size_per_proc / static_cast(corpus.size())) << " eta=" << lr->eta(titer); - const string svv = vv.str(); - cerr << svv << endl; - Weights::WriteToFile(fname, lambdas, true, &svv); - } - - for (int i = 0; i < size_per_proc; ++i) { - int ei = corpus.size() * rng->next(); - int id = ids[ei]; - decoder.SetId(id); - decoder.Decode(corpus[ei], &observer); - } - SparseVector local_grad, g; - observer.GetGradient(&local_grad); -#ifdef HAVE_MPI - reduce(world, local_grad, g, std::plus >(), 0); -#else - g.swap(local_grad); -#endif - local_grad.clear(); - if (rank == 0) { - g /= (size_per_proc * size); - o->UpdateWeights(g, FD::NumFeats(), &x); - } -#ifdef HAVE_MPI - broadcast(world, x, 0); - broadcast(world, converged, 0); - world.barrier(); - if (rank == 0) { cerr << " ELAPSED TIME THIS ITERATION=" << timer.elapsed() << endl; } -#endif - } - } - return 0; -} diff --git a/training/mr_em_adapted_reduce.cc b/training/mr_em_adapted_reduce.cc deleted file mode 100644 index f65b5440..00000000 --- a/training/mr_em_adapted_reduce.cc +++ /dev/null @@ -1,173 +0,0 @@ -#include -#include -#include -#include - -#include -#include - -#include "filelib.h" -#include "fdict.h" -#include "weights.h" -#include "sparse_vector.h" -#include "m.h" - -using namespace std; -namespace po = boost::program_options; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("optimization_method,m", po::value()->default_value("em"), "Optimization method (em, vb)") - ("input_format,f",po::value()->default_value("b64"),"Encoding of the input (b64 or text)"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -double NoZero(const double& x) { - if (x) return x; - return 1e-35; -} - -void Maximize(const bool use_vb, - const double& alpha, - const int total_event_types, - SparseVector* pc) { - const SparseVector& counts = *pc; - - if (use_vb) - assert(total_event_types >= counts.size()); - - double tot = 0; - for (SparseVector::const_iterator it = counts.begin(); - it != counts.end(); ++it) - tot += it->second; -// cerr << " = " << tot << endl; - assert(tot > 0.0); - double ltot = log(tot); - if (use_vb) - ltot = Md::digamma(tot + total_event_types * alpha); - for (SparseVector::const_iterator it = counts.begin(); - it != counts.end(); ++it) { - if (use_vb) { - pc->set_value(it->first, NoZero(Md::digamma(it->second + alpha) - ltot)); - } else { - pc->set_value(it->first, NoZero(log(it->second) - ltot)); - } - } -#if 0 - if (counts.size() < 50) { - for (SparseVector::const_iterator it = counts.begin(); - it != counts.end(); ++it) { - cerr << " p(" << FD::Convert(it->first) << ")=" << exp(it->second); - } - cerr << endl; - } -#endif -} - -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - - const bool use_b64 = conf["input_format"].as() == "b64"; - const bool use_vb = conf["optimization_method"].as() == "vb"; - const double alpha = 1e-09; - if (use_vb) - cerr << "Using variational Bayes, make sure alphas are set\n"; - - const string s_obj = "**OBJ**"; - // E-step - string cur_key = ""; - SparseVector acc; - double logprob = 0; - while(cin) { - string line; - getline(cin, line); - if (line.empty()) continue; - int feat; - double val; - size_t i = line.find("\t"); - const string key = line.substr(0, i); - assert(i != string::npos); - ++i; - if (key != cur_key) { - if (cur_key.size() > 0) { - // TODO shouldn't be num_active, should be total number - // of events - Maximize(use_vb, alpha, acc.size(), &acc); - cout << cur_key << '\t'; - if (use_b64) - B64::Encode(0.0, acc, &cout); - else - cout << acc; - cout << endl; - acc.clear(); - } - cur_key = key; - } - if (use_b64) { - SparseVector g; - double obj; - if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) { - cerr << "B64 decoder returned error, skipping!\n"; - continue; - } - logprob += obj; - acc += g; - } else { // text encoding - your counts will not be accurate! - while (i < line.size()) { - size_t start = i; - while (line[i] != '=' && i < line.size()) ++i; - if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; } - string fname = line.substr(start, i - start); - if (fname == s_obj) { - feat = -1; - } else { - feat = FD::Convert(line.substr(start, i - start)); - } - ++i; - start = i; - while (line[i] != ';' && i < line.size()) ++i; - if (i - start == 0) continue; - val = atof(line.substr(start, i - start).c_str()); - ++i; - if (feat == -1) { - logprob += val; - } else { - acc.add_value(feat, val); - } - } - } - } - // TODO shouldn't be num_active, should be total number - // of events - Maximize(use_vb, alpha, acc.size(), &acc); - cout << cur_key << '\t'; - if (use_b64) - B64::Encode(0.0, acc, &cout); - else - cout << acc; - cout << endl << flush; - - cerr << "LOGPROB: " << logprob << endl; - - return 0; -} diff --git a/training/mr_em_map_adapter.cc b/training/mr_em_map_adapter.cc deleted file mode 100644 index ead4598d..00000000 --- a/training/mr_em_map_adapter.cc +++ /dev/null @@ -1,160 +0,0 @@ -#include -#include -#include -#include - -#include -#include -#include -#include "boost/tuple/tuple.hpp" - -#include "fdict.h" -#include "sparse_vector.h" - -using namespace std; -namespace po = boost::program_options; - -// useful for EM models parameterized by a bunch of multinomials -// this converts event counts (returned from cdec as feature expectations) -// into different keys and values (which are lists of all the events, -// conditioned on the key) for summing and normalization by a reducer - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("buffer_size,b", po::value()->default_value(1), "Buffer size (in # of counts) before emitting counts") - ("format,f",po::value()->default_value("b64"), "Encoding of the input (b64 or text)"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -struct EventMapper { - int Map(int fid) { - int& cv = map_[fid]; - if (!cv) { - cv = GetConditioningVariable(fid); - } - return cv; - } - void Clear() { map_.clear(); } - protected: - virtual int GetConditioningVariable(int fid) const = 0; - private: - map map_; -}; - -struct LexAlignEventMapper : public EventMapper { - protected: - virtual int GetConditioningVariable(int fid) const { - const string& str = FD::Convert(fid); - size_t pos = str.rfind("_"); - if (pos == string::npos || pos == 0 || pos >= str.size() - 1) { - cerr << "Bad feature for EM adapter: " << str << endl; - abort(); - } - return FD::Convert(str.substr(0, pos)); - } -}; - -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - - const bool use_b64 = conf["format"].as() == "b64"; - const int buffer_size = conf["buffer_size"].as(); - - const string s_obj = "**OBJ**"; - // 0**OBJ**=12.2;Feat1=2.3;Feat2=-0.2; - // 0**OBJ**=1.1;Feat1=1.0; - - EventMapper* event_mapper = new LexAlignEventMapper; - map > counts; - size_t total = 0; - while(cin) { - string line; - getline(cin, line); - if (line.empty()) continue; - int feat; - double val; - size_t i = line.find("\t"); - assert(i != string::npos); - ++i; - SparseVector g; - double obj = 0; - if (use_b64) { - if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) { - cerr << "B64 decoder returned error, skipping!\n"; - continue; - } - } else { // text encoding - your counts will not be accurate! - while (i < line.size()) { - size_t start = i; - while (line[i] != '=' && i < line.size()) ++i; - if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; } - string fname = line.substr(start, i - start); - if (fname == s_obj) { - feat = -1; - } else { - feat = FD::Convert(line.substr(start, i - start)); - } - ++i; - start = i; - while (line[i] != ';' && i < line.size()) ++i; - if (i - start == 0) continue; - val = atof(line.substr(start, i - start).c_str()); - ++i; - if (feat == -1) { - obj = val; - } else { - g.set_value(feat, val); - } - } - } - //cerr << "OBJ: " << obj << endl; - const SparseVector& cg = g; - for (SparseVector::const_iterator it = cg.begin(); it != cg.end(); ++it) { - const int cond_var = event_mapper->Map(it->first); - SparseVector& cond_counts = counts[cond_var]; - int delta = cond_counts.size(); - cond_counts.add_value(it->first, it->second); - delta = cond_counts.size() - delta; - total += delta; - } - if (total > buffer_size) { - for (map >::iterator it = counts.begin(); - it != counts.end(); ++it) { - const SparseVector& cc = it->second; - cout << FD::Convert(it->first) << '\t'; - if (use_b64) { - B64::Encode(0.0, cc, &cout); - } else { - abort(); - } - cout << endl; - } - cout << flush; - total = 0; - counts.clear(); - } - } - - return 0; -} - diff --git a/training/mr_optimize_reduce.cc b/training/mr_optimize_reduce.cc deleted file mode 100644 index d490192f..00000000 --- a/training/mr_optimize_reduce.cc +++ /dev/null @@ -1,231 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "optimize.h" -#include "fdict.h" -#include "weights.h" -#include "sparse_vector.h" - -using namespace std; -namespace po = boost::program_options; - -void SanityCheck(const vector& w) { - for (int i = 0; i < w.size(); ++i) { - assert(!std::isnan(w[i])); - assert(!std::isinf(w[i])); - } -} - -struct FComp { - const vector& w_; - FComp(const vector& w) : w_(w) {} - bool operator()(int a, int b) const { - return fabs(w_[a]) > fabs(w_[b]); - } -}; - -void ShowLargestFeatures(const vector& w) { - vector fnums(w.size()); - for (int i = 0; i < w.size(); ++i) - fnums[i] = i; - vector::iterator mid = fnums.begin(); - mid += (w.size() > 10 ? 10 : w.size()); - partial_sort(fnums.begin(), mid, fnums.end(), FComp(w)); - cerr << "TOP FEATURES:"; - for (vector::iterator i = fnums.begin(); i != mid; ++i) { - cerr << ' ' << FD::Convert(*i) << '=' << w[*i]; - } - cerr << endl; -} - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("input_weights,i",po::value(),"Input feature weights file") - ("output_weights,o",po::value()->default_value("-"),"Output feature weights file") - ("optimization_method,m", po::value()->default_value("lbfgs"), "Optimization method (sgd, lbfgs, rprop)") - ("state,s",po::value(),"Read (and write if output_state is not set) optimizer state from this state file. In the first iteration, the file should not exist.") - ("input_format,f",po::value()->default_value("b64"),"Encoding of the input (b64 or text)") - ("output_state,S", po::value(), "Output state file (optional override)") - ("correction_buffers,M", po::value()->default_value(10), "Number of gradients for LBFGS to maintain in memory") - ("eta,e", po::value()->default_value(0.1), "Learning rate for SGD (eta)") - ("gaussian_prior,p","Use a Gaussian prior on the weights") - ("means,u", po::value(), "File containing the means for Gaussian prior") - ("sigma_squared", po::value()->default_value(1.0), "Sigma squared term for spherical Gaussian prior"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help") || !conf->count("input_weights") || !conf->count("state")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - - const bool use_b64 = conf["input_format"].as() == "b64"; - - vector lambdas; - Weights::InitFromFile(conf["input_weights"].as(), &lambdas); - const string s_obj = "**OBJ**"; - int num_feats = FD::NumFeats(); - cerr << "Number of features: " << num_feats << endl; - const bool gaussian_prior = conf.count("gaussian_prior"); - vector means(num_feats, 0); - if (conf.count("means")) { - if (!gaussian_prior) { - cerr << "Don't use --means without --gaussian_prior!\n"; - exit(1); - } - Weights::InitFromFile(conf["means"].as(), &means); - } - boost::shared_ptr o; - const string omethod = conf["optimization_method"].as(); - if (omethod == "rprop") - o.reset(new RPropOptimizer(num_feats)); // TODO add configuration - else - o.reset(new LBFGSOptimizer(num_feats, conf["correction_buffers"].as())); - cerr << "Optimizer: " << o->Name() << endl; - string state_file = conf["state"].as(); - { - ifstream in(state_file.c_str(), ios::binary); - if (in) - o->Load(&in); - else - cerr << "No state file found, assuming ITERATION 1\n"; - } - - double objective = 0; - vector gradient(num_feats, 0); - // 0**OBJ**=12.2;Feat1=2.3;Feat2=-0.2; - // 0**OBJ**=1.1;Feat1=1.0; - int total_lines = 0; // TODO - this should be a count of the - // training instances!! - while(cin) { - string line; - getline(cin, line); - if (line.empty()) continue; - ++total_lines; - int feat; - double val; - size_t i = line.find("\t"); - assert(i != string::npos); - ++i; - if (use_b64) { - SparseVector g; - double obj; - if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) { - cerr << "B64 decoder returned error, skipping gradient!\n"; - cerr << " START: " << line.substr(0,line.size() > 200 ? 200 : line.size()) << endl; - if (line.size() > 200) - cerr << " END: " << line.substr(line.size() - 200, 200) << endl; - cout << "-1\tRESTART\n"; - exit(99); - } - objective += obj; - const SparseVector& cg = g; - for (SparseVector::const_iterator it = cg.begin(); it != cg.end(); ++it) { - if (it->first >= num_feats) { - cerr << "Unexpected feature in gradient: " << FD::Convert(it->first) << endl; - abort(); - } - gradient[it->first] -= it->second; - } - } else { // text encoding - your gradients will not be accurate! - while (i < line.size()) { - size_t start = i; - while (line[i] != '=' && i < line.size()) ++i; - if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; } - string fname = line.substr(start, i - start); - if (fname == s_obj) { - feat = -1; - } else { - feat = FD::Convert(line.substr(start, i - start)); - if (feat >= num_feats) { - cerr << "Unexpected feature in gradient: " << line.substr(start, i - start) << endl; - abort(); - } - } - ++i; - start = i; - while (line[i] != ';' && i < line.size()) ++i; - if (i - start == 0) continue; - val = atof(line.substr(start, i - start).c_str()); - ++i; - if (feat == -1) { - objective += val; - } else { - gradient[feat] -= val; - } - } - } - } - - if (gaussian_prior) { - const double sigsq = conf["sigma_squared"].as(); - double norm = 0; - for (int k = 1; k < lambdas.size(); ++k) { - const double& lambda_k = lambdas[k]; - if (lambda_k) { - const double param = (lambda_k - means[k]); - norm += param * param; - gradient[k] += param / sigsq; - } - } - const double reg = norm / (2.0 * sigsq); - cerr << "REGULARIZATION TERM: " << reg << endl; - objective += reg; - } - cerr << "EVALUATION #" << o->EvaluationCount() << " OBJECTIVE: " << objective << endl; - double gnorm = 0; - for (int i = 0; i < gradient.size(); ++i) - gnorm += gradient[i] * gradient[i]; - cerr << " GNORM=" << sqrt(gnorm) << endl; - vector old = lambdas; - int c = 0; - while (old == lambdas) { - ++c; - if (c > 1) { cerr << "Same lambdas, repeating optimization\n"; } - o->Optimize(objective, gradient, &lambdas); - assert(c < 5); - } - old.clear(); - SanityCheck(lambdas); - ShowLargestFeatures(lambdas); - Weights::WriteToFile(conf["output_weights"].as(), lambdas, false); - - const bool conv = o->HasConverged(); - if (conv) { cerr << "OPTIMIZER REPORTS CONVERGENCE!\n"; } - - if (conf.count("output_state")) - state_file = conf["output_state"].as(); - ofstream out(state_file.c_str(), ios::binary); - cerr << "Writing state to: " << state_file << endl; - o->Save(&out); - out.close(); - - cout << o->EvaluationCount() << "\t" << conv << endl; - return 0; -} diff --git a/training/mr_reduce_to_weights.cc b/training/mr_reduce_to_weights.cc deleted file mode 100644 index 16b47720..00000000 --- a/training/mr_reduce_to_weights.cc +++ /dev/null @@ -1,109 +0,0 @@ -#include -#include -#include -#include - -#include -#include - -#include "filelib.h" -#include "fdict.h" -#include "weights.h" -#include "sparse_vector.h" - -using namespace std; -namespace po = boost::program_options; - -void InitCommandLine(int argc, char** argv, po::variables_map* conf) { - po::options_description opts("Configuration options"); - opts.add_options() - ("input_format,f",po::value()->default_value("b64"),"Encoding of the input (b64 or text)") - ("input,i",po::value()->default_value("-"),"Read file from") - ("output,o",po::value()->default_value("-"),"Write weights to"); - po::options_description clo("Command line options"); - clo.add_options() - ("config", po::value(), "Configuration file") - ("help,h", "Print this help message and exit"); - po::options_description dconfig_options, dcmdline_options; - dconfig_options.add(opts); - dcmdline_options.add(opts).add(clo); - - po::store(parse_command_line(argc, argv, dcmdline_options), *conf); - if (conf->count("config")) { - ifstream config((*conf)["config"].as().c_str()); - po::store(po::parse_config_file(config, dconfig_options), *conf); - } - po::notify(*conf); - - if (conf->count("help")) { - cerr << dcmdline_options << endl; - exit(1); - } -} - -void WriteWeights(const SparseVector& weights, ostream* out) { - for (SparseVector::const_iterator it = weights.begin(); - it != weights.end(); ++it) { - (*out) << FD::Convert(it->first) << " " << it->second << endl; - } -} - -int main(int argc, char** argv) { - po::variables_map conf; - InitCommandLine(argc, argv, &conf); - - const bool use_b64 = conf["input_format"].as() == "b64"; - - const string s_obj = "**OBJ**"; - // E-step - ReadFile rf(conf["input"].as()); - istream* in = rf.stream(); - assert(*in); - WriteFile wf(conf["output"].as()); - ostream* out = wf.stream(); - out->precision(17); - while(*in) { - string line; - getline(*in, line); - if (line.empty()) continue; - int feat; - double val; - size_t i = line.find("\t"); - assert(i != string::npos); - ++i; - if (use_b64) { - SparseVector g; - double obj; - if (!B64::Decode(&obj, &g, &line[i], line.size() - i)) { - cerr << "B64 decoder returned error, skipping!\n"; - continue; - } - WriteWeights(g, out); - } else { // text encoding - your counts will not be accurate! - SparseVector weights; - while (i < line.size()) { - size_t start = i; - while (line[i] != '=' && i < line.size()) ++i; - if (i == line.size()) { cerr << "FORMAT ERROR\n"; break; } - string fname = line.substr(start, i - start); - if (fname == s_obj) { - feat = -1; - } else { - feat = FD::Convert(line.substr(start, i - start)); - } - ++i; - start = i; - while (line[i] != ';' && i < line.size()) ++i; - if (i - start == 0) continue; - val = atof(line.substr(start, i - start).c_str()); - ++i; - if (feat != -1) { - weights.set_value(feat, val); - } - } - WriteWeights(weights, out); - } - } - - return 0; -} diff --git a/training/online_optimizer.cc b/training/online_optimizer.cc deleted file mode 100644 index 3ed95452..00000000 --- a/training/online_optimizer.cc +++ /dev/null @@ -1,16 +0,0 @@ -#include "online_optimizer.h" - -LearningRateSchedule::~LearningRateSchedule() {} - -double StandardLearningRate::eta(int k) const { - return eta_0_ / (1.0 + k / N_); -} - -double ExponentialDecayLearningRate::eta(int k) const { - return eta_0_ * pow(alpha_, k / N_); -} - -OnlineOptimizer::~OnlineOptimizer() {} - -void OnlineOptimizer::ResetEpochImpl() {} - diff --git a/training/online_optimizer.h b/training/online_optimizer.h deleted file mode 100644 index 28d89344..00000000 --- a/training/online_optimizer.h +++ /dev/null @@ -1,129 +0,0 @@ -#ifndef _ONL_OPTIMIZE_H_ -#define _ONL_OPTIMIZE_H_ - -#include -#include -#include -#include -#include "sparse_vector.h" - -struct LearningRateSchedule { - virtual ~LearningRateSchedule(); - // returns the learning rate for the kth iteration - virtual double eta(int k) const = 0; -}; - -// TODO in the Tsoruoaka et al. (ACL 2009) paper, they use N -// to mean the batch size in most places, but it doesn't completely -// make sense to me in the learning rate schedules-- this needs -// to be worked out to make sure they didn't mean corpus size -// in some places and batch size in others (since in the paper they -// only ever work with batch sizes of 1) -struct StandardLearningRate : public LearningRateSchedule { - StandardLearningRate( - size_t batch_size, // batch size, not corpus size! - double eta_0 = 0.2) : - eta_0_(eta_0), - N_(static_cast(batch_size)) {} - - virtual double eta(int k) const; - - private: - const double eta_0_; - const double N_; -}; - -struct ExponentialDecayLearningRate : public LearningRateSchedule { - ExponentialDecayLearningRate( - size_t batch_size, // batch size, not corpus size! - double eta_0 = 0.2, - double alpha = 0.85 // recommended by Tsuruoka et al. (ACL 2009) - ) : eta_0_(eta_0), - N_(static_cast(batch_size)), - alpha_(alpha) { - assert(alpha > 0); - assert(alpha < 1.0); - } - - virtual double eta(int k) const; - - private: - const double eta_0_; - const double N_; - const double alpha_; -}; - -class OnlineOptimizer { - public: - virtual ~OnlineOptimizer(); - OnlineOptimizer(const std::tr1::shared_ptr& s, - size_t batch_size, - const std::vector& frozen_feats = std::vector()) - : N_(batch_size),schedule_(s),k_() { - for (int i = 0; i < frozen_feats.size(); ++i) - frozen_.insert(frozen_feats[i]); - } - void ResetEpoch() { k_ = 0; ResetEpochImpl(); } - void UpdateWeights(const SparseVector& approx_g, int max_feat, SparseVector* weights) { - ++k_; - const double eta = schedule_->eta(k_); - UpdateWeightsImpl(eta, approx_g, max_feat, weights); - } - - protected: - virtual void ResetEpochImpl(); - virtual void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) = 0; - const size_t N_; // number of training instances per batch - std::set frozen_; // frozen (non-optimizing) features - - private: - std::tr1::shared_ptr schedule_; - int k_; // iteration count -}; - -class CumulativeL1OnlineOptimizer : public OnlineOptimizer { - public: - CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr& s, - size_t training_instances, double C, - const std::vector& frozen) : - OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {} - - protected: - void ResetEpochImpl() { u_ = 0; } - void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) { - u_ += eta * C_ / N_; - for (SparseVector::const_iterator it = approx_g.begin(); - it != approx_g.end(); ++it) { - if (frozen_.count(it->first) == 0) - weights->add_value(it->first, eta * it->second); - } - for (int i = 1; i < max_feat; ++i) - if (frozen_.count(i) == 0) ApplyPenalty(i, weights); - } - - private: - void ApplyPenalty(int i, SparseVector* w) { - const double z = w->value(i); - double w_i = z; - double q_i = q_.value(i); - if (w_i > 0.0) - w_i = std::max(0.0, w_i - (u_ + q_i)); - else if (w_i < 0.0) - w_i = std::min(0.0, w_i + (u_ - q_i)); - q_i += w_i - z; - if (q_i == 0.0) - q_.erase(i); - else - q_.set_value(i, q_i); - if (w_i == 0.0) - w->erase(i); - else - w->set_value(i, w_i); - } - - const double C_; // reguarlization strength - double u_; - SparseVector q_; -}; - -#endif diff --git a/training/optimize.cc b/training/optimize.cc deleted file mode 100644 index 41ac90d8..00000000 --- a/training/optimize.cc +++ /dev/null @@ -1,102 +0,0 @@ -#include "optimize.h" - -#include -#include - -#include "lbfgs.h" - -using namespace std; - -BatchOptimizer::~BatchOptimizer() {} - -void BatchOptimizer::Save(ostream* out) const { - out->write((const char*)&eval_, sizeof(eval_)); - out->write((const char*)&has_converged_, sizeof(has_converged_)); - SaveImpl(out); - unsigned int magic = 0xABCDDCBA; // should be uint32_t - out->write((const char*)&magic, sizeof(magic)); -} - -void BatchOptimizer::Load(istream* in) { - in->read((char*)&eval_, sizeof(eval_)); - in->read((char*)&has_converged_, sizeof(has_converged_)); - LoadImpl(in); - unsigned int magic = 0; // should be uint32_t - in->read((char*)&magic, sizeof(magic)); - assert(magic == 0xABCDDCBA); - cerr << Name() << " EVALUATION #" << eval_ << endl; -} - -void BatchOptimizer::SaveImpl(ostream* out) const { - (void)out; -} - -void BatchOptimizer::LoadImpl(istream* in) { - (void)in; -} - -string RPropOptimizer::Name() const { - return "RPropOptimizer"; -} - -void RPropOptimizer::OptimizeImpl(const double& obj, - const vector& g, - vector* x) { - for (int i = 0; i < g.size(); ++i) { - const double g_i = g[i]; - const double sign_i = (signbit(g_i) ? -1.0 : 1.0); - const double prod = g_i * prev_g_[i]; - if (prod > 0.0) { - const double dij = min(delta_ij_[i] * eta_plus_, delta_max_); - (*x)[i] -= dij * sign_i; - delta_ij_[i] = dij; - prev_g_[i] = g_i; - } else if (prod < 0.0) { - delta_ij_[i] = max(delta_ij_[i] * eta_minus_, delta_min_); - prev_g_[i] = 0.0; - } else { - (*x)[i] -= delta_ij_[i] * sign_i; - prev_g_[i] = g_i; - } - } -} - -void RPropOptimizer::SaveImpl(ostream* out) const { - const size_t n = prev_g_.size(); - out->write((const char*)&n, sizeof(n)); - out->write((const char*)&prev_g_[0], sizeof(double) * n); - out->write((const char*)&delta_ij_[0], sizeof(double) * n); -} - -void RPropOptimizer::LoadImpl(istream* in) { - size_t n; - in->read((char*)&n, sizeof(n)); - assert(n == prev_g_.size()); - assert(n == delta_ij_.size()); - in->read((char*)&prev_g_[0], sizeof(double) * n); - in->read((char*)&delta_ij_[0], sizeof(double) * n); -} - -string LBFGSOptimizer::Name() const { - return "LBFGSOptimizer"; -} - -LBFGSOptimizer::LBFGSOptimizer(int num_feats, int memory_buffers) : - opt_(num_feats, memory_buffers) {} - -void LBFGSOptimizer::SaveImpl(ostream* out) const { - opt_.serialize(out); -} - -void LBFGSOptimizer::LoadImpl(istream* in) { - opt_.deserialize(in); -} - -void LBFGSOptimizer::OptimizeImpl(const double& obj, - const vector& g, - vector* x) { - opt_.run(&(*x)[0], obj, &g[0]); - if (!opt_.requests_f_and_g()) opt_.run(&(*x)[0], obj, &g[0]); - // cerr << opt_ << endl; -} - diff --git a/training/optimize.h b/training/optimize.h deleted file mode 100644 index 07943b44..00000000 --- a/training/optimize.h +++ /dev/null @@ -1,92 +0,0 @@ -#ifndef _OPTIMIZE_H_ -#define _OPTIMIZE_H_ - -#include -#include -#include -#include - -#include "lbfgs.h" - -// abstract base class for first order optimizers -// order of invocation: new, Load(), Optimize(), Save(), delete -class BatchOptimizer { - public: - BatchOptimizer() : eval_(1), has_converged_(false) {} - virtual ~BatchOptimizer(); - virtual std::string Name() const = 0; - int EvaluationCount() const { return eval_; } - bool HasConverged() const { return has_converged_; } - - void Optimize(const double& obj, - const std::vector& g, - std::vector* x) { - assert(g.size() == x->size()); - ++eval_; - OptimizeImpl(obj, g, x); - scitbx::lbfgs::traditional_convergence_test converged(g.size()); - has_converged_ = converged(&(*x)[0], &g[0]); - } - - void Save(std::ostream* out) const; - void Load(std::istream* in); - protected: - virtual void SaveImpl(std::ostream* out) const; - virtual void LoadImpl(std::istream* in); - virtual void OptimizeImpl(const double& obj, - const std::vector& g, - std::vector* x) = 0; - - int eval_; - private: - bool has_converged_; -}; - -class RPropOptimizer : public BatchOptimizer { - public: - explicit RPropOptimizer(int num_vars, - double eta_plus = 1.2, - double eta_minus = 0.5, - double delta_0 = 0.1, - double delta_max = 50.0, - double delta_min = 1e-6) : - prev_g_(num_vars, 0.0), - delta_ij_(num_vars, delta_0), - eta_plus_(eta_plus), - eta_minus_(eta_minus), - delta_max_(delta_max), - delta_min_(delta_min) { - assert(eta_plus > 1.0); - assert(eta_minus > 0.0 && eta_minus < 1.0); - assert(delta_max > 0.0); - assert(delta_min > 0.0); - } - std::string Name() const; - void OptimizeImpl(const double& obj, - const std::vector& g, - std::vector* x); - void SaveImpl(std::ostream* out) const; - void LoadImpl(std::istream* in); - private: - std::vector prev_g_; - std::vector delta_ij_; - const double eta_plus_; - const double eta_minus_; - const double delta_max_; - const double delta_min_; -}; - -class LBFGSOptimizer : public BatchOptimizer { - public: - explicit LBFGSOptimizer(int num_vars, int memory_buffers = 10); - std::string Name() const; - void SaveImpl(std::ostream* out) const; - void LoadImpl(std::istream* in); - void OptimizeImpl(const double& obj, - const std::vector& g, - std::vector* x); - private: - scitbx::lbfgs::minimizer opt_; -}; - -#endif diff --git a/training/optimize_test.cc b/training/optimize_test.cc deleted file mode 100644 index bff2ca03..00000000 --- a/training/optimize_test.cc +++ /dev/null @@ -1,118 +0,0 @@ -#include -#include -#include -#include -#include "optimize.h" -#include "online_optimizer.h" -#include "sparse_vector.h" -#include "fdict.h" - -using namespace std; - -double TestOptimizer(BatchOptimizer* opt) { - cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; - - // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 - // df/dx1 = 8*x1 + x2 - // df/dx2 = 2*x2 + x1 - // df/dx3 = 2*x3 + 6 - vector x(3); - vector g(3); - x[0] = 8; - x[1] = 8; - x[2] = 8; - double obj = 0; - do { - g[0] = 8 * x[0] + x[1]; - g[1] = 2 * x[1] + x[0]; - g[2] = 2 * x[2] + 6; - obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; - opt->Optimize(obj, g, &x); - - cerr << x[0] << " " << x[1] << " " << x[2] << endl; - cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; - } while (!opt->HasConverged()); - return obj; -} - -double TestPersistentOptimizer(BatchOptimizer* opt) { - cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; - // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 - // df/dx1 = 8*x1 + x2 - // df/dx2 = 2*x2 + x1 - // df/dx3 = 2*x3 + 6 - vector x(3); - vector g(3); - x[0] = 8; - x[1] = 8; - x[2] = 8; - double obj = 0; - string state; - bool converged = false; - while (!converged) { - g[0] = 8 * x[0] + x[1]; - g[1] = 2 * x[1] + x[0]; - g[2] = 2 * x[2] + 6; - obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; - - { - if (state.size() > 0) { - istringstream is(state, ios::binary); - opt->Load(&is); - } - opt->Optimize(obj, g, &x); - ostringstream os(ios::binary); opt->Save(&os); state = os.str(); - - } - - cerr << x[0] << " " << x[1] << " " << x[2] << endl; - cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; - converged = opt->HasConverged(); - if (!converged) { - // now screw up the state (should be undone by Load) - obj += 2.0; - g[1] = -g[2]; - vector x2 = x; - try { - opt->Optimize(obj, g, &x2); - } catch (...) { } - } - } - return obj; -} - -template -void TestOptimizerVariants(int num_vars) { - O oa(num_vars); - cerr << "-------------------------------------------------------------------------\n"; - cerr << "TESTING: " << oa.Name() << endl; - double o1 = TestOptimizer(&oa); - O ob(num_vars); - double o2 = TestPersistentOptimizer(&ob); - if (o1 != o2) { - cerr << oa.Name() << " VARIANTS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; - exit(1); - } - cerr << oa.Name() << " SUCCESS\n"; -} - -using namespace std::tr1; - -void TestOnline() { - size_t N = 20; - double C = 1.0; - double eta0 = 0.2; - std::tr1::shared_ptr r(new ExponentialDecayLearningRate(N, eta0, 0.85)); - //shared_ptr r(new StandardLearningRate(N, eta0)); - CumulativeL1OnlineOptimizer opt(r, N, C, std::vector()); - assert(r->eta(10) < r->eta(1)); -} - -int main() { - int n = 3; - TestOptimizerVariants(n); - TestOptimizerVariants(n); - TestOnline(); - return 0; -} - diff --git a/training/pro/Makefile.am b/training/pro/Makefile.am new file mode 100644 index 00000000..1916b6b2 --- /dev/null +++ b/training/pro/Makefile.am @@ -0,0 +1,11 @@ +bin_PROGRAMS = \ + mr_pro_map \ + mr_pro_reduce + +mr_pro_map_SOURCES = mr_pro_map.cc +mr_pro_map_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz + +mr_pro_reduce_SOURCES = mr_pro_reduce.cc +mr_pro_reduce_LDADD = $(top_srcdir)/training/liblbfgs/liblbfgs.a $(top_srcdir)/utils/libutils.a -lz + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training/utils -I$(top_srcdir)/training diff --git a/training/pro/mr_pro_generate_mapper_input.pl b/training/pro/mr_pro_generate_mapper_input.pl new file mode 100755 index 00000000..b30fc4fd --- /dev/null +++ b/training/pro/mr_pro_generate_mapper_input.pl @@ -0,0 +1,18 @@ +#!/usr/bin/perl -w +use strict; + +die "Usage: $0 HG_DIR\n" unless scalar @ARGV == 1; +my $d = shift @ARGV; +die "Can't find directory $d" unless -d $d; + +opendir(DIR, $d) or die "Can't read $d: $!"; +my @hgs = grep { /\.gz$/ } readdir(DIR); +closedir DIR; + +for my $hg (@hgs) { + my $file = $hg; + my $id = $hg; + $id =~ s/(\.json)?\.gz//; + print "$d/$file $id\n"; +} + diff --git a/training/pro/mr_pro_map.cc b/training/pro/mr_pro_map.cc new file mode 100644 index 00000000..eef40b8a --- /dev/null +++ b/training/pro/mr_pro_map.cc @@ -0,0 +1,201 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "candidate_set.h" +#include "sampler.h" +#include "filelib.h" +#include "stringlib.h" +#include "weights.h" +#include "inside_outside.h" +#include "hg_io.h" +#include "ns.h" +#include "ns_docscorer.h" + +// This is Figure 4 (Algorithm Sampler) from Hopkins&May (2011) + +using namespace std; +namespace po = boost::program_options; + +boost::shared_ptr rng; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") + ("weights,w",po::value(), "[REQD] Weights files from current iterations") + ("kbest_repository,K",po::value()->default_value("./kbest"),"K-best list repository (directory)") + ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") + ("source,s",po::value()->default_value(""), "Source file (ignored, except for AER)") + ("evaluation_metric,m",po::value()->default_value("IBM_BLEU"), "Evaluation metric (ibm_bleu, koehn_bleu, nist_bleu, ter, meteor, etc.)") + ("kbest_size,k",po::value()->default_value(1500u), "Top k-hypotheses to extract") + ("candidate_pairs,G", po::value()->default_value(5000u), "Number of pairs to sample per hypothesis (Gamma)") + ("best_pairs,X", po::value()->default_value(50u), "Number of pairs, ranked by magnitude of objective delta, to retain (Xi)") + ("random_seed,S", po::value(), "Random seed (if not specified, /dev/random will be used)") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (!conf->count("reference")) { + cerr << "Please specify one or more references using -r \n"; + flag = true; + } + if (!conf->count("weights")) { + cerr << "Please specify weights using -w \n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +struct ThresholdAlpha { + explicit ThresholdAlpha(double t = 0.05) : threshold(t) {} + double operator()(double mag) const { + if (mag < threshold) return 0.0; else return 1.0; + } + const double threshold; +}; + +struct TrainingInstance { + TrainingInstance(const SparseVector& feats, bool positive, float diff) : x(feats), y(positive), gdiff(diff) {} + SparseVector x; +#undef DEBUGGING_PRO +#ifdef DEBUGGING_PRO + vector a; + vector b; +#endif + bool y; + float gdiff; +}; +#ifdef DEBUGGING_PRO +ostream& operator<<(ostream& os, const TrainingInstance& d) { + return os << d.gdiff << " y=" << d.y << "\tA:" << TD::GetString(d.a) << "\n\tB: " << TD::GetString(d.b) << "\n\tX: " << d.x; +} +#endif + +struct DiffOrder { + bool operator()(const TrainingInstance& a, const TrainingInstance& b) const { + return a.gdiff > b.gdiff; + } +}; + +void Sample(const unsigned gamma, + const unsigned xi, + const training::CandidateSet& J_i, + const EvaluationMetric* metric, + vector* pv) { + const bool invert_score = metric->IsErrorMetric(); + vector v1, v2; + float avg_diff = 0; + for (unsigned i = 0; i < gamma; ++i) { + const size_t a = rng->inclusive(0, J_i.size() - 1)(); + const size_t b = rng->inclusive(0, J_i.size() - 1)(); + if (a == b) continue; + float ga = metric->ComputeScore(J_i[a].eval_feats); + float gb = metric->ComputeScore(J_i[b].eval_feats); + bool positive = gb < ga; + if (invert_score) positive = !positive; + const float gdiff = fabs(ga - gb); + if (!gdiff) continue; + avg_diff += gdiff; + SparseVector xdiff = (J_i[a].fmap - J_i[b].fmap).erase_zeros(); + if (xdiff.empty()) { + cerr << "Empty diff:\n " << TD::GetString(J_i[a].ewords) << endl << "x=" << J_i[a].fmap << endl; + cerr << " " << TD::GetString(J_i[b].ewords) << endl << "x=" << J_i[b].fmap << endl; + continue; + } + v1.push_back(TrainingInstance(xdiff, positive, gdiff)); +#ifdef DEBUGGING_PRO + v1.back().a = J_i[a].hyp; + v1.back().b = J_i[b].hyp; + cerr << "N: " << v1.back() << endl; +#endif + } + avg_diff /= v1.size(); + + for (unsigned i = 0; i < v1.size(); ++i) { + double p = 1.0 / (1.0 + exp(-avg_diff - v1[i].gdiff)); + // cerr << "avg_diff=" << avg_diff << " gdiff=" << v1[i].gdiff << " p=" << p << endl; + if (rng->next() < p) v2.push_back(v1[i]); + } + vector::iterator mid = v2.begin() + xi; + if (xi > v2.size()) mid = v2.end(); + partial_sort(v2.begin(), mid, v2.end(), DiffOrder()); + copy(v2.begin(), mid, back_inserter(*pv)); +#ifdef DEBUGGING_PRO + if (v2.size() >= 5) { + for (int i =0; i < (mid - v2.begin()); ++i) { + cerr << v2[i] << endl; + } + cerr << pv->back() << endl; + } +#endif +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + if (conf.count("random_seed")) + rng.reset(new MT19937(conf["random_seed"].as())); + else + rng.reset(new MT19937); + const string evaluation_metric = conf["evaluation_metric"].as(); + + EvaluationMetric* metric = EvaluationMetric::Instance(evaluation_metric); + DocumentScorer ds(metric, conf["reference"].as >()); + cerr << "Loaded " << ds.size() << " references for scoring with " << evaluation_metric << endl; + + Hypergraph hg; + string last_file; + ReadFile in_read(conf["input"].as()); + istream &in=*in_read.stream(); + const unsigned kbest_size = conf["kbest_size"].as(); + const unsigned gamma = conf["candidate_pairs"].as(); + const unsigned xi = conf["best_pairs"].as(); + string weightsf = conf["weights"].as(); + vector weights; + Weights::InitFromFile(weightsf, &weights); + string kbest_repo = conf["kbest_repository"].as(); + MkDirP(kbest_repo); + while(in) { + vector v; + string line; + getline(in, line); + if (line.empty()) continue; + istringstream is(line); + int sent_id; + string file; + // path-to-file (JSON) sent_id + is >> file >> sent_id; + ReadFile rf(file); + ostringstream os; + training::CandidateSet J_i; + os << kbest_repo << "/kbest." << sent_id << ".txt.gz"; + const string kbest_file = os.str(); + if (FileExists(kbest_file)) + J_i.ReadFromFile(kbest_file); + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + hg.Reweight(weights); + J_i.AddKBestCandidates(hg, kbest_size, ds[sent_id]); + J_i.WriteToFile(kbest_file); + + Sample(gamma, xi, J_i, metric, &v); + for (unsigned i = 0; i < v.size(); ++i) { + const TrainingInstance& vi = v[i]; + cout << vi.y << "\t" << vi.x << endl; + cout << (!vi.y) << "\t" << (vi.x * -1.0) << endl; + } + } + return 0; +} + diff --git a/training/pro/mr_pro_reduce.cc b/training/pro/mr_pro_reduce.cc new file mode 100644 index 00000000..5ef9b470 --- /dev/null +++ b/training/pro/mr_pro_reduce.cc @@ -0,0 +1,286 @@ +#include +#include +#include +#include +#include + +#include +#include + +#include "filelib.h" +#include "weights.h" +#include "sparse_vector.h" +#include "optimize.h" +#include "liblbfgs/lbfgs++.h" + +using namespace std; +namespace po = boost::program_options; + +// since this is a ranking model, there should be equal numbers of +// positive and negative examples, so the bias should be 0 +static const double MAX_BIAS = 1e-10; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("weights,w", po::value(), "Weights from previous iteration (used as initialization and interpolation") + ("regularization_strength,C",po::value()->default_value(500.0), "l2 regularization strength") + ("l1",po::value()->default_value(0.0), "l1 regularization strength") + ("regularize_to_weights,y",po::value()->default_value(5000.0), "Differences in learned weights to previous weights are penalized with an l2 penalty with this strength; 0.0 = no effect") + ("memory_buffers,m",po::value()->default_value(100), "Number of memory buffers (LBFGS)") + ("min_reg,r",po::value()->default_value(0.01), "When tuning (-T) regularization strength, minimum regularization strenght") + ("max_reg,R",po::value()->default_value(1e6), "When tuning (-T) regularization strength, maximum regularization strenght") + ("testset,t",po::value(), "Optional held-out test set") + ("tune_regularizer,T", "Use the held out test set (-t) to tune the regularization strength") + ("interpolate_with_weights,p",po::value()->default_value(1.0), "[deprecated] Output weights are p*w + (1-p)*w_prev; 1.0 = no effect") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +void ParseSparseVector(string& line, size_t cur, SparseVector* out) { + SparseVector& x = *out; + size_t last_start = cur; + size_t last_comma = string::npos; + while(cur <= line.size()) { + if (line[cur] == ' ' || cur == line.size()) { + if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { + cerr << "[ERROR] " << line << endl << " position = " << cur << endl; + exit(1); + } + const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); + if (cur < line.size()) line[cur] = 0; + const weight_t val = strtod(&line[last_comma + 1], NULL); + x.set_value(fid, val); + + last_comma = string::npos; + last_start = cur+1; + } else { + if (line[cur] == '=') + last_comma = cur; + } + ++cur; + } +} + +void ReadCorpus(istream* pin, vector > >* corpus) { + istream& in = *pin; + corpus->clear(); + bool flag = false; + int lc = 0; + string line; + SparseVector x; + while(getline(in, line)) { + ++lc; + if (lc % 1000 == 0) { cerr << '.'; flag = true; } + if (lc % 40000 == 0) { cerr << " [" << lc << "]\n"; flag = false; } + if (line.empty()) continue; + const size_t ks = line.find("\t"); + assert(string::npos != ks); + assert(ks == 1); + const bool y = line[0] == '1'; + x.clear(); + ParseSparseVector(line, ks + 1, &x); + corpus->push_back(make_pair(y, x)); + } + if (flag) cerr << endl; +} + +void GradAdd(const SparseVector& v, const double scale, weight_t* acc) { + for (SparseVector::const_iterator it = v.begin(); + it != v.end(); ++it) { + acc[it->first] += it->second * scale; + } +} + +double ApplyRegularizationTerms(const double C, + const double T, + const vector& weights, + const vector& prev_weights, + weight_t* g) { + double reg = 0; + for (size_t i = 0; i < weights.size(); ++i) { + const double prev_w_i = (i < prev_weights.size() ? prev_weights[i] : 0.0); + const double& w_i = weights[i]; + reg += C * w_i * w_i; + g[i] += 2 * C * w_i; + + const double diff_i = w_i - prev_w_i; + reg += T * diff_i * diff_i; + g[i] += 2 * T * diff_i; + } + return reg; +} + +double TrainingInference(const vector& x, + const vector > >& corpus, + weight_t* g = NULL) { + double cll = 0; + for (int i = 0; i < corpus.size(); ++i) { + const double dotprod = corpus[i].second.dot(x) + (x.size() ? x[0] : weight_t()); // x[0] is bias + double lp_false = dotprod; + double lp_true = -dotprod; + if (0 < lp_true) { + lp_true += log1p(exp(-lp_true)); + lp_false = log1p(exp(lp_false)); + } else { + lp_true = log1p(exp(lp_true)); + lp_false += log1p(exp(-lp_false)); + } + lp_true*=-1; + lp_false*=-1; + if (corpus[i].first) { // true label + cll -= lp_true; + if (g) { + // g -= corpus[i].second * exp(lp_false); + GradAdd(corpus[i].second, -exp(lp_false), g); + g[0] -= exp(lp_false); // bias + } + } else { // false label + cll -= lp_false; + if (g) { + // g += corpus[i].second * exp(lp_true); + GradAdd(corpus[i].second, exp(lp_true), g); + g[0] += exp(lp_true); // bias + } + } + } + return cll; +} + +struct ProLoss { + ProLoss(const vector > >& tr, + const vector > >& te, + const double c, + const double t, + const vector& px) : training(tr), testing(te), C(c), T(t), prev_x(px){} + double operator()(const vector& x, double* g) const { + fill(g, g + x.size(), 0.0); + double cll = TrainingInference(x, training, g); + tppl = 0; + if (testing.size()) + tppl = pow(2.0, TrainingInference(x, testing, g) / (log(2) * testing.size())); + double ppl = cll / log(2); + ppl /= training.size(); + ppl = pow(2.0, ppl); + double reg = ApplyRegularizationTerms(C, T, x, prev_x, g); + return cll + reg; + } + const vector > >& training, testing; + const double C, T; + const vector& prev_x; + mutable double tppl; +}; + +// return held-out log likelihood +double LearnParameters(const vector > >& training, + const vector > >& testing, + const double C, + const double C1, + const double T, + const unsigned memory_buffers, + const vector& prev_x, + vector* px) { + assert(px->size() == prev_x.size()); + ProLoss loss(training, testing, C, T, prev_x); + LBFGS lbfgs(px, loss, memory_buffers, C1); + lbfgs.MinimizeFunction(); + return loss.tppl; +} + +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + string line; + vector > > training, testing; + const bool tune_regularizer = conf.count("tune_regularizer"); + if (tune_regularizer && !conf.count("testset")) { + cerr << "--tune_regularizer requires --testset to be set\n"; + return 1; + } + const double min_reg = conf["min_reg"].as(); + const double max_reg = conf["max_reg"].as(); + double C = conf["regularization_strength"].as(); // will be overridden if parameter is tuned + double C1 = conf["l1"].as(); // will be overridden if parameter is tuned + const double T = conf["regularize_to_weights"].as(); + assert(C >= 0.0); + assert(min_reg >= 0.0); + assert(max_reg >= 0.0); + assert(max_reg > min_reg); + const double psi = conf["interpolate_with_weights"].as(); + if (psi < 0.0 || psi > 1.0) { cerr << "Invalid interpolation weight: " << psi << endl; return 1; } + ReadCorpus(&cin, &training); + if (conf.count("testset")) { + ReadFile rf(conf["testset"].as()); + ReadCorpus(rf.stream(), &testing); + } + cerr << "Number of features: " << FD::NumFeats() << endl; + + vector x, prev_x; // x[0] is bias + if (conf.count("weights")) { + Weights::InitFromFile(conf["weights"].as(), &x); + x.resize(FD::NumFeats()); + prev_x = x; + } else { + x.resize(FD::NumFeats()); + prev_x = x; + } + cerr << " Number of features: " << x.size() << endl; + cerr << "Number of training examples: " << training.size() << endl; + cerr << "Number of testing examples: " << testing.size() << endl; + double tppl = 0.0; + vector > sp; + vector smoothed; + if (tune_regularizer) { + C = min_reg; + const double steps = 18; + double sweep_factor = exp((log(max_reg) - log(min_reg)) / steps); + cerr << "SWEEP FACTOR: " << sweep_factor << endl; + while(C < max_reg) { + cerr << "C=" << C << "\tT=" <(), prev_x, &x); + sp.push_back(make_pair(C, tppl)); + C *= sweep_factor; + } + smoothed.resize(sp.size(), 0); + smoothed[0] = sp[0].second; + smoothed.back() = sp.back().second; + for (int i = 1; i < sp.size()-1; ++i) { + double prev = sp[i-1].second; + double next = sp[i+1].second; + double cur = sp[i].second; + smoothed[i] = (prev*0.2) + cur * 0.6 + (0.2*next); + } + double best_ppl = 9999999; + unsigned best_i = 0; + for (unsigned i = 0; i < sp.size(); ++i) { + if (smoothed[i] < best_ppl) { + best_ppl = smoothed[i]; + best_i = i; + } + } + C = sp[best_i].first; + } // tune regularizer + tppl = LearnParameters(training, testing, C, C1, T, conf["memory_buffers"].as(), prev_x, &x); + if (conf.count("weights")) { + for (int i = 1; i < x.size(); ++i) { + x[i] = (x[i] * psi) + prev_x[i] * (1.0 - psi); + } + } + cout.precision(15); + cout << "# C=" << C << "\theld out perplexity="; + if (tppl) { cout << tppl << endl; } else { cout << "N/A\n"; } + if (sp.size()) { + cout << "# Parameter sweep:\n"; + for (int i = 0; i < sp.size(); ++i) { + cout << "# " << sp[i].first << "\t" << sp[i].second << "\t" << smoothed[i] << endl; + } + } + Weights::WriteToFile("-", x); + return 0; +} diff --git a/training/pro/pro.pl b/training/pro/pro.pl new file mode 100755 index 00000000..3b30c379 --- /dev/null +++ b/training/pro/pro.pl @@ -0,0 +1,555 @@ +#!/usr/bin/env perl +use strict; +use File::Basename qw(basename); +my @ORIG_ARGV=@ARGV; +use Cwd qw(getcwd); +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment", "$SCRIPT_DIR/../utils"; } + +# Skip local config (used for distributing jobs) if we're running in local-only mode +use LocalConfig; +use Getopt::Long; +use IPC::Open2; +use POSIX ":sys_wait_h"; +my $QSUB_CMD = qsub_args(mert_memory()); +my $default_jobs = env_default_jobs(); + +my $UTILS_DIR="$SCRIPT_DIR/../utils"; +require "$UTILS_DIR/libcall.pl"; + +# Default settings +my $srcFile; +my $refFiles; +my $bin_dir = $SCRIPT_DIR; +die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; +my $FAST_SCORE="$bin_dir/../../mteval/fast_score"; +die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; +my $MAPINPUT = "$bin_dir/mr_pro_generate_mapper_input.pl"; +my $MAPPER = "$bin_dir/mr_pro_map"; +my $REDUCER = "$bin_dir/mr_pro_reduce"; +my $parallelize = "$UTILS_DIR/parallelize.pl"; +my $libcall = "$UTILS_DIR/libcall.pl"; +my $sentserver = "$UTILS_DIR/sentserver"; +my $sentclient = "$UTILS_DIR/sentclient"; +my $LocalConfig = "$SCRIPT_DIR/../../environment/LocalConfig.pm"; + +my $SCORER = $FAST_SCORE; +die "Can't find $MAPPER" unless -x $MAPPER; +my $cdec = "$bin_dir/../../decoder/cdec"; +die "Can't find decoder in $cdec" unless -x $cdec; +die "Can't find $parallelize" unless -x $parallelize; +die "Can't find $libcall" unless -e $libcall; +my $decoder = $cdec; +my $lines_per_mapper = 30; +my $iteration = 1; +my $best_weights; +my $psi = 1; +my $default_max_iter = 30; +my $max_iterations = $default_max_iter; +my $jobs = $default_jobs; # number of decode nodes +my $pmem = "4g"; +my $disable_clean = 0; +my %seen_weights; +my $help = 0; +my $epsilon = 0.0001; +my $dryrun = 0; +my $last_score = -10000000; +my $metric = "ibm_bleu"; +my $dir; +my $iniFile; +my $weights; +my $use_make = 1; # use make to parallelize +my $useqsub = 0; +my $initial_weights; +my $pass_suffix = ''; +my $devset; + +# regularization strength +my $reg = 500; +my $reg_previous = 5000; + +# Process command-line options +if (GetOptions( + "config=s" => \$iniFile, + "weights=s" => \$initial_weights, + "devset=s" => \$devset, + "jobs=i" => \$jobs, + "metric=s" => \$metric, + "pass-suffix=s" => \$pass_suffix, + "qsub" => \$useqsub, + "help" => \$help, + "reg=f" => \$reg, + "reg-previous=f" => \$reg_previous, + "output-dir=s" => \$dir, +) == 0 || @ARGV!=0 || $help) { + print_help(); + exit; +} + +if ($useqsub) { + $use_make = 0; + die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); +} + +my @missing_args = (); +if (!defined $iniFile) { push @missing_args, "--config"; } +if (!defined $devset) { push @missing_args, "--devset"; } +if (!defined $initial_weights) { push @missing_args, "--weights"; } +die "Please specify missing arguments: " . join (', ', @missing_args) . "\n" if (@missing_args); + +if ($metric =~ /^(combi|ter)$/i) { + $lines_per_mapper = 5; +} + +my $host =check_output("hostname"); chomp $host; +my $bleu; +my $interval_count = 0; +my $logfile; +my $projected_score; + +# used in sorting scores +my $DIR_FLAG = '-r'; +if ($metric =~ /^ter$|^aer$/i) { + $DIR_FLAG = ''; +} + +unless ($dir){ + $dir = 'pro'; +} +unless ($dir =~ /^\//){ # convert relative path to absolute path + my $basedir = check_output("pwd"); + chomp $basedir; + $dir = "$basedir/$dir"; +} + +# Initializations and helper functions +srand; + +my @childpids = (); +my @cleanupcmds = (); + +sub cleanup { + print STDERR "Cleanup...\n"; + for my $pid (@childpids){ unchecked_call("kill $pid"); } + for my $cmd (@cleanupcmds){ unchecked_call("$cmd"); } + exit 1; +}; +# Always call cleanup, no matter how we exit +*CORE::GLOBAL::exit = + sub{ cleanup(); }; +$SIG{INT} = "cleanup"; +$SIG{TERM} = "cleanup"; +$SIG{HUP} = "cleanup"; + +my $decoderBase = check_output("basename $decoder"); chomp $decoderBase; +my $newIniFile = "$dir/$decoderBase.ini"; +my $inputFileName = "$dir/input"; +my $user = $ENV{"USER"}; + + +# process ini file +-e $iniFile || die "Error: could not open $iniFile for reading\n"; +open(INI, $iniFile); + +if (-e $dir) { + die "ERROR: working dir $dir already exists\n\n"; +} else { + mkdir "$dir" or die "Can't mkdir $dir: $!"; + mkdir "$dir/hgs" or die; + mkdir "$dir/scripts" or die; + print STDERR <) { $devSize++; } +close F; + +unless($best_weights){ $best_weights = $weights; } +unless($projected_score){ $projected_score = 0.0; } +$seen_weights{$weights} = 1; + +my $random_seed = int(time / 1000); +my $lastWeightsFile; +my $lastPScore = 0; +# main optimization loop +my @allweights; +while (1){ + print STDERR "\n\nITERATION $iteration\n==========\n"; + + if ($iteration > $max_iterations){ + print STDERR "\nREACHED STOPPING CRITERION: Maximum iterations\n"; + last; + } + # iteration-specific files + my $runFile="$dir/run.raw.$iteration"; + my $onebestFile="$dir/1best.$iteration"; + my $logdir="$dir/logs.$iteration"; + my $decoderLog="$logdir/decoder.sentserver.log.$iteration"; + my $scorerLog="$logdir/scorer.log.$iteration"; + check_call("mkdir -p $logdir"); + + + #decode + print STDERR "RUNNING DECODER AT "; + print STDERR unchecked_output("date"); + my $im1 = $iteration - 1; + my $weightsFile="$dir/weights.$im1"; + push @allweights, "-w $dir/weights.$im1"; + `rm -f $dir/hgs/*.gz`; + my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; + my $pcmd; + if ($use_make) { + $pcmd = "cat $srcFile | $parallelize --use-fork -p $pmem -e $logdir -j $jobs --"; + } else { + $pcmd = "cat $srcFile | $parallelize -p $pmem -e $logdir -j $jobs --"; + } + my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $runFile"; + print STDERR "COMMAND:\n$cmd\n"; + check_bash_call($cmd); + my $num_hgs; + my $num_topbest; + my $retries = 0; + while($retries < 5) { + $num_hgs = check_output("ls $dir/hgs/*.gz | wc -l"); + $num_topbest = check_output("wc -l < $runFile"); + print STDERR "NUMBER OF HGs: $num_hgs\n"; + print STDERR "NUMBER OF TOP-BEST HYPs: $num_topbest\n"; + if($devSize == $num_hgs && $devSize == $num_topbest) { + last; + } else { + print STDERR "Incorrect number of hypergraphs or topbest. Waiting for distributed filesystem and retrying...\n"; + sleep(3); + } + $retries++; + } + die "Dev set contains $devSize sentences, but we don't have topbest and hypergraphs for all these! Decoder failure? Check $decoderLog\n" if ($devSize != $num_hgs || $devSize != $num_topbest); + my $dec_score = check_output("cat $runFile | $SCORER -r $refs -m $metric"); + chomp $dec_score; + print STDERR "DECODER SCORE: $dec_score\n"; + + # save space + check_call("gzip -f $runFile"); + check_call("gzip -f $decoderLog"); + + # run optimizer + print STDERR "RUNNING OPTIMIZER AT "; + print STDERR unchecked_output("date"); + print STDERR " - GENERATE TRAINING EXEMPLARS\n"; + my $mergeLog="$logdir/prune-merge.log.$iteration"; + + my $score = 0; + my $icc = 0; + my $inweights="$dir/weights.$im1"; + $cmd="$MAPINPUT $dir/hgs > $dir/agenda.$im1"; + print STDERR "COMMAND:\n$cmd\n"; + check_call($cmd); + check_call("mkdir -p $dir/splag.$im1"); + $cmd="split -a 3 -l $lines_per_mapper $dir/agenda.$im1 $dir/splag.$im1/mapinput."; + print STDERR "COMMAND:\n$cmd\n"; + check_call($cmd); + opendir(DIR, "$dir/splag.$im1") or die "Can't open directory: $!"; + my @shards = grep { /^mapinput\./ } readdir(DIR); + closedir DIR; + die "No shards!" unless scalar @shards > 0; + my $joblist = ""; + my $nmappers = 0; + @cleanupcmds = (); + my %o2i = (); + my $first_shard = 1; + my $mkfile; # only used with makefiles + my $mkfilename; + if ($use_make) { + $mkfilename = "$dir/splag.$im1/domap.mk"; + open $mkfile, ">$mkfilename" or die "Couldn't write $mkfilename: $!"; + print $mkfile "all: $dir/splag.$im1/map.done\n\n"; + } + my @mkouts = (); # only used with makefiles + my @mapoutputs = (); + for my $shard (@shards) { + my $mapoutput = $shard; + my $client_name = $shard; + $client_name =~ s/mapinput.//; + $client_name = "pro.$client_name"; + $mapoutput =~ s/mapinput/mapoutput/; + push @mapoutputs, "$dir/splag.$im1/$mapoutput"; + $o2i{"$dir/splag.$im1/$mapoutput"} = "$dir/splag.$im1/$shard"; + my $script = "$MAPPER -s $srcFile -m $metric -r $refs -w $inweights -K $dir/kbest < $dir/splag.$im1/$shard > $dir/splag.$im1/$mapoutput"; + if ($use_make) { + my $script_file = "$dir/scripts/map.$shard"; + open F, ">$script_file" or die "Can't write $script_file: $!"; + print F "#!/bin/bash\n"; + print F "$script\n"; + close F; + my $output = "$dir/splag.$im1/$mapoutput"; + push @mkouts, $output; + chmod(0755, $script_file) or die "Can't chmod $script_file: $!"; + if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } + print $mkfile "$output: $dir/splag.$im1/$shard\n\t$script_file\n\n"; + } else { + my $script_file = "$dir/scripts/map.$shard"; + open F, ">$script_file" or die "Can't write $script_file: $!"; + print F "$script\n"; + close F; + if ($first_shard) { print STDERR "$script\n"; $first_shard=0; } + + $nmappers++; + my $qcmd = "$QSUB_CMD -N $client_name -o /dev/null -e $logdir/$client_name.ER $script_file"; + my $jobid = check_output("$qcmd"); + chomp $jobid; + $jobid =~ s/^(\d+)(.*?)$/\1/g; + $jobid =~ s/^Your job (\d+) .*$/\1/; + push(@cleanupcmds, "qdel $jobid 2> /dev/null"); + print STDERR " $jobid"; + if ($joblist == "") { $joblist = $jobid; } + else {$joblist = $joblist . "\|" . $jobid; } + } + } + my @dev_outs = (); + my @devtest_outs = (); + @dev_outs = @mapoutputs; + if ($use_make) { + print $mkfile "$dir/splag.$im1/map.done: @mkouts\n\ttouch $dir/splag.$im1/map.done\n\n"; + close $mkfile; + my $mcmd = "make -j $jobs -f $mkfilename"; + print STDERR "\nExecuting: $mcmd\n"; + check_call($mcmd); + } else { + print STDERR "\nLaunched $nmappers mappers.\n"; + sleep 8; + print STDERR "Waiting for mappers to complete...\n"; + while ($nmappers > 0) { + sleep 5; + my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat | grep -v ' C '"))); + $nmappers = scalar @livejobs; + } + print STDERR "All mappers complete.\n"; + } + my $tol = 0; + my $til = 0; + my $dev_test_file = "$dir/splag.$im1/devtest.gz"; + print STDERR "\nRUNNING CLASSIFIER (REDUCER)\n"; + print STDERR unchecked_output("date"); + $cmd="cat @dev_outs | $REDUCER -w $dir/weights.$im1 -C $reg -y $reg_previous --interpolate_with_weights $psi"; + $cmd .= " > $dir/weights.$iteration"; + print STDERR "COMMAND:\n$cmd\n"; + check_bash_call($cmd); + $lastWeightsFile = "$dir/weights.$iteration"; + $lastPScore = $score; + $iteration++; + print STDERR "\n==========\n"; +} + + +check_call("cp $lastWeightsFile $dir/weights.final"); +print STDERR "\nFINAL WEIGHTS: $dir/weights.final\n(Use -w with the decoder)\n\n"; +print STDOUT "$dir/weights.final\n"; + +exit 0; + +sub read_weights_file { + my ($file) = @_; + open F, "<$file" or die "Couldn't read $file: $!"; + my @r = (); + my $pm = -1; + while() { + next if /^#/; + next if /^\s*$/; + chomp; + if (/^(.+)\s+(.+)$/) { + my $m = $1; + my $w = $2; + die "Weights out of order: $m <= $pm" unless $m > $pm; + push @r, $w; + } else { + warn "Unexpected feature name in weight file: $_"; + } + } + close F; + return join ' ', @r; +} + +sub enseg { + my $src = shift; + my $newsrc = shift; + open(SRC, $src); + open(NEWSRC, ">$newsrc"); + my $i=0; + while (my $line=){ + chomp $line; + if ($line =~ /^\s* tags, you must include a zero-based id attribute"; + } + } else { + print NEWSRC "$line\n"; + } + $i++; + } + close SRC; + close NEWSRC; + die "Empty dev set!" if ($i == 0); +} + +sub print_help { + + my $executable = basename($0); chomp $executable; + print << "Help"; + +Usage: $executable [options] + + $executable [options] + Runs a complete PRO optimization using the ini file specified. + +Required: + + --config + Decoder configuration file. + + --devset + Dev set source and reference data. + + --weights + Initial weights file (use empty file to start from 0) + +General options: + + --help + Print this message and exit. + + --max-iterations + Maximum number of iterations to run. If not specified, defaults + to $default_max_iter. + + --metric + Metric to optimize. + Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi + + --pass-suffix + If the decoder is doing multi-pass decoding, the pass suffix "2", + "3", etc., is used to control what iteration of weights is set. + + --workdir + Directory for intermediate and output files. If not specified, the + name is derived from the ini filename. Assuming that the ini + filename begins with the decoder name and ends with ini, the default + name of the working directory is inferred from the middle part of + the filename. E.g. an ini file named decoder.foo.ini would have + a default working directory name foo. + +Regularization options: + + --reg + l2 regularization strength [default=500]. The greater this value, + the closer to zero the weights will be. + + --reg-previous + l2 penalty for moving away from the weights from the previous + iteration. [default=5000]. The greater this value, the closer + to the previous iteration's weights the next iteration's weights + will be. + +Job control options: + + --jobs + Number of decoder processes to run in parallel. [default=$default_jobs] + + --qsub + Use qsub to run jobs in parallel (qsub must be configured in + environment/LocalEnvironment.pm) + + --pmem + Amount of physical memory requested for parallel decoding jobs + (used with qsub requests only) + +Deprecated options: + + --interpolate-with-weights + [deprecated] At each iteration the resulting weights are + interpolated with the weights from the previous iteration, with + this factor. [default=1.0, i.e., no effect] + +Help +} + +sub convert { + my ($str) = @_; + my @ps = split /;/, $str; + my %dict = (); + for my $p (@ps) { + my ($k, $v) = split /=/, $p; + $dict{$k} = $v; + } + return %dict; +} + + +sub cmdline { + return join ' ',($0,@ORIG_ARGV); +} + +#buggy: last arg gets quoted sometimes? +my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; +my $shell_escape_in_quote=qr{[\\"\$`!]}; + +sub escape_shell { + my ($arg)=@_; + return undef unless defined $arg; + if ($arg =~ /$is_shell_special/) { + $arg =~ s/($shell_escape_in_quote)/\\$1/g; + return "\"$arg\""; + } + return $arg; +} + +sub escaped_shell_args { + return map {local $_=$_;chomp;escape_shell($_)} @_; +} + +sub escaped_shell_args_str { + return join ' ',&escaped_shell_args(@_); +} + +sub escaped_cmdline { + return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); +} + +sub split_devset { + my ($infile, $outsrc, $outref) = @_; + open F, "<$infile" or die "Can't read $infile: $!"; + open S, ">$outsrc" or die "Can't write $outsrc: $!"; + open R, ">$outref" or die "Can't write $outref: $!"; + while() { + chomp; + my ($src, @refs) = split /\s*\|\|\|\s*/; + die "Malformed devset line: $_\n" unless scalar @refs > 0; + print S "$src\n"; + print R join(' ||| ', @refs) . "\n"; + } + close R; + close S; + close F; +} + diff --git a/training/rampion/Makefile.am b/training/rampion/Makefile.am new file mode 100644 index 00000000..1633d0f7 --- /dev/null +++ b/training/rampion/Makefile.am @@ -0,0 +1,6 @@ +bin_PROGRAMS = rampion_cccp + +rampion_cccp_SOURCES = rampion_cccp.cc +rampion_cccp_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz + +AM_CPPFLAGS = -W -Wall $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training/utils diff --git a/training/rampion/rampion.pl b/training/rampion/rampion.pl new file mode 100755 index 00000000..ae084db6 --- /dev/null +++ b/training/rampion/rampion.pl @@ -0,0 +1,540 @@ +#!/usr/bin/env perl +use strict; +my @ORIG_ARGV=@ARGV; +use Cwd qw(getcwd); +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment", "$SCRIPT_DIR/../utils"; } + +# Skip local config (used for distributing jobs) if we're running in local-only mode +use LocalConfig; +use Getopt::Long; +use IPC::Open2; +use POSIX ":sys_wait_h"; +my $QSUB_CMD = qsub_args(mert_memory()); +my $default_jobs = env_default_jobs(); + +my $UTILS_DIR="$SCRIPT_DIR/../utils"; +require "$UTILS_DIR/libcall.pl"; + +# Default settings +my $srcFile; +my $refFiles; +my $bin_dir = $SCRIPT_DIR; +die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; +my $FAST_SCORE="$bin_dir/../../mteval/fast_score"; +die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; +my $MAPINPUT = "$bin_dir/rampion_generate_input.pl"; +my $MAPPER = "$bin_dir/rampion_cccp"; +my $parallelize = "$UTILS_DIR/parallelize.pl"; +my $libcall = "$UTILS_DIR/libcall.pl"; +my $sentserver = "$UTILS_DIR/sentserver"; +my $sentclient = "$UTILS_DIR/sentclient"; +my $LocalConfig = "$SCRIPT_DIR/../../environment/LocalConfig.pm"; + +my $SCORER = $FAST_SCORE; +die "Can't find $MAPPER" unless -x $MAPPER; +my $cdec = "$bin_dir/../../decoder/cdec"; +die "Can't find decoder in $cdec" unless -x $cdec; +die "Can't find $parallelize" unless -x $parallelize; +die "Can't find $libcall" unless -e $libcall; +my $decoder = $cdec; +my $lines_per_mapper = 30; +my $iteration = 1; +my $best_weights; +my $psi = 1; +my $default_max_iter = 30; +my $max_iterations = $default_max_iter; +my $jobs = $default_jobs; # number of decode nodes +my $pmem = "4g"; +my $disable_clean = 0; +my %seen_weights; +my $help = 0; +my $epsilon = 0.0001; +my $dryrun = 0; +my $last_score = -10000000; +my $metric = "ibm_bleu"; +my $dir; +my $iniFile; +my $weights; +my $use_make = 1; # use make to parallelize +my $useqsub = 0; +my $initial_weights; +my $pass_suffix = ''; +my $cpbin=1; + +# regularization strength +my $tune_regularizer = 0; +my $reg = 500; +my $reg_previous = 5000; +my $dont_accum = 0; + +# Process command-line options +Getopt::Long::Configure("no_auto_abbrev"); +if (GetOptions( + "jobs=i" => \$jobs, + "dont-clean" => \$disable_clean, + "dont-accumulate" => \$dont_accum, + "pass-suffix=s" => \$pass_suffix, + "qsub" => \$useqsub, + "dry-run" => \$dryrun, + "epsilon=s" => \$epsilon, + "help" => \$help, + "weights=s" => \$initial_weights, + "reg=f" => \$reg, + "use-make=i" => \$use_make, + "max-iterations=i" => \$max_iterations, + "pmem=s" => \$pmem, + "cpbin!" => \$cpbin, + "ref-files=s" => \$refFiles, + "metric=s" => \$metric, + "source-file=s" => \$srcFile, + "workdir=s" => \$dir, +) == 0 || @ARGV!=1 || $help) { + print_help(); + exit; +} + +die "--tune-regularizer is no longer supported with --reg-previous and --reg. Please tune manually.\n" if $tune_regularizer; + +if ($useqsub) { + $use_make = 0; + die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); +} + +my @missing_args = (); +if (!defined $srcFile) { push @missing_args, "--source-file"; } +if (!defined $refFiles) { push @missing_args, "--ref-files"; } +if (!defined $initial_weights) { push @missing_args, "--weights"; } +die "Please specify missing arguments: " . join (', ', @missing_args) . "\n" if (@missing_args); + +if ($metric =~ /^(combi|ter)$/i) { + $lines_per_mapper = 5; +} + +($iniFile) = @ARGV; + + +sub write_config; +sub enseg; +sub print_help; + +my $nodelist; +my $host =check_output("hostname"); chomp $host; +my $bleu; +my $interval_count = 0; +my $logfile; +my $projected_score; + +# used in sorting scores +my $DIR_FLAG = '-r'; +if ($metric =~ /^ter$|^aer$/i) { + $DIR_FLAG = ''; +} + +my $refs_comma_sep = get_comma_sep_refs('r',$refFiles); + +unless ($dir){ + $dir = "rampion"; +} +unless ($dir =~ /^\//){ # convert relative path to absolute path + my $basedir = check_output("pwd"); + chomp $basedir; + $dir = "$basedir/$dir"; +} + + +# Initializations and helper functions +srand; + +my @childpids = (); +my @cleanupcmds = (); + +sub cleanup { + print STDERR "Cleanup...\n"; + for my $pid (@childpids){ unchecked_call("kill $pid"); } + for my $cmd (@cleanupcmds){ unchecked_call("$cmd"); } + exit 1; +}; +# Always call cleanup, no matter how we exit +*CORE::GLOBAL::exit = + sub{ cleanup(); }; +$SIG{INT} = "cleanup"; +$SIG{TERM} = "cleanup"; +$SIG{HUP} = "cleanup"; + +my $decoderBase = check_output("basename $decoder"); chomp $decoderBase; +my $newIniFile = "$dir/$decoderBase.ini"; +my $inputFileName = "$dir/input"; +my $user = $ENV{"USER"}; +# process ini file +-e $iniFile || die "Error: could not open $iniFile for reading\n"; +open(INI, $iniFile); + +use File::Basename qw(basename); +#pass bindir, refs to vars holding bin +sub modbin { + local $_; + my $bindir=shift; + check_call("mkdir -p $bindir"); + -d $bindir || die "couldn't make bindir $bindir"; + for (@_) { + my $src=$$_; + $$_="$bindir/".basename($src); + check_call("cp -p $src $$_"); + } +} +sub dirsize { + opendir ISEMPTY,$_[0]; + return scalar(readdir(ISEMPTY))-1; +} +my @allweights; +if ($dryrun){ + write_config(*STDERR); + exit 0; +} else { + if (-e $dir && dirsize($dir)>1 && -e "$dir/hgs" ){ # allow preexisting logfile, binaries, but not dist-pro.pl outputs + die "ERROR: working dir $dir already exists\n\n"; + } else { + -e $dir || mkdir $dir; + mkdir "$dir/hgs"; + modbin("$dir/bin",\$LocalConfig,\$cdec,\$SCORER,\$MAPINPUT,\$MAPPER,\$parallelize,\$sentserver,\$sentclient,\$libcall) if $cpbin; + mkdir "$dir/scripts"; + my $cmdfile="$dir/rerun-pro.sh"; + open CMD,'>',$cmdfile; + print CMD "cd ",&getcwd,"\n"; +# print CMD &escaped_cmdline,"\n"; #buggy - last arg is quoted. + my $cline=&cmdline."\n"; + print CMD $cline; + close CMD; + print STDERR $cline; + chmod(0755,$cmdfile); + check_call("cp $initial_weights $dir/weights.0"); + die "Can't find weights.0" unless (-e "$dir/weights.0"); + } + write_config(*STDERR); +} + + +# Generate initial files and values +check_call("cp $iniFile $newIniFile"); +$iniFile = $newIniFile; + +my $newsrc = "$dir/dev.input"; +enseg($srcFile, $newsrc); +$srcFile = $newsrc; +my $devSize = 0; +open F, "<$srcFile" or die "Can't read $srcFile: $!"; +while() { $devSize++; } +close F; + +unless($best_weights){ $best_weights = $weights; } +unless($projected_score){ $projected_score = 0.0; } +$seen_weights{$weights} = 1; +my $kbest = "$dir/kbest"; +if ($dont_accum) { + $kbest = ''; +} else { + check_call("mkdir -p $kbest"); + $kbest = "--kbest_repository $kbest"; +} + +my $random_seed = int(time / 1000); +my $lastWeightsFile; +my $lastPScore = 0; +# main optimization loop +while (1){ + print STDERR "\n\nITERATION $iteration\n==========\n"; + + if ($iteration > $max_iterations){ + print STDERR "\nREACHED STOPPING CRITERION: Maximum iterations\n"; + last; + } + # iteration-specific files + my $runFile="$dir/run.raw.$iteration"; + my $onebestFile="$dir/1best.$iteration"; + my $logdir="$dir/logs.$iteration"; + my $decoderLog="$logdir/decoder.sentserver.log.$iteration"; + my $scorerLog="$logdir/scorer.log.$iteration"; + check_call("mkdir -p $logdir"); + + + #decode + print STDERR "RUNNING DECODER AT "; + print STDERR unchecked_output("date"); + my $im1 = $iteration - 1; + my $weightsFile="$dir/weights.$im1"; + push @allweights, "-w $dir/weights.$im1"; + `rm -f $dir/hgs/*.gz`; + my $decoder_cmd = "$decoder -c $iniFile --weights$pass_suffix $weightsFile -O $dir/hgs"; + my $pcmd; + if ($use_make) { + $pcmd = "cat $srcFile | $parallelize --use-fork -p $pmem -e $logdir -j $jobs --"; + } else { + $pcmd = "cat $srcFile | $parallelize -p $pmem -e $logdir -j $jobs --"; + } + my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $runFile"; + print STDERR "COMMAND:\n$cmd\n"; + check_bash_call($cmd); + my $num_hgs; + my $num_topbest; + my $retries = 0; + while($retries < 5) { + $num_hgs = check_output("ls $dir/hgs/*.gz | wc -l"); + $num_topbest = check_output("wc -l < $runFile"); + print STDERR "NUMBER OF HGs: $num_hgs\n"; + print STDERR "NUMBER OF TOP-BEST HYPs: $num_topbest\n"; + if($devSize == $num_hgs && $devSize == $num_topbest) { + last; + } else { + print STDERR "Incorrect number of hypergraphs or topbest. Waiting for distributed filesystem and retrying...\n"; + sleep(3); + } + $retries++; + } + die "Dev set contains $devSize sentences, but we don't have topbest and hypergraphs for all these! Decoder failure? Check $decoderLog\n" if ($devSize != $num_hgs || $devSize != $num_topbest); + my $dec_score = check_output("cat $runFile | $SCORER $refs_comma_sep -m $metric"); + chomp $dec_score; + print STDERR "DECODER SCORE: $dec_score\n"; + + # save space + check_call("gzip -f $runFile"); + check_call("gzip -f $decoderLog"); + + # run optimizer + print STDERR "RUNNING OPTIMIZER AT "; + print STDERR unchecked_output("date"); + print STDERR " - GENERATE TRAINING EXEMPLARS\n"; + my $mergeLog="$logdir/prune-merge.log.$iteration"; + + my $score = 0; + my $icc = 0; + my $inweights="$dir/weights.$im1"; + my $outweights="$dir/weights.$iteration"; + $cmd="$MAPINPUT $dir/hgs > $dir/agenda.$im1"; + print STDERR "COMMAND:\n$cmd\n"; + check_call($cmd); + $cmd="$MAPPER $refs_comma_sep -m $metric -i $dir/agenda.$im1 $kbest -w $inweights > $outweights"; + check_call($cmd); + $lastWeightsFile = $outweights; + $iteration++; + `rm hgs/*.gz`; + print STDERR "\n==========\n"; +} + +print STDERR "\nFINAL WEIGHTS: $lastWeightsFile\n(Use -w with the decoder)\n\n"; + +print STDOUT "$lastWeightsFile\n"; + +exit 0; + +sub get_lines { + my $fn = shift @_; + open FL, "<$fn" or die "Couldn't read $fn: $!"; + my $lc = 0; + while() { $lc++; } + return $lc; +} + +sub get_comma_sep_refs { + my ($r,$p) = @_; + my $o = check_output("echo $p"); + chomp $o; + my @files = split /\s+/, $o; + return "-$r " . join(" -$r ", @files); +} + +sub read_weights_file { + my ($file) = @_; + open F, "<$file" or die "Couldn't read $file: $!"; + my @r = (); + my $pm = -1; + while() { + next if /^#/; + next if /^\s*$/; + chomp; + if (/^(.+)\s+(.+)$/) { + my $m = $1; + my $w = $2; + die "Weights out of order: $m <= $pm" unless $m > $pm; + push @r, $w; + } else { + warn "Unexpected feature name in weight file: $_"; + } + } + close F; + return join ' ', @r; +} + +# subs +sub write_config { + my $fh = shift; + my $cleanup = "yes"; + if ($disable_clean) {$cleanup = "no";} + + print $fh "\n"; + print $fh "DECODER: $decoder\n"; + print $fh "INI FILE: $iniFile\n"; + print $fh "WORKING DIR: $dir\n"; + print $fh "SOURCE (DEV): $srcFile\n"; + print $fh "REFS (DEV): $refFiles\n"; + print $fh "EVAL METRIC: $metric\n"; + print $fh "MAX ITERATIONS: $max_iterations\n"; + print $fh "JOBS: $jobs\n"; + print $fh "HEAD NODE: $host\n"; + print $fh "PMEM (DECODING): $pmem\n"; + print $fh "CLEANUP: $cleanup\n"; +} + +sub update_weights_file { + my ($neww, $rfn, $rpts) = @_; + my @feats = @$rfn; + my @pts = @$rpts; + my $num_feats = scalar @feats; + my $num_pts = scalar @pts; + die "$num_feats (num_feats) != $num_pts (num_pts)" unless $num_feats == $num_pts; + open G, ">$neww" or die; + for (my $i = 0; $i < $num_feats; $i++) { + my $f = $feats[$i]; + my $lambda = $pts[$i]; + print G "$f $lambda\n"; + } + close G; +} + +sub enseg { + my $src = shift; + my $newsrc = shift; + open(SRC, $src); + open(NEWSRC, ">$newsrc"); + my $i=0; + while (my $line=){ + chomp $line; + if ($line =~ /^\s* tags, you must include a zero-based id attribute"; + } + } else { + print NEWSRC "$line\n"; + } + $i++; + } + close SRC; + close NEWSRC; + die "Empty dev set!" if ($i == 0); +} + +sub print_help { + + my $executable = check_output("basename $0"); chomp $executable; + print << "Help"; + +Usage: $executable [options] + + $executable [options] + Runs a complete PRO optimization using the ini file specified. + +Required: + + --ref-files + Dev set ref files. This option takes only a single string argument. + To use multiple files (including file globbing), this argument should + be quoted. + + --source-file + Dev set source file. + + --weights + Initial weights file (use empty file to start from 0) + +General options: + + --help + Print this message and exit. + + --dont-accumulate + Don't accumulate k-best lists from multiple iterations. + + --max-iterations + Maximum number of iterations to run. If not specified, defaults + to $default_max_iter. + + --metric + Metric to optimize. + Example values: IBM_BLEU, NIST_BLEU, Koehn_BLEU, TER, Combi + + --pass-suffix + If the decoder is doing multi-pass decoding, the pass suffix "2", + "3", etc., is used to control what iteration of weights is set. + + --workdir + Directory for intermediate and output files. If not specified, the + name is derived from the ini filename. Assuming that the ini + filename begins with the decoder name and ends with ini, the default + name of the working directory is inferred from the middle part of + the filename. E.g. an ini file named decoder.foo.ini would have + a default working directory name foo. + +Regularization options: + + --reg + l2 regularization strength [default=500]. The greater this value, + the closer to zero the weights will be. + +Job control options: + + --jobs + Number of decoder processes to run in parallel. [default=$default_jobs] + + --qsub + Use qsub to run jobs in parallel (qsub must be configured in + environment/LocalEnvironment.pm) + + --pmem + Amount of physical memory requested for parallel decoding jobs + (used with qsub requests only) + +Help +} + +sub convert { + my ($str) = @_; + my @ps = split /;/, $str; + my %dict = (); + for my $p (@ps) { + my ($k, $v) = split /=/, $p; + $dict{$k} = $v; + } + return %dict; +} + + +sub cmdline { + return join ' ',($0,@ORIG_ARGV); +} + +#buggy: last arg gets quoted sometimes? +my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; +my $shell_escape_in_quote=qr{[\\"\$`!]}; + +sub escape_shell { + my ($arg)=@_; + return undef unless defined $arg; + if ($arg =~ /$is_shell_special/) { + $arg =~ s/($shell_escape_in_quote)/\\$1/g; + return "\"$arg\""; + } + return $arg; +} + +sub escaped_shell_args { + return map {local $_=$_;chomp;escape_shell($_)} @_; +} + +sub escaped_shell_args_str { + return join ' ',&escaped_shell_args(@_); +} + +sub escaped_cmdline { + return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); +} diff --git a/training/rampion/rampion_cccp.cc b/training/rampion/rampion_cccp.cc new file mode 100644 index 00000000..1e36dc51 --- /dev/null +++ b/training/rampion/rampion_cccp.cc @@ -0,0 +1,168 @@ +#include +#include +#include +#include + +#include +#include + +#include "filelib.h" +#include "stringlib.h" +#include "weights.h" +#include "hg_io.h" +#include "kbest.h" +#include "viterbi.h" +#include "ns.h" +#include "ns_docscorer.h" +#include "candidate_set.h" + +using namespace std; +namespace po = boost::program_options; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("reference,r",po::value >(), "[REQD] Reference translation (tokenized text)") + ("weights,w",po::value(), "[REQD] Weights files from current iterations") + ("input,i",po::value()->default_value("-"), "Input file to map (- is STDIN)") + ("evaluation_metric,m",po::value()->default_value("IBM_BLEU"), "Evaluation metric (ibm_bleu, koehn_bleu, nist_bleu, ter, meteor, etc.)") + ("kbest_repository,R",po::value(), "Accumulate k-best lists from previous iterations (parameter is path to repository)") + ("kbest_size,k",po::value()->default_value(500u), "Top k-hypotheses to extract") + ("cccp_iterations,I", po::value()->default_value(10u), "CCCP iterations (T')") + ("ssd_iterations,J", po::value()->default_value(5u), "Stochastic subgradient iterations (T'')") + ("eta", po::value()->default_value(1e-4), "Step size") + ("regularization_strength,C", po::value()->default_value(1.0), "L2 regularization strength") + ("alpha,a", po::value()->default_value(10.0), "Cost scale (alpha); alpha * [1-metric(y,y')]") + ("help,h", "Help"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + bool flag = false; + if (!conf->count("reference")) { + cerr << "Please specify one or more references using -r \n"; + flag = true; + } + if (!conf->count("weights")) { + cerr << "Please specify weights using -w \n"; + flag = true; + } + if (flag || conf->count("help")) { + cerr << dcmdline_options << endl; + exit(1); + } +} + +struct GainFunction { + explicit GainFunction(const EvaluationMetric* m) : metric(m) {} + float operator()(const SufficientStats& eval_feats) const { + float g = metric->ComputeScore(eval_feats); + if (!metric->IsErrorMetric()) g = 1 - g; + return g; + } + const EvaluationMetric* metric; +}; + +template +void CostAugmentedSearch(const GainFunc& gain, + const training::CandidateSet& cs, + const SparseVector& w, + double alpha, + SparseVector* fmap) { + unsigned best_i = 0; + double best = -numeric_limits::infinity(); + for (unsigned i = 0; i < cs.size(); ++i) { + double s = cs[i].fmap.dot(w) + alpha * gain(cs[i].eval_feats); + if (s > best) { + best = s; + best_i = i; + } + } + *fmap = cs[best_i].fmap; +} + + + +// runs lines 4--15 of rampion algorithm +int main(int argc, char** argv) { + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + const string evaluation_metric = conf["evaluation_metric"].as(); + + EvaluationMetric* metric = EvaluationMetric::Instance(evaluation_metric); + DocumentScorer ds(metric, conf["reference"].as >()); + cerr << "Loaded " << ds.size() << " references for scoring with " << evaluation_metric << endl; + double goodsign = -1; + double badsign = -goodsign; + + Hypergraph hg; + string last_file; + ReadFile in_read(conf["input"].as()); + string kbest_repo; + if (conf.count("kbest_repository")) { + kbest_repo = conf["kbest_repository"].as(); + MkDirP(kbest_repo); + } + istream &in=*in_read.stream(); + const unsigned kbest_size = conf["kbest_size"].as(); + const unsigned tp = conf["cccp_iterations"].as(); + const unsigned tpp = conf["ssd_iterations"].as(); + const double eta = conf["eta"].as(); + const double reg = conf["regularization_strength"].as(); + const double alpha = conf["alpha"].as(); + SparseVector weights; + { + vector vweights; + const string weightsf = conf["weights"].as(); + Weights::InitFromFile(weightsf, &vweights); + Weights::InitSparseVector(vweights, &weights); + } + string line, file; + vector kis; + cerr << "Loading hypergraphs...\n"; + while(getline(in, line)) { + istringstream is(line); + int sent_id; + kis.resize(kis.size() + 1); + training::CandidateSet& curkbest = kis.back(); + string kbest_file; + if (kbest_repo.size()) { + ostringstream os; + os << kbest_repo << "/kbest." << sent_id << ".txt.gz"; + kbest_file = os.str(); + if (FileExists(kbest_file)) + curkbest.ReadFromFile(kbest_file); + } + is >> file >> sent_id; + ReadFile rf(file); + if (kis.size() % 5 == 0) { cerr << '.'; } + if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; } + HypergraphIO::ReadFromJSON(rf.stream(), &hg); + hg.Reweight(weights); + curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]); + if (kbest_file.size()) + curkbest.WriteToFile(kbest_file); + } + cerr << "\nHypergraphs loaded.\n"; + + vector > goals(kis.size()); // f(x_i,y+,h+) + SparseVector fear; // f(x,y-,h-) + const GainFunction gain(metric); + for (unsigned iterp = 1; iterp <= tp; ++iterp) { + cerr << "CCCP Iteration " << iterp << endl; + for (unsigned i = 0; i < goals.size(); ++i) + CostAugmentedSearch(gain, kis[i], weights, goodsign * alpha, &goals[i]); + for (unsigned iterpp = 1; iterpp <= tpp; ++iterpp) { + cerr << " SSD Iteration " << iterpp << endl; + for (unsigned i = 0; i < goals.size(); ++i) { + CostAugmentedSearch(gain, kis[i], weights, badsign * alpha, &fear); + weights -= weights * (eta * reg / goals.size()); + weights += (goals[i] - fear) * eta; + } + } + } + vector w; + weights.init_vector(&w); + Weights::WriteToFile("-", w); + return 0; +} + diff --git a/training/rampion/rampion_generate_input.pl b/training/rampion/rampion_generate_input.pl new file mode 100755 index 00000000..b30fc4fd --- /dev/null +++ b/training/rampion/rampion_generate_input.pl @@ -0,0 +1,18 @@ +#!/usr/bin/perl -w +use strict; + +die "Usage: $0 HG_DIR\n" unless scalar @ARGV == 1; +my $d = shift @ARGV; +die "Can't find directory $d" unless -d $d; + +opendir(DIR, $d) or die "Can't read $d: $!"; +my @hgs = grep { /\.gz$/ } readdir(DIR); +closedir DIR; + +for my $hg (@hgs) { + my $file = $hg; + my $id = $hg; + $id =~ s/(\.json)?\.gz//; + print "$d/$file $id\n"; +} + diff --git a/training/risk.cc b/training/risk.cc deleted file mode 100644 index d5a12cfd..00000000 --- a/training/risk.cc +++ /dev/null @@ -1,45 +0,0 @@ -#include "risk.h" - -#include "prob.h" -#include "candidate_set.h" -#include "ns.h" - -using namespace std; - -namespace training { - -// g = \sum_e p(e|f) * loss(e) * (phi(e,f) - E[phi(e,f)]) -double CandidateSetRisk::operator()(const vector& params, - SparseVector* g) const { - prob_t z; - for (unsigned i = 0; i < cands_.size(); ++i) { - const prob_t u(cands_[i].fmap.dot(params), init_lnx()); - z += u; - } - const double log_z = log(z); - - SparseVector exp_feats; - if (g) { - for (unsigned i = 0; i < cands_.size(); ++i) { - const double log_prob = cands_[i].fmap.dot(params) - log_z; - const double prob = exp(log_prob); - exp_feats += cands_[i].fmap * prob; - } - } - - double risk = 0; - for (unsigned i = 0; i < cands_.size(); ++i) { - const double log_prob = cands_[i].fmap.dot(params) - log_z; - const double prob = exp(log_prob); - const double cost = metric_.IsErrorMetric() ? metric_.ComputeScore(cands_[i].eval_feats) - : 1.0 - metric_.ComputeScore(cands_[i].eval_feats); - const double r = prob * cost; - risk += r; - if (g) (*g) += (cands_[i].fmap - exp_feats) * r; - } - return risk; -} - -} - - diff --git a/training/risk.h b/training/risk.h deleted file mode 100644 index 2e8db0fb..00000000 --- a/training/risk.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef _RISK_H_ -#define _RISK_H_ - -#include -#include "sparse_vector.h" -class EvaluationMetric; - -namespace training { - class CandidateSet; - - class CandidateSetRisk { - public: - explicit CandidateSetRisk(const CandidateSet& cs, const EvaluationMetric& metric) : - cands_(cs), - metric_(metric) {} - // compute the risk (expected loss) of a CandidateSet - // (optional) the gradient of the risk with respect to params - double operator()(const std::vector& params, - SparseVector* g = NULL) const; - private: - const CandidateSet& cands_; - const EvaluationMetric& metric_; - }; -}; - -#endif diff --git a/training/ttables.cc b/training/ttables.cc deleted file mode 100644 index 45bf14c5..00000000 --- a/training/ttables.cc +++ /dev/null @@ -1,31 +0,0 @@ -#include "ttables.h" - -#include - -#include "dict.h" - -using namespace std; -using namespace std::tr1; - -void TTable::DeserializeProbsFromText(std::istream* in) { - int c = 0; - while(*in) { - string e; - string f; - double p; - (*in) >> e >> f >> p; - if (e.empty()) break; - ++c; - ttable[TD::Convert(e)][TD::Convert(f)] = p; - } - cerr << "Loaded " << c << " translation parameters.\n"; -} - -void TTable::SerializeHelper(string* out, const Word2Word2Double& o) { - assert(!"not implemented"); -} - -void TTable::DeserializeHelper(const string& in, Word2Word2Double* o) { - assert(!"not implemented"); -} - diff --git a/training/ttables.h b/training/ttables.h deleted file mode 100644 index 9baa13ca..00000000 --- a/training/ttables.h +++ /dev/null @@ -1,101 +0,0 @@ -#ifndef _TTABLES_H_ -#define _TTABLES_H_ - -#include -#include - -#include "sparse_vector.h" -#include "m.h" -#include "wordid.h" -#include "tdict.h" - -class TTable { - public: - TTable() {} - typedef std::tr1::unordered_map Word2Double; - typedef std::tr1::unordered_map Word2Word2Double; - inline double prob(const int& e, const int& f) const { - const Word2Word2Double::const_iterator cit = ttable.find(e); - if (cit != ttable.end()) { - const Word2Double& cpd = cit->second; - const Word2Double::const_iterator it = cpd.find(f); - if (it == cpd.end()) return 1e-9; - return it->second; - } else { - return 1e-9; - } - } - inline void Increment(const int& e, const int& f) { - counts[e][f] += 1.0; - } - inline void Increment(const int& e, const int& f, double x) { - counts[e][f] += x; - } - void NormalizeVB(const double alpha) { - ttable.swap(counts); - for (Word2Word2Double::iterator cit = ttable.begin(); - cit != ttable.end(); ++cit) { - double tot = 0; - Word2Double& cpd = cit->second; - for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) - tot += it->second + alpha; - for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) - it->second = exp(Md::digamma(it->second + alpha) - Md::digamma(tot)); - } - counts.clear(); - } - void Normalize() { - ttable.swap(counts); - for (Word2Word2Double::iterator cit = ttable.begin(); - cit != ttable.end(); ++cit) { - double tot = 0; - Word2Double& cpd = cit->second; - for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) - tot += it->second; - for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) - it->second /= tot; - } - counts.clear(); - } - // adds counts from another TTable - probabilities remain unchanged - TTable& operator+=(const TTable& rhs) { - for (Word2Word2Double::const_iterator it = rhs.counts.begin(); - it != rhs.counts.end(); ++it) { - const Word2Double& cpd = it->second; - Word2Double& tgt = counts[it->first]; - for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { - tgt[j->first] += j->second; - } - } - return *this; - } - void ShowTTable() const { - for (Word2Word2Double::const_iterator it = ttable.begin(); it != ttable.end(); ++it) { - const Word2Double& cpd = it->second; - for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { - std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; - } - } - } - void ShowCounts() const { - for (Word2Word2Double::const_iterator it = counts.begin(); it != counts.end(); ++it) { - const Word2Double& cpd = it->second; - for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { - std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; - } - } - } - void DeserializeProbsFromText(std::istream* in); - void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); } - void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); } - void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); } - void DeserializeProbs(const std::string& in) { DeserializeHelper(in, &ttable); } - private: - static void SerializeHelper(std::string*, const Word2Word2Double& o); - static void DeserializeHelper(const std::string&, Word2Word2Double* o); - public: - Word2Word2Double ttable; - Word2Word2Double counts; -}; - -#endif diff --git a/training/utils/candidate_set.cc b/training/utils/candidate_set.cc new file mode 100644 index 00000000..087efec3 --- /dev/null +++ b/training/utils/candidate_set.cc @@ -0,0 +1,169 @@ +#include "candidate_set.h" + +#include + +#include + +#include "verbose.h" +#include "ns.h" +#include "filelib.h" +#include "wordid.h" +#include "tdict.h" +#include "hg.h" +#include "kbest.h" +#include "viterbi.h" + +using namespace std; + +namespace training { + +struct ApproxVectorHasher { + static const size_t MASK = 0xFFFFFFFFull; + union UType { + double f; // leave as double + size_t i; + }; + static inline double round(const double x) { + UType t; + t.f = x; + size_t r = t.i & MASK; + if ((r << 1) > MASK) + t.i += MASK - r + 1; + else + t.i &= (1ull - MASK); + return t.f; + } + size_t operator()(const SparseVector& x) const { + size_t h = 0x573915839; + for (SparseVector::const_iterator it = x.begin(); it != x.end(); ++it) { + UType t; + t.f = it->second; + if (t.f) { + size_t z = (t.i >> 32); + boost::hash_combine(h, it->first); + boost::hash_combine(h, z); + } + } + return h; + } +}; + +struct ApproxVectorEquals { + bool operator()(const SparseVector& a, const SparseVector& b) const { + SparseVector::const_iterator bit = b.begin(); + for (SparseVector::const_iterator ait = a.begin(); ait != a.end(); ++ait) { + if (bit == b.end() || + ait->first != bit->first || + ApproxVectorHasher::round(ait->second) != ApproxVectorHasher::round(bit->second)) + return false; + ++bit; + } + if (bit != b.end()) return false; + return true; + } +}; + +struct CandidateCompare { + bool operator()(const Candidate& a, const Candidate& b) const { + ApproxVectorEquals eq; + return (a.ewords == b.ewords && eq(a.fmap,b.fmap)); + } +}; + +struct CandidateHasher { + size_t operator()(const Candidate& x) const { + boost::hash > hhasher; + ApproxVectorHasher vhasher; + size_t ha = hhasher(x.ewords); + boost::hash_combine(ha, vhasher(x.fmap)); + return ha; + } +}; + +static void ParseSparseVector(string& line, size_t cur, SparseVector* out) { + SparseVector& x = *out; + size_t last_start = cur; + size_t last_comma = string::npos; + while(cur <= line.size()) { + if (line[cur] == ' ' || cur == line.size()) { + if (!(cur > last_start && last_comma != string::npos && cur > last_comma)) { + cerr << "[ERROR] " << line << endl << " position = " << cur << endl; + exit(1); + } + const int fid = FD::Convert(line.substr(last_start, last_comma - last_start)); + if (cur < line.size()) line[cur] = 0; + const double val = strtod(&line[last_comma + 1], NULL); + x.set_value(fid, val); + + last_comma = string::npos; + last_start = cur+1; + } else { + if (line[cur] == '=') + last_comma = cur; + } + ++cur; + } +} + +void CandidateSet::WriteToFile(const string& file) const { + WriteFile wf(file); + ostream& out = *wf.stream(); + out.precision(10); + string ss; + for (unsigned i = 0; i < cs.size(); ++i) { + out << TD::GetString(cs[i].ewords) << endl; + out << cs[i].fmap << endl; + cs[i].eval_feats.Encode(&ss); + out << ss << endl; + } +} + +void CandidateSet::ReadFromFile(const string& file) { + if(!SILENT) cerr << "Reading candidates from " << file << endl; + ReadFile rf(file); + istream& in = *rf.stream(); + string cand; + string feats; + string ss; + while(getline(in, cand)) { + getline(in, feats); + getline(in, ss); + assert(in); + cs.push_back(Candidate()); + TD::ConvertSentence(cand, &cs.back().ewords); + ParseSparseVector(feats, 0, &cs.back().fmap); + cs.back().eval_feats = SufficientStats(ss); + } + if(!SILENT) cerr << " read " << cs.size() << " candidates\n"; +} + +void CandidateSet::Dedup() { + if(!SILENT) cerr << "Dedup in=" << cs.size(); + tr1::unordered_set u; + while(cs.size() > 0) { + u.insert(cs.back()); + cs.pop_back(); + } + tr1::unordered_set::iterator it = u.begin(); + while (it != u.end()) { + cs.push_back(*it); + it = u.erase(it); + } + if(!SILENT) cerr << " out=" << cs.size() << endl; +} + +void CandidateSet::AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer) { + KBest::KBestDerivations, ESentenceTraversal> kbest(hg, kbest_size); + + for (unsigned i = 0; i < kbest_size; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg.nodes_.size() - 1, i); + if (!d) break; + cs.push_back(Candidate(d->yield, d->feature_values)); + if (scorer) + scorer->Evaluate(d->yield, &cs.back().eval_feats); + } + Dedup(); +} + +} diff --git a/training/utils/candidate_set.h b/training/utils/candidate_set.h new file mode 100644 index 00000000..9d326ed0 --- /dev/null +++ b/training/utils/candidate_set.h @@ -0,0 +1,60 @@ +#ifndef _CANDIDATE_SET_H_ +#define _CANDIDATE_SET_H_ + +#include +#include + +#include "ns.h" +#include "wordid.h" +#include "sparse_vector.h" + +class Hypergraph; + +namespace training { + +struct Candidate { + Candidate() {} + Candidate(const std::vector& e, const SparseVector& fm) : + ewords(e), + fmap(fm) {} + Candidate(const std::vector& e, + const SparseVector& fm, + const SegmentEvaluator& se) : + ewords(e), + fmap(fm) { + se.Evaluate(ewords, &eval_feats); + } + + void swap(Candidate& other) { + eval_feats.swap(other.eval_feats); + ewords.swap(other.ewords); + fmap.swap(other.fmap); + } + + std::vector ewords; + SparseVector fmap; + SufficientStats eval_feats; +}; + +// represents some kind of collection of translation candidates, e.g. +// aggregated k-best lists, sample lists, etc. +class CandidateSet { + public: + CandidateSet() {} + inline size_t size() const { return cs.size(); } + const Candidate& operator[](size_t i) const { return cs[i]; } + + void ReadFromFile(const std::string& file); + void WriteToFile(const std::string& file) const; + void AddKBestCandidates(const Hypergraph& hg, size_t kbest_size, const SegmentEvaluator* scorer = NULL); + // TODO add code to do unique k-best + // TODO add code to draw k samples + + private: + void Dedup(); + std::vector cs; +}; + +} + +#endif diff --git a/training/utils/decode-and-evaluate.pl b/training/utils/decode-and-evaluate.pl new file mode 100755 index 00000000..1a332c08 --- /dev/null +++ b/training/utils/decode-and-evaluate.pl @@ -0,0 +1,246 @@ +#!/usr/bin/env perl +use strict; +my @ORIG_ARGV=@ARGV; +use Cwd qw(getcwd); +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment"; } + +# Skip local config (used for distributing jobs) if we're running in local-only mode +use LocalConfig; +use Getopt::Long; +use File::Basename qw(basename); +my $QSUB_CMD = qsub_args(mert_memory()); + +require "libcall.pl"; + +# Default settings +my $default_jobs = env_default_jobs(); +my $bin_dir = $SCRIPT_DIR; +die "Bin directory $bin_dir missing/inaccessible" unless -d $bin_dir; +my $FAST_SCORE="$bin_dir/../../mteval/fast_score"; +die "Can't execute $FAST_SCORE" unless -x $FAST_SCORE; +my $parallelize = "$bin_dir/parallelize.pl"; +my $libcall = "$bin_dir/libcall.pl"; +my $sentserver = "$bin_dir/sentserver"; +my $sentclient = "$bin_dir/sentclient"; +my $LocalConfig = "$SCRIPT_DIR/../../environment/LocalConfig.pm"; + +my $SCORER = $FAST_SCORE; +my $cdec = "$bin_dir/../../decoder/cdec"; +die "Can't find decoder in $cdec" unless -x $cdec; +die "Can't find $parallelize" unless -x $parallelize; +die "Can't find $libcall" unless -e $libcall; +my $decoder = $cdec; +my $jobs = $default_jobs; # number of decode nodes +my $pmem = "9g"; +my $help = 0; +my $config; +my $test_set; +my $weights; +my $use_make = 1; +my $useqsub; +my $cpbin=1; +# Process command-line options +if (GetOptions( + "jobs=i" => \$jobs, + "help" => \$help, + "qsub" => \$useqsub, + "input=s" => \$test_set, + "config=s" => \$config, + "weights=s" => \$weights, +) == 0 || @ARGV!=0 || $help) { + print_help(); + exit; +} + +if ($useqsub) { + $use_make = 0; + die "LocalEnvironment.pm does not have qsub configuration for this host. Cannot run with --qsub!\n" unless has_qsub(); +} + +my @missing_args = (); + +if (!defined $test_set) { push @missing_args, "--input"; } +if (!defined $config) { push @missing_args, "--config"; } +if (!defined $weights) { push @missing_args, "--weights"; } +die "Please specify missing arguments: " . join (', ', @missing_args) . "\nUse --help for more information.\n" if (@missing_args); + +my @tf = localtime(time); +my $tname = basename($test_set); +$tname =~ s/\.(sgm|sgml|xml)$//i; +my $dir = "eval.$tname." . sprintf('%d%02d%02d-%02d%02d%02d', 1900+$tf[5], $tf[4], $tf[3], $tf[2], $tf[1], $tf[0]); + +my $time = unchecked_output("date"); + +check_call("mkdir -p $dir"); + +split_devset($test_set, "$dir/test.input.raw", "$dir/test.refs"); +my $refs = "-r $dir/test.refs"; +my $newsrc = "$dir/test.input"; +enseg("$dir/test.input.raw", $newsrc); +my $src_file = $newsrc; +open F, "<$src_file" or die "Can't read $src_file: $!"; close F; + +my $test_trans="$dir/test.trans"; +my $logdir="$dir/logs"; +my $decoderLog="$logdir/decoder.sentserver.log"; +check_call("mkdir -p $logdir"); + +#decode +print STDERR "RUNNING DECODER AT "; +print STDERR unchecked_output("date"); +my $decoder_cmd = "$decoder -c $config --weights $weights"; +my $pcmd; +if ($use_make) { + $pcmd = "cat $src_file | $parallelize --workdir $dir --use-fork -p $pmem -e $logdir -j $jobs --"; +} else { + $pcmd = "cat $src_file | $parallelize --workdir $dir -p $pmem -e $logdir -j $jobs --"; +} +my $cmd = "$pcmd $decoder_cmd 2> $decoderLog 1> $test_trans"; +check_bash_call($cmd); +print STDERR "DECODER COMPLETED AT "; +print STDERR unchecked_output("date"); +print STDERR "\nOUTPUT: $test_trans\n\n"; +my $bleu = check_output("cat $test_trans | $SCORER $refs -m ibm_bleu"); +chomp $bleu; +print STDERR "BLEU: $bleu\n"; +my $ter = check_output("cat $test_trans | $SCORER $refs -m ter"); +chomp $ter; +print STDERR " TER: $ter\n"; +open TR, ">$dir/test.scores" or die "Can't write $dir/test.scores: $!"; +print TR <$newsrc"); + my $i=0; + while (my $line=){ + chomp $line; + if ($line =~ /^\s* tags, you must include a zero-based id attribute"; + } + } else { + print NEWSRC "$line\n"; + } + $i++; + } + close SRC; + close NEWSRC; +} + +sub print_help { + my $executable = basename($0); chomp $executable; + print << "Help"; + +Usage: $executable [options] + + $executable --config cdec.ini --weights weights.txt [--jobs N] [--qsub] + +Options: + + --help + Print this message and exit. + + --config + A path to the cdec.ini file. + + --weights + A file specifying feature weights. + + --dir + Directory for intermediate and output files. + +Job control options: + + --jobs + Number of decoder processes to run in parallel. [default=$default_jobs] + + --qsub + Use qsub to run jobs in parallel (qsub must be configured in + environment/LocalEnvironment.pm) + + --pmem + Amount of physical memory requested for parallel decoding jobs + (used with qsub requests only) + +Help +} + +sub convert { + my ($str) = @_; + my @ps = split /;/, $str; + my %dict = (); + for my $p (@ps) { + my ($k, $v) = split /=/, $p; + $dict{$k} = $v; + } + return %dict; +} + + + +sub cmdline { + return join ' ',($0,@ORIG_ARGV); +} + +#buggy: last arg gets quoted sometimes? +my $is_shell_special=qr{[ \t\n\\><|&;"'`~*?{}$!()]}; +my $shell_escape_in_quote=qr{[\\"\$`!]}; + +sub escape_shell { + my ($arg)=@_; + return undef unless defined $arg; + if ($arg =~ /$is_shell_special/) { + $arg =~ s/($shell_escape_in_quote)/\\$1/g; + return "\"$arg\""; + } + return $arg; +} + +sub escaped_shell_args { + return map {local $_=$_;chomp;escape_shell($_)} @_; +} + +sub escaped_shell_args_str { + return join ' ',&escaped_shell_args(@_); +} + +sub escaped_cmdline { + return "$0 ".&escaped_shell_args_str(@ORIG_ARGV); +} + +sub split_devset { + my ($infile, $outsrc, $outref) = @_; + open F, "<$infile" or die "Can't read $infile: $!"; + open S, ">$outsrc" or die "Can't write $outsrc: $!"; + open R, ">$outref" or die "Can't write $outref: $!"; + while() { + chomp; + my ($src, @refs) = split /\s*\|\|\|\s*/; + die "Malformed devset line: $_\n" unless scalar @refs > 0; + print S "$src\n"; + print R join(' ||| ', @refs) . "\n"; + } + close R; + close S; + close F; +} + diff --git a/training/utils/entropy.cc b/training/utils/entropy.cc new file mode 100644 index 00000000..4fdbe2be --- /dev/null +++ b/training/utils/entropy.cc @@ -0,0 +1,41 @@ +#include "entropy.h" + +#include "prob.h" +#include "candidate_set.h" + +using namespace std; + +namespace training { + +// see Mann and McCallum "Efficient Computation of Entropy Gradient ..." for +// a mostly clear derivation of: +// g = E[ F(x,y) * log p(y|x) ] + H(y | x) * E[ F(x,y) ] +double CandidateSetEntropy::operator()(const vector& params, + SparseVector* g) const { + prob_t z; + vector dps(cands_.size()); + for (unsigned i = 0; i < cands_.size(); ++i) { + dps[i] = cands_[i].fmap.dot(params); + const prob_t u(dps[i], init_lnx()); + z += u; + } + const double log_z = log(z); + + SparseVector exp_feats; + double entropy = 0; + for (unsigned i = 0; i < cands_.size(); ++i) { + const double log_prob = cands_[i].fmap.dot(params) - log_z; + const double prob = exp(log_prob); + const double e_logprob = prob * log_prob; + entropy -= e_logprob; + if (g) { + (*g) += cands_[i].fmap * e_logprob; + exp_feats += cands_[i].fmap * prob; + } + } + if (g) (*g) += exp_feats * entropy; + return entropy; +} + +} + diff --git a/training/utils/entropy.h b/training/utils/entropy.h new file mode 100644 index 00000000..796589ca --- /dev/null +++ b/training/utils/entropy.h @@ -0,0 +1,22 @@ +#ifndef _CSENTROPY_H_ +#define _CSENTROPY_H_ + +#include +#include "sparse_vector.h" + +namespace training { + class CandidateSet; + + class CandidateSetEntropy { + public: + explicit CandidateSetEntropy(const CandidateSet& cs) : cands_(cs) {} + // compute the entropy (expected log likelihood) of a CandidateSet + // (optional) the gradient of the entropy with respect to params + double operator()(const std::vector& params, + SparseVector* g = NULL) const; + private: + const CandidateSet& cands_; + }; +}; + +#endif diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc new file mode 100644 index 00000000..607a7cb9 --- /dev/null +++ b/training/utils/grammar_convert.cc @@ -0,0 +1,348 @@ +/* + this program modifies cfg hypergraphs (forests) and extracts kbests? + what are: json, split ? + */ +#include +#include +#include + +#include +#include + +#include "inside_outside.h" +#include "tdict.h" +#include "filelib.h" +#include "hg.h" +#include "hg_io.h" +#include "kbest.h" +#include "viterbi.h" +#include "weights.h" + +namespace po = boost::program_options; +using namespace std; + +WordID kSTART; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input,i", po::value()->default_value("-"), "Input file") + ("format,f", po::value()->default_value("cfg"), "Input format. Values: cfg, json, split") + ("output,o", po::value()->default_value("json"), "Output command. Values: json, 1best") + ("reorder,r", "Add Yamada & Knight (2002) reorderings") + ("weights,w", po::value(), "Feature weights for k-best derivations [optional]") + ("collapse_weights,C", "Collapse order features into a single feature whose value is all of the locally applying feature weights") + ("k_derivations,k", po::value(), "Show k derivations and their features") + ("max_reorder,m", po::value()->default_value(999), "Move a constituent at most this far") + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("input") == 0) { + cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n"; + cerr << dcmdline_options << endl; + exit(1); + } +} + +int GetOrCreateNode(const WordID& lhs, map* lhs2node, Hypergraph* hg) { + int& node_id = (*lhs2node)[lhs]; + if (!node_id) + node_id = hg->AddNode(lhs)->id_ + 1; + return node_id - 1; +} + +void FilterAndCheckCorrectness(int goal, Hypergraph* hg) { + if (goal < 0) { + cerr << "Error! [S] not found in grammar!\n"; + exit(1); + } + if (hg->nodes_[goal].in_edges_.size() != 1) { + cerr << "Error! [S] has more than one rewrite!\n"; + exit(1); + } + int old_size = hg->nodes_.size(); + hg->TopologicallySortNodesAndEdges(goal); + if (hg->nodes_.size() != old_size) { + cerr << "Warning! During sorting " << (old_size - hg->nodes_.size()) << " disappeared!\n"; + } + vector inside; // inside score at each node + double p = Inside(*hg, &inside); + if (!p) { + cerr << "Warning! Grammar defines the empty language!\n"; + hg->clear(); + return; + } + vector prune(hg->edges_.size(), false); + int bad_edges = 0; + for (unsigned i = 0; i < hg->edges_.size(); ++i) { + Hypergraph::Edge& edge = hg->edges_[i]; + bool bad = false; + for (unsigned j = 0; j < edge.tail_nodes_.size(); ++j) { + if (!inside[edge.tail_nodes_[j]]) { + bad = true; + ++bad_edges; + } + } + prune[i] = bad; + } + cerr << "Removing " << bad_edges << " bad edges from the grammar.\n"; + for (unsigned i = 0; i < hg->edges_.size(); ++i) { + if (prune[i]) + cerr << " " << hg->edges_[i].rule_->AsString() << endl; + } + hg->PruneEdges(prune); +} + +void CreateEdge(const TRulePtr& r, const Hypergraph::TailNodeVector& tail, Hypergraph::Node* head_node, Hypergraph* hg) { + Hypergraph::Edge* new_edge = hg->AddEdge(r, tail); + hg->ConnectEdgeToHeadNode(new_edge, head_node); + new_edge->feature_values_ = r->scores_; +} + +// from a category label like "NP_2", return "NP" +string PureCategory(WordID cat) { + assert(cat < 0); + string c = TD::Convert(cat*-1); + size_t p = c.find("_"); + if (p == string::npos) return c; + return c.substr(0, p); +}; + +string ConstituentOrderFeature(const TRule& rule, const vector& pi) { + const static string kTERM_VAR = "x"; + const vector& f = rule.f(); + map used; + vector terms(f.size()); + for (int i = 0; i < f.size(); ++i) { + const string term = (f[i] < 0 ? PureCategory(f[i]) : kTERM_VAR); + int& count = used[term]; + if (!count) { + terms[i] = term; + } else { + ostringstream os; + os << term << count; + terms[i] = os.str(); + } + ++count; + } + ostringstream os; + os << PureCategory(rule.GetLHS()) << ':'; + for (int i = 0; i < f.size(); ++i) { + if (i > 0) os << '_'; + os << terms[pi[i]]; + } + return os.str(); +} + +bool CheckPermutationMask(const vector& mask, const vector& pi) { + assert(mask.size() == pi.size()); + + int req_min = -1; + int cur_max = 0; + int cur_mask = -1; + for (int i = 0; i < mask.size(); ++i) { + if (mask[i] != cur_mask) { + cur_mask = mask[i]; + req_min = cur_max - 1; + } + if (pi[i] > req_min) { + if (pi[i] > cur_max) cur_max = pi[i]; + } else { + return false; + } + } + + return true; +} + +void PermuteYKRecursive(int nodeid, const WordID& parent, const int max_reorder, Hypergraph* hg) { + // Hypergraph tmp = *hg; + Hypergraph::Node* node = &hg->nodes_[nodeid]; + if (node->in_edges_.size() != 1) { + cerr << "Multiple rewrites of [" << TD::Convert(node->cat_ * -1) << "] (parent is [" << TD::Convert(parent*-1) << "])\n"; + cerr << " not recursing!\n"; + return; + } +// for (int eii = 0; eii < node->in_edges_.size(); ++eii) { + const int oe_index = node->in_edges_.front(); + const TRule& rule = *hg->edges_[oe_index].rule_; + const Hypergraph::TailNodeVector orig_tail = hg->edges_[oe_index].tail_nodes_; + const int tail_size = orig_tail.size(); + for (int i = 0; i < tail_size; ++i) { + PermuteYKRecursive(hg->edges_[oe_index].tail_nodes_[i], node->cat_, max_reorder, hg); + } + const vector& of = rule.f_; + if (of.size() == 1) return; + // cerr << "Permuting [" << TD::Convert(node->cat_ * -1) << "]\n"; + // cerr << "ORIG: " << rule.AsString() << endl; + vector pi(of.size(), 0); + for (int i = 0; i < pi.size(); ++i) pi[i] = i; + + vector permutation_mask(of.size(), 0); + const bool dont_reorder_across_PU = true; // TODO add configuration + if (dont_reorder_across_PU) { + int cur = 0; + for (int i = 0; i < pi.size(); ++i) { + if (of[i] >= 0) continue; + const string cat = PureCategory(of[i]); + if (cat == "PU" || cat == "PU!H" || cat == "PUNC" || cat == "PUNC!H" || cat == "CC") { + ++cur; + permutation_mask[i] = cur; + ++cur; + } else { + permutation_mask[i] = cur; + } + } + } + int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); + hg->edges_[oe_index].feature_values_.set_value(fid, 1.0); + while (next_permutation(pi.begin(), pi.end())) { + if (!CheckPermutationMask(permutation_mask, pi)) + continue; + vector nf(pi.size(), 0); + Hypergraph::TailNodeVector tail(pi.size(), 0); + bool skip = false; + for (int i = 0; i < pi.size(); ++i) { + int dist = pi[i] - i; if (dist < 0) dist *= -1; + if (dist > max_reorder) { skip = true; break; } + nf[i] = of[pi[i]]; + tail[i] = orig_tail[pi[i]]; + } + if (skip) continue; + TRulePtr nr(new TRule(rule)); + nr->f_ = nf; + int fid = FD::Convert(ConstituentOrderFeature(rule, pi)); + nr->scores_.set_value(fid, 1.0); + // cerr << "PERM: " << nr->AsString() << endl; + CreateEdge(nr, tail, node, hg); + } + // } +} + +void PermuteYamadaAndKnight(Hypergraph* hg, int max_reorder) { + assert(hg->nodes_.back().cat_ == kSTART); + assert(hg->nodes_.back().in_edges_.size() == 1); + PermuteYKRecursive(hg->nodes_.size() - 1, kSTART, max_reorder, hg); +} + +void CollapseWeights(Hypergraph* hg) { + int fid = FD::Convert("Reordering"); + for (int i = 0; i < hg->edges_.size(); ++i) { + Hypergraph::Edge& edge = hg->edges_[i]; + edge.feature_values_.clear(); + if (edge.edge_prob_ != prob_t::Zero()) { + edge.feature_values_.set_value(fid, log(edge.edge_prob_)); + } + } +} + +void ProcessHypergraph(const vector& w, const po::variables_map& conf, const string& ref, Hypergraph* hg) { + if (conf.count("reorder")) + PermuteYamadaAndKnight(hg, conf["max_reorder"].as()); + if (w.size() > 0) { hg->Reweight(w); } + if (conf.count("collapse_weights")) CollapseWeights(hg); + if (conf["output"].as() == "json") { + HypergraphIO::WriteToJSON(*hg, false, &cout); + if (!ref.empty()) { cerr << "REF: " << ref << endl; } + } else { + vector onebest; + ViterbiESentence(*hg, &onebest); + if (ref.empty()) { + cout << TD::GetString(onebest) << endl; + } else { + cout << TD::GetString(onebest) << " ||| " << ref << endl; + } + } + if (conf.count("k_derivations")) { + const int k = conf["k_derivations"].as(); + KBest::KBestDerivations, ESentenceTraversal> kbest(*hg, k); + for (int i = 0; i < k; ++i) { + const KBest::KBestDerivations, ESentenceTraversal>::Derivation* d = + kbest.LazyKthBest(hg->nodes_.size() - 1, i); + if (!d) break; + cerr << log(d->score) << " ||| " << TD::GetString(d->yield) << " ||| " << d->feature_values << endl; + } + } +} + +int main(int argc, char **argv) { + kSTART = TD::Convert("S") * -1; + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + string infile = conf["input"].as(); + const bool is_split_input = (conf["format"].as() == "split"); + const bool is_json_input = is_split_input || (conf["format"].as() == "json"); + const bool collapse_weights = conf.count("collapse_weights"); + vector w; + if (conf.count("weights")) + Weights::InitFromFile(conf["weights"].as(), &w); + + if (collapse_weights && !w.size()) { + cerr << "--collapse_weights requires a weights file to be specified!\n"; + exit(1); + } + ReadFile rf(infile); + istream* in = rf.stream(); + assert(*in); + int lc = 0; + Hypergraph hg; + map lhs2node; + while(*in) { + string line; + ++lc; + getline(*in, line); + if (is_json_input) { + if (line.empty() || line[0] == '#') continue; + string ref; + if (is_split_input) { + size_t pos = line.rfind("}}"); + assert(pos != string::npos); + size_t rstart = line.find("||| ", pos); + assert(rstart != string::npos); + ref = line.substr(rstart + 4); + line = line.substr(0, pos + 2); + } + istringstream is(line); + if (HypergraphIO::ReadFromJSON(&is, &hg)) { + ProcessHypergraph(w, conf, ref, &hg); + hg.clear(); + } else { + cerr << "Error reading grammar from JSON: line " << lc << endl; + exit(1); + } + } else { + if (line.empty()) { + int goal = lhs2node[kSTART] - 1; + FilterAndCheckCorrectness(goal, &hg); + ProcessHypergraph(w, conf, "", &hg); + hg.clear(); + lhs2node.clear(); + continue; + } + if (line[0] == '#') continue; + if (line[0] != '[') { + cerr << "Line " << lc << ": bad format\n"; + exit(1); + } + TRulePtr tr(TRule::CreateRuleMonolingual(line)); + Hypergraph::TailNodeVector tail; + for (int i = 0; i < tr->f_.size(); ++i) { + WordID var_cat = tr->f_[i]; + if (var_cat < 0) + tail.push_back(GetOrCreateNode(var_cat, &lhs2node, &hg)); + } + const WordID lhs = tr->GetLHS(); + int head = GetOrCreateNode(lhs, &lhs2node, &hg); + Hypergraph::Edge* edge = hg.AddEdge(tr, tail); + edge->feature_values_ = tr->scores_; + Hypergraph::Node* node = &hg.nodes_[head]; + hg.ConnectEdgeToHeadNode(edge, node); + } + } +} + diff --git a/training/utils/lbfgs.h b/training/utils/lbfgs.h new file mode 100644 index 00000000..e8baecab --- /dev/null +++ b/training/utils/lbfgs.h @@ -0,0 +1,1459 @@ +#ifndef SCITBX_LBFGS_H +#define SCITBX_LBFGS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace scitbx { + +//! Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) %minimizer. +/*! Implementation of the + Limited-memory Broyden-Fletcher-Goldfarb-Shanno (LBFGS) + algorithm for large-scale multidimensional minimization + problems. + + This code was manually derived from Java code which was + in turn derived from the Fortran program + lbfgs.f. The Java translation was + effected mostly mechanically, with some manual + clean-up; in particular, array indices start at 0 + instead of 1. Most of the comments from the Fortran + code have been pasted in. + + Information on the original LBFGS Fortran source code is + available at + http://www.netlib.org/opt/lbfgs_um.shar . The following + information is taken verbatim from the Netlib documentation + for the Fortran source. + +
+    file    opt/lbfgs_um.shar
+    for     unconstrained optimization problems
+    alg     limited memory BFGS method
+    by      J. Nocedal
+    contact nocedal@eecs.nwu.edu
+    ref     D. C. Liu and J. Nocedal, ``On the limited memory BFGS method for
+    ,       large scale optimization methods'' Mathematical Programming 45
+    ,       (1989), pp. 503-528.
+    ,       (Postscript file of this paper is available via anonymous ftp
+    ,       to eecs.nwu.edu in the directory pub/%lbfgs/lbfgs_um.)
+    
+ + @author Jorge Nocedal: original Fortran version, including comments + (July 1990).
+ Robert Dodier: Java translation, August 1997.
+ Ralf W. Grosse-Kunstleve: C++ port, March 2002.
+ Chris Dyer: serialize/deserialize functionality + */ +namespace lbfgs { + + //! Generic exception class for %lbfgs %error messages. + /*! All exceptions thrown by the minimizer are derived from this class. + */ + class error : public std::exception { + public: + //! Constructor. + error(std::string const& msg) throw() + : msg_("lbfgs error: " + msg) + {} + //! Access to error message. + virtual const char* what() const throw() { return msg_.c_str(); } + protected: + virtual ~error() throw() {} + std::string msg_; + public: + static std::string itoa(unsigned long i) { + std::ostringstream os; + os << i; + return os.str(); + } + }; + + //! Specific exception class. + class error_internal_error : public error { + public: + //! Constructor. + error_internal_error(const char* file, unsigned long line) throw() + : error( + "Internal Error: " + std::string(file) + "(" + itoa(line) + ")") + {} + }; + + //! Specific exception class. + class error_improper_input_parameter : public error { + public: + //! Constructor. + error_improper_input_parameter(std::string const& msg) throw() + : error("Improper input parameter: " + msg) + {} + }; + + //! Specific exception class. + class error_improper_input_data : public error { + public: + //! Constructor. + error_improper_input_data(std::string const& msg) throw() + : error("Improper input data: " + msg) + {} + }; + + //! Specific exception class. + class error_search_direction_not_descent : public error { + public: + //! Constructor. + error_search_direction_not_descent() throw() + : error("The search direction is not a descent direction.") + {} + }; + + //! Specific exception class. + class error_line_search_failed : public error { + public: + //! Constructor. + error_line_search_failed(std::string const& msg) throw() + : error("Line search failed: " + msg) + {} + }; + + //! Specific exception class. + class error_line_search_failed_rounding_errors + : public error_line_search_failed { + public: + //! Constructor. + error_line_search_failed_rounding_errors(std::string const& msg) throw() + : error_line_search_failed(msg) + {} + }; + + namespace detail { + + template + inline + NumType + pow2(NumType const& x) { return x * x; } + + template + inline + NumType + abs(NumType const& x) { + if (x < NumType(0)) return -x; + return x; + } + + // This class implements an algorithm for multi-dimensional line search. + template + class mcsrch + { + protected: + int infoc; + FloatType dginit; + bool brackt; + bool stage1; + FloatType finit; + FloatType dgtest; + FloatType width; + FloatType width1; + FloatType stx; + FloatType fx; + FloatType dgx; + FloatType sty; + FloatType fy; + FloatType dgy; + FloatType stmin; + FloatType stmax; + + static FloatType const& max3( + FloatType const& x, + FloatType const& y, + FloatType const& z) + { + return x < y ? (y < z ? z : y ) : (x < z ? z : x ); + } + + public: + /* Minimize a function along a search direction. This code is + a Java translation of the function MCSRCH from + lbfgs.f, which in turn is a slight modification + of the subroutine CSRCH of More' and Thuente. + The changes are to allow reverse communication, and do not + affect the performance of the routine. This function, in turn, + calls mcstep.

+ + The Java translation was effected mostly mechanically, with + some manual clean-up; in particular, array indices start at 0 + instead of 1. Most of the comments from the Fortran code have + been pasted in here as well.

+ + The purpose of mcsrch is to find a step which + satisfies a sufficient decrease condition and a curvature + condition.

+ + At each stage this function updates an interval of uncertainty + with endpoints stx and sty. The + interval of uncertainty is initially chosen so that it + contains a minimizer of the modified function +

+                f(x+stp*s) - f(x) - ftol*stp*(gradf(x)'s).
+           
+ If a step is obtained for which the modified function has a + nonpositive function value and nonnegative derivative, then + the interval of uncertainty is chosen so that it contains a + minimizer of f(x+stp*s).

+ + The algorithm is designed to find a step which satisfies + the sufficient decrease condition +

+                 f(x+stp*s) <= f(X) + ftol*stp*(gradf(x)'s),
+           
+ and the curvature condition +
+                 abs(gradf(x+stp*s)'s)) <= gtol*abs(gradf(x)'s).
+           
+ If ftol is less than gtol and if, + for example, the function is bounded below, then there is + always a step which satisfies both conditions. If no step can + be found which satisfies both conditions, then the algorithm + usually stops when rounding errors prevent further progress. + In this case stp only satisfies the sufficient + decrease condition.

+ + @author Original Fortran version by Jorge J. More' and + David J. Thuente as part of the Minpack project, June 1983, + Argonne National Laboratory. Java translation by Robert + Dodier, August 1997. + + @param n The number of variables. + + @param x On entry this contains the base point for the line + search. On exit it contains x + stp*s. + + @param f On entry this contains the value of the objective + function at x. On exit it contains the value + of the objective function at x + stp*s. + + @param g On entry this contains the gradient of the objective + function at x. On exit it contains the gradient + at x + stp*s. + + @param s The search direction. + + @param stp On entry this contains an initial estimate of a + satifactory step length. On exit stp contains + the final estimate. + + @param ftol Tolerance for the sufficient decrease condition. + + @param xtol Termination occurs when the relative width of the + interval of uncertainty is at most xtol. + + @param maxfev Termination occurs when the number of evaluations + of the objective function is at least maxfev by + the end of an iteration. + + @param info This is an output variable, which can have these + values: +

    +
  • info = -1 A return is made to compute + the function and gradient. +
  • info = 1 The sufficient decrease condition + and the directional derivative condition hold. +
+ + @param nfev On exit, this is set to the number of function + evaluations. + + @param wa Temporary storage array, of length n. + */ + void run( + FloatType const& gtol, + FloatType const& stpmin, + FloatType const& stpmax, + SizeType n, + FloatType* x, + FloatType f, + const FloatType* g, + FloatType* s, + SizeType is0, + FloatType& stp, + FloatType ftol, + FloatType xtol, + SizeType maxfev, + int& info, + SizeType& nfev, + FloatType* wa); + + /* The purpose of this function is to compute a safeguarded step + for a linesearch and to update an interval of uncertainty for + a minimizer of the function.

+ + The parameter stx contains the step with the + least function value. The parameter stp contains + the current step. It is assumed that the derivative at + stx is negative in the direction of the step. If + brackt is true when + mcstep returns then a minimizer has been + bracketed in an interval of uncertainty with endpoints + stx and sty.

+ + Variables that must be modified by mcstep are + implemented as 1-element arrays. + + @param stx Step at the best step obtained so far. + This variable is modified by mcstep. + @param fx Function value at the best step obtained so far. + This variable is modified by mcstep. + @param dx Derivative at the best step obtained so far. + The derivative must be negative in the direction of the + step, that is, dx and stp-stx must + have opposite signs. This variable is modified by + mcstep. + + @param sty Step at the other endpoint of the interval of + uncertainty. This variable is modified by mcstep. + @param fy Function value at the other endpoint of the interval + of uncertainty. This variable is modified by + mcstep. + + @param dy Derivative at the other endpoint of the interval of + uncertainty. This variable is modified by mcstep. + + @param stp Step at the current step. If brackt is set + then on input stp must be between stx + and sty. On output stp is set to the + new step. + @param fp Function value at the current step. + @param dp Derivative at the current step. + + @param brackt Tells whether a minimizer has been bracketed. + If the minimizer has not been bracketed, then on input this + variable must be set false. If the minimizer has + been bracketed, then on output this variable is + true. + + @param stpmin Lower bound for the step. + @param stpmax Upper bound for the step. + + If the return value is 1, 2, 3, or 4, then the step has + been computed successfully. A return value of 0 indicates + improper input parameters. + + @author Jorge J. More, David J. Thuente: original Fortran version, + as part of Minpack project. Argonne Nat'l Laboratory, June 1983. + Robert Dodier: Java translation, August 1997. + */ + static int mcstep( + FloatType& stx, + FloatType& fx, + FloatType& dx, + FloatType& sty, + FloatType& fy, + FloatType& dy, + FloatType& stp, + FloatType fp, + FloatType dp, + bool& brackt, + FloatType stpmin, + FloatType stpmax); + + void serialize(std::ostream* out) const { + out->write((const char*)&infoc,sizeof(infoc)); + out->write((const char*)&dginit,sizeof(dginit)); + out->write((const char*)&brackt,sizeof(brackt)); + out->write((const char*)&stage1,sizeof(stage1)); + out->write((const char*)&finit,sizeof(finit)); + out->write((const char*)&dgtest,sizeof(dgtest)); + out->write((const char*)&width,sizeof(width)); + out->write((const char*)&width1,sizeof(width1)); + out->write((const char*)&stx,sizeof(stx)); + out->write((const char*)&fx,sizeof(fx)); + out->write((const char*)&dgx,sizeof(dgx)); + out->write((const char*)&sty,sizeof(sty)); + out->write((const char*)&fy,sizeof(fy)); + out->write((const char*)&dgy,sizeof(dgy)); + out->write((const char*)&stmin,sizeof(stmin)); + out->write((const char*)&stmax,sizeof(stmax)); + } + + void deserialize(std::istream* in) const { + in->read((char*)&infoc, sizeof(infoc)); + in->read((char*)&dginit, sizeof(dginit)); + in->read((char*)&brackt, sizeof(brackt)); + in->read((char*)&stage1, sizeof(stage1)); + in->read((char*)&finit, sizeof(finit)); + in->read((char*)&dgtest, sizeof(dgtest)); + in->read((char*)&width, sizeof(width)); + in->read((char*)&width1, sizeof(width1)); + in->read((char*)&stx, sizeof(stx)); + in->read((char*)&fx, sizeof(fx)); + in->read((char*)&dgx, sizeof(dgx)); + in->read((char*)&sty, sizeof(sty)); + in->read((char*)&fy, sizeof(fy)); + in->read((char*)&dgy, sizeof(dgy)); + in->read((char*)&stmin, sizeof(stmin)); + in->read((char*)&stmax, sizeof(stmax)); + } + }; + + template + void mcsrch::run( + FloatType const& gtol, + FloatType const& stpmin, + FloatType const& stpmax, + SizeType n, + FloatType* x, + FloatType f, + const FloatType* g, + FloatType* s, + SizeType is0, + FloatType& stp, + FloatType ftol, + FloatType xtol, + SizeType maxfev, + int& info, + SizeType& nfev, + FloatType* wa) + { + if (info != -1) { + infoc = 1; + if ( n == 0 + || maxfev == 0 + || gtol < FloatType(0) + || xtol < FloatType(0) + || stpmin < FloatType(0) + || stpmax < stpmin) { + throw error_internal_error(__FILE__, __LINE__); + } + if (stp <= FloatType(0) || ftol < FloatType(0)) { + throw error_internal_error(__FILE__, __LINE__); + } + // Compute the initial gradient in the search direction + // and check that s is a descent direction. + dginit = FloatType(0); + for (SizeType j = 0; j < n; j++) { + dginit += g[j] * s[is0+j]; + } + if (dginit >= FloatType(0)) { + throw error_search_direction_not_descent(); + } + brackt = false; + stage1 = true; + nfev = 0; + finit = f; + dgtest = ftol*dginit; + width = stpmax - stpmin; + width1 = FloatType(2) * width; + std::copy(x, x+n, wa); + // The variables stx, fx, dgx contain the values of the step, + // function, and directional derivative at the best step. + // The variables sty, fy, dgy contain the value of the step, + // function, and derivative at the other endpoint of + // the interval of uncertainty. + // The variables stp, f, dg contain the values of the step, + // function, and derivative at the current step. + stx = FloatType(0); + fx = finit; + dgx = dginit; + sty = FloatType(0); + fy = finit; + dgy = dginit; + } + for (;;) { + if (info != -1) { + // Set the minimum and maximum steps to correspond + // to the present interval of uncertainty. + if (brackt) { + stmin = std::min(stx, sty); + stmax = std::max(stx, sty); + } + else { + stmin = stx; + stmax = stp + FloatType(4) * (stp - stx); + } + // Force the step to be within the bounds stpmax and stpmin. + stp = std::max(stp, stpmin); + stp = std::min(stp, stpmax); + // If an unusual termination is to occur then let + // stp be the lowest point obtained so far. + if ( (brackt && (stp <= stmin || stp >= stmax)) + || nfev >= maxfev - 1 || infoc == 0 + || (brackt && stmax - stmin <= xtol * stmax)) { + stp = stx; + } + // Evaluate the function and gradient at stp + // and compute the directional derivative. + // We return to main program to obtain F and G. + for (SizeType j = 0; j < n; j++) { + x[j] = wa[j] + stp * s[is0+j]; + } + info=-1; + break; + } + info = 0; + nfev++; + FloatType dg(0); + for (SizeType j = 0; j < n; j++) { + dg += g[j] * s[is0+j]; + } + FloatType ftest1 = finit + stp*dgtest; + // Test for convergence. + if ((brackt && (stp <= stmin || stp >= stmax)) || infoc == 0) { + throw error_line_search_failed_rounding_errors( + "Rounding errors prevent further progress." + " There may not be a step which satisfies the" + " sufficient decrease and curvature conditions." + " Tolerances may be too small."); + } + if (stp == stpmax && f <= ftest1 && dg <= dgtest) { + throw error_line_search_failed( + "The step is at the upper bound stpmax()."); + } + if (stp == stpmin && (f > ftest1 || dg >= dgtest)) { + throw error_line_search_failed( + "The step is at the lower bound stpmin()."); + } + if (nfev >= maxfev) { + throw error_line_search_failed( + "Number of function evaluations has reached maxfev()."); + } + if (brackt && stmax - stmin <= xtol * stmax) { + throw error_line_search_failed( + "Relative width of the interval of uncertainty" + " is at most xtol()."); + } + // Check for termination. + if (f <= ftest1 && abs(dg) <= gtol * (-dginit)) { + info = 1; + break; + } + // In the first stage we seek a step for which the modified + // function has a nonpositive value and nonnegative derivative. + if ( stage1 && f <= ftest1 + && dg >= std::min(ftol, gtol) * dginit) { + stage1 = false; + } + // A modified function is used to predict the step only if + // we have not obtained a step for which the modified + // function has a nonpositive function value and nonnegative + // derivative, and if a lower function value has been + // obtained but the decrease is not sufficient. + if (stage1 && f <= fx && f > ftest1) { + // Define the modified function and derivative values. + FloatType fm = f - stp*dgtest; + FloatType fxm = fx - stx*dgtest; + FloatType fym = fy - sty*dgtest; + FloatType dgm = dg - dgtest; + FloatType dgxm = dgx - dgtest; + FloatType dgym = dgy - dgtest; + // Call cstep to update the interval of uncertainty + // and to compute the new step. + infoc = mcstep(stx, fxm, dgxm, sty, fym, dgym, stp, fm, dgm, + brackt, stmin, stmax); + // Reset the function and gradient values for f. + fx = fxm + stx*dgtest; + fy = fym + sty*dgtest; + dgx = dgxm + dgtest; + dgy = dgym + dgtest; + } + else { + // Call mcstep to update the interval of uncertainty + // and to compute the new step. + infoc = mcstep(stx, fx, dgx, sty, fy, dgy, stp, f, dg, + brackt, stmin, stmax); + } + // Force a sufficient decrease in the size of the + // interval of uncertainty. + if (brackt) { + if (abs(sty - stx) >= FloatType(0.66) * width1) { + stp = stx + FloatType(0.5) * (sty - stx); + } + width1 = width; + width = abs(sty - stx); + } + } + } + + template + int mcsrch::mcstep( + FloatType& stx, + FloatType& fx, + FloatType& dx, + FloatType& sty, + FloatType& fy, + FloatType& dy, + FloatType& stp, + FloatType fp, + FloatType dp, + bool& brackt, + FloatType stpmin, + FloatType stpmax) + { + bool bound; + FloatType gamma, p, q, r, s, sgnd, stpc, stpf, stpq, theta; + int info = 0; + if ( ( brackt && (stp <= std::min(stx, sty) + || stp >= std::max(stx, sty))) + || dx * (stp - stx) >= FloatType(0) || stpmax < stpmin) { + return 0; + } + // Determine if the derivatives have opposite sign. + sgnd = dp * (dx / abs(dx)); + if (fp > fx) { + // First case. A higher function value. + // The minimum is bracketed. If the cubic step is closer + // to stx than the quadratic step, the cubic step is taken, + // else the average of the cubic and quadratic steps is taken. + info = 1; + bound = true; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); + if (stp < stx) gamma = - gamma; + p = (gamma - dx) + theta; + q = ((gamma - dx) + gamma) + dp; + r = p/q; + stpc = stx + r * (stp - stx); + stpq = stx + + ((dx / ((fx - fp) / (stp - stx) + dx)) / FloatType(2)) + * (stp - stx); + if (abs(stpc - stx) < abs(stpq - stx)) { + stpf = stpc; + } + else { + stpf = stpc + (stpq - stpc) / FloatType(2); + } + brackt = true; + } + else if (sgnd < FloatType(0)) { + // Second case. A lower function value and derivatives of + // opposite sign. The minimum is bracketed. If the cubic + // step is closer to stx than the quadratic (secant) step, + // the cubic step is taken, else the quadratic step is taken. + info = 2; + bound = false; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dx / s) * (dp / s)); + if (stp > stx) gamma = - gamma; + p = (gamma - dp) + theta; + q = ((gamma - dp) + gamma) + dx; + r = p/q; + stpc = stp + r * (stx - stp); + stpq = stp + (dp / (dp - dx)) * (stx - stp); + if (abs(stpc - stp) > abs(stpq - stp)) { + stpf = stpc; + } + else { + stpf = stpq; + } + brackt = true; + } + else if (abs(dp) < abs(dx)) { + // Third case. A lower function value, derivatives of the + // same sign, and the magnitude of the derivative decreases. + // The cubic step is only used if the cubic tends to infinity + // in the direction of the step or if the minimum of the cubic + // is beyond stp. Otherwise the cubic step is defined to be + // either stpmin or stpmax. The quadratic (secant) step is also + // computed and if the minimum is bracketed then the the step + // closest to stx is taken, else the step farthest away is taken. + info = 3; + bound = true; + theta = FloatType(3) * (fx - fp) / (stp - stx) + dx + dp; + s = max3(abs(theta), abs(dx), abs(dp)); + gamma = s * std::sqrt( + std::max(FloatType(0), pow2(theta / s) - (dx / s) * (dp / s))); + if (stp > stx) gamma = -gamma; + p = (gamma - dp) + theta; + q = (gamma + (dx - dp)) + gamma; + r = p/q; + if (r < FloatType(0) && gamma != FloatType(0)) { + stpc = stp + r * (stx - stp); + } + else if (stp > stx) { + stpc = stpmax; + } + else { + stpc = stpmin; + } + stpq = stp + (dp / (dp - dx)) * (stx - stp); + if (brackt) { + if (abs(stp - stpc) < abs(stp - stpq)) { + stpf = stpc; + } + else { + stpf = stpq; + } + } + else { + if (abs(stp - stpc) > abs(stp - stpq)) { + stpf = stpc; + } + else { + stpf = stpq; + } + } + } + else { + // Fourth case. A lower function value, derivatives of the + // same sign, and the magnitude of the derivative does + // not decrease. If the minimum is not bracketed, the step + // is either stpmin or stpmax, else the cubic step is taken. + info = 4; + bound = false; + if (brackt) { + theta = FloatType(3) * (fp - fy) / (sty - stp) + dy + dp; + s = max3(abs(theta), abs(dy), abs(dp)); + gamma = s * std::sqrt(pow2(theta / s) - (dy / s) * (dp / s)); + if (stp > sty) gamma = -gamma; + p = (gamma - dp) + theta; + q = ((gamma - dp) + gamma) + dy; + r = p/q; + stpc = stp + r * (sty - stp); + stpf = stpc; + } + else if (stp > stx) { + stpf = stpmax; + } + else { + stpf = stpmin; + } + } + // Update the interval of uncertainty. This update does not + // depend on the new step or the case analysis above. + if (fp > fx) { + sty = stp; + fy = fp; + dy = dp; + } + else { + if (sgnd < FloatType(0)) { + sty = stx; + fy = fx; + dy = dx; + } + stx = stp; + fx = fp; + dx = dp; + } + // Compute the new step and safeguard it. + stpf = std::min(stpmax, stpf); + stpf = std::max(stpmin, stpf); + stp = stpf; + if (brackt && bound) { + if (sty > stx) { + stp = std::min(stx + FloatType(0.66) * (sty - stx), stp); + } + else { + stp = std::max(stx + FloatType(0.66) * (sty - stx), stp); + } + } + return info; + } + + /* Compute the sum of a vector times a scalar plus another vector. + Adapted from the subroutine daxpy in + lbfgs.f. + */ + template + void daxpy( + SizeType n, + FloatType da, + const FloatType* dx, + SizeType ix0, + SizeType incx, + FloatType* dy, + SizeType iy0, + SizeType incy) + { + SizeType i, ix, iy, m; + if (n == 0) return; + if (da == FloatType(0)) return; + if (!(incx == 1 && incy == 1)) { + ix = 0; + iy = 0; + for (i = 0; i < n; i++) { + dy[iy0+iy] += da * dx[ix0+ix]; + ix += incx; + iy += incy; + } + return; + } + m = n % 4; + for (i = 0; i < m; i++) { + dy[iy0+i] += da * dx[ix0+i]; + } + for (; i < n;) { + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + dy[iy0+i] += da * dx[ix0+i]; i++; + } + } + + template + inline + void daxpy( + SizeType n, + FloatType da, + const FloatType* dx, + SizeType ix0, + FloatType* dy) + { + daxpy(n, da, dx, ix0, SizeType(1), dy, SizeType(0), SizeType(1)); + } + + /* Compute the dot product of two vectors. + Adapted from the subroutine ddot + in lbfgs.f. + */ + template + FloatType ddot( + SizeType n, + const FloatType* dx, + SizeType ix0, + SizeType incx, + const FloatType* dy, + SizeType iy0, + SizeType incy) + { + SizeType i, ix, iy, m; + FloatType dtemp(0); + if (n == 0) return FloatType(0); + if (!(incx == 1 && incy == 1)) { + ix = 0; + iy = 0; + for (i = 0; i < n; i++) { + dtemp += dx[ix0+ix] * dy[iy0+iy]; + ix += incx; + iy += incy; + } + return dtemp; + } + m = n % 5; + for (i = 0; i < m; i++) { + dtemp += dx[ix0+i] * dy[iy0+i]; + } + for (; i < n;) { + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + dtemp += dx[ix0+i] * dy[iy0+i]; i++; + } + return dtemp; + } + + template + inline + FloatType ddot( + SizeType n, + const FloatType* dx, + const FloatType* dy) + { + return ddot( + n, dx, SizeType(0), SizeType(1), dy, SizeType(0), SizeType(1)); + } + + } // namespace detail + + //! Interface to the LBFGS %minimizer. + /*! This class solves the unconstrained minimization problem +

+          min f(x),  x = (x1,x2,...,x_n),
+      
+ using the limited-memory BFGS method. The routine is + especially effective on problems involving a large number of + variables. In a typical iteration of this method an + approximation Hk to the inverse of the Hessian + is obtained by applying m BFGS updates to a + diagonal matrix Hk0, using information from the + previous m steps. The user specifies the number + m, which determines the amount of storage + required by the routine. The user may also provide the + diagonal matrices Hk0 (parameter diag in the run() + function) if not satisfied with the default choice. The + algorithm is described in "On the limited memory BFGS method for + large scale optimization", by D. Liu and J. Nocedal, Mathematical + Programming B 45 (1989) 503-528. + + The user is required to calculate the function value + f and its gradient g. In order to + allow the user complete control over these computations, + reverse communication is used. The routine must be called + repeatedly under the control of the member functions + requests_f_and_g(), + requests_diag(). + If neither requests_f_and_g() nor requests_diag() is + true the user should check for convergence + (using class traditional_convergence_test or any + other custom test). If the convergence test is negative, + the minimizer may be called again for the next iteration. + + The steplength (stp()) is determined at each iteration + by means of the line search routine mcsrch, which is + a slight modification of the routine CSRCH written + by More' and Thuente. + + The only variables that are machine-dependent are + xtol, + stpmin and + stpmax. + + Fatal errors cause error exceptions to be thrown. + The generic class error is sub-classed (e.g. + class error_line_search_failed) to facilitate + granular %error handling. + + A note on performance: Using Compaq Fortran V5.4 and + Compaq C++ V6.5, the C++ implementation is about 15% slower + than the Fortran implementation. + */ + template + class minimizer + { + public: + //! Default constructor. Some members are not initialized! + minimizer() + : n_(0), m_(0), maxfev_(0), + gtol_(0), xtol_(0), + stpmin_(0), stpmax_(0), + ispt(0), iypt(0) + {} + + //! Constructor. + /*! @param n The number of variables in the minimization problem. + Restriction: n > 0. + + @param m The number of corrections used in the BFGS update. + Values of m less than 3 are not recommended; + large values of m will result in excessive + computing time. 3 <= m <= 7 is + recommended. + Restriction: m > 0. + + @param maxfev Maximum number of function evaluations + per line search. + Termination occurs when the number of evaluations + of the objective function is at least maxfev by + the end of an iteration. + + @param gtol Controls the accuracy of the line search. + If the function and gradient evaluations are inexpensive with + respect to the cost of the iteration (which is sometimes the + case when solving very large problems) it may be advantageous + to set gtol to a small value. A typical small + value is 0.1. + Restriction: gtol should be greater than 1e-4. + + @param xtol An estimate of the machine precision (e.g. 10e-16 + on a SUN station 3/60). The line search routine will + terminate if the relative width of the interval of + uncertainty is less than xtol. + + @param stpmin Specifies the lower bound for the step + in the line search. + The default value is 1e-20. This value need not be modified + unless the exponent is too large for the machine being used, + or unless the problem is extremely badly scaled (in which + case the exponent should be increased). + + @param stpmax specifies the upper bound for the step + in the line search. + The default value is 1e20. This value need not be modified + unless the exponent is too large for the machine being used, + or unless the problem is extremely badly scaled (in which + case the exponent should be increased). + */ + explicit + minimizer( + SizeType n, + SizeType m = 5, + SizeType maxfev = 20, + FloatType gtol = FloatType(0.9), + FloatType xtol = FloatType(1.e-16), + FloatType stpmin = FloatType(1.e-20), + FloatType stpmax = FloatType(1.e20)) + : n_(n), m_(m), maxfev_(maxfev), + gtol_(gtol), xtol_(xtol), + stpmin_(stpmin), stpmax_(stpmax), + iflag_(0), requests_f_and_g_(false), requests_diag_(false), + iter_(0), nfun_(0), stp_(0), + stp1(0), ftol(0.0001), ys(0), point(0), npt(0), + ispt(n+2*m), iypt((n+2*m)+n*m), + info(0), bound(0), nfev(0) + { + if (n_ == 0) { + throw error_improper_input_parameter("n = 0."); + } + if (m_ == 0) { + throw error_improper_input_parameter("m = 0."); + } + if (maxfev_ == 0) { + throw error_improper_input_parameter("maxfev = 0."); + } + if (gtol_ <= FloatType(1.e-4)) { + throw error_improper_input_parameter("gtol <= 1.e-4."); + } + if (xtol_ < FloatType(0)) { + throw error_improper_input_parameter("xtol < 0."); + } + if (stpmin_ < FloatType(0)) { + throw error_improper_input_parameter("stpmin < 0."); + } + if (stpmax_ < stpmin) { + throw error_improper_input_parameter("stpmax < stpmin"); + } + w_.resize(n_*(2*m_+1)+2*m_); + scratch_array_.resize(n_); + } + + //! Number of free parameters (as passed to the constructor). + SizeType n() const { return n_; } + + //! Number of corrections kept (as passed to the constructor). + SizeType m() const { return m_; } + + /*! \brief Maximum number of evaluations of the objective function + per line search (as passed to the constructor). + */ + SizeType maxfev() const { return maxfev_; } + + /*! \brief Control of the accuracy of the line search. + (as passed to the constructor). + */ + FloatType gtol() const { return gtol_; } + + //! Estimate of the machine precision (as passed to the constructor). + FloatType xtol() const { return xtol_; } + + /*! \brief Lower bound for the step in the line search. + (as passed to the constructor). + */ + FloatType stpmin() const { return stpmin_; } + + /*! \brief Upper bound for the step in the line search. + (as passed to the constructor). + */ + FloatType stpmax() const { return stpmax_; } + + //! Status indicator for reverse communication. + /*! true if the run() function returns to request + evaluation of the objective function (f) and + gradients (g) for the current point + (x). To continue the minimization the + run() function is called again with the updated values for + f and g. +

+ See also: requests_diag() + */ + bool requests_f_and_g() const { return requests_f_and_g_; } + + //! Status indicator for reverse communication. + /*! true if the run() function returns to request + evaluation of the diagonal matrix (diag) + for the current point (x). + To continue the minimization the run() function is called + again with the updated values for diag. +

+ See also: requests_f_and_g() + */ + bool requests_diag() const { return requests_diag_; } + + //! Number of iterations so far. + /*! Note that one iteration may involve multiple evaluations + of the objective function. +

+ See also: nfun() + */ + SizeType iter() const { return iter_; } + + //! Total number of evaluations of the objective function so far. + /*! The total number of function evaluations increases by the + number of evaluations required for the line search. The total + is only increased after a successful line search. +

+ See also: iter() + */ + SizeType nfun() const { return nfun_; } + + //! Norm of gradient given gradient array of length n(). + FloatType euclidean_norm(const FloatType* a) const { + return std::sqrt(detail::ddot(n_, a, a)); + } + + //! Current stepsize. + FloatType stp() const { return stp_; } + + //! Execution of one step of the minimization. + /*! @param x On initial entry this must be set by the user to + the values of the initial estimate of the solution vector. + + @param f Before initial entry or on re-entry under the + control of requests_f_and_g(), f must be set + by the user to contain the value of the objective function + at the current point x. + + @param g Before initial entry or on re-entry under the + control of requests_f_and_g(), g must be set + by the user to contain the components of the gradient at + the current point x. + + The return value is true if either + requests_f_and_g() or requests_diag() is true. + Otherwise the user should check for convergence + (e.g. using class traditional_convergence_test) and + call the run() function again to continue the minimization. + If the return value is false the user + should not update f, g or + diag (other overload) before calling + the run() function again. + + Note that x is always modified by the run() + function. Depending on the situation it can therefore be + necessary to evaluate the objective function one more time + after the minimization is terminated. + */ + bool run( + FloatType* x, + FloatType f, + const FloatType* g) + { + return generic_run(x, f, g, false, 0); + } + + //! Execution of one step of the minimization. + /*! @param x See other overload. + + @param f See other overload. + + @param g See other overload. + + @param diag On initial entry or on re-entry under the + control of requests_diag(), diag must be set by + the user to contain the values of the diagonal matrix Hk0. + The routine will return at each iteration of the algorithm + with requests_diag() set to true. +

+ Restriction: all elements of diag must be + positive. + */ + bool run( + FloatType* x, + FloatType f, + const FloatType* g, + const FloatType* diag) + { + return generic_run(x, f, g, true, diag); + } + + void serialize(std::ostream* out) const { + out->write((const char*)&n_, sizeof(n_)); // sanity check + out->write((const char*)&m_, sizeof(m_)); // sanity check + SizeType fs = sizeof(FloatType); + out->write((const char*)&fs, sizeof(fs)); // sanity check + + mcsrch_instance.serialize(out); + out->write((const char*)&iflag_, sizeof(iflag_)); + out->write((const char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); + out->write((const char*)&requests_diag_, sizeof(requests_diag_)); + out->write((const char*)&iter_, sizeof(iter_)); + out->write((const char*)&nfun_, sizeof(nfun_)); + out->write((const char*)&stp_, sizeof(stp_)); + out->write((const char*)&stp1, sizeof(stp1)); + out->write((const char*)&ftol, sizeof(ftol)); + out->write((const char*)&ys, sizeof(ys)); + out->write((const char*)&point, sizeof(point)); + out->write((const char*)&npt, sizeof(npt)); + out->write((const char*)&info, sizeof(info)); + out->write((const char*)&bound, sizeof(bound)); + out->write((const char*)&nfev, sizeof(nfev)); + out->write((const char*)&w_[0], sizeof(FloatType) * w_.size()); + out->write((const char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); + } + + void deserialize(std::istream* in) { + SizeType n, m, fs; + in->read((char*)&n, sizeof(n)); + in->read((char*)&m, sizeof(m)); + in->read((char*)&fs, sizeof(fs)); + assert(n == n_); + assert(m == m_); + assert(fs == sizeof(FloatType)); + + mcsrch_instance.deserialize(in); + in->read((char*)&iflag_, sizeof(iflag_)); + in->read((char*)&requests_f_and_g_, sizeof(requests_f_and_g_)); + in->read((char*)&requests_diag_, sizeof(requests_diag_)); + in->read((char*)&iter_, sizeof(iter_)); + in->read((char*)&nfun_, sizeof(nfun_)); + in->read((char*)&stp_, sizeof(stp_)); + in->read((char*)&stp1, sizeof(stp1)); + in->read((char*)&ftol, sizeof(ftol)); + in->read((char*)&ys, sizeof(ys)); + in->read((char*)&point, sizeof(point)); + in->read((char*)&npt, sizeof(npt)); + in->read((char*)&info, sizeof(info)); + in->read((char*)&bound, sizeof(bound)); + in->read((char*)&nfev, sizeof(nfev)); + in->read((char*)&w_[0], sizeof(FloatType) * w_.size()); + in->read((char*)&scratch_array_[0], sizeof(FloatType) * scratch_array_.size()); + } + + protected: + static void throw_diagonal_element_not_positive(SizeType i) { + throw error_improper_input_data( + "The " + error::itoa(i) + ". diagonal element of the" + " inverse Hessian approximation is not positive."); + } + + bool generic_run( + FloatType* x, + FloatType f, + const FloatType* g, + bool diagco, + const FloatType* diag); + + detail::mcsrch mcsrch_instance; + const SizeType n_; + const SizeType m_; + const SizeType maxfev_; + const FloatType gtol_; + const FloatType xtol_; + const FloatType stpmin_; + const FloatType stpmax_; + int iflag_; + bool requests_f_and_g_; + bool requests_diag_; + SizeType iter_; + SizeType nfun_; + FloatType stp_; + FloatType stp1; + FloatType ftol; + FloatType ys; + SizeType point; + SizeType npt; + const SizeType ispt; + const SizeType iypt; + int info; + SizeType bound; + SizeType nfev; + std::vector w_; + std::vector scratch_array_; + }; + + template + bool minimizer::generic_run( + FloatType* x, + FloatType f, + const FloatType* g, + bool diagco, + const FloatType* diag) + { + bool execute_entire_while_loop = false; + if (!(requests_f_and_g_ || requests_diag_)) { + execute_entire_while_loop = true; + } + requests_f_and_g_ = false; + requests_diag_ = false; + FloatType* w = &(*(w_.begin())); + if (iflag_ == 0) { // Initialize. + nfun_ = 1; + if (diagco) { + for (SizeType i = 0; i < n_; i++) { + if (diag[i] <= FloatType(0)) { + throw_diagonal_element_not_positive(i); + } + } + } + else { + std::fill_n(scratch_array_.begin(), n_, FloatType(1)); + diag = &(*(scratch_array_.begin())); + } + for (SizeType i = 0; i < n_; i++) { + w[ispt + i] = -g[i] * diag[i]; + } + FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); + if (gnorm == FloatType(0)) return false; + stp1 = FloatType(1) / gnorm; + execute_entire_while_loop = true; + } + if (execute_entire_while_loop) { + bound = iter_; + iter_++; + info = 0; + if (iter_ != 1) { + if (iter_ > m_) bound = m_; + ys = detail::ddot( + n_, w, iypt + npt, SizeType(1), w, ispt + npt, SizeType(1)); + if (!diagco) { + FloatType yy = detail::ddot( + n_, w, iypt + npt, SizeType(1), w, iypt + npt, SizeType(1)); + std::fill_n(scratch_array_.begin(), n_, ys / yy); + diag = &(*(scratch_array_.begin())); + } + else { + iflag_ = 2; + requests_diag_ = true; + return true; + } + } + } + if (execute_entire_while_loop || iflag_ == 2) { + if (iter_ != 1) { + if (diag == 0) { + throw error_internal_error(__FILE__, __LINE__); + } + if (diagco) { + for (SizeType i = 0; i < n_; i++) { + if (diag[i] <= FloatType(0)) { + throw_diagonal_element_not_positive(i); + } + } + } + SizeType cp = point; + if (point == 0) cp = m_; + w[n_ + cp -1] = 1 / ys; + SizeType i; + for (i = 0; i < n_; i++) { + w[i] = -g[i]; + } + cp = point; + for (i = 0; i < bound; i++) { + if (cp == 0) cp = m_; + cp--; + FloatType sq = detail::ddot( + n_, w, ispt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); + SizeType inmc=n_+m_+cp; + SizeType iycn=iypt+cp*n_; + w[inmc] = w[n_ + cp] * sq; + detail::daxpy(n_, -w[inmc], w, iycn, w); + } + for (i = 0; i < n_; i++) { + w[i] *= diag[i]; + } + for (i = 0; i < bound; i++) { + FloatType yr = detail::ddot( + n_, w, iypt + cp * n_, SizeType(1), w, SizeType(0), SizeType(1)); + FloatType beta = w[n_ + cp] * yr; + SizeType inmc=n_+m_+cp; + beta = w[inmc] - beta; + SizeType iscn=ispt+cp*n_; + detail::daxpy(n_, beta, w, iscn, w); + cp++; + if (cp == m_) cp = 0; + } + std::copy(w, w+n_, w+(ispt + point * n_)); + } + stp_ = FloatType(1); + if (iter_ == 1) stp_ = stp1; + std::copy(g, g+n_, w); + } + mcsrch_instance.run( + gtol_, stpmin_, stpmax_, n_, x, f, g, w, ispt + point * n_, + stp_, ftol, xtol_, maxfev_, info, nfev, &(*(scratch_array_.begin()))); + if (info == -1) { + iflag_ = 1; + requests_f_and_g_ = true; + return true; + } + if (info != 1) { + throw error_internal_error(__FILE__, __LINE__); + } + nfun_ += nfev; + npt = point*n_; + for (SizeType i = 0; i < n_; i++) { + w[ispt + npt + i] = stp_ * w[ispt + npt + i]; + w[iypt + npt + i] = g[i] - w[i]; + } + point++; + if (point == m_) point = 0; + return false; + } + + //! Traditional LBFGS convergence test. + /*! This convergence test is equivalent to the test embedded + in the lbfgs.f Fortran code. The test assumes that + there is a meaningful relation between the Euclidean norm of the + parameter vector x and the norm of the gradient + vector g. Therefore this test should not be used if + this assumption is not correct for a given problem. + */ + template + class traditional_convergence_test + { + public: + //! Default constructor. + traditional_convergence_test() + : n_(0), eps_(0) + {} + + //! Constructor. + /*! @param n The number of variables in the minimization problem. + Restriction: n > 0. + + @param eps Determines the accuracy with which the solution + is to be found. + */ + explicit + traditional_convergence_test( + SizeType n, + FloatType eps = FloatType(1.e-5)) + : n_(n), eps_(eps) + { + if (n_ == 0) { + throw error_improper_input_parameter("n = 0."); + } + if (eps_ < FloatType(0)) { + throw error_improper_input_parameter("eps < 0."); + } + } + + //! Number of free parameters (as passed to the constructor). + SizeType n() const { return n_; } + + /*! \brief Accuracy with which the solution is to be found + (as passed to the constructor). + */ + FloatType eps() const { return eps_; } + + //! Execution of the convergence test for the given parameters. + /*! Returns true if +

+            ||g|| < eps * max(1,||x||),
+          
+ where ||.|| denotes the Euclidean norm. + + @param x Current solution vector. + + @param g Components of the gradient at the current + point x. + */ + bool + operator()(const FloatType* x, const FloatType* g) const + { + FloatType xnorm = std::sqrt(detail::ddot(n_, x, x)); + FloatType gnorm = std::sqrt(detail::ddot(n_, g, g)); + if (gnorm <= eps_ * std::max(FloatType(1), xnorm)) return true; + return false; + } + protected: + const SizeType n_; + const FloatType eps_; + }; + +}} // namespace scitbx::lbfgs + +template +std::ostream& operator<<(std::ostream& os, const scitbx::lbfgs::minimizer& min) { + return os << "ITER=" << min.iter() << "\tNFUN=" << min.nfun() << "\tSTP=" << min.stp() << "\tDIAG=" << min.requests_diag() << "\tF&G=" << min.requests_f_and_g(); +} + + +#endif // SCITBX_LBFGS_H diff --git a/training/utils/lbfgs_test.cc b/training/utils/lbfgs_test.cc new file mode 100644 index 00000000..9678e788 --- /dev/null +++ b/training/utils/lbfgs_test.cc @@ -0,0 +1,117 @@ +#include +#include +#include +#include +#include "lbfgs.h" +#include "sparse_vector.h" +#include "fdict.h" + +using namespace std; + +double TestOptimizer() { + cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; + + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + double x[3]; + double g[3]; + scitbx::lbfgs::minimizer opt(3); + scitbx::lbfgs::traditional_convergence_test converged(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + opt.run(x, obj, g); + if (!opt.requests_f_and_g()) { + if (converged(x,g)) break; + opt.run(x, obj, g); + } + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + cerr << opt << endl; + } while (true); + return obj; +} + +double TestPersistentOptimizer() { + cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + double x[3]; + double g[3]; + scitbx::lbfgs::traditional_convergence_test converged(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + string state; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + + { + scitbx::lbfgs::minimizer opt(3); + if (state.size() > 0) { + istringstream is(state, ios::binary); + opt.deserialize(&is); + } + opt.run(x, obj, g); + ostringstream os(ios::binary); opt.serialize(&os); state = os.str(); + } + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + } while (!converged(x, g)); + return obj; +} + +void TestSparseVector() { + cerr << "Testing SparseVector serialization.\n"; + int f1 = FD::Convert("Feature_1"); + int f2 = FD::Convert("Feature_2"); + FD::Convert("LanguageModel"); + int f4 = FD::Convert("SomeFeature"); + int f5 = FD::Convert("SomeOtherFeature"); + SparseVector g; + g.set_value(f2, log(0.5)); + g.set_value(f4, log(0.125)); + g.set_value(f1, 0); + g.set_value(f5, 23.777); + ostringstream os; + double iobj = 1.5; + B64::Encode(iobj, g, &os); + cerr << iobj << "\t" << g << endl; + string data = os.str(); + cout << data << endl; + SparseVector v; + double obj; + bool decode_b64 = B64::Decode(&obj, &v, &data[0], data.size()); + cerr << obj << "\t" << v << endl; + assert(decode_b64); + assert(obj == iobj); + assert(g.size() == v.size()); +} + +int main() { + double o1 = TestOptimizer(); + double o2 = TestPersistentOptimizer(); + if (fabs(o1 - o2) > 1e-5) { + cerr << "OPTIMIZERS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; + return 1; + } + TestSparseVector(); + cerr << "SUCCESS\n"; + return 0; +} + diff --git a/training/utils/libcall.pl b/training/utils/libcall.pl new file mode 100644 index 00000000..c7d0f128 --- /dev/null +++ b/training/utils/libcall.pl @@ -0,0 +1,71 @@ +use IPC::Open3; +use Symbol qw(gensym); + +$DUMMY_STDERR = gensym(); +$DUMMY_STDIN = gensym(); + +# Run the command and ignore failures +sub unchecked_call { + system("@_") +} + +# Run the command and return its output, if any ignoring failures +sub unchecked_output { + return `@_` +} + +# WARNING: Do not use this for commands that will return large amounts +# of stdout or stderr -- they might block indefinitely +sub check_output { + print STDERR "Executing and gathering output: @_\n"; + + my $pid = open3($DUMMY_STDIN, \*PH, $DUMMY_STDERR, @_); + my $proc_output = ""; + while( ) { + $proc_output .= $_; + } + waitpid($pid, 0); + # TODO: Grab signal that the process died from + my $child_exit_status = $? >> 8; + if($child_exit_status == 0) { + return $proc_output; + } else { + print STDERR "ERROR: Execution of @_ failed.\n"; + exit(1); + } +} + +# Based on Moses' safesystem sub +sub check_call { + print STDERR "Executing: @_\n"; + system(@_); + my $exitcode = $? >> 8; + if($exitcode == 0) { + return 0; + } elsif ($? == -1) { + print STDERR "ERROR: Failed to execute: @_\n $!\n"; + exit(1); + + } elsif ($? & 127) { + printf STDERR "ERROR: Execution of: @_\n died with signal %d, %s coredump\n", + ($? & 127), ($? & 128) ? 'with' : 'without'; + exit(1); + + } else { + print STDERR "Failed with exit code: $exitcode\n" if $exitcode; + exit($exitcode); + } +} + +sub check_bash_call { + my @args = ( "bash", "-auxeo", "pipefail", "-c", "@_"); + check_call(@args); +} + +sub check_bash_output { + my @args = ( "bash", "-auxeo", "pipefail", "-c", "@_"); + return check_output(@args); +} + +# perl module weirdness... +return 1; diff --git a/training/utils/online_optimizer.cc b/training/utils/online_optimizer.cc new file mode 100644 index 00000000..3ed95452 --- /dev/null +++ b/training/utils/online_optimizer.cc @@ -0,0 +1,16 @@ +#include "online_optimizer.h" + +LearningRateSchedule::~LearningRateSchedule() {} + +double StandardLearningRate::eta(int k) const { + return eta_0_ / (1.0 + k / N_); +} + +double ExponentialDecayLearningRate::eta(int k) const { + return eta_0_ * pow(alpha_, k / N_); +} + +OnlineOptimizer::~OnlineOptimizer() {} + +void OnlineOptimizer::ResetEpochImpl() {} + diff --git a/training/utils/online_optimizer.h b/training/utils/online_optimizer.h new file mode 100644 index 00000000..28d89344 --- /dev/null +++ b/training/utils/online_optimizer.h @@ -0,0 +1,129 @@ +#ifndef _ONL_OPTIMIZE_H_ +#define _ONL_OPTIMIZE_H_ + +#include +#include +#include +#include +#include "sparse_vector.h" + +struct LearningRateSchedule { + virtual ~LearningRateSchedule(); + // returns the learning rate for the kth iteration + virtual double eta(int k) const = 0; +}; + +// TODO in the Tsoruoaka et al. (ACL 2009) paper, they use N +// to mean the batch size in most places, but it doesn't completely +// make sense to me in the learning rate schedules-- this needs +// to be worked out to make sure they didn't mean corpus size +// in some places and batch size in others (since in the paper they +// only ever work with batch sizes of 1) +struct StandardLearningRate : public LearningRateSchedule { + StandardLearningRate( + size_t batch_size, // batch size, not corpus size! + double eta_0 = 0.2) : + eta_0_(eta_0), + N_(static_cast(batch_size)) {} + + virtual double eta(int k) const; + + private: + const double eta_0_; + const double N_; +}; + +struct ExponentialDecayLearningRate : public LearningRateSchedule { + ExponentialDecayLearningRate( + size_t batch_size, // batch size, not corpus size! + double eta_0 = 0.2, + double alpha = 0.85 // recommended by Tsuruoka et al. (ACL 2009) + ) : eta_0_(eta_0), + N_(static_cast(batch_size)), + alpha_(alpha) { + assert(alpha > 0); + assert(alpha < 1.0); + } + + virtual double eta(int k) const; + + private: + const double eta_0_; + const double N_; + const double alpha_; +}; + +class OnlineOptimizer { + public: + virtual ~OnlineOptimizer(); + OnlineOptimizer(const std::tr1::shared_ptr& s, + size_t batch_size, + const std::vector& frozen_feats = std::vector()) + : N_(batch_size),schedule_(s),k_() { + for (int i = 0; i < frozen_feats.size(); ++i) + frozen_.insert(frozen_feats[i]); + } + void ResetEpoch() { k_ = 0; ResetEpochImpl(); } + void UpdateWeights(const SparseVector& approx_g, int max_feat, SparseVector* weights) { + ++k_; + const double eta = schedule_->eta(k_); + UpdateWeightsImpl(eta, approx_g, max_feat, weights); + } + + protected: + virtual void ResetEpochImpl(); + virtual void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) = 0; + const size_t N_; // number of training instances per batch + std::set frozen_; // frozen (non-optimizing) features + + private: + std::tr1::shared_ptr schedule_; + int k_; // iteration count +}; + +class CumulativeL1OnlineOptimizer : public OnlineOptimizer { + public: + CumulativeL1OnlineOptimizer(const std::tr1::shared_ptr& s, + size_t training_instances, double C, + const std::vector& frozen) : + OnlineOptimizer(s, training_instances, frozen), C_(C), u_() {} + + protected: + void ResetEpochImpl() { u_ = 0; } + void UpdateWeightsImpl(const double& eta, const SparseVector& approx_g, int max_feat, SparseVector* weights) { + u_ += eta * C_ / N_; + for (SparseVector::const_iterator it = approx_g.begin(); + it != approx_g.end(); ++it) { + if (frozen_.count(it->first) == 0) + weights->add_value(it->first, eta * it->second); + } + for (int i = 1; i < max_feat; ++i) + if (frozen_.count(i) == 0) ApplyPenalty(i, weights); + } + + private: + void ApplyPenalty(int i, SparseVector* w) { + const double z = w->value(i); + double w_i = z; + double q_i = q_.value(i); + if (w_i > 0.0) + w_i = std::max(0.0, w_i - (u_ + q_i)); + else if (w_i < 0.0) + w_i = std::min(0.0, w_i + (u_ - q_i)); + q_i += w_i - z; + if (q_i == 0.0) + q_.erase(i); + else + q_.set_value(i, q_i); + if (w_i == 0.0) + w->erase(i); + else + w->set_value(i, w_i); + } + + const double C_; // reguarlization strength + double u_; + SparseVector q_; +}; + +#endif diff --git a/training/utils/optimize.cc b/training/utils/optimize.cc new file mode 100644 index 00000000..41ac90d8 --- /dev/null +++ b/training/utils/optimize.cc @@ -0,0 +1,102 @@ +#include "optimize.h" + +#include +#include + +#include "lbfgs.h" + +using namespace std; + +BatchOptimizer::~BatchOptimizer() {} + +void BatchOptimizer::Save(ostream* out) const { + out->write((const char*)&eval_, sizeof(eval_)); + out->write((const char*)&has_converged_, sizeof(has_converged_)); + SaveImpl(out); + unsigned int magic = 0xABCDDCBA; // should be uint32_t + out->write((const char*)&magic, sizeof(magic)); +} + +void BatchOptimizer::Load(istream* in) { + in->read((char*)&eval_, sizeof(eval_)); + in->read((char*)&has_converged_, sizeof(has_converged_)); + LoadImpl(in); + unsigned int magic = 0; // should be uint32_t + in->read((char*)&magic, sizeof(magic)); + assert(magic == 0xABCDDCBA); + cerr << Name() << " EVALUATION #" << eval_ << endl; +} + +void BatchOptimizer::SaveImpl(ostream* out) const { + (void)out; +} + +void BatchOptimizer::LoadImpl(istream* in) { + (void)in; +} + +string RPropOptimizer::Name() const { + return "RPropOptimizer"; +} + +void RPropOptimizer::OptimizeImpl(const double& obj, + const vector& g, + vector* x) { + for (int i = 0; i < g.size(); ++i) { + const double g_i = g[i]; + const double sign_i = (signbit(g_i) ? -1.0 : 1.0); + const double prod = g_i * prev_g_[i]; + if (prod > 0.0) { + const double dij = min(delta_ij_[i] * eta_plus_, delta_max_); + (*x)[i] -= dij * sign_i; + delta_ij_[i] = dij; + prev_g_[i] = g_i; + } else if (prod < 0.0) { + delta_ij_[i] = max(delta_ij_[i] * eta_minus_, delta_min_); + prev_g_[i] = 0.0; + } else { + (*x)[i] -= delta_ij_[i] * sign_i; + prev_g_[i] = g_i; + } + } +} + +void RPropOptimizer::SaveImpl(ostream* out) const { + const size_t n = prev_g_.size(); + out->write((const char*)&n, sizeof(n)); + out->write((const char*)&prev_g_[0], sizeof(double) * n); + out->write((const char*)&delta_ij_[0], sizeof(double) * n); +} + +void RPropOptimizer::LoadImpl(istream* in) { + size_t n; + in->read((char*)&n, sizeof(n)); + assert(n == prev_g_.size()); + assert(n == delta_ij_.size()); + in->read((char*)&prev_g_[0], sizeof(double) * n); + in->read((char*)&delta_ij_[0], sizeof(double) * n); +} + +string LBFGSOptimizer::Name() const { + return "LBFGSOptimizer"; +} + +LBFGSOptimizer::LBFGSOptimizer(int num_feats, int memory_buffers) : + opt_(num_feats, memory_buffers) {} + +void LBFGSOptimizer::SaveImpl(ostream* out) const { + opt_.serialize(out); +} + +void LBFGSOptimizer::LoadImpl(istream* in) { + opt_.deserialize(in); +} + +void LBFGSOptimizer::OptimizeImpl(const double& obj, + const vector& g, + vector* x) { + opt_.run(&(*x)[0], obj, &g[0]); + if (!opt_.requests_f_and_g()) opt_.run(&(*x)[0], obj, &g[0]); + // cerr << opt_ << endl; +} + diff --git a/training/utils/optimize.h b/training/utils/optimize.h new file mode 100644 index 00000000..07943b44 --- /dev/null +++ b/training/utils/optimize.h @@ -0,0 +1,92 @@ +#ifndef _OPTIMIZE_H_ +#define _OPTIMIZE_H_ + +#include +#include +#include +#include + +#include "lbfgs.h" + +// abstract base class for first order optimizers +// order of invocation: new, Load(), Optimize(), Save(), delete +class BatchOptimizer { + public: + BatchOptimizer() : eval_(1), has_converged_(false) {} + virtual ~BatchOptimizer(); + virtual std::string Name() const = 0; + int EvaluationCount() const { return eval_; } + bool HasConverged() const { return has_converged_; } + + void Optimize(const double& obj, + const std::vector& g, + std::vector* x) { + assert(g.size() == x->size()); + ++eval_; + OptimizeImpl(obj, g, x); + scitbx::lbfgs::traditional_convergence_test converged(g.size()); + has_converged_ = converged(&(*x)[0], &g[0]); + } + + void Save(std::ostream* out) const; + void Load(std::istream* in); + protected: + virtual void SaveImpl(std::ostream* out) const; + virtual void LoadImpl(std::istream* in); + virtual void OptimizeImpl(const double& obj, + const std::vector& g, + std::vector* x) = 0; + + int eval_; + private: + bool has_converged_; +}; + +class RPropOptimizer : public BatchOptimizer { + public: + explicit RPropOptimizer(int num_vars, + double eta_plus = 1.2, + double eta_minus = 0.5, + double delta_0 = 0.1, + double delta_max = 50.0, + double delta_min = 1e-6) : + prev_g_(num_vars, 0.0), + delta_ij_(num_vars, delta_0), + eta_plus_(eta_plus), + eta_minus_(eta_minus), + delta_max_(delta_max), + delta_min_(delta_min) { + assert(eta_plus > 1.0); + assert(eta_minus > 0.0 && eta_minus < 1.0); + assert(delta_max > 0.0); + assert(delta_min > 0.0); + } + std::string Name() const; + void OptimizeImpl(const double& obj, + const std::vector& g, + std::vector* x); + void SaveImpl(std::ostream* out) const; + void LoadImpl(std::istream* in); + private: + std::vector prev_g_; + std::vector delta_ij_; + const double eta_plus_; + const double eta_minus_; + const double delta_max_; + const double delta_min_; +}; + +class LBFGSOptimizer : public BatchOptimizer { + public: + explicit LBFGSOptimizer(int num_vars, int memory_buffers = 10); + std::string Name() const; + void SaveImpl(std::ostream* out) const; + void LoadImpl(std::istream* in); + void OptimizeImpl(const double& obj, + const std::vector& g, + std::vector* x); + private: + scitbx::lbfgs::minimizer opt_; +}; + +#endif diff --git a/training/utils/optimize_test.cc b/training/utils/optimize_test.cc new file mode 100644 index 00000000..bff2ca03 --- /dev/null +++ b/training/utils/optimize_test.cc @@ -0,0 +1,118 @@ +#include +#include +#include +#include +#include "optimize.h" +#include "online_optimizer.h" +#include "sparse_vector.h" +#include "fdict.h" + +using namespace std; + +double TestOptimizer(BatchOptimizer* opt) { + cerr << "TESTING NON-PERSISTENT OPTIMIZER\n"; + + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + vector x(3); + vector g(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + do { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + opt->Optimize(obj, g, &x); + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + } while (!opt->HasConverged()); + return obj; +} + +double TestPersistentOptimizer(BatchOptimizer* opt) { + cerr << "\nTESTING PERSISTENT OPTIMIZER\n"; + // f(x,y) = 4x1^2 + x1*x2 + x2^2 + x3^2 + 6x3 + 5 + // df/dx1 = 8*x1 + x2 + // df/dx2 = 2*x2 + x1 + // df/dx3 = 2*x3 + 6 + vector x(3); + vector g(3); + x[0] = 8; + x[1] = 8; + x[2] = 8; + double obj = 0; + string state; + bool converged = false; + while (!converged) { + g[0] = 8 * x[0] + x[1]; + g[1] = 2 * x[1] + x[0]; + g[2] = 2 * x[2] + 6; + obj = 4 * x[0]*x[0] + x[0] * x[1] + x[1]*x[1] + x[2]*x[2] + 6 * x[2] + 5; + + { + if (state.size() > 0) { + istringstream is(state, ios::binary); + opt->Load(&is); + } + opt->Optimize(obj, g, &x); + ostringstream os(ios::binary); opt->Save(&os); state = os.str(); + + } + + cerr << x[0] << " " << x[1] << " " << x[2] << endl; + cerr << " obj=" << obj << "\td/dx1=" << g[0] << " d/dx2=" << g[1] << " d/dx3=" << g[2] << endl; + converged = opt->HasConverged(); + if (!converged) { + // now screw up the state (should be undone by Load) + obj += 2.0; + g[1] = -g[2]; + vector x2 = x; + try { + opt->Optimize(obj, g, &x2); + } catch (...) { } + } + } + return obj; +} + +template +void TestOptimizerVariants(int num_vars) { + O oa(num_vars); + cerr << "-------------------------------------------------------------------------\n"; + cerr << "TESTING: " << oa.Name() << endl; + double o1 = TestOptimizer(&oa); + O ob(num_vars); + double o2 = TestPersistentOptimizer(&ob); + if (o1 != o2) { + cerr << oa.Name() << " VARIANTS PERFORMED DIFFERENTLY!\n" << o1 << " vs. " << o2 << endl; + exit(1); + } + cerr << oa.Name() << " SUCCESS\n"; +} + +using namespace std::tr1; + +void TestOnline() { + size_t N = 20; + double C = 1.0; + double eta0 = 0.2; + std::tr1::shared_ptr r(new ExponentialDecayLearningRate(N, eta0, 0.85)); + //shared_ptr r(new StandardLearningRate(N, eta0)); + CumulativeL1OnlineOptimizer opt(r, N, C, std::vector()); + assert(r->eta(10) < r->eta(1)); +} + +int main() { + int n = 3; + TestOptimizerVariants(n); + TestOptimizerVariants(n); + TestOnline(); + return 0; +} + diff --git a/training/utils/parallelize.pl b/training/utils/parallelize.pl new file mode 100755 index 00000000..4197e0e5 --- /dev/null +++ b/training/utils/parallelize.pl @@ -0,0 +1,423 @@ +#!/usr/bin/env perl + +# Author: Adam Lopez +# +# This script takes a command that processes input +# from stdin one-line-at-time, and parallelizes it +# on the cluster using David Chiang's sentserver/ +# sentclient architecture. +# +# Prerequisites: the command *must* read each line +# without waiting for subsequent lines of input +# (for instance, a command which must read all lines +# of input before processing will not work) and +# return it to the output *without* buffering +# multiple lines. + +#TODO: if -j 1, run immediately, not via sentserver? possible differences in environment might make debugging harder + +#ANNOYANCE: if input is shorter than -j n lines, or at the very last few lines, repeatedly sleeps. time cut down to 15s from 60s + +my $SCRIPT_DIR; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $SCRIPT_DIR = dirname(abs_path($0)); push @INC, $SCRIPT_DIR, "$SCRIPT_DIR/../../environment"; } +use LocalConfig; + +use Cwd qw/ abs_path cwd getcwd /; +use File::Temp qw/ tempfile /; +use Getopt::Long; +use IPC::Open2; +use strict; +use POSIX ":sys_wait_h"; + +use File::Basename; +my $myDir = dirname(__FILE__); +print STDERR __FILE__." -> $myDir\n"; +push(@INC, $myDir); +require "libcall.pl"; + +my $tailn=5; # +0 = concatenate all the client logs. 5 = last 5 lines +my $recycle_clients; # spawn new clients when previous ones terminate +my $stay_alive; # dont let server die when having zero clients +my $joblist = ""; +my $errordir=""; +my $multiline; +my $workdir = '.'; +my $numnodes = 8; +my $user = $ENV{"USER"}; +my $pmem = "9g"; +my $basep=50300; +my $randp=300; +my $tryp=50; +my $no_which; +my $no_cd; + +my $DEBUG=$ENV{DEBUG}; +print STDERR "DEBUG=$DEBUG output enabled.\n" if $DEBUG; +my $verbose = 1; +sub verbose { + if ($verbose) { + print STDERR @_,"\n"; + } +} +sub debug { + if ($DEBUG) { + my ($package, $filename, $line) = caller; + print STDERR "DEBUG: $filename($line): ",join(' ',@_),"\n"; + } +} +my $is_shell_special=qr.[ \t\n\\><|&;"'`~*?{}$!()].; +my $shell_escape_in_quote=qr.[\\"\$`!].; +sub escape_shell { + my ($arg)=@_; + return undef unless defined $arg; + return '""' unless $arg; + if ($arg =~ /$is_shell_special/) { + $arg =~ s/($shell_escape_in_quote)/\\$1/g; + return "\"$arg\""; + } + return $arg; +} +sub preview_files { + my ($l,$skipempty,$footer,$n)=@_; + $n=$tailn unless defined $n; + my @f=grep { ! ($skipempty && -z $_) } @$l; + my $fn=join(' ',map {escape_shell($_)} @f); + my $cmd="tail -n $n $fn"; + unchecked_output("$cmd").($footer?"\nNONEMPTY FILES:\n$fn\n":""); +} +sub prefix_dirname($) { + #like `dirname but if ends in / then return the whole thing + local ($_)=@_; + if (/\/$/) { + $_; + } else { + s#/[^/]$##; + $_ ? $_ : ''; + } +} +sub ensure_final_slash($) { + local ($_)=@_; + m#/$# ? $_ : ($_."/"); +} +sub extend_path($$;$$) { + my ($base,$ext,$mkdir,$baseisdir)=@_; + if (-d $base) { + $base.="/"; + } else { + my $dir; + if ($baseisdir) { + $dir=$base; + $base.='/' unless $base =~ /\/$/; + } else { + $dir=prefix_dirname($base); + } + my @cmd=("/bin/mkdir","-p",$dir); + check_call(@cmd) if $mkdir; + } + return $base.$ext; +} + +my $abscwd=abs_path(&getcwd); +sub print_help; + +my $use_fork; +my @pids; + +# Process command-line options +unless (GetOptions( + "stay-alive" => \$stay_alive, + "recycle-clients" => \$recycle_clients, + "error-dir=s" => \$errordir, + "multi-line" => \$multiline, + "workdir=s" => \$workdir, + "use-fork" => \$use_fork, + "verbose" => \$verbose, + "jobs=i" => \$numnodes, + "pmem=s" => \$pmem, + "baseport=i" => \$basep, +# "iport=i" => \$randp, #for short name -i + "no-which!" => \$no_which, + "no-cd!" => \$no_cd, + "tailn=s" => \$tailn, +) && scalar @ARGV){ + print_help(); + die "bad options."; +} + +my $cmd = ""; +my $prog=shift; +if ($no_which) { + $cmd=$prog; +} else { + $cmd=check_output("which $prog"); + chomp $cmd; + die "$prog not found - $cmd" unless $cmd; +} +#$cmd=abs_path($cmd); +for my $arg (@ARGV) { + $cmd .= " ".escape_shell($arg); +} +die "Please specify a command to parallelize\n" if $cmd eq ''; + +my $cdcmd=$no_cd ? '' : ("cd ".escape_shell($abscwd)."\n"); + +my $executable = $cmd; +$executable =~ s/^\s*(\S+)($|\s.*)/$1/; +$executable=check_output("basename $executable"); +chomp $executable; + + +print STDERR "Parallelizing ($numnodes ways): $cmd\n\n"; + +# create -e dir and save .sh +use File::Temp qw/tempdir/; +unless ($errordir) { + $errordir=tempdir("$executable.XXXXXX",CLEANUP=>1); +} +if ($errordir) { + my $scriptfile=extend_path("$errordir/","$executable.sh",1,1); + -d $errordir || die "should have created -e dir $errordir"; + open SF,">",$scriptfile || die; + print SF "$cdcmd$cmd\n"; + close SF; + chmod 0755,$scriptfile; + $errordir=abs_path($errordir); + &verbose("-e dir: $errordir"); +} + +# set cleanup handler +my @cleanup_cmds; +sub cleanup; +sub cleanup_and_die; +$SIG{INT} = "cleanup_and_die"; +$SIG{TERM} = "cleanup_and_die"; +$SIG{HUP} = "cleanup_and_die"; + +# other subs: +sub numof_live_jobs; +sub launch_job_on_node; + + +# vars +my $mydir = check_output("dirname $0"); chomp $mydir; +my $sentserver = "$mydir/sentserver"; +my $sentclient = "$mydir/sentclient"; +my $host = check_output("hostname"); +chomp $host; + + +# find open port +srand; +my $port = 50300+int(rand($randp)); +my $endp=$port+$tryp; +sub listening_port_lines { + my $quiet=$verbose?'':'2>/dev/null'; + return unchecked_output("netstat -a -n $quiet | grep LISTENING | grep -i tcp"); +} +my $netstat=&listening_port_lines; + +if ($verbose){ print STDERR "Testing port $port...";} + +while ($netstat=~/$port/ || &listening_port_lines=~/$port/){ + if ($verbose){ print STDERR "port is busy\n";} + $port++; + if ($port > $endp){ + die "Unable to find open port\n"; + } + if ($verbose){ print STDERR "Testing port $port... "; } +} +if ($verbose){ + print STDERR "port $port is available\n"; +} + +my $key = int(rand()*1000000); + +my $multiflag = ""; +if ($multiline){ $multiflag = "-m"; print STDERR "expecting multiline output.\n"; } +my $stay_alive_flag = ""; +if ($stay_alive){ $stay_alive_flag = "--stay-alive"; print STDERR "staying alive while no clients are connected.\n"; } + +my $node_count = 0; +my $script = ""; +# fork == one thread runs the sentserver, while the +# other spawns the sentclient commands. +my $pid = fork; +if ($pid == 0) { # child + sleep 8; # give other thread time to start sentserver + $script = "$cdcmd$sentclient $host:$port:$key $cmd"; + + if ($verbose){ + print STDERR "Client script:\n====\n"; + print STDERR $script; + print STDERR "====\n"; + } + for (my $jobn=0; $jobn<$numnodes; $jobn++){ + launch_job(); + } + if ($recycle_clients) { + my $ret; + my $livejobs; + while (1) { + $ret = waitpid($pid, WNOHANG); + #print STDERR "waitpid $pid ret = $ret \n"; + last if ($ret != 0); + $livejobs = numof_live_jobs(); + if ($numnodes >= $livejobs ) { # a client terminated, OR # lines of input was less than -j + print STDERR "num of requested nodes = $numnodes; num of currently live jobs = $livejobs; Client terminated - launching another.\n"; + launch_job(); + } else { + sleep 15; + } + } + } + print STDERR "CHILD PROCESSES SPAWNED ... WAITING\n"; + for my $p (@pids) { + waitpid($p, 0); + } +} else { +# my $todo = "$sentserver -k $key $multiflag $port "; + my $todo = "$sentserver -k $key $multiflag $port $stay_alive_flag "; + if ($verbose){ print STDERR "Running: $todo\n"; } + check_call($todo); + print STDERR "Call to $sentserver returned.\n"; + cleanup(); + exit(0); +} + +sub numof_live_jobs { + if ($use_fork) { + die "not implemented"; + } else { + # We can probably continue decoding if the qstat error is only temporary + my @livejobs = grep(/$joblist/, split(/\n/, unchecked_output("qstat"))); + return ($#livejobs + 1); + } +} +my (@errors,@outs,@cmds); + +sub launch_job { + if ($use_fork) { return launch_job_fork(); } + my $errorfile = "/dev/null"; + my $outfile = "/dev/null"; + $node_count++; + my $clientname = $executable; + $clientname =~ s/^(.{4}).*$/$1/; + $clientname = "$clientname.$node_count"; + if ($errordir){ + $errorfile = "$errordir/$clientname.ER"; + $outfile = "$errordir/$clientname.OU"; + push @errors,$errorfile; + push @outs,$outfile; + } + my $todo = qsub_args($pmem) . " -N $clientname -o $outfile -e $errorfile"; + push @cmds,$todo; + + print STDERR "Running: $todo\n"; + local(*QOUT, *QIN); + open2(\*QOUT, \*QIN, $todo) or die "Failed to open2: $!"; + print QIN $script; + close QIN; + while (my $jobid=){ + chomp $jobid; + if ($verbose){ print STDERR "Launched client job: $jobid"; } + $jobid =~ s/^(\d+)(.*?)$/\1/g; + $jobid =~ s/^Your job (\d+) .*$/\1/; + print STDERR " short job id $jobid\n"; + if ($verbose){ + print STDERR "cd: $abscwd\n"; + print STDERR "cmd: $cmd\n"; + } + if ($joblist == "") { $joblist = $jobid; } + else {$joblist = $joblist . "\|" . $jobid; } + my $cleanfn="qdel $jobid 2> /dev/null"; + push(@cleanup_cmds, $cleanfn); + } + close QOUT; +} + +sub launch_job_fork { + my $errorfile = "/dev/null"; + my $outfile = "/dev/null"; + $node_count++; + my $clientname = $executable; + $clientname =~ s/^(.{4}).*$/$1/; + $clientname = "$clientname.$node_count"; + if ($errordir){ + $errorfile = "$errordir/$clientname.ER"; + $outfile = "$errordir/$clientname.OU"; + push @errors,$errorfile; + push @outs,$outfile; + } + my $pid = fork; + if ($pid == 0) { + my ($fh, $scr_name) = get_temp_script(); + print $fh $script; + close $fh; + my $todo = "/bin/bash -xeo pipefail $scr_name 1> $outfile 2> $errorfile"; + print STDERR "EXEC: $todo\n"; + my $out = check_output("$todo"); + unlink $scr_name or warn "Failed to remove $scr_name"; + exit 0; + } else { + push @pids, $pid; + } +} + +sub get_temp_script { + my ($fh, $filename) = tempfile( "$workdir/workXXXX", SUFFIX => '.sh'); + return ($fh, $filename); +} + +sub cleanup_and_die { + cleanup(); + die "\n"; +} + +sub cleanup { + print STDERR "Cleaning up...\n"; + for $cmd (@cleanup_cmds){ + print STDERR " Cleanup command: $cmd\n"; + eval $cmd; + } + print STDERR "outputs:\n",preview_files(\@outs,1),"\n"; + print STDERR "errors:\n",preview_files(\@errors,1),"\n"; + print STDERR "cmd:\n",$cmd,"\n"; + print STDERR " cat $errordir/*.ER\nfor logs.\n"; + print STDERR "Cleanup finished.\n"; +} + +sub print_help +{ + my $name = check_output("basename $0"); chomp $name; + print << "Help"; + +usage: $name [options] + + Automatic black-box parallelization of commands. + +options: + + --use-fork + Instead of using qsub, use fork. + + -e, --error-dir + Retain output files from jobs in , rather + than silently deleting them. + + -m, --multi-line + Expect that command may produce multiple output + lines for a single input line. $name makes a + reasonable attempt to obtain all output before + processing additional inputs. However, use of this + option is inherently unsafe. + + -v, --verbose + Print diagnostic informatoin on stderr. + + -j, --jobs + Number of jobs to use. + + -p, --pmem + pmem setting for each job. + +Help +} diff --git a/training/utils/risk.cc b/training/utils/risk.cc new file mode 100644 index 00000000..d5a12cfd --- /dev/null +++ b/training/utils/risk.cc @@ -0,0 +1,45 @@ +#include "risk.h" + +#include "prob.h" +#include "candidate_set.h" +#include "ns.h" + +using namespace std; + +namespace training { + +// g = \sum_e p(e|f) * loss(e) * (phi(e,f) - E[phi(e,f)]) +double CandidateSetRisk::operator()(const vector& params, + SparseVector* g) const { + prob_t z; + for (unsigned i = 0; i < cands_.size(); ++i) { + const prob_t u(cands_[i].fmap.dot(params), init_lnx()); + z += u; + } + const double log_z = log(z); + + SparseVector exp_feats; + if (g) { + for (unsigned i = 0; i < cands_.size(); ++i) { + const double log_prob = cands_[i].fmap.dot(params) - log_z; + const double prob = exp(log_prob); + exp_feats += cands_[i].fmap * prob; + } + } + + double risk = 0; + for (unsigned i = 0; i < cands_.size(); ++i) { + const double log_prob = cands_[i].fmap.dot(params) - log_z; + const double prob = exp(log_prob); + const double cost = metric_.IsErrorMetric() ? metric_.ComputeScore(cands_[i].eval_feats) + : 1.0 - metric_.ComputeScore(cands_[i].eval_feats); + const double r = prob * cost; + risk += r; + if (g) (*g) += (cands_[i].fmap - exp_feats) * r; + } + return risk; +} + +} + + diff --git a/training/utils/risk.h b/training/utils/risk.h new file mode 100644 index 00000000..2e8db0fb --- /dev/null +++ b/training/utils/risk.h @@ -0,0 +1,26 @@ +#ifndef _RISK_H_ +#define _RISK_H_ + +#include +#include "sparse_vector.h" +class EvaluationMetric; + +namespace training { + class CandidateSet; + + class CandidateSetRisk { + public: + explicit CandidateSetRisk(const CandidateSet& cs, const EvaluationMetric& metric) : + cands_(cs), + metric_(metric) {} + // compute the risk (expected loss) of a CandidateSet + // (optional) the gradient of the risk with respect to params + double operator()(const std::vector& params, + SparseVector* g = NULL) const; + private: + const CandidateSet& cands_; + const EvaluationMetric& metric_; + }; +}; + +#endif diff --git a/training/utils/sentclient.c b/training/utils/sentclient.c new file mode 100644 index 00000000..91d994ab --- /dev/null +++ b/training/utils/sentclient.c @@ -0,0 +1,76 @@ +/* Copyright (c) 2001 by David Chiang. All rights reserved.*/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sentserver.h" + +int main (int argc, char *argv[]) { + int sock, port; + char *s, *key; + struct hostent *hp; + struct sockaddr_in server; + int errors = 0; + + if (argc < 3) { + fprintf(stderr, "Usage: sentclient host[:port[:key]] command [args ...]\n"); + exit(1); + } + + s = strchr(argv[1], ':'); + key = NULL; + + if (s == NULL) { + port = DEFAULT_PORT; + } else { + *s = '\0'; + s+=1; + /* dumb hack */ + key = strchr(s, ':'); + if (key != NULL){ + *key = '\0'; + key += 1; + } + port = atoi(s); + } + + sock = socket(AF_INET, SOCK_STREAM, 0); + + hp = gethostbyname(argv[1]); + if (hp == NULL) { + fprintf(stderr, "unknown host %s\n", argv[1]); + exit(1); + } + + bzero((char *)&server, sizeof(server)); + bcopy(hp->h_addr, (char *)&server.sin_addr, hp->h_length); + server.sin_family = hp->h_addrtype; + server.sin_port = htons(port); + + while (connect(sock, (struct sockaddr *)&server, sizeof(server)) < 0) { + perror("connect()"); + sleep(1); + errors++; + if (errors > 5) + exit(1); + } + + close(0); + close(1); + dup2(sock, 0); + dup2(sock, 1); + + if (key != NULL){ + write(1, key, strlen(key)); + write(1, "\n", 1); + } + + execvp(argv[2], argv+2); + return 0; +} diff --git a/training/utils/sentserver.c b/training/utils/sentserver.c new file mode 100644 index 00000000..c20b4fa6 --- /dev/null +++ b/training/utils/sentserver.c @@ -0,0 +1,515 @@ +/* Copyright (c) 2001 by David Chiang. All rights reserved.*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sentserver.h" + +#define MAX_CLIENTS 64 + +struct clientinfo { + int s; + struct sockaddr_in sin; +}; + +struct line { + int id; + char *s; + int status; + struct line *next; +} *head, **ptail; + +int n_sent = 0, n_received=0, n_flushed=0; + +#define STATUS_RUNNING 0 +#define STATUS_ABORTED 1 +#define STATUS_FINISHED 2 + +pthread_mutex_t queue_mutex = PTHREAD_MUTEX_INITIALIZER; +pthread_mutex_t clients_mutex = PTHREAD_MUTEX_INITIALIZER; +pthread_mutex_t input_mutex = PTHREAD_MUTEX_INITIALIZER; + +int n_clients = 0; +int s; +int expect_multiline_output = 0; +int log_mutex = 0; +int stay_alive = 0; /* dont panic and die with zero clients */ + +void queue_finish(struct line *node, char *s, int fid); +char * read_line(int fd, int multiline); +void done (int code); + +struct line * queue_get(int fid) { + struct line *cur; + char *s, *synch; + + if (log_mutex) fprintf(stderr, "Getting for data for fid %d\n", fid); + if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); + pthread_mutex_lock(&queue_mutex); + + /* First, check for aborted sentences. */ + + if (log_mutex) fprintf(stderr, " Checking queue for aborted jobs (fid %d)\n", fid); + for (cur = head; cur != NULL; cur = cur->next) { + if (cur->status == STATUS_ABORTED) { + cur->status = STATUS_RUNNING; + + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + + return cur; + } + } + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + + /* Otherwise, read a new one. */ + if (log_mutex) fprintf(stderr, "Locking input mutex (%d)\n", fid); + if (log_mutex) fprintf(stderr, " Reading input for new data (fid %d)\n", fid); + pthread_mutex_lock(&input_mutex); + s = read_line(0,0); + + while (s) { + if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); + pthread_mutex_lock(&queue_mutex); + if (log_mutex) fprintf(stderr, "Unlocking input mutex (%d)\n", fid); + pthread_mutex_unlock(&input_mutex); + + cur = malloc(sizeof (struct line)); + cur->id = n_sent; + cur->s = s; + cur->next = NULL; + + *ptail = cur; + ptail = &cur->next; + + n_sent++; + + if (strcmp(s,"===SYNCH===\n")==0){ + fprintf(stderr, "Received ===SYNCH=== signal (fid %d)\n", fid); + // Note: queue_finish calls free(cur->s). + // Therefore we need to create a new string here. + synch = malloc((strlen("===SYNCH===\n")+2) * sizeof (char)); + synch = strcpy(synch, s); + + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + queue_finish(cur, synch, fid); /* handles its own lock */ + + if (log_mutex) fprintf(stderr, "Locking input mutex (%d)\n", fid); + if (log_mutex) fprintf(stderr, " Reading input for new data (fid %d)\n", fid); + pthread_mutex_lock(&input_mutex); + + s = read_line(0,0); + } else { + if (log_mutex) fprintf(stderr, " Received new data %d (fid %d)\n", cur->id, fid); + cur->status = STATUS_RUNNING; + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + return cur; + } + } + + if (log_mutex) fprintf(stderr, "Unlocking input mutex (%d)\n", fid); + pthread_mutex_unlock(&input_mutex); + /* Only way to reach this point: no more output */ + + if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); + pthread_mutex_lock(&queue_mutex); + if (head == NULL) { + fprintf(stderr, "Reached end of file. Exiting.\n"); + done(0); + } else + ptail = NULL; /* This serves as a signal that there is no more input */ + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + + return NULL; +} + +void queue_panic() { + struct line *next; + while (head && head->status == STATUS_FINISHED) { + /* Write out finished sentences */ + if (head->status == STATUS_FINISHED) { + fputs(head->s, stdout); + fflush(stdout); + } + /* Write out blank line for unfinished sentences */ + if (head->status == STATUS_ABORTED) { + fputs("\n", stdout); + fflush(stdout); + } + /* By defition, there cannot be any RUNNING sentences, since + function is only called when n_clients == 0 */ + free(head->s); + next = head->next; + free(head); + head = next; + n_flushed++; + } + fclose(stdout); + fprintf(stderr, "All clients died. Panicking, flushing completed sentences and exiting.\n"); + done(1); +} + +void queue_abort(struct line *node, int fid) { + if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); + pthread_mutex_lock(&queue_mutex); + node->status = STATUS_ABORTED; + if (n_clients == 0) { + if (stay_alive) { + fprintf(stderr, "Warning! No live clients detected! Staying alive, will retry soon.\n"); + } else { + queue_panic(); + } + } + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); +} + + +void queue_print() { + struct line *cur; + + fprintf(stderr, " Queue\n"); + + for (cur = head; cur != NULL; cur = cur->next) { + switch(cur->status) { + case STATUS_RUNNING: + fprintf(stderr, " %d running ", cur->id); break; + case STATUS_ABORTED: + fprintf(stderr, " %d aborted ", cur->id); break; + case STATUS_FINISHED: + fprintf(stderr, " %d finished ", cur->id); break; + + } + fprintf(stderr, "\n"); + //fprintf(stderr, cur->s); + } +} + +void queue_finish(struct line *node, char *s, int fid) { + struct line *next; + if (log_mutex) fprintf(stderr, "Locking queue mutex (%d)\n", fid); + pthread_mutex_lock(&queue_mutex); + + free(node->s); + node->s = s; + node->status = STATUS_FINISHED; + n_received++; + + /* Flush out finished nodes */ + while (head && head->status == STATUS_FINISHED) { + + if (log_mutex) fprintf(stderr, " Flushing finished node %d\n", head->id); + + fputs(head->s, stdout); + fflush(stdout); + if (log_mutex) fprintf(stderr, " Flushed node %d\n", head->id); + free(head->s); + + next = head->next; + free(head); + + head = next; + + n_flushed++; + + if (head == NULL) { /* empty queue */ + if (ptail == NULL) { /* This can only happen if set in queue_get as signal that there is no more input. */ + fprintf(stderr, "All sentences finished. Exiting.\n"); + done(0); + } else /* ptail pointed at something which was just popped off the stack -- reset to head*/ + ptail = &head; + } + } + + if (log_mutex) fprintf(stderr, " Flushing output %d\n", head->id); + fflush(stdout); + fprintf(stderr, "%d sentences sent, %d sentences finished, %d sentences flushed\n", n_sent, n_received, n_flushed); + + if (log_mutex) fprintf(stderr, "Unlocking queue mutex (%d)\n", fid); + pthread_mutex_unlock(&queue_mutex); + +} + +char * read_line(int fd, int multiline) { + int size = 80; + char errorbuf[100]; + char *s = malloc(size+2); + int result, errors=0; + int i = 0; + + result = read(fd, s+i, 1); + + while (1) { + if (result < 0) { + perror("read()"); + sprintf(errorbuf, "Error code: %d\n", errno); + fprintf(stderr, errorbuf); + errors++; + if (errors > 5) { + free(s); + return NULL; + } else { + sleep(1); /* retry after delay */ + } + } else if (result == 0) { + break; + } else if (multiline==0 && s[i] == '\n') { + break; + } else { + if (s[i] == '\n'){ + /* if we've reached this point, + then multiline must be 1, and we're + going to poll the fd for an additional + line of data. The basic design is to + run a select on the filedescriptor fd. + Select will return under two conditions: + if there is data on the fd, or if a + timeout is reached. We'll select on this + fd. If select returns because there's data + ready, keep going; else assume there's no + more and return the data we already have. + */ + + fd_set set; + FD_ZERO(&set); + FD_SET(fd, &set); + + struct timeval timeout; + timeout.tv_sec = 3; // number of seconds for timeout + timeout.tv_usec = 0; + + int ready = select(FD_SETSIZE, &set, NULL, NULL, &timeout); + if (ready<1){ + break; // no more data, stop looping + } + } + i++; + + if (i == size) { + size = size*2; + s = realloc(s, size+2); + } + } + + result = read(fd, s+i, 1); + } + + if (result == 0 && i == 0) { /* end of file */ + free(s); + return NULL; + } + + s[i] = '\n'; + s[i+1] = '\0'; + + return s; +} + +void * new_client(void *arg) { + struct clientinfo *client = (struct clientinfo *)arg; + struct line *cur; + int result; + char *s; + char errorbuf[100]; + + pthread_mutex_lock(&clients_mutex); + n_clients++; + pthread_mutex_unlock(&clients_mutex); + + fprintf(stderr, "Client connected (%d connected)\n", n_clients); + + for (;;) { + + cur = queue_get(client->s); + + if (cur) { + /* fprintf(stderr, "Sending to client: %s", cur->s); */ + fprintf(stderr, "Sending data %d to client (fid %d)\n", cur->id, client->s); + result = write(client->s, cur->s, strlen(cur->s)); + if (result < strlen(cur->s)){ + perror("write()"); + sprintf(errorbuf, "Error code: %d\n", errno); + fprintf(stderr, errorbuf); + + pthread_mutex_lock(&clients_mutex); + n_clients--; + pthread_mutex_unlock(&clients_mutex); + + fprintf(stderr, "Client died (%d connected)\n", n_clients); + queue_abort(cur, client->s); + + close(client->s); + free(client); + + pthread_exit(NULL); + } + } else { + close(client->s); + pthread_mutex_lock(&clients_mutex); + n_clients--; + pthread_mutex_unlock(&clients_mutex); + fprintf(stderr, "Client dismissed (%d connected)\n", n_clients); + pthread_exit(NULL); + } + + s = read_line(client->s,expect_multiline_output); + if (s) { + /* fprintf(stderr, "Client (fid %d) returned: %s", client->s, s); */ + fprintf(stderr, "Client (fid %d) returned data %d\n", client->s, cur->id); +// queue_print(); + queue_finish(cur, s, client->s); + } else { + pthread_mutex_lock(&clients_mutex); + n_clients--; + pthread_mutex_unlock(&clients_mutex); + + fprintf(stderr, "Client died (%d connected)\n", n_clients); + queue_abort(cur, client->s); + + close(client->s); + free(client); + + pthread_exit(NULL); + } + + } + return 0; +} + +void done (int code) { + close(s); + exit(code); +} + + + +int main (int argc, char *argv[]) { + struct sockaddr_in sin, from; + int g; + socklen_t len; + struct clientinfo *client; + int port; + int opt; + int errors = 0; + int argi; + char *key = NULL, *client_key; + int use_key = 0; + /* the key stuff here doesn't provide any + real measure of security, it's mainly to keep + jobs from bumping into each other. */ + + pthread_t tid; + port = DEFAULT_PORT; + + for (argi=1; argi < argc; argi++){ + if (strcmp(argv[argi], "-m")==0){ + expect_multiline_output = 1; + } else if (strcmp(argv[argi], "-k")==0){ + argi++; + if (argi == argc){ + fprintf(stderr, "Key must be specified after -k\n"); + exit(1); + } + key = argv[argi]; + use_key = 1; + } else if (strcmp(argv[argi], "--stay-alive")==0){ + stay_alive = 1; /* dont panic and die with zero clients */ + } else { + port = atoi(argv[argi]); + } + } + + /* Initialize data structures */ + head = NULL; + ptail = &head; + + /* Set up listener */ + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + opt = 1; + setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); + + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = htonl(INADDR_ANY); + sin.sin_port = htons(port); + while (bind(s, (struct sockaddr *) &sin, sizeof(sin)) < 0) { + perror("bind()"); + sleep(1); + errors++; + if (errors > 100) + exit(1); + } + + len = sizeof(sin); + getsockname(s, (struct sockaddr *) &sin, &len); + + fprintf(stderr, "Listening on port %hu\n", ntohs(sin.sin_port)); + + while (listen(s, MAX_CLIENTS) < 0) { + perror("listen()"); + sleep(1); + errors++; + if (errors > 100) + exit(1); + } + + for (;;) { + len = sizeof(from); + g = accept(s, (struct sockaddr *)&from, &len); + if (g < 0) { + perror("accept()"); + sleep(1); + continue; + } + client = malloc(sizeof(struct clientinfo)); + client->s = g; + bcopy(&from, &client->sin, len); + + if (use_key){ + fd_set set; + FD_ZERO(&set); + FD_SET(client->s, &set); + + struct timeval timeout; + timeout.tv_sec = 3; // number of seconds for timeout + timeout.tv_usec = 0; + + int ready = select(FD_SETSIZE, &set, NULL, NULL, &timeout); + if (ready<1){ + fprintf(stderr, "Prospective client failed to respond with correct key.\n"); + close(client->s); + free(client); + } else { + client_key = read_line(client->s,0); + client_key[strlen(client_key)-1]='\0'; /* chop trailing newline */ + if (strcmp(key, client_key)==0){ + pthread_create(&tid, NULL, new_client, client); + } else { + fprintf(stderr, "Prospective client failed to respond with correct key.\n"); + close(client->s); + free(client); + } + free(client_key); + } + } else { + pthread_create(&tid, NULL, new_client, client); + } + } + +} + + + diff --git a/training/utils/sentserver.h b/training/utils/sentserver.h new file mode 100644 index 00000000..cd17a546 --- /dev/null +++ b/training/utils/sentserver.h @@ -0,0 +1,6 @@ +#ifndef SENTSERVER_H +#define SENTSERVER_H + +#define DEFAULT_PORT 50000 + +#endif diff --git a/word-aligner/Makefile.am b/word-aligner/Makefile.am new file mode 100644 index 00000000..280d3ae7 --- /dev/null +++ b/word-aligner/Makefile.am @@ -0,0 +1,6 @@ +bin_PROGRAMS = fast_align + +fast_align_SOURCES = fast_align.cc ttables.cc +fast_align_LDADD = $(top_srcdir)/utils/libutils.a -lz + +AM_CPPFLAGS = -W -Wall $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/training diff --git a/word-aligner/fast_align.cc b/word-aligner/fast_align.cc new file mode 100644 index 00000000..7492d26f --- /dev/null +++ b/word-aligner/fast_align.cc @@ -0,0 +1,281 @@ +#include +#include + +#include +#include + +#include "m.h" +#include "corpus_tools.h" +#include "stringlib.h" +#include "filelib.h" +#include "ttables.h" +#include "tdict.h" + +namespace po = boost::program_options; +using namespace std; + +bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input,i",po::value(),"Parallel corpus input file") + ("reverse,r","Reverse estimation (swap source and target during training)") + ("iterations,I",po::value()->default_value(5),"Number of iterations of EM training") + //("bidir,b", "Run bidirectional alignment") + ("favor_diagonal,d", "Use a static alignment distribution that assigns higher probabilities to alignments near the diagonal") + ("prob_align_null", po::value()->default_value(0.08), "When --favor_diagonal is set, what's the probability of a null alignment?") + ("diagonal_tension,T", po::value()->default_value(4.0), "How sharp or flat around the diagonal is the alignment distribution (<1 = flat >1 = sharp)") + ("variational_bayes,v","Infer VB estimate of parameters under a symmetric Dirichlet prior") + ("alpha,a", po::value()->default_value(0.01), "Hyperparameter for optional Dirichlet prior") + ("no_null_word,N","Do not generate from a null token") + ("output_parameters,p", "Write model parameters instead of alignments") + ("beam_threshold,t",po::value()->default_value(-4),"When writing parameters, log_10 of beam threshold for writing parameter (-10000 to include everything, 0 max parameter only)") + ("hide_training_alignments,H", "Hide training alignments (only useful if you want to use -x option and just compute testset statistics)") + ("testset,x", po::value(), "After training completes, compute the log likelihood of this set of sentence pairs under the learned model") + ("no_add_viterbi,V","When writing model parameters, do not add Viterbi alignment points (may generate a grammar where some training sentence pairs are unreachable)"); + po::options_description clo("Command line options"); + clo.add_options() + ("config", po::value(), "Configuration file") + ("help,h", "Print this help message and exit"); + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + dcmdline_options.add(opts).add(clo); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + if (conf->count("config")) { + ifstream config((*conf)["config"].as().c_str()); + po::store(po::parse_config_file(config, dconfig_options), *conf); + } + po::notify(*conf); + + if (conf->count("help") || conf->count("input") == 0) { + cerr << "Usage " << argv[0] << " [OPTIONS] -i corpus.fr-en\n"; + cerr << dcmdline_options << endl; + return false; + } + return true; +} + +int main(int argc, char** argv) { + po::variables_map conf; + if (!InitCommandLine(argc, argv, &conf)) return 1; + const string fname = conf["input"].as(); + const bool reverse = conf.count("reverse") > 0; + const int ITERATIONS = conf["iterations"].as(); + const double BEAM_THRESHOLD = pow(10.0, conf["beam_threshold"].as()); + const bool use_null = (conf.count("no_null_word") == 0); + const WordID kNULL = TD::Convert(""); + const bool add_viterbi = (conf.count("no_add_viterbi") == 0); + const bool variational_bayes = (conf.count("variational_bayes") > 0); + const bool write_alignments = (conf.count("output_parameters") == 0); + const double diagonal_tension = conf["diagonal_tension"].as(); + const double prob_align_null = conf["prob_align_null"].as(); + const bool hide_training_alignments = (conf.count("hide_training_alignments") > 0); + string testset; + if (conf.count("testset")) testset = conf["testset"].as(); + const double prob_align_not_null = 1.0 - prob_align_null; + const double alpha = conf["alpha"].as(); + const bool favor_diagonal = conf.count("favor_diagonal"); + if (variational_bayes && alpha <= 0.0) { + cerr << "--alpha must be > 0\n"; + return 1; + } + + TTable s2t, t2s; + TTable::Word2Word2Double s2t_viterbi; + double tot_len_ratio = 0; + double mean_srclen_multiplier = 0; + vector unnormed_a_i; + for (int iter = 0; iter < ITERATIONS; ++iter) { + const bool final_iteration = (iter == (ITERATIONS - 1)); + cerr << "ITERATION " << (iter + 1) << (final_iteration ? " (FINAL)" : "") << endl; + ReadFile rf(fname); + istream& in = *rf.stream(); + double likelihood = 0; + double denom = 0.0; + int lc = 0; + bool flag = false; + string line; + string ssrc, strg; + vector src, trg; + while(true) { + getline(in, line); + if (!in) break; + ++lc; + if (lc % 1000 == 0) { cerr << '.'; flag = true; } + if (lc %50000 == 0) { cerr << " [" << lc << "]\n" << flush; flag = false; } + src.clear(); trg.clear(); + CorpusTools::ReadLine(line, &src, &trg); + if (reverse) swap(src, trg); + if (src.size() == 0 || trg.size() == 0) { + cerr << "Error: " << lc << "\n" << line << endl; + return 1; + } + if (src.size() > unnormed_a_i.size()) + unnormed_a_i.resize(src.size()); + if (iter == 0) + tot_len_ratio += static_cast(trg.size()) / static_cast(src.size()); + denom += trg.size(); + vector probs(src.size() + 1); + bool first_al = true; // used for write_alignments + for (int j = 0; j < trg.size(); ++j) { + const WordID& f_j = trg[j]; + double sum = 0; + const double j_over_ts = double(j) / trg.size(); + double prob_a_i = 1.0 / (src.size() + use_null); // uniform (model 1) + if (use_null) { + if (favor_diagonal) prob_a_i = prob_align_null; + probs[0] = s2t.prob(kNULL, f_j) * prob_a_i; + sum += probs[0]; + } + double az = 0; + if (favor_diagonal) { + for (int ta = 0; ta < src.size(); ++ta) { + unnormed_a_i[ta] = exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); + az += unnormed_a_i[ta]; + } + az /= prob_align_not_null; + } + for (int i = 1; i <= src.size(); ++i) { + if (favor_diagonal) + prob_a_i = unnormed_a_i[i-1] / az; + probs[i] = s2t.prob(src[i-1], f_j) * prob_a_i; + sum += probs[i]; + } + if (final_iteration) { + if (add_viterbi || write_alignments) { + WordID max_i = 0; + double max_p = -1; + int max_index = -1; + if (use_null) { + max_i = kNULL; + max_index = 0; + max_p = probs[0]; + } + for (int i = 1; i <= src.size(); ++i) { + if (probs[i] > max_p) { + max_index = i; + max_p = probs[i]; + max_i = src[i-1]; + } + } + if (!hide_training_alignments && write_alignments) { + if (max_index > 0) { + if (first_al) first_al = false; else cout << ' '; + if (reverse) + cout << j << '-' << (max_index - 1); + else + cout << (max_index - 1) << '-' << j; + } + } + s2t_viterbi[max_i][f_j] = 1.0; + } + } else { + if (use_null) + s2t.Increment(kNULL, f_j, probs[0] / sum); + for (int i = 1; i <= src.size(); ++i) + s2t.Increment(src[i-1], f_j, probs[i] / sum); + } + likelihood += log(sum); + } + if (write_alignments && final_iteration && !hide_training_alignments) cout << endl; + } + + // log(e) = 1.0 + double base2_likelihood = likelihood / log(2); + + if (flag) { cerr << endl; } + if (iter == 0) { + mean_srclen_multiplier = tot_len_ratio / lc; + cerr << "expected target length = source length * " << mean_srclen_multiplier << endl; + } + cerr << " log_e likelihood: " << likelihood << endl; + cerr << " log_2 likelihood: " << base2_likelihood << endl; + cerr << " cross entropy: " << (-base2_likelihood / denom) << endl; + cerr << " perplexity: " << pow(2.0, -base2_likelihood / denom) << endl; + if (!final_iteration) { + if (variational_bayes) + s2t.NormalizeVB(alpha); + else + s2t.Normalize(); + } + } + if (testset.size()) { + ReadFile rf(testset); + istream& in = *rf.stream(); + int lc = 0; + double tlp = 0; + string line; + while (getline(in, line)) { + ++lc; + vector src, trg; + CorpusTools::ReadLine(line, &src, &trg); + cout << TD::GetString(src) << " ||| " << TD::GetString(trg) << " |||"; + if (reverse) swap(src, trg); + double log_prob = Md::log_poisson(trg.size(), 0.05 + src.size() * mean_srclen_multiplier); + if (src.size() > unnormed_a_i.size()) + unnormed_a_i.resize(src.size()); + + // compute likelihood + for (int j = 0; j < trg.size(); ++j) { + const WordID& f_j = trg[j]; + double sum = 0; + int a_j = 0; + double max_pat = 0; + const double j_over_ts = double(j) / trg.size(); + double prob_a_i = 1.0 / (src.size() + use_null); // uniform (model 1) + if (use_null) { + if (favor_diagonal) prob_a_i = prob_align_null; + max_pat = s2t.prob(kNULL, f_j) * prob_a_i; + sum += max_pat; + } + double az = 0; + if (favor_diagonal) { + for (int ta = 0; ta < src.size(); ++ta) { + unnormed_a_i[ta] = exp(-fabs(double(ta) / src.size() - j_over_ts) * diagonal_tension); + az += unnormed_a_i[ta]; + } + az /= prob_align_not_null; + } + for (int i = 1; i <= src.size(); ++i) { + if (favor_diagonal) + prob_a_i = unnormed_a_i[i-1] / az; + double pat = s2t.prob(src[i-1], f_j) * prob_a_i; + if (pat > max_pat) { max_pat = pat; a_j = i; } + sum += pat; + } + log_prob += log(sum); + if (write_alignments) { + if (a_j > 0) { + cout << ' '; + if (reverse) + cout << j << '-' << (a_j - 1); + else + cout << (a_j - 1) << '-' << j; + } + } + } + tlp += log_prob; + cout << " ||| " << log_prob << endl << flush; + } // loop over test set sentences + cerr << "TOTAL LOG PROB " << tlp << endl; + } + + if (write_alignments) return 0; + + for (TTable::Word2Word2Double::iterator ei = s2t.ttable.begin(); ei != s2t.ttable.end(); ++ei) { + const TTable::Word2Double& cpd = ei->second; + const TTable::Word2Double& vit = s2t_viterbi[ei->first]; + const string& esym = TD::Convert(ei->first); + double max_p = -1; + for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) + if (fi->second > max_p) max_p = fi->second; + const double threshold = max_p * BEAM_THRESHOLD; + for (TTable::Word2Double::const_iterator fi = cpd.begin(); fi != cpd.end(); ++fi) { + if (fi->second > threshold || (vit.find(fi->first) != vit.end())) { + cout << esym << ' ' << TD::Convert(fi->first) << ' ' << log(fi->second) << endl; + } + } + } + return 0; +} + diff --git a/word-aligner/makefiles/makefile.grammars b/word-aligner/makefiles/makefile.grammars index 08ff33e1..ce3e1638 100644 --- a/word-aligner/makefiles/makefile.grammars +++ b/word-aligner/makefiles/makefile.grammars @@ -16,7 +16,7 @@ STEM_E = $(SCRIPT_DIR)/stemmers/$(E_LANG).pl CLASSIFY = $(SUPPORT_DIR)/classify.pl MAKE_LEX_GRAMMAR = $(SUPPORT_DIR)/make_lex_grammar.pl -MODEL1 = $(TRAINING_DIR)/fast_align +MODEL1 = $(SCRIPT_DIR)/fast_align MERGE_CORPUS = $(SUPPORT_DIR)/merge_corpus.pl e.voc: corpus.e diff --git a/word-aligner/paste-parallel-files.pl b/word-aligner/paste-parallel-files.pl deleted file mode 100755 index ce53b325..00000000 --- a/word-aligner/paste-parallel-files.pl +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/perl -w -use strict; - -my @fs = (); -for my $file (@ARGV) { - my $fh; - open $fh, "<$file" or die "Can't open $file for reading: $!"; - push @fs, $fh; -} -my $num = scalar @fs; -die "Usage: $0 file1.txt file2.txt [...]\n" unless $num > 1; - -my $first = $fs[0]; -while(<$first>) { - chomp; - my @out = (); - push @out, $_; - for (my $i=1; $i < $num; $i++) { - my $f = $fs[$i]; - my $line = <$f>; - die "Mismatched number of lines!" unless defined $line; - chomp $line; - push @out, $line; - } - print join(' ||| ', @out) . "\n"; -} - -for my $fh (@fs) { - my $x=<$fh>; - die "Mismatched number of lines!" if defined $x; - close $fh; -} - -exit 0; - diff --git a/word-aligner/ttables.cc b/word-aligner/ttables.cc new file mode 100644 index 00000000..45bf14c5 --- /dev/null +++ b/word-aligner/ttables.cc @@ -0,0 +1,31 @@ +#include "ttables.h" + +#include + +#include "dict.h" + +using namespace std; +using namespace std::tr1; + +void TTable::DeserializeProbsFromText(std::istream* in) { + int c = 0; + while(*in) { + string e; + string f; + double p; + (*in) >> e >> f >> p; + if (e.empty()) break; + ++c; + ttable[TD::Convert(e)][TD::Convert(f)] = p; + } + cerr << "Loaded " << c << " translation parameters.\n"; +} + +void TTable::SerializeHelper(string* out, const Word2Word2Double& o) { + assert(!"not implemented"); +} + +void TTable::DeserializeHelper(const string& in, Word2Word2Double* o) { + assert(!"not implemented"); +} + diff --git a/word-aligner/ttables.h b/word-aligner/ttables.h new file mode 100644 index 00000000..9baa13ca --- /dev/null +++ b/word-aligner/ttables.h @@ -0,0 +1,101 @@ +#ifndef _TTABLES_H_ +#define _TTABLES_H_ + +#include +#include + +#include "sparse_vector.h" +#include "m.h" +#include "wordid.h" +#include "tdict.h" + +class TTable { + public: + TTable() {} + typedef std::tr1::unordered_map Word2Double; + typedef std::tr1::unordered_map Word2Word2Double; + inline double prob(const int& e, const int& f) const { + const Word2Word2Double::const_iterator cit = ttable.find(e); + if (cit != ttable.end()) { + const Word2Double& cpd = cit->second; + const Word2Double::const_iterator it = cpd.find(f); + if (it == cpd.end()) return 1e-9; + return it->second; + } else { + return 1e-9; + } + } + inline void Increment(const int& e, const int& f) { + counts[e][f] += 1.0; + } + inline void Increment(const int& e, const int& f, double x) { + counts[e][f] += x; + } + void NormalizeVB(const double alpha) { + ttable.swap(counts); + for (Word2Word2Double::iterator cit = ttable.begin(); + cit != ttable.end(); ++cit) { + double tot = 0; + Word2Double& cpd = cit->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + tot += it->second + alpha; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + it->second = exp(Md::digamma(it->second + alpha) - Md::digamma(tot)); + } + counts.clear(); + } + void Normalize() { + ttable.swap(counts); + for (Word2Word2Double::iterator cit = ttable.begin(); + cit != ttable.end(); ++cit) { + double tot = 0; + Word2Double& cpd = cit->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + tot += it->second; + for (Word2Double::iterator it = cpd.begin(); it != cpd.end(); ++it) + it->second /= tot; + } + counts.clear(); + } + // adds counts from another TTable - probabilities remain unchanged + TTable& operator+=(const TTable& rhs) { + for (Word2Word2Double::const_iterator it = rhs.counts.begin(); + it != rhs.counts.end(); ++it) { + const Word2Double& cpd = it->second; + Word2Double& tgt = counts[it->first]; + for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { + tgt[j->first] += j->second; + } + } + return *this; + } + void ShowTTable() const { + for (Word2Word2Double::const_iterator it = ttable.begin(); it != ttable.end(); ++it) { + const Word2Double& cpd = it->second; + for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { + std::cerr << "P(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + } + } + } + void ShowCounts() const { + for (Word2Word2Double::const_iterator it = counts.begin(); it != counts.end(); ++it) { + const Word2Double& cpd = it->second; + for (Word2Double::const_iterator j = cpd.begin(); j != cpd.end(); ++j) { + std::cerr << "c(" << TD::Convert(j->first) << '|' << TD::Convert(it->first) << ") = " << j->second << std::endl; + } + } + } + void DeserializeProbsFromText(std::istream* in); + void SerializeCounts(std::string* out) const { SerializeHelper(out, counts); } + void DeserializeCounts(const std::string& in) { DeserializeHelper(in, &counts); } + void SerializeProbs(std::string* out) const { SerializeHelper(out, ttable); } + void DeserializeProbs(const std::string& in) { DeserializeHelper(in, &ttable); } + private: + static void SerializeHelper(std::string*, const Word2Word2Double& o); + static void DeserializeHelper(const std::string&, Word2Word2Double* o); + public: + Word2Word2Double ttable; + Word2Word2Double counts; +}; + +#endif -- cgit v1.2.3 From 212decb4382b84c2370c369b0507a5534399aa56 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 14 Dec 2012 13:35:11 -0500 Subject: add compression libraries --- configure.ac | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'configure.ac') diff --git a/configure.ac b/configure.ac index 366112a3..f1b9d132 100644 --- a/configure.ac +++ b/configure.ac @@ -78,6 +78,13 @@ LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_SERIALIZATION_LIBS $BOOST_SYSTEM_ AC_CHECK_HEADER(google/dense_hash_map, [AC_DEFINE([HAVE_SPARSEHASH], [1], [flag for google::dense_hash_map])]) +AC_CHECK_HEADER(zlib.h, + [AC_DEFINE([HAVE_ZLIB], [1], [zlib])]) +AC_CHECK_HEADER(bzlib.h, + [AC_DEFINE([HAVE_BZLIB], [1], [bzlib])]) +AC_CHECK_HEADER(lzma.h, + [AC_DEFINE([HAVE_XZLIB], [1], [xzlib])]) + AC_PROG_INSTALL CPPFLAGS="-DPIC -fPIC $CPPFLAGS -DHAVE_CONFIG_H" -- cgit v1.2.3 From dd0fdabb1db41a4230e487c80b61ace9697f150d Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 14 Dec 2012 12:48:26 -0800 Subject: Updated kenlm --- configure.ac | 2 +- klm/lm/binary_format.cc | 21 +- klm/lm/config.cc | 1 + klm/lm/config.hh | 59 +++--- klm/lm/max_order.hh | 2 - klm/lm/model.cc | 30 +-- klm/lm/search_trie.cc | 47 ++--- klm/util/Makefile.am | 1 + klm/util/exception.hh | 8 +- klm/util/file.cc | 38 ++-- klm/util/file.hh | 8 +- klm/util/file_piece.cc | 66 ++----- klm/util/file_piece.hh | 41 ++-- klm/util/file_piece_test.cc | 4 +- klm/util/have.hh | 12 +- klm/util/joint_sort.hh | 4 +- klm/util/read_compressed.cc | 403 +++++++++++++++++++++++++++++++++++++++ klm/util/read_compressed.hh | 74 +++++++ klm/util/read_compressed_test.cc | 94 +++++++++ klm/util/scoped.hh | 65 ++++--- klm/util/string_piece.hh | 19 +- klm/util/tokenize_piece.hh | 14 +- 22 files changed, 781 insertions(+), 232 deletions(-) create mode 100644 klm/util/read_compressed.cc create mode 100644 klm/util/read_compressed.hh create mode 100644 klm/util/read_compressed_test.cc (limited to 'configure.ac') diff --git a/configure.ac b/configure.ac index f1b9d132..f4650ca4 100644 --- a/configure.ac +++ b/configure.ac @@ -87,7 +87,7 @@ AC_CHECK_HEADER(lzma.h, AC_PROG_INSTALL -CPPFLAGS="-DPIC -fPIC $CPPFLAGS -DHAVE_CONFIG_H" +CPPFLAGS="-DPIC -fPIC $CPPFLAGS -DHAVE_CONFIG_H -DKENLM_MAX_ORDER=6" # core cdec stuff AC_CONFIG_FILES([Makefile]) diff --git a/klm/lm/binary_format.cc b/klm/lm/binary_format.cc index efa67056..39c4a9b6 100644 --- a/klm/lm/binary_format.cc +++ b/klm/lm/binary_format.cc @@ -16,11 +16,11 @@ namespace ngram { namespace { const char kMagicBeforeVersion[] = "mmap lm http://kheafield.com/code format version"; const char kMagicBytes[] = "mmap lm http://kheafield.com/code format version 5\n\0"; -// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). +// This must be shorter than kMagicBytes and indicates an incomplete binary file (i.e. build failed). const char kMagicIncomplete[] = "mmap lm http://kheafield.com/code incomplete\n"; const long int kMagicVersion = 5; -// Old binary files built on 32-bit machines have this header. +// Old binary files built on 32-bit machines have this header. // TODO: eliminate with next binary release. struct OldSanity { char magic[sizeof(kMagicBytes)]; @@ -39,7 +39,7 @@ struct OldSanity { }; -// Test values aligned to 8 bytes. +// Test values aligned to 8 bytes. struct Sanity { char magic[ALIGN8(sizeof(kMagicBytes))]; float zero_f, one_f, minus_half_f; @@ -101,7 +101,7 @@ uint8_t *SetupJustVocab(const Config &config, uint8_t order, std::size_t memory_ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t memory_size, Backing &backing) { std::size_t adjusted_vocab = backing.vocab.size() + vocab_pad; if (config.write_mmap) { - // Grow the file to accomodate the search, using zeros. + // Grow the file to accomodate the search, using zeros. try { util::ResizeOrThrow(backing.file.get(), adjusted_vocab + memory_size); } catch (util::ErrnoException &e) { @@ -114,7 +114,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t return reinterpret_cast(backing.search.get()); } // mmap it now. - // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. + // We're skipping over the header and vocab for the search space mmap. mmap likes page aligned offsets, so some arithmetic to round the offset down. std::size_t page_size = util::SizePage(); std::size_t alignment_cruft = adjusted_vocab % page_size; backing.search.reset(util::MapOrThrow(alignment_cruft + memory_size, true, util::kFileFlags, false, backing.file.get(), adjusted_vocab - alignment_cruft), alignment_cruft + memory_size, util::scoped_memory::MMAP_ALLOCATED); @@ -122,7 +122,7 @@ uint8_t *GrowForSearch(const Config &config, std::size_t vocab_pad, std::size_t } else { util::MapAnonymous(memory_size, backing.search); return reinterpret_cast(backing.search.get()); - } + } } void FinishFile(const Config &config, ModelType model_type, unsigned int search_version, const std::vector &counts, std::size_t vocab_pad, Backing &backing) { @@ -140,7 +140,7 @@ void FinishFile(const Config &config, ModelType model_type, unsigned int search_ util::FSyncOrThrow(backing.file.get()); break; } - // header and vocab share the same mmap. The header is written here because we know the counts. + // header and vocab share the same mmap. The header is written here because we know the counts. Parameters params = Parameters(); params.counts = counts; params.fixed.order = counts.size(); @@ -160,7 +160,7 @@ namespace detail { bool IsBinaryFormat(int fd) { const uint64_t size = util::SizeFile(fd); if (size == util::kBadSize || (size <= static_cast(sizeof(Sanity)))) return false; - // Try reading the header. + // Try reading the header. util::scoped_memory memory; try { util::MapRead(util::LAZY, fd, 0, sizeof(Sanity), memory); @@ -214,7 +214,7 @@ void SeekPastHeader(int fd, const Parameters ¶ms) { uint8_t *SetupBinary(const Config &config, const Parameters ¶ms, uint64_t memory_size, Backing &backing) { const uint64_t file_size = util::SizeFile(backing.file.get()); - // The header is smaller than a page, so we have to map the whole header as well. + // The header is smaller than a page, so we have to map the whole header as well. std::size_t total_map = util::CheckOverflow(TotalHeaderSize(params.counts.size()) + memory_size); if (file_size != util::kBadSize && static_cast(file_size) < total_map) UTIL_THROW(FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); @@ -233,7 +233,8 @@ void ComplainAboutARPA(const Config &config, ModelType model_type) { if (config.write_mmap || !config.messages) return; if (config.arpa_complain == Config::ALL) { *config.messages << "Loading the LM will be faster if you build a binary file." << std::endl; - } else if (config.arpa_complain == Config::EXPENSIVE && model_type == TRIE_SORTED) { + } else if (config.arpa_complain == Config::EXPENSIVE && + (model_type == TRIE || model_type == QUANT_TRIE || model_type == ARRAY_TRIE || model_type == QUANT_ARRAY_TRIE)) { *config.messages << "Building " << kModelNames[model_type] << " from ARPA is expensive. Save time by building a binary format." << std::endl; } } diff --git a/klm/lm/config.cc b/klm/lm/config.cc index f9d988ca..9520c41c 100644 --- a/klm/lm/config.cc +++ b/klm/lm/config.cc @@ -6,6 +6,7 @@ namespace lm { namespace ngram { Config::Config() : + show_progress(true), messages(&std::cerr), enumerate_vocab(NULL), unknown_missing(COMPLAIN), diff --git a/klm/lm/config.hh b/klm/lm/config.hh index 739cee9c..0de7b7c6 100644 --- a/klm/lm/config.hh +++ b/klm/lm/config.hh @@ -11,46 +11,52 @@ /* Configuration for ngram model. Separate header to reduce pollution. */ namespace lm { - + class EnumerateVocab; namespace ngram { struct Config { - // EFFECTIVE FOR BOTH ARPA AND BINARY READS + // EFFECTIVE FOR BOTH ARPA AND BINARY READS + + // (default true) print progress bar to messages + bool show_progress; // Where to log messages including the progress bar. Set to NULL for // silence. std::ostream *messages; + std::ostream *ProgressMessages() const { + return show_progress ? messages : 0; + } + // This will be called with every string in the vocabulary. See // enumerate_vocab.hh for more detail. Config does not take ownership; you - // are still responsible for deleting it (or stack allocating). + // are still responsible for deleting it (or stack allocating). EnumerateVocab *enumerate_vocab; - // ONLY EFFECTIVE WHEN READING ARPA - // What to do when isn't in the provided model. + // What to do when isn't in the provided model. WarningAction unknown_missing; - // What to do when or is missing from the model. - // If THROW_UP, the exception will be of type util::SpecialWordMissingException. + // What to do when or is missing from the model. + // If THROW_UP, the exception will be of type util::SpecialWordMissingException. WarningAction sentence_marker_missing; // What to do with a positive log probability. For COMPLAIN and SILENT, map - // to 0. + // to 0. WarningAction positive_log_probability; - // The probability to substitute for if it's missing from the model. + // The probability to substitute for if it's missing from the model. // No effect if the model has or unknown_missing == THROW_UP. float unknown_missing_logprob; // Size multiplier for probing hash table. Must be > 1. Space is linear in // this. Time is probing_multiplier / (probing_multiplier - 1). No effect - // for sorted variant. + // for sorted variant. // If you find yourself setting this to a low number, consider using the - // TrieModel which has lower memory consumption. + // TrieModel which has lower memory consumption. float probing_multiplier; // Amount of memory to use for building. The actual memory usage will be @@ -58,10 +64,10 @@ struct Config { // models. std::size_t building_memory; - // Template for temporary directory appropriate for passing to mkdtemp. + // Template for temporary directory appropriate for passing to mkdtemp. // The characters XXXXXX are appended before passing to mkdtemp. Only // applies to trie. If NULL, defaults to write_mmap. If that's NULL, - // defaults to input file name. + // defaults to input file name. const char *temporary_directory_prefix; // Level of complaining to do when loading from ARPA instead of binary format. @@ -69,49 +75,46 @@ struct Config { ARPALoadComplain arpa_complain; // While loading an ARPA file, also write out this binary format file. Set - // to NULL to disable. + // to NULL to disable. const char *write_mmap; enum WriteMethod { - WRITE_MMAP, // Map the file directly. - WRITE_AFTER // Write after we're done. + WRITE_MMAP, // Map the file directly. + WRITE_AFTER // Write after we're done. }; WriteMethod write_method; - // Include the vocab in the binary file? Only effective if write_mmap != NULL. + // Include the vocab in the binary file? Only effective if write_mmap != NULL. bool include_vocab; - // Left rest options. Only used when the model includes rest costs. + // Left rest options. Only used when the model includes rest costs. enum RestFunction { REST_MAX, // Maximum of any score to the left - REST_LOWER, // Use lower-order files given below. + REST_LOWER, // Use lower-order files given below. }; RestFunction rest_function; - // Only used for REST_LOWER. + // Only used for REST_LOWER. std::vector rest_lower_files; - // Quantization options. Only effective for QuantTrieModel. One value is // reserved for each of prob and backoff, so 2^bits - 1 buckets will be used - // to quantize (and one of the remaining backoffs will be 0). + // to quantize (and one of the remaining backoffs will be 0). uint8_t prob_bits, backoff_bits; // Bhiksha compression (simple form). Only works with trie. uint8_t pointer_bhiksha_bits; - - + // ONLY EFFECTIVE WHEN READING BINARY - + // How to get the giant array into memory: lazy mmap, populate, read etc. - // See util/mmap.hh for details of MapMethod. + // See util/mmap.hh for details of MapMethod. util::LoadMethod load_method; - - // Set defaults. + // Set defaults. Config(); }; diff --git a/klm/lm/max_order.hh b/klm/lm/max_order.hh index ea0dea46..3eb97ccd 100644 --- a/klm/lm/max_order.hh +++ b/klm/lm/max_order.hh @@ -7,5 +7,3 @@ #ifndef KENLM_ORDER_MESSAGE #define KENLM_ORDER_MESSAGE "If your build system supports changing KENLM_MAX_ORDER, change it there and recompile. In the KenLM tarball or Moses, use e.g. `bjam --max-kenlm-order=6 -a'. Otherwise, edit lm/max_order.hh." #endif - -#define KENLM_MAX_ORDER 5 diff --git a/klm/lm/model.cc b/klm/lm/model.cc index fc61efee..a40fd2fb 100644 --- a/klm/lm/model.cc +++ b/klm/lm/model.cc @@ -37,7 +37,7 @@ template void GenericModel GenericModel::GenericModel(const char *file, const Config &config) { LoadLM(file, config, *this); - // g++ prints warnings unless these are fully initialized. + // g++ prints warnings unless these are fully initialized. State begin_sentence = State(); begin_sentence.length = 1; begin_sentence.words[0] = vocab_.BeginSentence(); @@ -69,8 +69,8 @@ template void GenericModel void GenericModel::InitializeFromARPA(const char *file, const Config &config) { - // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. - util::FilePiece f(backing_.file.release(), file, config.messages); + // Backing file is the ARPA. Steal it so we can make the backing file the mmap output if any. + util::FilePiece f(backing_.file.release(), file, config.ProgressMessages()); try { std::vector counts; // File counts do not include pruned trigrams that extend to quadgrams etc. These will be fixed by search_. @@ -80,7 +80,7 @@ template void GenericModel 1.0"); std::size_t vocab_size = util::CheckOverflow(VocabularyT::Size(counts[0], config)); - // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. + // Setup the binary file for writing the vocab lookup table. The search_ is responsible for growing the binary file to its needs. vocab_.SetupMemory(SetupJustVocab(config, counts.size(), vocab_size, backing_), vocab_size, counts[0], config); if (config.write_mmap) { @@ -95,7 +95,7 @@ template void GenericModel FullScoreReturn GenericModel void GenericModel::GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const { - // Generate a state from context. + // Generate a state from context. context_rend = std::min(context_rend, context_rbegin + P::Order() - 1); if (context_rend == context_rbegin) { out_state.length = 0; @@ -191,7 +191,7 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel FullScoreReturn GenericModel(out_state.length) - 1; @@ -217,10 +217,10 @@ void CopyRemainingHistory(const WordIndex *from, State &out_state) { } } // namespace -/* Ugly optimized function. Produce a score excluding backoff. - * The search goes in increasing order of ngram length. +/* Ugly optimized function. Produce a score excluding backoff. + * The search goes in increasing order of ngram length. * Context goes backward, so context_begin is the word immediately preceeding - * new_word. + * new_word. */ template FullScoreReturn GenericModel::ScoreExceptBackoff( const WordIndex *const context_rbegin, @@ -229,7 +229,7 @@ template FullScoreReturn GenericModel FullScoreReturn GenericModel(current_); const unsigned char order = (entry_size_ - sizeof(ProbPointer)) / sizeof(WordIndex); for (reader.Rewind(); reader && (current_ != allocated_); ) { @@ -109,7 +109,7 @@ class BackoffMessages { ++reader; break; case 1: - // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. + // Message but nobody to receive it. Write it down at the beginning of the buffer so we can inform this blank that it extends. for (const WordIndex *w = reinterpret_cast(current_); w != reinterpret_cast(current_) + order; ++w, ++extend_out) *extend_out = *w; current_ += entry_size_; break; @@ -126,7 +126,7 @@ class BackoffMessages { break; } } - // Now this is a list of blanks that extend right. + // Now this is a list of blanks that extend right. entry_size_ = sizeof(WordIndex) * order; Resize(sizeof(WordIndex) * (extend_out - (const WordIndex*)backing_.get())); current_ = (uint8_t*)backing_.get(); @@ -153,7 +153,7 @@ class BackoffMessages { private: void FinishedAdding() { Resize(current_ - (uint8_t*)backing_.get()); - // Sort requests in same order as files. + // Sort requests in same order as files. std::sort( util::SizedIterator(util::SizedProxy(backing_.get(), entry_size_)), util::SizedIterator(util::SizedProxy(current_, entry_size_)), @@ -220,7 +220,7 @@ class SRISucks { } private: - // This used to be one array. Then I needed to separate it by order for quantization to work. + // This used to be one array. Then I needed to separate it by order for quantization to work. std::vector values_[KENLM_MAX_ORDER - 1]; BackoffMessages messages_[KENLM_MAX_ORDER - 1]; @@ -253,7 +253,7 @@ class FindBlanks { ++counts_.back(); } - // Unigrams wrote one past. + // Unigrams wrote one past. void Cleanup() { --counts_[0]; } @@ -270,15 +270,15 @@ class FindBlanks { SRISucks &sri_; }; -// Phase to actually write n-grams to the trie. +// Phase to actually write n-grams to the trie. template class WriteEntries { public: - WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : + WriteEntries(RecordReader *contexts, const Quant &quant, UnigramValue *unigrams, BitPackedMiddle *middle, BitPackedLongest &longest, unsigned char order, SRISucks &sri) : contexts_(contexts), quant_(quant), unigrams_(unigrams), middle_(middle), - longest_(longest), + longest_(longest), bigram_pack_((order == 2) ? static_cast(longest_) : static_cast(*middle_)), order_(order), sri_(sri) {} @@ -328,7 +328,7 @@ struct Gram { const WordIndex *begin, *end; - // For queue, this is the direction we want. + // For queue, this is the direction we want. bool operator<(const Gram &other) const { return std::lexicographical_compare(other.begin, other.end, begin, end); } @@ -353,7 +353,7 @@ template class BlankManager { been_length_ = length; return; } - // There are blanks to insert starting with order blank. + // There are blanks to insert starting with order blank. unsigned char blank = cur - to + 1; UTIL_THROW_IF(blank == 1, FormatLoadException, "Missing a unigram that appears as context."); const float *lower_basis; @@ -363,7 +363,7 @@ template class BlankManager { assert(*lower_basis != kBadProb); doing_.MiddleBlank(blank, to, based_on, *lower_basis); *pre = *cur; - // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. + // Mark that the probability is a blank so it shouldn't be used as the basis for a later n-gram. basis_[blank - 1] = kBadProb; } *pre = *cur; @@ -377,7 +377,7 @@ template class BlankManager { unsigned char been_length_; float basis_[KENLM_MAX_ORDER]; - + Doing &doing_; }; @@ -451,7 +451,7 @@ template void TrainProbQuantizer(uint8_t order, uint64_t count, Re } void PopulateUnigramWeights(FILE *file, WordIndex unigram_count, RecordReader &contexts, UnigramValue *unigrams) { - // Fill unigram probabilities. + // Fill unigram probabilities. try { rewind(file); for (WordIndex i = 0; i < unigram_count; ++i) { @@ -486,7 +486,7 @@ template void BuildTrie(SortedFiles &files, std::ve util::scoped_memory unigrams; MapRead(util::POPULATE_OR_READ, unigram_fd.get(), 0, counts[0] * sizeof(ProbBackoff), unigrams); FindBlanks finder(counts.size(), reinterpret_cast(unigrams.get()), sri); - RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Identifying n-grams omitted by SRI", finder); + RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Identifying n-grams omitted by SRI", finder); fixed_counts = finder.Counts(); } unigram_file.reset(util::FDOpenOrThrow(unigram_fd)); @@ -504,7 +504,8 @@ template void BuildTrie(SortedFiles &files, std::ve inputs[i-2].Rewind(); } if (Quant::kTrain) { - util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), config.messages, "Quantizing"); + util::ErsatzProgress progress(std::accumulate(counts.begin() + 1, counts.end(), 0), + config.ProgressMessages(), "Quantizing"); for (unsigned char i = 2; i < counts.size(); ++i) { TrainQuantizer(i, counts[i-1], sri.Values(i), inputs[i-2], progress, quant); } @@ -519,13 +520,13 @@ template void BuildTrie(SortedFiles &files, std::ve for (unsigned char i = 2; i <= counts.size(); ++i) { inputs[i-2].Rewind(); } - // Fill entries except unigram probabilities. + // Fill entries except unigram probabilities. { WriteEntries writer(contexts, quant, unigrams, out.middle_begin_, out.longest_, counts.size(), sri); - RecursiveInsert(counts.size(), counts[0], inputs, config.messages, "Writing trie", writer); + RecursiveInsert(counts.size(), counts[0], inputs, config.ProgressMessages(), "Writing trie", writer); } - // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. + // Do not disable this error message or else too little state will be returned. Both WriteEntries::Middle and returning state based on found n-grams will need to be fixed to handle this situation. for (unsigned char order = 2; order <= counts.size(); ++order) { const RecordReader &context = contexts[order - 2]; if (context) { @@ -541,13 +542,13 @@ template void BuildTrie(SortedFiles &files, std::ve } /* Set ending offsets so the last entry will be sized properly */ - // Last entry for unigrams was already set. + // Last entry for unigrams was already set. if (out.middle_begin_ != out.middle_end_) { for (typename TrieSearch::Middle *i = out.middle_begin_; i != out.middle_end_ - 1; ++i) { i->FinishedLoading((i+1)->InsertIndex(), config); } (out.middle_end_ - 1)->FinishedLoading(out.longest_.InsertIndex(), config); - } + } } template uint8_t *TrieSearch::SetupMemory(uint8_t *start, const std::vector &counts, const Config &config) { @@ -595,7 +596,7 @@ template void TrieSearch::Initializ } else { temporary_prefix = file; } - // At least 1MB sorting memory. + // At least 1MB sorting memory. SortedFiles sorted(config, f, counts, std::max(config.building_memory, 1048576), temporary_prefix, vocab); BuildTrie(sorted, counts, config, *this, quant_, vocab, backing); diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 5306850f..a676bdb3 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -27,6 +27,7 @@ libklm_util_a_SOURCES = \ mmap.cc \ murmur_hash.cc \ pool.cc \ + read_compressed.cc \ string_piece.cc \ usage.cc diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 053a850b..0165a7a3 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -87,8 +87,14 @@ template typename Except::template ExceptionTag= 3 +#define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0) +#else +#define UTIL_UNLIKELY(x) (x) +#endif + #define UTIL_THROW_IF(Condition, Exception, Modify) do { \ - if (Condition) { \ + if (UTIL_UNLIKELY(Condition)) { \ Exception UTIL_e; \ UTIL_SET_LOCATION(UTIL_e, #Exception, #Condition); \ UTIL_e << Modify; \ diff --git a/klm/util/file.cc b/klm/util/file.cc index 6bf879ac..b9a77cf9 100644 --- a/klm/util/file.cc +++ b/klm/util/file.cc @@ -15,6 +15,8 @@ #if defined(_WIN32) || defined(_WIN64) #include #include +#include +#include #else #include #endif @@ -48,7 +50,7 @@ int OpenReadOrThrow(const char *name) { int CreateOrThrow(const char *name) { int ret; #if defined(_WIN32) || defined(_WIN64) - UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); + UTIL_THROW_IF(-1 == (ret = _open(name, _O_CREAT | _O_TRUNC | _O_RDWR | _O_BINARY, _S_IREAD | _S_IWRITE)), ErrnoException, "while creating " << name); #else UTIL_THROW_IF(-1 == (ret = open(name, O_CREAT | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH)), ErrnoException, "while creating " << name); #endif @@ -74,16 +76,22 @@ void ResizeOrThrow(int fd, uint64_t to) { #endif } -#ifdef WIN32 -typedef int ssize_t; +std::size_t PartialRead(int fd, void *to, std::size_t amount) { +#if defined(_WIN32) || defined(_WIN64) + amount = min(static_cast(INT_MAX), amount); + int ret = _read(fd, to, amount); +#else + ssize_t ret = read(fd, to, amount); #endif + UTIL_THROW_IF(ret < 0, ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); + return static_cast(ret); +} void ReadOrThrow(int fd, void *to_void, std::size_t amount) { uint8_t *to = static_cast(to_void); while (amount) { - ssize_t ret = read(fd, to, amount); - UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); - UTIL_THROW_IF(ret == 0, EndOfFileException, "Hit EOF in fd " << fd << " but there should be " << amount << " more bytes to read."); + std::size_t ret = PartialRead(fd, to, amount); + UTIL_THROW_IF(ret == 0, EndOfFileException, " in fd " << fd << " but there should be " << amount << " more bytes to read."); amount -= ret; to += ret; } @@ -93,8 +101,7 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) { uint8_t *to = static_cast(to_void); std::size_t remaining = amount; while (remaining) { - ssize_t ret = read(fd, to, remaining); - UTIL_THROW_IF(ret == -1, ErrnoException, "Reading " << remaining << " from fd " << fd << " failed."); + std::size_t ret = PartialRead(fd, to, remaining); if (!ret) return amount - remaining; remaining -= ret; to += ret; @@ -105,7 +112,11 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) { void WriteOrThrow(int fd, const void *data_void, std::size_t size) { const uint8_t *data = static_cast(data_void); while (size) { +#if defined(_WIN32) || defined(_WIN64) + int ret = write(fd, data, min(static_cast(INT_MAX), size)); +#else ssize_t ret = write(fd, data, size); +#endif if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); data += ret; size -= ret; @@ -114,7 +125,7 @@ void WriteOrThrow(int fd, const void *data_void, std::size_t size) { void WriteOrThrow(FILE *to, const void *data, std::size_t size) { assert(size); - if (1 != std::fwrite(data, size, 1, to)) UTIL_THROW(util::ErrnoException, "Short write; requested size " << size); + UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), util::ErrnoException, "Short write; requested size " << size); } void FSyncOrThrow(int fd) { @@ -149,14 +160,15 @@ void SeekEnd(int fd) { std::FILE *FDOpenOrThrow(scoped_fd &file) { std::FILE *ret = fdopen(file.get(), "r+b"); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen"); + if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get()); file.release(); return ret; } -std::FILE *FOpenOrThrow(const char *path, const char *mode) { - std::FILE *ret; - UTIL_THROW_IF(!(ret = fopen(path, mode)), util::ErrnoException, "Could not fopen " << path << " for " << mode); +std::FILE *FDOpenReadOrThrow(scoped_fd &file) { + std::FILE *ret = fdopen(file.get(), "rb"); + if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get()); + file.release(); return ret; } diff --git a/klm/util/file.hh b/klm/util/file.hh index 185cb1f3..c24580d6 100644 --- a/klm/util/file.hh +++ b/klm/util/file.hh @@ -32,8 +32,6 @@ class scoped_fd { return ret; } - operator bool() { return fd_ != -1; } - private: int fd_; @@ -76,8 +74,9 @@ uint64_t SizeFile(int fd); void ResizeOrThrow(int fd, uint64_t to); +std::size_t PartialRead(int fd, void *to, std::size_t size); void ReadOrThrow(int fd, void *to, std::size_t size); -std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount); +std::size_t ReadOrEOF(int fd, void *to_void, std::size_t size); void WriteOrThrow(int fd, const void *data_void, std::size_t size); void WriteOrThrow(FILE *to, const void *data, std::size_t size); @@ -90,8 +89,7 @@ void AdvanceOrThrow(int fd, int64_t off); void SeekEnd(int fd); std::FILE *FDOpenOrThrow(scoped_fd &file); - -std::FILE *FOpenOrThrow(const char *path, const char *mode); +std::FILE *FDOpenReadOrThrow(scoped_fd &file); class TempMaker { public: diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index 280f438c..5a208eff 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -14,7 +14,6 @@ #include #include -#include #include #include #include @@ -26,13 +25,6 @@ ParseNumberException::ParseNumberException(StringPiece value) throw() { *this << "Could not parse \"" << value << "\" into a number"; } -#ifdef HAVE_ZLIB -GZException::GZException(gzFile file) { - int num; - *this << gzerror(file, &num) << " from zlib"; -} -#endif // HAVE_ZLIB - // Sigh this is the only way I could come up with to do a _const_ bool. It has ' ', '\f', '\n', '\r', '\t', and '\v' (same as isspace on C locale). const bool kSpaces[256] = {0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; @@ -48,19 +40,7 @@ FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std: Initialize(name, show_progress, min_buffer); } -FilePiece::~FilePiece() { -#ifdef HAVE_ZLIB - if (gz_file_) { - // zlib took ownership - file_.release(); - int ret; - if (Z_OK != (ret = gzclose(gz_file_))) { - std::cerr << "could not close file " << file_name_ << " using zlib" << std::endl; - abort(); - } - } -#endif -} +FilePiece::~FilePiece() {} StringPiece FilePiece::ReadLine(char delim) { std::size_t skip = 0; @@ -95,9 +75,6 @@ unsigned long int FilePiece::ReadULong() { } void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer) { -#ifdef HAVE_ZLIB - gz_file_ = NULL; -#endif file_name_ = name; default_map_size_ = page_ * std::max((min_buffer / page_ + 1), 2); @@ -117,10 +94,7 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::s } Shift(); // gzip detect. - if ((position_end_ - position_) > 2 && *position_ == 0x1f && static_cast(*(position_ + 1)) == 0x8b) { -#ifndef HAVE_ZLIB - UTIL_THROW(GZException, "Looks like a gzip file but support was not compiled in."); -#endif + if ((position_end_ - position_) >= ReadCompressed::kMagicSize && ReadCompressed::DetectCompressedMagic(position_)) { if (!fallback_to_read_) { at_end_ = false; TransitionToRead(); @@ -197,7 +171,7 @@ void FilePiece::Shift() { if (fallback_to_read_) ReadShift(); for (last_space_ = position_end_ - 1; last_space_ >= position_; --last_space_) { - if (isspace(*last_space_)) break; + if (kSpaces[static_cast(*last_space_)]) break; } } @@ -248,17 +222,14 @@ void FilePiece::TransitionToRead() { position_ = data_.begin(); position_end_ = position_; -#ifdef HAVE_ZLIB - assert(!gz_file_); - gz_file_ = gzdopen(file_.get(), "r"); - UTIL_THROW_IF(!gz_file_, GZException, "zlib failed to open " << file_name_); -#endif + try { + fell_back_.Reset(file_.release()); + } catch (util::Exception &e) { + e << " in file " << file_name_; + throw; + } } -#ifdef WIN32 -typedef int ssize_t; -#endif - void FilePiece::ReadShift() { assert(fallback_to_read_); // Bytes [data_.begin(), position_) have been consumed. @@ -283,7 +254,7 @@ void FilePiece::ReadShift() { position_ = data_.begin(); position_end_ = position_ + valid_length; } else { - size_t moving = position_end_ - position_; + std::size_t moving = position_end_ - position_; memmove(data_.get(), position_, moving); position_ = data_.begin(); position_end_ = position_ + moving; @@ -291,20 +262,9 @@ void FilePiece::ReadShift() { } } - ssize_t read_return; -#ifdef HAVE_ZLIB - read_return = gzread(gz_file_, static_cast(data_.get()) + already_read, default_map_size_ - already_read); - if (read_return == -1) throw GZException(gz_file_); - if (total_size_ != kBadSize) { - // Just get the position, don't actually seek. Apparently this is how you do it. . . - off_t ret = lseek(file_.get(), 0, SEEK_CUR); - if (ret != -1) progress_.Set(ret); - } -#else - read_return = read(file_.get(), static_cast(data_.get()) + already_read, default_map_size_ - already_read); - UTIL_THROW_IF(read_return == -1, ErrnoException, "read failed"); - progress_.Set(mapped_offset_); -#endif + std::size_t read_return = fell_back_.Read(static_cast(data_.get()) + already_read, default_map_size_ - already_read); + progress_.Set(fell_back_.RawAmount()); + if (read_return == 0) { at_end_ = true; } diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index af93d8aa..39bd1581 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -4,8 +4,8 @@ #include "util/ersatz_progress.hh" #include "util/exception.hh" #include "util/file.hh" -#include "util/have.hh" #include "util/mmap.hh" +#include "util/read_compressed.hh" #include "util/string_piece.hh" #include @@ -13,10 +13,6 @@ #include -#ifdef HAVE_ZLIB -#include -#endif - namespace util { class ParseNumberException : public Exception { @@ -25,28 +21,19 @@ class ParseNumberException : public Exception { ~ParseNumberException() throw() {} }; -class GZException : public Exception { - public: -#ifdef HAVE_ZLIB - explicit GZException(gzFile file); -#endif - GZException() throw() {} - ~GZException() throw() {} -}; - extern const bool kSpaces[256]; -// Memory backing the returned StringPiece may vanish on the next call. +// Memory backing the returned StringPiece may vanish on the next call. class FilePiece { public: - // 32 MB default. - explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432); - // Takes ownership of fd. name is used for messages. - explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 33554432); + // 1 MB default. + explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576); + // Takes ownership of fd. name is used for messages. + explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576); ~FilePiece(); - - char get() { + + char get() { if (position_ == position_end_) { Shift(); if (at_end_) throw EndOfFileException(); @@ -54,14 +41,14 @@ class FilePiece { return *(position_++); } - // Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace(). + // Leaves the delimiter, if any, to be returned by get(). Delimiters defined by isspace(). StringPiece ReadDelimited(const bool *delim = kSpaces) { SkipSpaces(delim); return Consume(FindDelimiterOrEOF(delim)); } // Unlike ReadDelimited, this includes leading spaces and consumes the delimiter. - // It is similar to getline in that way. + // It is similar to getline in that way. StringPiece ReadLine(char delim = '\n'); float ReadFloat(); @@ -69,7 +56,7 @@ class FilePiece { long int ReadLong(); unsigned long int ReadULong(); - // Skip spaces defined by isspace. + // Skip spaces defined by isspace. void SkipSpaces(const bool *delim = kSpaces) { for (; ; ++position_) { if (position_ == position_end_) Shift(); @@ -82,7 +69,7 @@ class FilePiece { } const std::string &FileName() const { return file_name_; } - + private: void Initialize(const char *name, std::ostream *show_progress, std::size_t min_buffer); @@ -122,9 +109,7 @@ class FilePiece { std::string file_name_; -#ifdef HAVE_ZLIB - gzFile gz_file_; -#endif // HAVE_ZLIB + ReadCompressed fell_back_; }; } // namespace util diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc index f912e18a..e79ece7a 100644 --- a/klm/util/file_piece_test.cc +++ b/klm/util/file_piece_test.cc @@ -38,7 +38,7 @@ BOOST_AUTO_TEST_CASE(MMapReadLine) { BOOST_CHECK_THROW(test.get(), EndOfFileException); } -#ifndef __APPLE__ +#if !defined(_WIN32) && !defined(_WIN64) && !defined(__APPLE__) /* Apple isn't happy with the popen, fileno, dup. And I don't want to * reimplement popen. This is an issue with the test. */ @@ -65,7 +65,7 @@ BOOST_AUTO_TEST_CASE(StreamReadLine) { BOOST_CHECK_THROW(test.get(), EndOfFileException); BOOST_REQUIRE(!pclose(catter)); } -#endif // __APPLE__ +#endif #ifdef HAVE_ZLIB diff --git a/klm/util/have.hh b/klm/util/have.hh index b8181e99..1523c0c5 100644 --- a/klm/util/have.hh +++ b/klm/util/have.hh @@ -2,22 +2,12 @@ #ifndef UTIL_HAVE__ #define UTIL_HAVE__ -#ifndef HAVE_ZLIB -#if !defined(_WIN32) && !defined(_WIN64) -#define HAVE_ZLIB -#endif -#endif - #ifndef HAVE_ICU //#define HAVE_ICU #endif #ifndef HAVE_BOOST -#define HAVE_BOOST -#endif - -#ifndef HAVE_THREADS -//#define HAVE_THREADS +//#define HAVE_BOOST #endif #endif // UTIL_HAVE__ diff --git a/klm/util/joint_sort.hh b/klm/util/joint_sort.hh index cf3d8432..1b43ddcf 100644 --- a/klm/util/joint_sort.hh +++ b/klm/util/joint_sort.hh @@ -60,7 +60,7 @@ template class JointProxy { JointProxy(const KeyIter &key_iter, const ValueIter &value_iter) : inner_(key_iter, value_iter) {} JointProxy(const JointProxy &other) : inner_(other.inner_) {} - operator const value_type() const { + operator value_type() const { value_type ret; ret.key = *inner_.key_; ret.value = *inner_.value_; @@ -121,7 +121,7 @@ template class LessWrapper : public std::binary_functi template class PairedIterator : public ProxyIterator > { public: - PairedIterator(const KeyIter &key, const ValueIter &value) : + PairedIterator(const KeyIter &key, const ValueIter &value) : ProxyIterator >(detail::JointProxy(key, value)) {} }; diff --git a/klm/util/read_compressed.cc b/klm/util/read_compressed.cc new file mode 100644 index 00000000..4ec94c4e --- /dev/null +++ b/klm/util/read_compressed.cc @@ -0,0 +1,403 @@ +#include "util/read_compressed.hh" + +#include "util/file.hh" +#include "util/have.hh" +#include "util/scoped.hh" + +#include +#include + +#include +#include +#include +#include + +#ifdef HAVE_ZLIB +#include +#endif + +#ifdef HAVE_BZLIB +#include +#endif + +#ifdef HAVE_XZLIB +#include +#endif + +namespace util { + +CompressedException::CompressedException() throw() {} +CompressedException::~CompressedException() throw() {} + +GZException::GZException() throw() {} +GZException::~GZException() throw() {} + +BZException::BZException() throw() {} +BZException::~BZException() throw() {} + +XZException::XZException() throw() {} +XZException::~XZException() throw() {} + +class ReadBase { + public: + virtual ~ReadBase() {} + + virtual std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) = 0; + + protected: + static void ReplaceThis(ReadBase *with, ReadCompressed &thunk) { + thunk.internal_.reset(with); + } + + static uint64_t &ReadCount(ReadCompressed &thunk) { + return thunk.raw_amount_; + } +}; + +namespace { + +// Completed file that other classes can thunk to. +class Complete : public ReadBase { + public: + std::size_t Read(void *, std::size_t, ReadCompressed &) { + return 0; + } +}; + +class Uncompressed : public ReadBase { + public: + explicit Uncompressed(int fd) : fd_(fd) {} + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + std::size_t got = PartialRead(fd_.get(), to, amount); + ReadCount(thunk) += got; + return got; + } + + private: + scoped_fd fd_; +}; + +class UncompressedWithHeader : public ReadBase { + public: + UncompressedWithHeader(int fd, void *already_data, std::size_t already_size) : fd_(fd) { + assert(already_size); + buf_.reset(malloc(already_size)); + if (!buf_.get()) throw std::bad_alloc(); + memcpy(buf_.get(), already_data, already_size); + remain_ = static_cast(buf_.get()); + end_ = remain_ + already_size; + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + assert(buf_.get()); + std::size_t sending = std::min(amount, end_ - remain_); + memcpy(to, remain_, sending); + remain_ += sending; + if (remain_ == end_) { + ReplaceThis(new Uncompressed(fd_.release()), thunk); + } + return sending; + } + + private: + scoped_malloc buf_; + uint8_t *remain_; + uint8_t *end_; + + scoped_fd fd_; +}; + +#ifdef HAVE_ZLIB +class GZip : public ReadBase { + private: + static const std::size_t kInputBuffer = 16384; + public: + GZip(int fd, void *already_data, std::size_t already_size) + : file_(fd), in_buffer_(malloc(kInputBuffer)) { + if (!in_buffer_.get()) throw std::bad_alloc(); + assert(already_size < kInputBuffer); + if (already_size) { + memcpy(in_buffer_.get(), already_data, already_size); + stream_.next_in = static_cast(in_buffer_.get()); + stream_.avail_in = already_size; + stream_.avail_in += ReadOrEOF(file_.get(), static_cast(in_buffer_.get()) + already_size, kInputBuffer - already_size); + } else { + stream_.avail_in = 0; + } + stream_.zalloc = Z_NULL; + stream_.zfree = Z_NULL; + stream_.opaque = Z_NULL; + stream_.msg = NULL; + // 32 for zlib and gzip decoding with automatic header detection. + // 15 for maximum window size. + UTIL_THROW_IF(Z_OK != inflateInit2(&stream_, 32 + 15), GZException, "Failed to initialize zlib."); + } + + ~GZip() { + if (Z_OK != inflateEnd(&stream_)) { + std::cerr << "zlib could not close properly." << std::endl; + abort(); + } + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + if (amount == 0) return 0; + stream_.next_out = static_cast(to); + stream_.avail_out = std::min(std::numeric_limits::max(), amount); + do { + if (!stream_.avail_in) ReadInput(thunk); + int result = inflate(&stream_, 0); + switch (result) { + case Z_OK: + break; + case Z_STREAM_END: + { + std::size_t ret = static_cast(stream_.next_out) - static_cast(to); + ReplaceThis(new Complete(), thunk); + return ret; + } + case Z_ERRNO: + UTIL_THROW(ErrnoException, "zlib error"); + default: + UTIL_THROW(GZException, "zlib encountered " << (stream_.msg ? stream_.msg : "an error ") << " code " << result); + } + } while (stream_.next_out == to); + return static_cast(stream_.next_out) - static_cast(to); + } + + private: + void ReadInput(ReadCompressed &thunk) { + assert(!stream_.avail_in); + stream_.next_in = static_cast(in_buffer_.get()); + stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); + ReadCount(thunk) += stream_.avail_in; + } + + scoped_fd file_; + scoped_malloc in_buffer_; + z_stream stream_; +}; +#endif // HAVE_ZLIB + +#ifdef HAVE_BZLIB +class BZip : public ReadBase { + public: + explicit BZip(int fd, void *already_data, std::size_t already_size) { + scoped_fd hold(fd); + closer_.reset(FDOpenReadOrThrow(hold)); + int bzerror = BZ_OK; + file_ = BZ2_bzReadOpen(&bzerror, closer_.get(), 0, 0, already_data, already_size); + switch (bzerror) { + case BZ_OK: + return; + case BZ_CONFIG_ERROR: + UTIL_THROW(BZException, "Looks like bzip2 was miscompiled."); + case BZ_PARAM_ERROR: + UTIL_THROW(BZException, "Parameter error"); + case BZ_IO_ERROR: + UTIL_THROW(BZException, "IO error reading file"); + case BZ_MEM_ERROR: + throw std::bad_alloc(); + } + } + + ~BZip() { + int bzerror = BZ_OK; + BZ2_bzReadClose(&bzerror, file_); + if (bzerror != BZ_OK) { + std::cerr << "bz2 readclose error" << std::endl; + abort(); + } + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + int bzerror = BZ_OK; + int ret = BZ2_bzRead(&bzerror, file_, to, std::min(static_cast(INT_MAX), amount)); + long pos; + switch (bzerror) { + case BZ_STREAM_END: + pos = ftell(closer_.get()); + if (pos != -1) ReadCount(thunk) = pos; + ReplaceThis(new Complete(), thunk); + return ret; + case BZ_OK: + pos = ftell(closer_.get()); + if (pos != -1) ReadCount(thunk) = pos; + return ret; + default: + UTIL_THROW(BZException, "bzip2 error " << BZ2_bzerror(file_, &bzerror) << " code " << bzerror); + } + } + + private: + scoped_FILE closer_; + BZFILE *file_; +}; +#endif // HAVE_BZLIB + +#ifdef HAVE_XZLIB +class XZip : public ReadBase { + private: + static const std::size_t kInputBuffer = 16384; + public: + XZip(int fd, void *already_data, std::size_t already_size) + : file_(fd), in_buffer_(malloc(kInputBuffer)), stream_(), action_(LZMA_RUN) { + if (!in_buffer_.get()) throw std::bad_alloc(); + assert(already_size < kInputBuffer); + if (already_size) { + memcpy(in_buffer_.get(), already_data, already_size); + stream_.next_in = static_cast(in_buffer_.get()); + stream_.avail_in = already_size; + stream_.avail_in += ReadOrEOF(file_.get(), static_cast(in_buffer_.get()) + already_size, kInputBuffer - already_size); + } else { + stream_.avail_in = 0; + } + stream_.allocator = NULL; + lzma_ret ret = lzma_stream_decoder(&stream_, UINT64_MAX, LZMA_CONCATENATED); + switch (ret) { + case LZMA_OK: + break; + case LZMA_MEM_ERROR: + UTIL_THROW(ErrnoException, "xz open error"); + default: + UTIL_THROW(XZException, "xz error code " << ret); + } + } + + ~XZip() { + lzma_end(&stream_); + } + + std::size_t Read(void *to, std::size_t amount, ReadCompressed &thunk) { + if (amount == 0) return 0; + stream_.next_out = static_cast(to); + stream_.avail_out = amount; + do { + if (!stream_.avail_in) ReadInput(thunk); + lzma_ret status = lzma_code(&stream_, action_); + switch (status) { + case LZMA_OK: + break; + case LZMA_STREAM_END: + UTIL_THROW_IF(action_ != LZMA_FINISH, XZException, "Input not finished yet."); + { + std::size_t ret = static_cast(stream_.next_out) - static_cast(to); + ReplaceThis(new Complete(), thunk); + return ret; + } + case LZMA_MEM_ERROR: + throw std::bad_alloc(); + case LZMA_FORMAT_ERROR: + UTIL_THROW(XZException, "xzlib says file format not recognized"); + case LZMA_OPTIONS_ERROR: + UTIL_THROW(XZException, "xzlib says unsupported compression options"); + case LZMA_DATA_ERROR: + UTIL_THROW(XZException, "xzlib says this file is corrupt"); + case LZMA_BUF_ERROR: + UTIL_THROW(XZException, "xzlib says unexpected end of input"); + default: + UTIL_THROW(XZException, "unrecognized xzlib error " << status); + } + } while (stream_.next_out == to); + return static_cast(stream_.next_out) - static_cast(to); + } + + private: + void ReadInput(ReadCompressed &thunk) { + assert(!stream_.avail_in); + stream_.next_in = static_cast(in_buffer_.get()); + stream_.avail_in = ReadOrEOF(file_.get(), in_buffer_.get(), kInputBuffer); + if (!stream_.avail_in) action_ = LZMA_FINISH; + ReadCount(thunk) += stream_.avail_in; + } + + scoped_fd file_; + scoped_malloc in_buffer_; + lzma_stream stream_; + + lzma_action action_; +}; +#endif // HAVE_XZLIB + +enum MagicResult { + UNKNOWN, GZIP, BZIP, XZIP +}; + +MagicResult DetectMagic(const void *from_void) { + const uint8_t *header = static_cast(from_void); + if (header[0] == 0x1f && header[1] == 0x8b) { + return GZIP; + } + if (header[0] == 'B' && header[1] == 'Z') { + return BZIP; + } + const uint8_t xzmagic[6] = { 0xFD, '7', 'z', 'X', 'Z', 0x00 }; + if (!memcmp(header, xzmagic, 6)) { + return XZIP; + } + return UNKNOWN; +} + +ReadBase *ReadFactory(int fd, uint64_t &raw_amount) { + scoped_fd hold(fd); + unsigned char header[ReadCompressed::kMagicSize]; + raw_amount = ReadOrEOF(fd, header, ReadCompressed::kMagicSize); + if (!raw_amount) + return new Uncompressed(hold.release()); + if (raw_amount != ReadCompressed::kMagicSize) + return new UncompressedWithHeader(hold.release(), header, raw_amount); + switch (DetectMagic(header)) { + case GZIP: +#ifdef HAVE_ZLIB + return new GZip(hold.release(), header, ReadCompressed::kMagicSize); +#else + UTIL_THROW(CompressedException, "This looks like a gzip file but gzip support was not compiled in."); +#endif + case BZIP: +#ifdef HAVE_BZLIB + return new BZip(hold.release(), header, ReadCompressed::kMagicSize); +#else + UTIL_THROW(CompressedException, "This looks like a bzip file (it begins with BZ), but bzip support was not compiled in."); +#endif + case XZIP: +#ifdef HAVE_XZLIB + return new XZip(hold.release(), header, ReadCompressed::kMagicSize); +#else + UTIL_THROW(CompressedException, "This looks like an xz file, but xz support was not compiled in."); +#endif + case UNKNOWN: + break; + } + try { + AdvanceOrThrow(fd, -ReadCompressed::kMagicSize); + } catch (const util::ErrnoException &e) { + return new UncompressedWithHeader(hold.release(), header, ReadCompressed::kMagicSize); + } + return new Uncompressed(hold.release()); +} + +} // namespace + +bool ReadCompressed::DetectCompressedMagic(const void *from_void) { + return DetectMagic(from_void) != UNKNOWN; +} + +ReadCompressed::ReadCompressed(int fd) { + Reset(fd); +} + +ReadCompressed::ReadCompressed() {} + +ReadCompressed::~ReadCompressed() {} + +void ReadCompressed::Reset(int fd) { + internal_.reset(); + internal_.reset(ReadFactory(fd, raw_amount_)); +} + +std::size_t ReadCompressed::Read(void *to, std::size_t amount) { + return internal_->Read(to, amount, *this); +} + +} // namespace util diff --git a/klm/util/read_compressed.hh b/klm/util/read_compressed.hh new file mode 100644 index 00000000..83ca9fb2 --- /dev/null +++ b/klm/util/read_compressed.hh @@ -0,0 +1,74 @@ +#ifndef UTIL_READ_COMPRESSED__ +#define UTIL_READ_COMPRESSED__ + +#include "util/exception.hh" +#include "util/scoped.hh" + +#include + +#include + +namespace util { + +class CompressedException : public Exception { + public: + CompressedException() throw(); + virtual ~CompressedException() throw(); +}; + +class GZException : public CompressedException { + public: + GZException() throw(); + ~GZException() throw(); +}; + +class BZException : public CompressedException { + public: + BZException() throw(); + ~BZException() throw(); +}; + +class XZException : public CompressedException { + public: + XZException() throw(); + ~XZException() throw(); +}; + +class ReadBase; + +class ReadCompressed { + public: + static const std::size_t kMagicSize = 6; + // Must have at least kMagicSize bytes. + static bool DetectCompressedMagic(const void *from); + + // Takes ownership of fd. + explicit ReadCompressed(int fd); + + // Must call Reset later. + ReadCompressed(); + + ~ReadCompressed(); + + // Takes ownership of fd. + void Reset(int fd); + + std::size_t Read(void *to, std::size_t amount); + + uint64_t RawAmount() const { return raw_amount_; } + + private: + friend class ReadBase; + + scoped_ptr internal_; + + uint64_t raw_amount_; + + // No copying. + ReadCompressed(const ReadCompressed &); + void operator=(const ReadCompressed &); +}; + +} // namespace util + +#endif // UTIL_READ_COMPRESSED__ diff --git a/klm/util/read_compressed_test.cc b/klm/util/read_compressed_test.cc new file mode 100644 index 00000000..6fd97e5e --- /dev/null +++ b/klm/util/read_compressed_test.cc @@ -0,0 +1,94 @@ +#include "util/read_compressed.hh" + +#include "util/file.hh" +#include "util/have.hh" + +#define BOOST_TEST_MODULE ReadCompressedTest +#include +#include + +#include +#include + +#include + +namespace util { +namespace { + +void ReadLoop(ReadCompressed &reader, void *to_void, std::size_t amount) { + uint8_t *to = static_cast(to_void); + while (amount) { + std::size_t ret = reader.Read(to, amount); + BOOST_REQUIRE(ret); + to += ret; + amount -= ret; + } +} + +void TestRandom(const char *compressor) { + const uint32_t kSize4 = 100000 / 4; + char name[] = "tempXXXXXX"; + + // Write test file. + { + scoped_fd original(mkstemp(name)); + BOOST_REQUIRE(original.get() > 0); + for (uint32_t i = 0; i < kSize4; ++i) { + WriteOrThrow(original.get(), &i, sizeof(uint32_t)); + } + } + + char gzname[] = "tempXXXXXX"; + scoped_fd gzipped(mkstemp(gzname)); + + std::string command(compressor); +#ifdef __CYGWIN__ + command += ".exe"; +#endif + command += " <\""; + command += name; + command += "\" >\""; + command += gzname; + command += "\""; + BOOST_REQUIRE_EQUAL(0, system(command.c_str())); + + BOOST_CHECK_EQUAL(0, unlink(name)); + BOOST_CHECK_EQUAL(0, unlink(gzname)); + + ReadCompressed reader(gzipped.release()); + for (uint32_t i = 0; i < kSize4; ++i) { + uint32_t got; + ReadLoop(reader, &got, sizeof(uint32_t)); + BOOST_CHECK_EQUAL(i, got); + } + + char ignored; + BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1)); + // Test double EOF call. + BOOST_CHECK_EQUAL((std::size_t)0, reader.Read(&ignored, 1)); +} + +BOOST_AUTO_TEST_CASE(Uncompressed) { + TestRandom("cat"); +} + +#ifdef HAVE_ZLIB +BOOST_AUTO_TEST_CASE(ReadGZ) { + TestRandom("gzip"); +} +#endif // HAVE_ZLIB + +#ifdef HAVE_BZLIB +BOOST_AUTO_TEST_CASE(ReadBZ) { + TestRandom("bzip2"); +} +#endif // HAVE_BZLIB + +#ifdef HAVE_XZLIB +BOOST_AUTO_TEST_CASE(ReadXZ) { + TestRandom("xz"); +} +#endif + +} // namespace +} // namespace util diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index 93e2e817..d62c6df1 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -1,40 +1,13 @@ #ifndef UTIL_SCOPED__ #define UTIL_SCOPED__ +/* Other scoped objects in the style of scoped_ptr. */ #include "util/exception.hh" - -/* Other scoped objects in the style of scoped_ptr. */ #include #include namespace util { -template class scoped_thing { - public: - explicit scoped_thing(T *c = static_cast(0)) : c_(c) {} - - ~scoped_thing() { if (c_) Free(c_); } - - void reset(T *c) { - if (c_) Free(c_); - c_ = c; - } - - T &operator*() { return *c_; } - const T&operator*() const { return *c_; } - T &operator->() { return *c_; } - const T&operator->() const { return *c_; } - - T *get() { return c_; } - const T *get() const { return c_; } - - private: - T *c_; - - scoped_thing(const scoped_thing &); - scoped_thing &operator=(const scoped_thing &); -}; - class scoped_malloc { public: scoped_malloc() : p_(NULL) {} @@ -77,9 +50,6 @@ template class scoped_array { T &operator*() { return *c_; } const T&operator*() const { return *c_; } - T &operator->() { return *c_; } - const T&operator->() const { return *c_; } - T &operator[](std::size_t idx) { return c_[idx]; } const T &operator[](std::size_t idx) const { return c_[idx]; } @@ -90,6 +60,39 @@ template class scoped_array { private: T *c_; + + scoped_array(const scoped_array &); + void operator=(const scoped_array &); +}; + +template class scoped_ptr { + public: + explicit scoped_ptr(T *content = NULL) : c_(content) {} + + ~scoped_ptr() { delete c_; } + + T *get() { return c_; } + const T* get() const { return c_; } + + T &operator*() { return *c_; } + const T&operator*() const { return *c_; } + + T *operator->() { return c_; } + const T*operator->() const { return c_; } + + T &operator[](std::size_t idx) { return c_[idx]; } + const T &operator[](std::size_t idx) const { return c_[idx]; } + + void reset(T *to = NULL) { + scoped_ptr other(c_); + c_ = to; + } + + private: + T *c_; + + scoped_ptr(const scoped_ptr &); + void operator=(const scoped_ptr &); }; } // namespace util diff --git a/klm/util/string_piece.hh b/klm/util/string_piece.hh index be6a643d..51481646 100644 --- a/klm/util/string_piece.hh +++ b/klm/util/string_piece.hh @@ -1,6 +1,6 @@ /* If you use ICU in your program, then compile with -DHAVE_ICU -licui18n. If * you don't use ICU, then this will use the Google implementation from Chrome. - * This has been modified from the original version to let you choose. + * This has been modified from the original version to let you choose. */ // Copyright 2008, Google Inc. @@ -62,9 +62,9 @@ #include #include -// Old versions of ICU don't define operator== and operator!=. +// Old versions of ICU don't define operator== and operator!=. #if (U_ICU_VERSION_MAJOR_NUM < 4) || ((U_ICU_VERSION_MAJOR_NUM == 4) && (U_ICU_VERSION_MINOR_NUM < 4)) -#warning You are using an old version of ICU. Consider upgrading to ICU >= 4.6. +#warning You are using an old version of ICU. Consider upgrading to ICU >= 4.6. inline bool operator==(const StringPiece& x, const StringPiece& y) { if (x.size() != y.size()) return false; @@ -274,15 +274,28 @@ struct StringPieceCompatibleEquals : public std::binary_function typename T::const_iterator FindStringPiece(const T &t, const StringPiece &key) { +#if BOOST_VERSION < 104200 + std::string temp(key.data(), key.size()); + return t.find(temp); +#else return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +#endif } + template typename T::iterator FindStringPiece(T &t, const StringPiece &key) { +#if BOOST_VERSION < 104200 + std::string temp(key.data(), key.size()); + return t.find(temp); +#else return t.find(key, StringPieceCompatibleHash(), StringPieceCompatibleEquals()); +#endif } #endif #ifdef HAVE_ICU U_NAMESPACE_END +using U_NAMESPACE_QUALIFIER StringPiece; #endif + #endif // BASE_STRING_PIECE_H__ diff --git a/klm/util/tokenize_piece.hh b/klm/util/tokenize_piece.hh index 4a7f5460..a588c3fc 100644 --- a/klm/util/tokenize_piece.hh +++ b/klm/util/tokenize_piece.hh @@ -20,6 +20,7 @@ class OutOfTokens : public Exception { class SingleCharacter { public: + SingleCharacter() {} explicit SingleCharacter(char delim) : delim_(delim) {} StringPiece Find(const StringPiece &in) const { @@ -32,6 +33,8 @@ class SingleCharacter { class MultiCharacter { public: + MultiCharacter() {} + explicit MultiCharacter(const StringPiece &delimiter) : delimiter_(delimiter) {} StringPiece Find(const StringPiece &in) const { @@ -44,6 +47,7 @@ class MultiCharacter { class AnyCharacter { public: + AnyCharacter() {} explicit AnyCharacter(const StringPiece &chars) : chars_(chars) {} StringPiece Find(const StringPiece &in) const { @@ -56,6 +60,8 @@ class AnyCharacter { class AnyCharacterLast { public: + AnyCharacterLast() {} + explicit AnyCharacterLast(const StringPiece &chars) : chars_(chars) {} StringPiece Find(const StringPiece &in) const { @@ -81,8 +87,8 @@ template class TokenIter : public boost::it return current_.data() != 0; } - static TokenIter end() { - return TokenIter(); + static TokenIter end() { + return TokenIter(); } private: @@ -100,8 +106,8 @@ template class TokenIter : public boost::it } while (SkipEmpty && current_.data() && current_.empty()); // Compiler should optimize this away if SkipEmpty is false. } - bool equal(const TokenIter &other) const { - return after_.data() == other.after_.data(); + bool equal(const TokenIter &other) const { + return current_.data() == other.current_.data(); } const StringPiece &dereference() const { -- cgit v1.2.3 From c7b1dc8eabd50eefb7403ce36d2746f2df39e30e Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 15 Dec 2012 02:53:56 -0500 Subject: enable kenlm compression --- configure.ac | 26 ++++++++++++++++++-------- decoder/Makefile.am | 11 +++++------ example_extff/Makefile.am | 2 +- klm/util/have.hh | 3 +-- mteval/Makefile.am | 6 +++--- python/setup.py.in | 2 +- training/dpmert/Makefile.am | 10 +++++----- training/dtrain/Makefile.am | 2 +- training/minrisk/Makefile.am | 2 +- training/mira/Makefile.am | 2 +- training/pro/Makefile.am | 4 ++-- training/rampion/Makefile.am | 2 +- training/utils/Makefile.am | 4 ++-- utils/Makefile.am | 18 +++++++++--------- word-aligner/Makefile.am | 2 +- 15 files changed, 52 insertions(+), 44 deletions(-) (limited to 'configure.ac') diff --git a/configure.ac b/configure.ac index f4650ca4..eabb8645 100644 --- a/configure.ac +++ b/configure.ac @@ -18,6 +18,23 @@ BOOST_TEST AM_PATH_PYTHON AC_CHECK_HEADER(dlfcn.h,AC_DEFINE(HAVE_DLFCN_H)) AC_CHECK_LIB(dl, dlopen) +AC_CHECK_HEADERS(zlib.h, + AC_CHECK_LIB(z, gzread,[ + AC_DEFINE(HAVE_ZLIB,[],[Do we have zlib]) + ZLIBS="$ZLIBS -lz" + ])) + +AC_CHECK_HEADERS(bzlib.h, + AC_CHECK_LIB(bz2, BZ2_bzReadOpen,[ + AC_DEFINE(HAVE_BZLIB,[],[Do we have bzlib]) + ZLIBS="$ZLIBS -lbz2" + ])) + +AC_CHECK_HEADERS(lzma.h, + AC_CHECK_LIB(lzma, lzma_code,[ + AC_DEFINE(HAVE_XZLIB,[],[Do we have lzma]) + ZLIBS="$ZLIBS -llzma" + ])) AC_ARG_ENABLE(mpi, [ --enable-mpi Build MPI binaries, assumes mpi.h is present ], @@ -72,19 +89,12 @@ fi CPPFLAGS="$CPPFLAGS $BOOST_CPPFLAGS" LDFLAGS="$LDFLAGS $BOOST_PROGRAM_OPTIONS_LDFLAGS $BOOST_SERIALIZATION_LDFLAGS $BOOST_SYSTEM_LDFLAGS" # $BOOST_THREAD_LDFLAGS" -LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_SERIALIZATION_LIBS $BOOST_SYSTEM_LIBS" +LIBS="$LIBS $BOOST_PROGRAM_OPTIONS_LIBS $BOOST_SERIALIZATION_LIBS $BOOST_SYSTEM_LIBS $ZLIBS" # $BOOST_THREAD_LIBS" AC_CHECK_HEADER(google/dense_hash_map, [AC_DEFINE([HAVE_SPARSEHASH], [1], [flag for google::dense_hash_map])]) -AC_CHECK_HEADER(zlib.h, - [AC_DEFINE([HAVE_ZLIB], [1], [zlib])]) -AC_CHECK_HEADER(bzlib.h, - [AC_DEFINE([HAVE_BZLIB], [1], [bzlib])]) -AC_CHECK_HEADER(lzma.h, - [AC_DEFINE([HAVE_XZLIB], [1], [xzlib])]) - AC_PROG_INSTALL CPPFLAGS="-DPIC -fPIC $CPPFLAGS -DHAVE_CONFIG_H -DKENLM_MAX_ORDER=6" diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 6914fa0f..88a6116c 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -8,16 +8,16 @@ noinst_PROGRAMS = \ TESTS = trule_test parser_test grammar_test hg_test parser_test_SOURCES = parser_test.cc -parser_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz +parser_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a grammar_test_SOURCES = grammar_test.cc -grammar_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz +grammar_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a hg_test_SOURCES = hg_test.cc -hg_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz +hg_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a trule_test_SOURCES = trule_test.cc -trule_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a -lz +trule_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a cdec_SOURCES = cdec.cc -cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/search/libksearch.a ../klm/lm/libklm.a ../klm/util/libklm_util.a -lz +cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/search/libksearch.a ../klm/lm/libklm.a ../klm/util/libklm_util.a AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils -I../klm @@ -82,4 +82,3 @@ libcdec_a_SOURCES = \ JSON_parser.c \ json_parse.cc \ grammar.cc - diff --git a/example_extff/Makefile.am b/example_extff/Makefile.am index ac2694ca..7b7c34b5 100644 --- a/example_extff/Makefile.am +++ b/example_extff/Makefile.am @@ -1,4 +1,4 @@ -AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. -I../mteval -I../utils -I../klm -I../decoder +AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare -I.. -I../mteval -I../utils -I../klm -I../decoder lib_LTLIBRARIES = libff_example.la libff_example_la_SOURCES = ff_example.cc diff --git a/klm/util/have.hh b/klm/util/have.hh index b86ba11e..85b838e4 100644 --- a/klm/util/have.hh +++ b/klm/util/have.hh @@ -11,8 +11,7 @@ #endif #ifdef HAVE_CONFIG_H -// Chris; uncomment this line. -//#include "config.h" +#include "config.h" #endif #endif // UTIL_HAVE__ diff --git a/mteval/Makefile.am b/mteval/Makefile.am index 5e9bba91..4444285f 100644 --- a/mteval/Makefile.am +++ b/mteval/Makefile.am @@ -23,12 +23,12 @@ libmteval_a_SOURCES = \ ter.cc fast_score_SOURCES = fast_score.cc -fast_score_LDADD = libmteval.a $(top_srcdir)/utils/libutils.a -lz +fast_score_LDADD = libmteval.a $(top_srcdir)/utils/libutils.a mbr_kbest_SOURCES = mbr_kbest.cc -mbr_kbest_LDADD = libmteval.a $(top_srcdir)/utils/libutils.a -lz +mbr_kbest_LDADD = libmteval.a $(top_srcdir)/utils/libutils.a scorer_test_SOURCES = scorer_test.cc -scorer_test_LDADD = libmteval.a $(top_srcdir)/utils/libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz +scorer_test_LDADD = libmteval.a $(top_srcdir)/utils/libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils diff --git a/python/setup.py.in b/python/setup.py.in index dac72903..fa8a9f5e 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -17,7 +17,7 @@ ext_modules = [ sources=['src/_cdec.cpp'], include_dirs=INC, library_dirs=LIB, - libraries=LIBS + ['z', 'cdec', 'utils', 'mteval', 'training_utils', 'klm', 'klm_util', 'ksearch'], + libraries=['cdec', 'utils', 'mteval', 'training_utils', 'klm', 'klm_util', 'ksearch'] + LIBS, extra_compile_args=CPPFLAGS, extra_link_args=LDFLAGS), Extension(name='cdec.sa._sa', diff --git a/training/dpmert/Makefile.am b/training/dpmert/Makefile.am index ff318bef..3dbdfa69 100644 --- a/training/dpmert/Makefile.am +++ b/training/dpmert/Makefile.am @@ -8,18 +8,18 @@ noinst_PROGRAMS = \ TESTS = lo_test mr_dpmert_generate_mapper_input_SOURCES = mr_dpmert_generate_mapper_input.cc line_optimizer.cc -mr_dpmert_generate_mapper_input_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz +mr_dpmert_generate_mapper_input_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a # nbest2hg_SOURCES = nbest2hg.cc -# nbest2hg_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lfst -lz +# nbest2hg_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lfst mr_dpmert_map_SOURCES = mert_geometry.cc ces.cc error_surface.cc mr_dpmert_map.cc line_optimizer.cc -mr_dpmert_map_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz +mr_dpmert_map_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a mr_dpmert_reduce_SOURCES = error_surface.cc ces.cc mr_dpmert_reduce.cc line_optimizer.cc mert_geometry.cc -mr_dpmert_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz +mr_dpmert_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a lo_test_SOURCES = lo_test.cc ces.cc mert_geometry.cc error_surface.cc line_optimizer.cc -lo_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz +lo_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am index 5b48e756..4f51b0c8 100644 --- a/training/dtrain/Makefile.am +++ b/training/dtrain/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = dtrain dtrain_SOURCES = dtrain.cc score.cc -dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz +dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/minrisk/Makefile.am b/training/minrisk/Makefile.am index a15e821e..821730c2 100644 --- a/training/minrisk/Makefile.am +++ b/training/minrisk/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = minrisk_optimize minrisk_optimize_SOURCES = minrisk_optimize.cc -minrisk_optimize_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/training/liblbfgs/liblbfgs.a -lz +minrisk_optimize_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/training/liblbfgs/liblbfgs.a AM_CPPFLAGS = -W -Wall $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training -I$(top_srcdir)/training/utils diff --git a/training/mira/Makefile.am b/training/mira/Makefile.am index ae609ede..c8f404fb 100644 --- a/training/mira/Makefile.am +++ b/training/mira/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = kbest_mira kbest_mira_SOURCES = kbest_mira.cc -kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz +kbest_mira_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/pro/Makefile.am b/training/pro/Makefile.am index 1916b6b2..e0a45a33 100644 --- a/training/pro/Makefile.am +++ b/training/pro/Makefile.am @@ -3,9 +3,9 @@ bin_PROGRAMS = \ mr_pro_reduce mr_pro_map_SOURCES = mr_pro_map.cc -mr_pro_map_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz +mr_pro_map_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a mr_pro_reduce_SOURCES = mr_pro_reduce.cc -mr_pro_reduce_LDADD = $(top_srcdir)/training/liblbfgs/liblbfgs.a $(top_srcdir)/utils/libutils.a -lz +mr_pro_reduce_LDADD = $(top_srcdir)/training/liblbfgs/liblbfgs.a $(top_srcdir)/utils/libutils.a AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training/utils -I$(top_srcdir)/training diff --git a/training/rampion/Makefile.am b/training/rampion/Makefile.am index 1633d0f7..ef0ca147 100644 --- a/training/rampion/Makefile.am +++ b/training/rampion/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = rampion_cccp rampion_cccp_SOURCES = rampion_cccp.cc -rampion_cccp_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lz +rampion_cccp_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a AM_CPPFLAGS = -W -Wall $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -I$(top_srcdir)/training/utils diff --git a/training/utils/Makefile.am b/training/utils/Makefile.am index 189d9a76..c9405d4e 100644 --- a/training/utils/Makefile.am +++ b/training/utils/Makefile.am @@ -24,10 +24,10 @@ libtraining_utils_a_SOURCES = \ risk.cc optimize_test_SOURCES = optimize_test.cc -optimize_test_LDADD = libtraining_utils.a $(top_srcdir)/utils/libutils.a -lz +optimize_test_LDADD = libtraining_utils.a $(top_srcdir)/utils/libutils.a lbfgs_test_SOURCES = lbfgs_test.cc -lbfgs_test_LDADD = $(top_srcdir)/utils/libutils.a -lz +lbfgs_test_LDADD = $(top_srcdir)/utils/libutils.a AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/decoder -I$(top_srcdir)/utils -I$(top_srcdir)/mteval -I$(top_srcdir)/klm diff --git a/utils/Makefile.am b/utils/Makefile.am index 3ad9d69e..639c30b8 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -33,24 +33,24 @@ if HAVE_CMPH endif reconstruct_weights_SOURCES = reconstruct_weights.cc -reconstruct_weights_LDADD = libutils.a -lz +reconstruct_weights_LDADD = libutils.a atools_SOURCES = atools.cc -atools_LDADD = libutils.a -lz +atools_LDADD = libutils.a phmt_SOURCES = phmt.cc -phmt_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz +phmt_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) ts_SOURCES = ts.cc -ts_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz +ts_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) m_test_SOURCES = m_test.cc -m_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz +m_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) dict_test_SOURCES = dict_test.cc -dict_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz +dict_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) weights_test_SOURCES = weights_test.cc -weights_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz +weights_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) logval_test_SOURCES = logval_test.cc -logval_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz +logval_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) small_vector_test_SOURCES = small_vector_test.cc -small_vector_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) -lz +small_vector_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) ################################################################ # do NOT NOT NOT add any other -I includes NO NO NO NO NO ###### diff --git a/word-aligner/Makefile.am b/word-aligner/Makefile.am index 280d3ae7..2dcb688e 100644 --- a/word-aligner/Makefile.am +++ b/word-aligner/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = fast_align fast_align_SOURCES = fast_align.cc ttables.cc -fast_align_LDADD = $(top_srcdir)/utils/libutils.a -lz +fast_align_LDADD = $(top_srcdir)/utils/libutils.a AM_CPPFLAGS = -W -Wall $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/training -- cgit v1.2.3 From 41bc60a856dc2d0bf9659b443c0cd03be8016db7 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 8 Jan 2013 15:44:45 -0500 Subject: add header files to sources to create correct distributions --- configure.ac | 5 ++-- decoder/Makefile.am | 70 +++++++++++++++++++++++++++++++++++++++++++ klm/lm/Makefile.am | 27 +++++++++++++++++ klm/search/Makefile.am | 14 ++++++++- klm/util/Makefile.am | 20 +++++++++++++ mteval/Makefile.am | 11 +++++++ training/crf/Makefile.am | 4 +-- training/dpmert/Makefile.am | 6 ++-- training/dtrain/Makefile.am | 2 +- training/liblbfgs/Makefile.am | 9 +++++- training/utils/Makefile.am | 7 +++++ utils/Makefile.am | 48 +++++++++++++++++++++++++++++ word-aligner/Makefile.am | 2 +- 13 files changed, 214 insertions(+), 11 deletions(-) (limited to 'configure.ac') diff --git a/configure.ac b/configure.ac index eabb8645..dcd0a0d8 100644 --- a/configure.ac +++ b/configure.ac @@ -1,5 +1,6 @@ -AC_INIT -AM_INIT_AUTOMAKE(cdec,0.1) +AC_INIT([cdec],[1.0]) +AC_CONFIG_SRCDIR([decoder/cdec.cc]) +AM_INIT_AUTOMAKE AC_CONFIG_HEADERS(config.h) AC_PROG_LIBTOOL AC_PROG_LEX diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 88a6116c..21187da8 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -26,7 +26,77 @@ rule_lexer.cc: rule_lexer.ll noinst_LIBRARIES = libcdec.a +EXTRA_DIST = rule_lexer.ll + libcdec_a_SOURCES = \ + JSON_parser.h \ + aligner.h \ + apply_fsa_models.h \ + apply_models.h \ + bottom_up_parser.h \ + cfg.h \ + cfg_binarize.h \ + cfg_format.h \ + cfg_options.h \ + csplit.h \ + decoder.h \ + dwarf.h \ + earley_composer.h \ + exp_semiring.h \ + factored_lexicon_helper.h \ + ff.h \ + ff_basic.h \ + ff_bleu.h \ + ff_charset.h \ + ff_context.h \ + ff_csplit.h \ + ff_dwarf.h \ + ff_external.h \ + ff_factory.h \ + ff_klm.h \ + ff_lm.h \ + ff_ngrams.h \ + ff_register.h \ + ff_rules.h \ + ff_ruleshape.h \ + ff_sample_fsa.h \ + ff_source_syntax.h \ + ff_spans.h \ + ff_tagger.h \ + ff_wordalign.h \ + ff_wordset.h \ + ffset.h \ + forest_writer.h \ + freqdict.h \ + grammar.h \ + hg.h \ + hg_cfg.h \ + hg_intersect.h \ + hg_io.h \ + hg_remove_eps.h \ + hg_sampler.h \ + hg_test.h \ + hg_union.h \ + incremental.h \ + inside_outside.h \ + json_parse.h \ + kbest.h \ + lattice.h \ + lexalign.h \ + lextrans.h \ + nt_span.h \ + oracle_bleu.h \ + phrasebased_translator.h \ + phrasetable_fst.h \ + program_options.h \ + rule_lexer.h \ + sentence_metadata.h \ + sentences.h \ + tagger.h \ + translator.h \ + tromble_loss.h \ + trule.h \ + viterbi.h \ forest_writer.cc \ maxtrans_blunsom.cc \ cdec_ff.cc \ diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index a12c5f03..436cfd08 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -12,6 +12,33 @@ build_binary_LDADD = libklm.a ../util/libklm_util.a -lz noinst_LIBRARIES = libklm.a libklm_a_SOURCES = \ + bhiksha.hh \ + binary_format.hh \ + blank.hh \ + config.hh \ + enumerate_vocab.hh \ + facade.hh \ + left.hh \ + lm_exception.hh \ + max_order.hh \ + model.hh \ + model_type.hh \ + ngram_query.hh \ + partial.hh \ + quantize.hh \ + read_arpa.hh \ + return.hh \ + search_hashed.hh \ + search_trie.hh \ + state.hh \ + trie.hh \ + trie_sort.hh \ + value.hh \ + value_build.hh \ + virtual_interface.hh \ + vocab.hh \ + weights.hh \ + word_index.hh \ bhiksha.cc \ binary_format.cc \ config.cc \ diff --git a/klm/search/Makefile.am b/klm/search/Makefile.am index 5aea33c2..a34f6cea 100644 --- a/klm/search/Makefile.am +++ b/klm/search/Makefile.am @@ -1,11 +1,23 @@ noinst_LIBRARIES = libksearch.a libksearch_a_SOURCES = \ + applied.hh \ + config.hh \ + context.hh \ + dedupe.hh \ + edge.hh \ + edge_generator.hh \ + header.hh \ + nbest.hh \ + rule.hh \ + types.hh \ + vertex.hh \ + vertex_generator.hh \ edge_generator.cc \ nbest.cc \ rule.cc \ vertex.cc \ vertex_generator.cc -AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I.. +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I.. diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index a676bdb3..bb441432 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -19,6 +19,26 @@ noinst_LIBRARIES = libklm_util.a libklm_util_a_SOURCES = \ + bit_packing.hh \ + ersatz_progress.hh \ + exception.hh \ + file.hh \ + file_piece.hh \ + getopt.hh \ + have.hh \ + joint_sort.hh \ + mmap.hh \ + murmur_hash.hh \ + pool.hh \ + probing_hash_table.hh \ + proxy_iterator.hh \ + read_compressed.hh \ + scoped.hh \ + sized_iterator.hh \ + sorted_uniform.hh \ + string_piece.hh \ + tokenize_piece.hh \ + usage.hh \ ersatz_progress.cc \ bit_packing.cc \ exception.cc \ diff --git a/mteval/Makefile.am b/mteval/Makefile.am index 4444285f..b19e4bb1 100644 --- a/mteval/Makefile.am +++ b/mteval/Makefile.am @@ -9,6 +9,17 @@ TESTS = scorer_test noinst_LIBRARIES = libmteval.a libmteval_a_SOURCES = \ + aer_scorer.h \ + comb_scorer.h \ + external_scorer.h \ + ns.h \ + ns_cer.h \ + ns_comb.h \ + ns_docscorer.h \ + ns_ext.h \ + ns_ter.h \ + scorer.h \ + ter.h \ aer_scorer.cc \ comb_scorer.cc \ external_scorer.cc \ diff --git a/training/crf/Makefile.am b/training/crf/Makefile.am index d203df25..f72d8f92 100644 --- a/training/crf/Makefile.am +++ b/training/crf/Makefile.am @@ -18,10 +18,10 @@ mpi_extract_reachable_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/ mpi_extract_features_SOURCES = mpi_extract_features.cc mpi_extract_features_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz -mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc cllh_observer.cc +mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc cllh_observer.cc cllh_observer.h mpi_batch_optimize_LDADD = $(top_srcdir)/training/utils/libtraining_utils.a $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz -mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc cllh_observer.cc +mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc cllh_observer.cc cllh_observer.h mpi_compute_cllh_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a -lz AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare -I$(top_srcdir)/training -I$(top_srcdir)/training/utils -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dpmert/Makefile.am b/training/dpmert/Makefile.am index 3dbdfa69..e5f13944 100644 --- a/training/dpmert/Makefile.am +++ b/training/dpmert/Makefile.am @@ -13,13 +13,13 @@ mr_dpmert_generate_mapper_input_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_sr # nbest2hg_SOURCES = nbest2hg.cc # nbest2hg_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lfst -mr_dpmert_map_SOURCES = mert_geometry.cc ces.cc error_surface.cc mr_dpmert_map.cc line_optimizer.cc +mr_dpmert_map_SOURCES = mert_geometry.cc ces.cc error_surface.cc mr_dpmert_map.cc line_optimizer.cc ces.h error_surface.h line_optimizer.h mert_geometry.h mr_dpmert_map_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -mr_dpmert_reduce_SOURCES = error_surface.cc ces.cc mr_dpmert_reduce.cc line_optimizer.cc mert_geometry.cc +mr_dpmert_reduce_SOURCES = error_surface.cc ces.cc mr_dpmert_reduce.cc line_optimizer.cc mert_geometry.cc ces.h error_surface.h line_optimizer.h mert_geometry.h mr_dpmert_reduce_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a -lo_test_SOURCES = lo_test.cc ces.cc mert_geometry.cc error_surface.cc line_optimizer.cc +lo_test_SOURCES = lo_test.cc ces.cc mert_geometry.cc error_surface.cc line_optimizer.cc ces.h error_surface.h line_optimizer.h mert_geometry.h lo_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am index 4f51b0c8..ee337ca8 100644 --- a/training/dtrain/Makefile.am +++ b/training/dtrain/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = dtrain -dtrain_SOURCES = dtrain.cc score.cc +dtrain_SOURCES = dtrain.cc score.cc dtrain.h kbestget.h ksampler.h pairsampling.h score.h dtrain_LDADD = $(top_srcdir)/decoder/libcdec.a $(top_srcdir)/klm/search/libksearch.a $(top_srcdir)/mteval/libmteval.a $(top_srcdir)/utils/libutils.a $(top_srcdir)/klm/lm/libklm.a $(top_srcdir)/klm/util/libklm_util.a AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/liblbfgs/Makefile.am b/training/liblbfgs/Makefile.am index 64a3794d..f0d5c8aa 100644 --- a/training/liblbfgs/Makefile.am +++ b/training/liblbfgs/Makefile.am @@ -6,10 +6,17 @@ ll_test_LDADD = liblbfgs.a -lz noinst_LIBRARIES = liblbfgs.a -liblbfgs_a_SOURCES = lbfgs.c +liblbfgs_a_SOURCES = \ + lbfgs.c \ + arithmetic_ansi.h \ + arithmetic_sse_double.h \ + arithmetic_sse_float.h \ + lbfgs++.h \ + lbfgs.h ################################################################ # do NOT NOT NOT add any other -I includes NO NO NO NO NO ###### AM_LDFLAGS = liblbfgs.a -lz AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -I. -I.. ################################################################ + diff --git a/training/utils/Makefile.am b/training/utils/Makefile.am index d708a9f5..a2ab86fd 100644 --- a/training/utils/Makefile.am +++ b/training/utils/Makefile.am @@ -18,6 +18,13 @@ sentclient_LDFLAGS = -pthread TESTS = lbfgs_test optimize_test libtraining_utils_a_SOURCES = \ + candidate_set.h \ + entropy.h \ + lbfgs.h \ + online_optimizer.h \ + optimize.h \ + risk.h \ + sentserver.h \ candidate_set.cc \ entropy.cc \ optimize.cc \ diff --git a/utils/Makefile.am b/utils/Makefile.am index 639c30b8..3177325b 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -14,6 +14,53 @@ TESTS = ts small_vector_test logval_test weights_test dict_test m_test noinst_LIBRARIES = libutils.a libutils_a_SOURCES = \ + alias_sampler.h \ + alignment_io.h \ + array2d.h \ + b64tools.h \ + batched_append.h \ + city.h \ + citycrc.h \ + corpus_tools.h \ + dict.h \ + fast_sparse_vector.h \ + fdict.h \ + feature_vector.h \ + filelib.h \ + gzstream.h \ + hash.h \ + have_64_bits.h \ + indices_after.h \ + kernel_string_subseq.h \ + logval.h \ + m.h \ + murmur_hash.h \ + named_enum.h \ + null_deleter.h \ + null_traits.h \ + perfect_hash.h \ + prob.h \ + sampler.h \ + semiring.h \ + show.h \ + small_vector.h \ + sparse_vector.h \ + static_utoa.h \ + stringlib.h \ + swap_pod.h \ + tdict.h \ + timing_stats.h \ + utoa.h \ + value_array.h \ + verbose.h \ + warning_compiler.h \ + warning_pop.h \ + warning_push.h \ + weights.h \ + wordid.h \ + writer.h \ + fast_lexical_cast.hpp \ + intrusive_refcount.hpp \ alignment_io.cc \ b64tools.cc \ corpus_tools.cc \ @@ -56,3 +103,4 @@ small_vector_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOS # do NOT NOT NOT add any other -I includes NO NO NO NO NO ###### AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -I. ################################################################ + diff --git a/word-aligner/Makefile.am b/word-aligner/Makefile.am index 2dcb688e..e274b209 100644 --- a/word-aligner/Makefile.am +++ b/word-aligner/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = fast_align -fast_align_SOURCES = fast_align.cc ttables.cc +fast_align_SOURCES = fast_align.cc ttables.cc da.h ttables.h fast_align_LDADD = $(top_srcdir)/utils/libutils.a AM_CPPFLAGS = -W -Wall $(GTEST_CPPFLAGS) -I$(top_srcdir)/utils -I$(top_srcdir)/training -- cgit v1.2.3 From 9d7167751a3712a79ad356764d803106a71ce5e3 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Tue, 15 Jan 2013 01:20:00 -0500 Subject: corpus files --- Makefile.am | 1 + configure.ac | 2 +- corpus/add-self-translations.pl | 2 +- corpus/filter-length.pl | 6 ++++-- corpus/paste-files.pl | 12 +++++++++++- 5 files changed, 18 insertions(+), 5 deletions(-) (limited to 'configure.ac') diff --git a/Makefile.am b/Makefile.am index dbf604a1..1d898156 100644 --- a/Makefile.am +++ b/Makefile.am @@ -15,6 +15,7 @@ SUBDIRS = \ #gi/pyp-topics/src gi/clda/src gi/posterior-regularisation/prjava +EXTRA_DIST = python/pkg python/src python/tests python/examples AUTOMAKE_OPTIONS = foreign ACLOCAL_AMFLAGS = -I m4 AM_CPPFLAGS = -D_GLIBCXX_PARALLEL diff --git a/configure.ac b/configure.ac index dcd0a0d8..69971dc3 100644 --- a/configure.ac +++ b/configure.ac @@ -1,4 +1,4 @@ -AC_INIT([cdec],[1.0]) +AC_INIT([cdec],[2013-01-15]) AC_CONFIG_SRCDIR([decoder/cdec.cc]) AM_INIT_AUTOMAKE AC_CONFIG_HEADERS(config.h) diff --git a/corpus/add-self-translations.pl b/corpus/add-self-translations.pl index 153bc454..d707ce29 100755 --- a/corpus/add-self-translations.pl +++ b/corpus/add-self-translations.pl @@ -6,7 +6,7 @@ use strict; my %df; my %def; while(<>) { - print; +# print; chomp; my ($sf, $se) = split / \|\|\| /; die "Format error: $_\n" unless defined $sf && defined $se; diff --git a/corpus/filter-length.pl b/corpus/filter-length.pl index 70032ca7..3cfa40cc 100755 --- a/corpus/filter-length.pl +++ b/corpus/filter-length.pl @@ -3,8 +3,8 @@ use strict; use utf8; ##### EDIT THESE SETTINGS #################################################### -my $MAX_LENGTH = 99; # discard a sentence if it is longer than this -my $AUTOMATIC_INCLUDE_IF_SHORTER_THAN = 6; # if both are shorter, include +my $MAX_LENGTH = 150; # discard a sentence if it is longer than this +my $AUTOMATIC_INCLUDE_IF_SHORTER_THAN = 7; # if both are shorter, include my $MAX_ZSCORE = 1.8; # how far from the mean can the (log)ratio be? ############################################################################## @@ -128,6 +128,8 @@ while() { next; } print; + } else { + print; } $to++; } diff --git a/corpus/paste-files.pl b/corpus/paste-files.pl index 24c70599..0b788386 100755 --- a/corpus/paste-files.pl +++ b/corpus/paste-files.pl @@ -17,6 +17,7 @@ for my $file (@ARGV) { binmode(STDOUT,":utf8"); binmode(STDERR,":utf8"); +my $bad = 0; my $lc = 0; my $done = 0; my $fl = 0; @@ -34,7 +35,15 @@ while(1) { last; } chomp $r; - die "$ARGV[$anum]:$lc contains a ||| symbol - please remove.\n" if $r =~ /\|\|\|/; + if ($r =~ /\|\|\|/) { + $r = ''; + $bad++; + } + warn "$ARGV[$anum]:$lc contains a ||| symbol - please remove.\n" if $r =~ /\|\|\|/; + $r =~ s/\|\|\|/ /g; + $r =~ s/ +//g; + $r =~ s/^ //; + $r =~ s/ $//; $anum++; push @line, $r; } @@ -47,4 +56,5 @@ for (my $i = 1; $i < scalar @fhs; $i++) { my $r = <$fh>; die "Mismatched number of lines.\n" if defined $r; } +print STDERR "Bad lines containing ||| were $bad\n"; -- cgit v1.2.3 From 0b9031042500d45a098762f0a930bd6a66a58fac Mon Sep 17 00:00:00 2001 From: Kenneth Heafield Date: Fri, 18 Jan 2013 17:12:51 +0000 Subject: KenLM dffafbf with lmplz source (but not built) --- Makefile.am | 1 + configure.ac | 1 + decoder/Makefile.am | 2 +- klm/LICENSE | 8 +- klm/README | 31 - klm/README.md | 105 +++ klm/clean.sh | 2 - klm/compile.sh | 15 - klm/lm/Makefile.am | 4 +- klm/lm/build_binary.cc | 77 +- klm/lm/builder/README.md | 47 ++ klm/lm/builder/TODO | 5 + klm/lm/builder/adjust_counts.cc | 216 ++++++ klm/lm/builder/adjust_counts.hh | 44 ++ klm/lm/builder/adjust_counts_test.cc | 106 +++ klm/lm/builder/corpus_count.cc | 223 ++++++ klm/lm/builder/corpus_count.hh | 42 ++ klm/lm/builder/corpus_count_test.cc | 76 ++ klm/lm/builder/discount.hh | 26 + klm/lm/builder/header_info.hh | 20 + klm/lm/builder/initial_probabilities.cc | 136 ++++ klm/lm/builder/initial_probabilities.hh | 34 + klm/lm/builder/interpolate.cc | 65 ++ klm/lm/builder/interpolate.hh | 27 + klm/lm/builder/joint_order.hh | 43 ++ klm/lm/builder/main.cc | 94 +++ klm/lm/builder/multi_stream.hh | 180 +++++ klm/lm/builder/ngram.hh | 84 +++ klm/lm/builder/ngram_stream.hh | 55 ++ klm/lm/builder/pipeline.cc | 320 +++++++++ klm/lm/builder/pipeline.hh | 40 ++ klm/lm/builder/print.cc | 135 ++++ klm/lm/builder/print.hh | 102 +++ klm/lm/builder/sort.hh | 103 +++ klm/lm/filter/arpa_io.cc | 122 ++++ klm/lm/filter/arpa_io.hh | 122 ++++ klm/lm/filter/count_io.hh | 91 +++ klm/lm/filter/format.hh | 250 +++++++ klm/lm/filter/main.cc | 249 +++++++ klm/lm/filter/phrase.cc | 281 ++++++++ klm/lm/filter/phrase.hh | 153 ++++ klm/lm/filter/thread.hh | 167 +++++ klm/lm/filter/vocab.cc | 54 ++ klm/lm/filter/vocab.hh | 132 ++++ klm/lm/filter/wrapper.hh | 58 ++ klm/lm/model_test.cc | 10 +- klm/lm/read_arpa.cc | 11 +- klm/lm/sizes.cc | 63 ++ klm/lm/sizes.hh | 17 + klm/lm/state.hh | 6 +- klm/lm/trie_sort.cc | 27 +- klm/lm/trie_sort.hh | 3 +- klm/search/Jamfile | 5 - klm/test.sh | 8 - klm/util/Makefile.am | 5 +- klm/util/double-conversion/LICENSE | 26 + klm/util/double-conversion/Makefile.am | 23 + klm/util/double-conversion/bignum-dtoa.cc | 640 +++++++++++++++++ klm/util/double-conversion/bignum-dtoa.h | 84 +++ klm/util/double-conversion/bignum.cc | 764 ++++++++++++++++++++ klm/util/double-conversion/bignum.h | 145 ++++ klm/util/double-conversion/cached-powers.cc | 175 +++++ klm/util/double-conversion/cached-powers.h | 64 ++ klm/util/double-conversion/diy-fp.cc | 57 ++ klm/util/double-conversion/diy-fp.h | 118 ++++ klm/util/double-conversion/double-conversion.cc | 889 ++++++++++++++++++++++++ klm/util/double-conversion/double-conversion.h | 536 ++++++++++++++ klm/util/double-conversion/fast-dtoa.cc | 664 ++++++++++++++++++ klm/util/double-conversion/fast-dtoa.h | 88 +++ klm/util/double-conversion/fixed-dtoa.cc | 402 +++++++++++ klm/util/double-conversion/fixed-dtoa.h | 56 ++ klm/util/double-conversion/ieee.h | 398 +++++++++++ klm/util/double-conversion/strtod.cc | 554 +++++++++++++++ klm/util/double-conversion/strtod.h | 45 ++ klm/util/double-conversion/utils.h | 313 +++++++++ klm/util/ersatz_progress.cc | 6 +- klm/util/ersatz_progress.hh | 4 +- klm/util/exception.cc | 5 - klm/util/exception.hh | 42 +- klm/util/file.cc | 196 +++++- klm/util/file.hh | 54 +- klm/util/file_piece.cc | 66 +- klm/util/file_piece.hh | 2 +- klm/util/file_piece_test.cc | 1 + klm/util/have.hh | 4 - klm/util/multi_intersection.hh | 80 +++ klm/util/multi_intersection_test.cc | 63 ++ klm/util/pcqueue.hh | 105 +++ klm/util/pool.cc | 5 +- klm/util/probing_hash_table.hh | 5 + klm/util/read_compressed.cc | 2 +- klm/util/scoped.cc | 29 + klm/util/scoped.hh | 17 +- klm/util/stream/block.hh | 43 ++ klm/util/stream/chain.cc | 155 +++++ klm/util/stream/chain.hh | 198 ++++++ klm/util/stream/config.hh | 32 + klm/util/stream/io.cc | 64 ++ klm/util/stream/io.hh | 76 ++ klm/util/stream/io_test.cc | 38 + klm/util/stream/line_input.cc | 52 ++ klm/util/stream/line_input.hh | 22 + klm/util/stream/multi_progress.cc | 86 +++ klm/util/stream/multi_progress.hh | 90 +++ klm/util/stream/sort.hh | 542 +++++++++++++++ klm/util/stream/sort_test.cc | 62 ++ klm/util/stream/stream.hh | 74 ++ klm/util/stream/stream_test.cc | 35 + klm/util/stream/timer.hh | 14 + klm/util/thread_pool.hh | 95 +++ klm/util/usage.cc | 60 ++ klm/util/usage.hh | 10 + training/crf/Makefile.am | 12 +- training/dtrain/Makefile.am | 2 +- training/mira/Makefile.am | 2 +- 115 files changed, 12537 insertions(+), 257 deletions(-) delete mode 100644 klm/README create mode 100644 klm/README.md delete mode 100755 klm/clean.sh delete mode 100755 klm/compile.sh create mode 100644 klm/lm/builder/README.md create mode 100644 klm/lm/builder/TODO create mode 100644 klm/lm/builder/adjust_counts.cc create mode 100644 klm/lm/builder/adjust_counts.hh create mode 100644 klm/lm/builder/adjust_counts_test.cc create mode 100644 klm/lm/builder/corpus_count.cc create mode 100644 klm/lm/builder/corpus_count.hh create mode 100644 klm/lm/builder/corpus_count_test.cc create mode 100644 klm/lm/builder/discount.hh create mode 100644 klm/lm/builder/header_info.hh create mode 100644 klm/lm/builder/initial_probabilities.cc create mode 100644 klm/lm/builder/initial_probabilities.hh create mode 100644 klm/lm/builder/interpolate.cc create mode 100644 klm/lm/builder/interpolate.hh create mode 100644 klm/lm/builder/joint_order.hh create mode 100644 klm/lm/builder/main.cc create mode 100644 klm/lm/builder/multi_stream.hh create mode 100644 klm/lm/builder/ngram.hh create mode 100644 klm/lm/builder/ngram_stream.hh create mode 100644 klm/lm/builder/pipeline.cc create mode 100644 klm/lm/builder/pipeline.hh create mode 100644 klm/lm/builder/print.cc create mode 100644 klm/lm/builder/print.hh create mode 100644 klm/lm/builder/sort.hh create mode 100644 klm/lm/filter/arpa_io.cc create mode 100644 klm/lm/filter/arpa_io.hh create mode 100644 klm/lm/filter/count_io.hh create mode 100644 klm/lm/filter/format.hh create mode 100644 klm/lm/filter/main.cc create mode 100644 klm/lm/filter/phrase.cc create mode 100644 klm/lm/filter/phrase.hh create mode 100644 klm/lm/filter/thread.hh create mode 100644 klm/lm/filter/vocab.cc create mode 100644 klm/lm/filter/vocab.hh create mode 100644 klm/lm/filter/wrapper.hh create mode 100644 klm/lm/sizes.cc create mode 100644 klm/lm/sizes.hh delete mode 100644 klm/search/Jamfile delete mode 100755 klm/test.sh create mode 100644 klm/util/double-conversion/LICENSE create mode 100644 klm/util/double-conversion/Makefile.am create mode 100644 klm/util/double-conversion/bignum-dtoa.cc create mode 100644 klm/util/double-conversion/bignum-dtoa.h create mode 100644 klm/util/double-conversion/bignum.cc create mode 100644 klm/util/double-conversion/bignum.h create mode 100644 klm/util/double-conversion/cached-powers.cc create mode 100644 klm/util/double-conversion/cached-powers.h create mode 100644 klm/util/double-conversion/diy-fp.cc create mode 100644 klm/util/double-conversion/diy-fp.h create mode 100644 klm/util/double-conversion/double-conversion.cc create mode 100644 klm/util/double-conversion/double-conversion.h create mode 100644 klm/util/double-conversion/fast-dtoa.cc create mode 100644 klm/util/double-conversion/fast-dtoa.h create mode 100644 klm/util/double-conversion/fixed-dtoa.cc create mode 100644 klm/util/double-conversion/fixed-dtoa.h create mode 100644 klm/util/double-conversion/ieee.h create mode 100644 klm/util/double-conversion/strtod.cc create mode 100644 klm/util/double-conversion/strtod.h create mode 100644 klm/util/double-conversion/utils.h create mode 100644 klm/util/multi_intersection.hh create mode 100644 klm/util/multi_intersection_test.cc create mode 100644 klm/util/pcqueue.hh create mode 100644 klm/util/scoped.cc create mode 100644 klm/util/stream/block.hh create mode 100644 klm/util/stream/chain.cc create mode 100644 klm/util/stream/chain.hh create mode 100644 klm/util/stream/config.hh create mode 100644 klm/util/stream/io.cc create mode 100644 klm/util/stream/io.hh create mode 100644 klm/util/stream/io_test.cc create mode 100644 klm/util/stream/line_input.cc create mode 100644 klm/util/stream/line_input.hh create mode 100644 klm/util/stream/multi_progress.cc create mode 100644 klm/util/stream/multi_progress.hh create mode 100644 klm/util/stream/sort.hh create mode 100644 klm/util/stream/sort_test.cc create mode 100644 klm/util/stream/stream.hh create mode 100644 klm/util/stream/stream_test.cc create mode 100644 klm/util/stream/timer.hh create mode 100644 klm/util/thread_pool.hh (limited to 'configure.ac') diff --git a/Makefile.am b/Makefile.am index 1d898156..c2444928 100644 --- a/Makefile.am +++ b/Makefile.am @@ -4,6 +4,7 @@ SUBDIRS = \ utils \ mteval \ + klm/util/double-conversion \ klm/util \ klm/lm \ klm/search \ diff --git a/configure.ac b/configure.ac index 69971dc3..d6030752 100644 --- a/configure.ac +++ b/configure.ac @@ -110,6 +110,7 @@ AC_CONFIG_FILES([python/setup.py]) AC_CONFIG_FILES([word-aligner/Makefile]) # KenLM stuff +AC_CONFIG_FILES([klm/util/double-conversion/Makefile]) AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) AC_CONFIG_FILES([klm/search/Makefile]) diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 558aeaed..6499b38b 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -17,7 +17,7 @@ trule_test_SOURCES = trule_test.cc trule_test_LDADD = $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_UNIT_TEST_FRAMEWORK_LIBS) libcdec.a ../mteval/libmteval.a ../utils/libutils.a cdec_SOURCES = cdec.cc -cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/search/libksearch.a ../klm/lm/libklm.a ../klm/util/libklm_util.a +cdec_LDADD = libcdec.a ../mteval/libmteval.a ../utils/libutils.a ../klm/search/libksearch.a ../klm/lm/libklm.a ../klm/util/libklm_util.a ../klm/util/double-conversion/libklm_util_double.a AM_CPPFLAGS = -DTEST_DATA=\"$(top_srcdir)/decoder/test_data\" -DBOOST_TEST_DYN_LINK -W -Wno-sign-compare -I$(top_srcdir) -I$(top_srcdir)/mteval -I$(top_srcdir)/utils -I$(top_srcdir)/klm diff --git a/klm/LICENSE b/klm/LICENSE index 20b76c13..38b6865c 100644 --- a/klm/LICENSE +++ b/klm/LICENSE @@ -1,16 +1,18 @@ Most of the code here is licensed under the LGPL. There are exceptions which have their own licenses, listed below. See comments in those files for more details. util/murmur_hash.cc is under the MIT license. -util/string_piece.hh and util/string_piece.cc are Google code and contains its own license. +util/string_piece.hh, util/string_piece.cc, and util/double-conversion are from Google; see the licenses in the code or util/double-conversion/LICENSE +util/file.cc contains a modified implementation of mkstemp under the LGPL. +jam-files/LICENSE_1_0.txt covers most of that directory (except sanity.jam which is mine). For the rest: - Avenue code is free software: you can redistribute it and/or modify + KenLM is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. - Avenue code is distributed in the hope that it will be useful, + KenLM is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. diff --git a/klm/README b/klm/README deleted file mode 100644 index 8d1050a8..00000000 --- a/klm/README +++ /dev/null @@ -1,31 +0,0 @@ -Language model inference code by Kenneth Heafield -See LICENSE for list of files by other people and their licenses. - -Compile: ./compile.sh -Run: ./query lm/test.arpa = 1.36.0 (and preferably newer). Compile with +```bash +./bjam +``` +If you don't have boost and only need the query code, compile with +```bash +./compile_query_only.sh +``` + +## Estimation +lmplz estimates unpruned language models with modified Kneser-Ney smoothing. After compiling with bjam, run +```bash +bin/lmplz -o 5 text.arpa +``` +The algorithm is on-disk, using an amount of memory that you specify. See http://kheafield.com/code/kenlm/estimation/ for more. + +MT Marathon 2012 team members Ivan Pouzyrevsky and Mohammed Mediani contributed to the computation design and early implementation. Jon Clark contributed to the design, clarified points about smoothing, and added logging. + +## Filtering + +filter takes an ARPA or count file and removes entries that will never be queried. The filter criterion can be corpus-level vocabulary, sentence-level vocabulary, or sentence-level phrases. Run +```bash +bin/filter +``` +and see http://kheafield.com/code/kenlm/filter.html for more documentation. + +## Querying + +Two data structures are supported: probing and trie. Probing is a probing hash table with keys that are 64-bit hashes of n-grams and floats as values. Trie is a fairly standard trie but with bit-level packing so it uses the minimum number of bits to store word indices and pointers. The trie node entries are sorted by word index. Probing is the fastest and uses the most memory. Trie uses the least memory and a bit slower. + +With trie, resident memory is 58% of IRST's smallest version and 21% of SRI's compact version. Simultaneously, trie CPU's use is 81% of IRST's fastest version and 84% of SRI's fast version. KenLM's probing hash table implementation goes even faster at the expense of using more memory. See http://kheafield.com/code/kenlm/benchmark/. + +Binary format via mmap is supported. Run `./build_binary` to make one then pass the binary file name to the appropriate Model constructor. + +## Platforms +`murmur_hash.cc` and `bit_packing.hh` perform unaligned reads and writes that make the code architecture-dependent. +It has been sucessfully tested on x86\_64, x86, and PPC64. +ARM support is reportedly working, at least on the iphone. + +Runs on Linux, OS X, Cygwin, and MinGW. + +Hideo Okuma and Tomoyuki Yoshimura from NICT contributed ports to ARM and MinGW. + +## Compile-time configuration +There are a number of macros you can set on the g++ command line or in util/have.hh . + +* `KENLM_MAX_ORDER` is the maximum order that can be loaded. This is done to make state an efficient POD rather than a vector. +* `HAVE_BOOST` enables Boost-style hashing of StringPiece. This is only needed if you intend to hash StringPiece in your code. +* `HAVE_ICU` If your code links against ICU, define this to disable the internal StringPiece and replace it with ICU's copy of StringPiece, avoiding naming conflicts. + +ARPA files can be read in compressed format with these options: +* `HAVE_ZLIB` Supports gzip. Link with -lz. I have enabled this by default. +* `HAVE_BZLIB` Supports bzip2. Link with -lbz2. +* `HAVE_XZLIB` Supports xz. Link with -llzma. + +Note that these macros impact only `read_compressed.cc` and `read_compressed_test.cc`. The bjam build system will auto-detect bzip2 and xz support. + +## Decoder developers +- I recommend copying the code and distributing it with your decoder. However, please send improvements upstream. + +- Omit the lm/filter directory if you do not want the language model filter. Only that and tests depend on Boost. + +- Select the macros you want, listed in the previous section. + +- There are two build systems: compile.sh and Jamroot+Jamfile. They're pretty simple and are intended to be reimplemented in your build system. + +- Use either the interface in `lm/model.hh` or `lm/virtual_interface.hh`. Interface documentation is in comments of `lm/virtual_interface.hh` and `lm/model.hh`. + +- There are several possible data structures in `model.hh`. Use `RecognizeBinary` in `binary_format.hh` to determine which one a user has provided. You probably already implement feature functions as an abstract virtual base class with several children. I suggest you co-opt this existing virtual dispatch by templatizing the language model feature implementation on the KenLM model identified by `RecognizeBinary`. This is the strategy used in Moses and cdec. + +- See `lm/config.hh` for run-time tuning options. + +## Contributors +Contributions to KenLM are welcome. Please base your contributions on https://github.com/kpu/kenlm and send pull requests (or I might give you commit access). Downstream copies in Moses and cdec are maintained by overwriting them so do not make changes there. + +## Python module +Contributed by Victor Chahuneau. + +### Installation + +```bash +pip install -e git+https://github.com/kpu/kenlm.git#egg=kenlm +``` + +### Basic Usage +```python +import kenlm +model = kenlm.LanguageModel('lm/test.arpa') +sentence = 'this is a sentence .' +print(model.score(sentence)) +``` + +--- + +The name was Hieu Hoang's idea, not mine. diff --git a/klm/clean.sh b/klm/clean.sh deleted file mode 100755 index e8c6dbde..00000000 --- a/klm/clean.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -rm -rf */*.o query build_binary */*_test lm/test.binary* lm/test.arpa?????? util/file_piece.cc.gz diff --git a/klm/compile.sh b/klm/compile.sh deleted file mode 100755 index 55759f97..00000000 --- a/klm/compile.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -#This is just an example compilation. You should integrate these files into your build system. I can provide boost jam if you want. -#If your code uses ICU, edit util/string_piece.hh and uncomment #define USE_ICU -#I use zlib by default. If you don't want to depend on zlib, remove #define USE_ZLIB from util/file_piece.hh - -#don't need to use if compiling with moses Makefiles already - -set -e - -rm {lm,util}/*.o -for i in util/{bit_packing,ersatz_progress,exception,file_piece,murmur_hash,file,mmap,usage} lm/{bhiksha,binary_format,config,lm_exception,model,quantize,read_arpa,search_hashed,search_trie,trie,trie_sort,value_build,virtual_interface,vocab}; do - g++ -I. -O3 -DNDEBUG $CXXFLAGS -c $i.cc -o $i.o -done -g++ -I. -O3 -DNDEBUG $CXXFLAGS lm/build_binary.cc {lm,util}/*.o -lz -o build_binary -g++ -I. -O3 -DNDEBUG $CXXFLAGS lm/ngram_query.cc {lm,util}/*.o -lz -o query diff --git a/klm/lm/Makefile.am b/klm/lm/Makefile.am index 870f7128..f15cbd77 100644 --- a/klm/lm/Makefile.am +++ b/klm/lm/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = build_binary build_binary_SOURCES = build_binary.cc -build_binary_LDADD = libklm.a ../util/libklm_util.a -lz +build_binary_LDADD = libklm.a ../util/libklm_util.a ../util/double-conversion/libklm_util_double.a -lz #noinst_PROGRAMS = \ # ngram_test @@ -30,6 +30,7 @@ libklm_a_SOURCES = \ return.hh \ search_hashed.hh \ search_trie.hh \ + sizes.hh \ state.hh \ trie.hh \ trie_sort.hh \ @@ -49,6 +50,7 @@ libklm_a_SOURCES = \ read_arpa.cc \ search_hashed.cc \ search_trie.cc \ + sizes.cc \ trie.cc \ trie_sort.cc \ value_build.cc \ diff --git a/klm/lm/build_binary.cc b/klm/lm/build_binary.cc index 2b8c9d5b..ab2c0c32 100644 --- a/klm/lm/build_binary.cc +++ b/klm/lm/build_binary.cc @@ -1,10 +1,14 @@ #include "lm/model.hh" +#include "lm/sizes.hh" #include "util/file_piece.hh" +#include "util/usage.hh" +#include #include #include #include #include +#include #include #include @@ -19,8 +23,8 @@ namespace lm { namespace ngram { namespace { -void Usage(const char *name) { - std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-t trie_temporary] [-m trie_building_megabytes] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" +void Usage(const char *name, const char *default_mem) { + std::cerr << "Usage: " << name << " [-u log10_unknown_probability] [-s] [-i] [-w mmap|after] [-p probing_multiplier] [-T trie_temporary] [-S trie_building_mem] [-q bits] [-b bits] [-a bits] [type] input.arpa [output.mmap]\n\n" "-u sets the log10 probability for if the ARPA file does not have one.\n" " Default is -100. The ARPA file will always take precedence.\n" "-s allows models to be built even if they do not have and .\n" @@ -38,8 +42,11 @@ void Usage(const char *name) { "trie is a straightforward trie with bit-level packing. It uses the least\n" "memory and is still faster than SRI or IRST. Building the trie format uses an\n" "on-disk sort to save memory.\n" -"-t is the temporary directory prefix. Default is the output file name.\n" -"-m limits memory use for sorting. Measured in MB. Default is 1024MB.\n" +"-T is the temporary directory prefix. Default is the output file name.\n" +"-S determines memory use for sorting. Default is " << default_mem << ". This is compatible\n" +" with GNU sort. The number is followed by a unit: \% for percent of physical\n" +" memory, b for bytes, K for Kilobytes, M for megabytes, then G,T,P,E,Z,Y. \n" +" Default unit is K for Kilobytes.\n" "-q turns quantization on and sets the number of bits (e.g. -q 8).\n" "-b sets backoff quantization bits. Requires -q and defaults to that value.\n" "-a compresses pointers using an array of offsets. The parameter is the\n" @@ -83,47 +90,6 @@ void ParseFileList(const char *from, std::vector &to) { } } -void ShowSizes(const char *file, const lm::ngram::Config &config) { - std::vector counts; - util::FilePiece f(file); - lm::ReadARPACounts(f, counts); - uint64_t sizes[6]; - sizes[0] = ProbingModel::Size(counts, config); - sizes[1] = RestProbingModel::Size(counts, config); - sizes[2] = TrieModel::Size(counts, config); - sizes[3] = QuantTrieModel::Size(counts, config); - sizes[4] = ArrayTrieModel::Size(counts, config); - sizes[5] = QuantArrayTrieModel::Size(counts, config); - uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); - uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); - uint64_t divide; - char prefix; - if (min_length < (1 << 10) * 10) { - prefix = ' '; - divide = 1; - } else if (min_length < (1 << 20) * 10) { - prefix = 'k'; - divide = 1 << 10; - } else if (min_length < (1ULL << 30) * 10) { - prefix = 'M'; - divide = 1 << 20; - } else { - prefix = 'G'; - divide = 1 << 30; - } - long int length = std::max(2, static_cast(ceil(log10((double) max_length / divide)))); - std::cout << "Memory estimate:\ntype "; - // right align bytes. - for (long int i = 0; i < length - 2; ++i) std::cout << ' '; - std::cout << prefix << "B\n" - "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" - "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n" - "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n" - "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" - "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" - "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; -} - void ProbingQuantizationUnsupported() { std::cerr << "Quantization is only implemented in the trie data structure." << std::endl; exit(1); @@ -136,11 +102,14 @@ void ProbingQuantizationUnsupported() { int main(int argc, char *argv[]) { using namespace lm::ngram; + const char *default_mem = util::GuessPhysicalMemory() ? "80%" : "1G"; + try { bool quantize = false, set_backoff_bits = false, bhiksha = false, set_write_method = false, rest = false; lm::ngram::Config config; + config.building_memory = util::ParseSize(default_mem); int opt; - while ((opt = getopt(argc, argv, "q:b:a:u:p:t:m:w:sir:")) != -1) { + while ((opt = getopt(argc, argv, "q:b:a:u:p:t:T:m:S:w:sir:")) != -1) { switch(opt) { case 'q': config.prob_bits = ParseBitCount(optarg); @@ -161,12 +130,16 @@ int main(int argc, char *argv[]) { case 'p': config.probing_multiplier = ParseFloat(optarg); break; - case 't': + case 't': // legacy + case 'T': config.temporary_directory_prefix = optarg; break; - case 'm': + case 'm': // legacy config.building_memory = ParseUInt(optarg) * 1048576; break; + case 'S': + config.building_memory = std::min(static_cast(std::numeric_limits::max()), util::ParseSize(optarg)); + break; case 'w': set_write_method = true; if (!strcmp(optarg, "mmap")) { @@ -174,7 +147,7 @@ int main(int argc, char *argv[]) { } else if (!strcmp(optarg, "after")) { config.write_method = Config::WRITE_AFTER; } else { - Usage(argv[0]); + Usage(argv[0], default_mem); } break; case 's': @@ -189,7 +162,7 @@ int main(int argc, char *argv[]) { config.rest_function = Config::REST_LOWER; break; default: - Usage(argv[0]); + Usage(argv[0], default_mem); } } if (!quantize && set_backoff_bits) { @@ -212,7 +185,7 @@ int main(int argc, char *argv[]) { from_file = argv[optind + 1]; config.write_mmap = argv[optind + 2]; } else { - Usage(argv[0]); + Usage(argv[0], default_mem); } if (!strcmp(model_type, "probing")) { if (!set_write_method) config.write_method = Config::WRITE_AFTER; @@ -242,7 +215,7 @@ int main(int argc, char *argv[]) { } } } else { - Usage(argv[0]); + Usage(argv[0], default_mem); } } catch (const std::exception &e) { diff --git a/klm/lm/builder/README.md b/klm/lm/builder/README.md new file mode 100644 index 00000000..be0d35e2 --- /dev/null +++ b/klm/lm/builder/README.md @@ -0,0 +1,47 @@ +Dependencies +============ + +Boost >= 1.42.0 is required. + +For Ubuntu, +```bash +sudo apt-get install libboost1.48-all-dev +``` + +Alternatively, you can download, compile, and install it yourself: + +```bash +wget http://sourceforge.net/projects/boost/files/boost/1.52.0/boost_1_52_0.tar.gz/download -O boost_1_52_0.tar.gz +tar -xvzf boost_1_52_0.tar.gz +cd boost_1_52_0 +./bootstrap.sh +./b2 +sudo ./b2 install +``` + +Local install options (in a user-space prefix directory) are also possible. See http://www.boost.org/doc/libs/1_52_0/doc/html/bbv2/installation.html. + + +Building +======== + +```bash +bjam +``` +Your distribution might package bjam and boost-build separately from Boost. Both are required. + +Usage +===== + +Run +```bash +$ bin/lmplz +``` +to see command line arguments + +Running +======= + +```bash +bin/lmplz -o 5 text.arpa +``` diff --git a/klm/lm/builder/TODO b/klm/lm/builder/TODO new file mode 100644 index 00000000..cb5aef3a --- /dev/null +++ b/klm/lm/builder/TODO @@ -0,0 +1,5 @@ +More tests! +Sharding. +Some way to manage all the crazy config options. +Option to build the binary file directly. +Interpolation of different orders. diff --git a/klm/lm/builder/adjust_counts.cc b/klm/lm/builder/adjust_counts.cc new file mode 100644 index 00000000..a6f48011 --- /dev/null +++ b/klm/lm/builder/adjust_counts.cc @@ -0,0 +1,216 @@ +#include "lm/builder/adjust_counts.hh" +#include "lm/builder/multi_stream.hh" +#include "util/stream/timer.hh" + +#include + +namespace lm { namespace builder { + +BadDiscountException::BadDiscountException() throw() {} +BadDiscountException::~BadDiscountException() throw() {} + +namespace { +// Return last word in full that is different. +const WordIndex* FindDifference(const NGram &full, const NGram &lower_last) { + const WordIndex *cur_word = full.end() - 1; + const WordIndex *pre_word = lower_last.end() - 1; + // Find last difference. + for (; pre_word >= lower_last.begin() && *pre_word == *cur_word; --cur_word, --pre_word) {} + return cur_word; +} + +class StatCollector { + public: + StatCollector(std::size_t order, std::vector &counts, std::vector &discounts) + : orders_(order), full_(orders_.back()), counts_(counts), discounts_(discounts) { + memset(&orders_[0], 0, sizeof(OrderStat) * order); + } + + ~StatCollector() {} + + void CalculateDiscounts() { + counts_.resize(orders_.size()); + discounts_.resize(orders_.size()); + for (std::size_t i = 0; i < orders_.size(); ++i) { + const OrderStat &s = orders_[i]; + counts_[i] = s.count; + + for (unsigned j = 1; j < 4; ++j) { + // TODO: Specialize error message for j == 3, meaning 3+ + UTIL_THROW_IF(s.n[j] == 0, BadDiscountException, "Could not calculate Kneser-Ney discounts for " + << (i+1) << "-grams with adjusted count " << (j+1) << " because we didn't observe any " + << (i+1) << "-grams with adjusted count " << j << "; Is this small or artificial data?"); + } + + // See equation (26) in Chen and Goodman. + discounts_[i].amount[0] = 0.0; + float y = static_cast(s.n[1]) / static_cast(s.n[1] + 2.0 * s.n[2]); + for (unsigned j = 1; j < 4; ++j) { + discounts_[i].amount[j] = static_cast(j) - static_cast(j + 1) * y * static_cast(s.n[j+1]) / static_cast(s.n[j]); + UTIL_THROW_IF(discounts_[i].amount[j] < 0.0 || discounts_[i].amount[j] > j, BadDiscountException, "ERROR: " << (i+1) << "-gram discount out of range for adjusted count " << j << ": " << discounts_[i].amount[j]); + } + } + } + + void Add(std::size_t order_minus_1, uint64_t count) { + OrderStat &stat = orders_[order_minus_1]; + ++stat.count; + if (count < 5) ++stat.n[count]; + } + + void AddFull(uint64_t count) { + ++full_.count; + if (count < 5) ++full_.n[count]; + } + + private: + struct OrderStat { + // n_1 in equation 26 of Chen and Goodman etc + uint64_t n[5]; + uint64_t count; + }; + + std::vector orders_; + OrderStat &full_; + + std::vector &counts_; + std::vector &discounts_; +}; + +// Reads all entries in order like NGramStream does. +// But deletes any entries that have in the 1st (not 0th) position on the +// way out by putting other entries in their place. This disrupts the sort +// order but we don't care because the data is going to be sorted again. +class CollapseStream { + public: + CollapseStream(const util::stream::ChainPosition &position) : + current_(NULL, NGram::OrderFromSize(position.GetChain().EntrySize())), + block_(position) { + StartBlock(); + } + + const NGram &operator*() const { return current_; } + const NGram *operator->() const { return ¤t_; } + + operator bool() const { return block_; } + + CollapseStream &operator++() { + assert(block_); + if (current_.begin()[1] == kBOS && current_.Base() < copy_from_) { + memcpy(current_.Base(), copy_from_, current_.TotalSize()); + UpdateCopyFrom(); + } + current_.NextInMemory(); + uint8_t *block_base = static_cast(block_->Get()); + if (current_.Base() == block_base + block_->ValidSize()) { + block_->SetValidSize(copy_from_ + current_.TotalSize() - block_base); + ++block_; + StartBlock(); + } + return *this; + } + + private: + void StartBlock() { + for (; ; ++block_) { + if (!block_) return; + if (block_->ValidSize()) break; + } + current_.ReBase(block_->Get()); + copy_from_ = static_cast(block_->Get()) + block_->ValidSize(); + UpdateCopyFrom(); + } + + // Find last without bos. + void UpdateCopyFrom() { + for (copy_from_ -= current_.TotalSize(); copy_from_ >= current_.Base(); copy_from_ -= current_.TotalSize()) { + if (NGram(copy_from_, current_.Order()).begin()[1] != kBOS) break; + } + } + + NGram current_; + + // Goes backwards in the block + uint8_t *copy_from_; + + util::stream::Link block_; +}; + +} // namespace + +void AdjustCounts::Run(const ChainPositions &positions) { + UTIL_TIMER("(%w s) Adjusted counts\n"); + + const std::size_t order = positions.size(); + StatCollector stats(order, counts_, discounts_); + if (order == 1) { + // Only unigrams. Just collect stats. + for (NGramStream full(positions[0]); full; ++full) + stats.AddFull(full->Count()); + stats.CalculateDiscounts(); + return; + } + + NGramStreams streams; + streams.Init(positions, positions.size() - 1); + CollapseStream full(positions[positions.size() - 1]); + + // Initialization: has count 0 and so does . + NGramStream *lower_valid = streams.begin(); + streams[0]->Count() = 0; + *streams[0]->begin() = kUNK; + stats.Add(0, 0); + (++streams[0])->Count() = 0; + *streams[0]->begin() = kBOS; + // not in stats because it will get put in later. + + // iterate over full (the stream of the highest order ngrams) + for (; full; ++full) { + const WordIndex *different = FindDifference(*full, **lower_valid); + std::size_t same = full->end() - 1 - different; + // Increment the adjusted count. + if (same) ++streams[same - 1]->Count(); + + // Output all the valid ones that changed. + for (; lower_valid >= &streams[same]; --lower_valid) { + stats.Add(lower_valid - streams.begin(), (*lower_valid)->Count()); + ++*lower_valid; + } + + // This is here because bos is also const WordIndex *, so copy gets + // consistent argument types. + const WordIndex *full_end = full->end(); + // Initialize and mark as valid up to bos. + const WordIndex *bos; + for (bos = different; (bos > full->begin()) && (*bos != kBOS); --bos) { + ++lower_valid; + std::copy(bos, full_end, (*lower_valid)->begin()); + (*lower_valid)->Count() = 1; + } + // Now bos indicates where is or is the 0th word of full. + if (bos != full->begin()) { + // There is an beyond the 0th word. + NGramStream &to = *++lower_valid; + std::copy(bos, full_end, to->begin()); + to->Count() = full->Count(); + } else { + stats.AddFull(full->Count()); + } + assert(lower_valid >= &streams[0]); + } + + // Output everything valid. + for (NGramStream *s = streams.begin(); s <= lower_valid; ++s) { + stats.Add(s - streams.begin(), (*s)->Count()); + ++*s; + } + // Poison everyone! Except the N-grams which were already poisoned by the input. + for (NGramStream *s = streams.begin(); s != streams.end(); ++s) + s->Poison(); + + stats.CalculateDiscounts(); + + // NOTE: See special early-return case for unigrams near the top of this function +} + +}} // namespaces diff --git a/klm/lm/builder/adjust_counts.hh b/klm/lm/builder/adjust_counts.hh new file mode 100644 index 00000000..f38ff79d --- /dev/null +++ b/klm/lm/builder/adjust_counts.hh @@ -0,0 +1,44 @@ +#ifndef LM_BUILDER_ADJUST_COUNTS__ +#define LM_BUILDER_ADJUST_COUNTS__ + +#include "lm/builder/discount.hh" +#include "util/exception.hh" + +#include + +#include + +namespace lm { +namespace builder { + +class ChainPositions; + +class BadDiscountException : public util::Exception { + public: + BadDiscountException() throw(); + ~BadDiscountException() throw(); +}; + +/* Compute adjusted counts. + * Input: unique suffix sorted N-grams (and just the N-grams) with raw counts. + * Output: [1,N]-grams with adjusted counts. + * [1,N)-grams are in suffix order + * N-grams are in undefined order (they're going to be sorted anyway). + */ +class AdjustCounts { + public: + AdjustCounts(std::vector &counts, std::vector &discounts) + : counts_(counts), discounts_(discounts) {} + + void Run(const ChainPositions &positions); + + private: + std::vector &counts_; + std::vector &discounts_; +}; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_ADJUST_COUNTS__ + diff --git a/klm/lm/builder/adjust_counts_test.cc b/klm/lm/builder/adjust_counts_test.cc new file mode 100644 index 00000000..68b5f33e --- /dev/null +++ b/klm/lm/builder/adjust_counts_test.cc @@ -0,0 +1,106 @@ +#include "lm/builder/adjust_counts.hh" + +#include "lm/builder/multi_stream.hh" +#include "util/scoped.hh" + +#include +#define BOOST_TEST_MODULE AdjustCounts +#include + +namespace lm { namespace builder { namespace { + +class KeepCopy { + public: + KeepCopy() : size_(0) {} + + void Run(const util::stream::ChainPosition &position) { + for (util::stream::Link link(position); link; ++link) { + mem_.call_realloc(size_ + link->ValidSize()); + memcpy(static_cast(mem_.get()) + size_, link->Get(), link->ValidSize()); + size_ += link->ValidSize(); + } + } + + uint8_t *Get() { return static_cast(mem_.get()); } + std::size_t Size() const { return size_; } + + private: + util::scoped_malloc mem_; + std::size_t size_; +}; + +struct Gram4 { + WordIndex ids[4]; + uint64_t count; +}; + +class WriteInput { + public: + void Run(const util::stream::ChainPosition &position) { + NGramStream input(position); + Gram4 grams[] = { + {{0,0,0,0},10}, + {{0,0,3,0},3}, + // bos + {{1,1,1,2},5}, + {{0,0,3,2},5}, + }; + for (size_t i = 0; i < sizeof(grams) / sizeof(Gram4); ++i, ++input) { + memcpy(input->begin(), grams[i].ids, sizeof(WordIndex) * 4); + input->Count() = grams[i].count; + } + input.Poison(); + } +}; + +BOOST_AUTO_TEST_CASE(Simple) { + KeepCopy outputs[4]; + std::vector counts; + std::vector discount; + { + util::stream::ChainConfig config; + config.total_memory = 100; + config.block_count = 1; + Chains chains(4); + for (unsigned i = 0; i < 4; ++i) { + config.entry_size = NGram::TotalSize(i + 1); + chains.push_back(config); + } + + chains[3] >> WriteInput(); + ChainPositions for_adjust(chains); + for (unsigned i = 0; i < 4; ++i) { + chains[i] >> boost::ref(outputs[i]); + } + chains >> util::stream::kRecycle; + BOOST_CHECK_THROW(AdjustCounts(counts, discount).Run(for_adjust), BadDiscountException); + } + BOOST_REQUIRE_EQUAL(4UL, counts.size()); + BOOST_CHECK_EQUAL(4UL, counts[0]); + // These are no longer set because the discounts are bad. +/* BOOST_CHECK_EQUAL(4UL, counts[1]); + BOOST_CHECK_EQUAL(3UL, counts[2]); + BOOST_CHECK_EQUAL(3UL, counts[3]);*/ + BOOST_REQUIRE_EQUAL(NGram::TotalSize(1) * 4, outputs[0].Size()); + NGram uni(outputs[0].Get(), 1); + BOOST_CHECK_EQUAL(kUNK, *uni.begin()); + BOOST_CHECK_EQUAL(0ULL, uni.Count()); + uni.NextInMemory(); + BOOST_CHECK_EQUAL(kBOS, *uni.begin()); + BOOST_CHECK_EQUAL(0ULL, uni.Count()); + uni.NextInMemory(); + BOOST_CHECK_EQUAL(0UL, *uni.begin()); + BOOST_CHECK_EQUAL(2ULL, uni.Count()); + uni.NextInMemory(); + BOOST_CHECK_EQUAL(2ULL, uni.Count()); + BOOST_CHECK_EQUAL(2UL, *uni.begin()); + + BOOST_REQUIRE_EQUAL(NGram::TotalSize(2) * 4, outputs[1].Size()); + NGram bi(outputs[1].Get(), 2); + BOOST_CHECK_EQUAL(0UL, *bi.begin()); + BOOST_CHECK_EQUAL(0UL, *(bi.begin() + 1)); + BOOST_CHECK_EQUAL(1ULL, bi.Count()); + bi.NextInMemory(); +} + +}}} // namespaces diff --git a/klm/lm/builder/corpus_count.cc b/klm/lm/builder/corpus_count.cc new file mode 100644 index 00000000..8c3de57d --- /dev/null +++ b/klm/lm/builder/corpus_count.cc @@ -0,0 +1,223 @@ +#include "lm/builder/corpus_count.hh" + +#include "lm/builder/ngram.hh" +#include "lm/lm_exception.hh" +#include "lm/word_index.hh" +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/murmur_hash.hh" +#include "util/probing_hash_table.hh" +#include "util/scoped.hh" +#include "util/stream/chain.hh" +#include "util/stream/timer.hh" +#include "util/tokenize_piece.hh" + +#include +#include + +#include + +#include + +namespace lm { +namespace builder { +namespace { + +class VocabHandout { + public: + explicit VocabHandout(int fd) { + util::scoped_fd duped(util::DupOrThrow(fd)); + word_list_.reset(util::FDOpenOrThrow(duped)); + + Lookup(""); // Force 0 + Lookup(""); // Force 1 + Lookup(""); // Force 2 + } + + WordIndex Lookup(const StringPiece &word) { + uint64_t hashed = util::MurmurHashNative(word.data(), word.size()); + std::pair ret(seen_.insert(std::pair(hashed, seen_.size()))); + if (ret.second) { + char null_delimit = 0; + util::WriteOrThrow(word_list_.get(), word.data(), word.size()); + util::WriteOrThrow(word_list_.get(), &null_delimit, 1); + UTIL_THROW_IF(seen_.size() >= std::numeric_limits::max(), VocabLoadException, "Too many vocabulary words. Change WordIndex to uint64_t in lm/word_index.hh."); + } + return ret.first->second; + } + + WordIndex Size() const { + return seen_.size(); + } + + private: + typedef boost::unordered_map Seen; + + Seen seen_; + + util::scoped_FILE word_list_; +}; + +class DedupeHash : public std::unary_function { + public: + explicit DedupeHash(std::size_t order) : size_(order * sizeof(WordIndex)) {} + + std::size_t operator()(const WordIndex *start) const { + return util::MurmurHashNative(start, size_); + } + + private: + const std::size_t size_; +}; + +class DedupeEquals : public std::binary_function { + public: + explicit DedupeEquals(std::size_t order) : size_(order * sizeof(WordIndex)) {} + + bool operator()(const WordIndex *first, const WordIndex *second) const { + return !memcmp(first, second, size_); + } + + private: + const std::size_t size_; +}; + +struct DedupeEntry { + typedef WordIndex *Key; + Key GetKey() const { return key; } + Key key; + static DedupeEntry Construct(WordIndex *at) { + DedupeEntry ret; + ret.key = at; + return ret; + } +}; + +typedef util::ProbingHashTable Dedupe; + +const float kProbingMultiplier = 1.5; + +class Writer { + public: + Writer(std::size_t order, const util::stream::ChainPosition &position, void *dedupe_mem, std::size_t dedupe_mem_size) + : block_(position), gram_(block_->Get(), order), + dedupe_invalid_(order, std::numeric_limits::max()), + dedupe_(dedupe_mem, dedupe_mem_size, &dedupe_invalid_[0], DedupeHash(order), DedupeEquals(order)), + buffer_(new WordIndex[order - 1]), + block_size_(position.GetChain().BlockSize()) { + dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); + assert(Dedupe::Size(position.GetChain().BlockSize() / position.GetChain().EntrySize(), kProbingMultiplier) == dedupe_mem_size); + if (order == 1) { + // Add special words. AdjustCounts is responsible if order != 1. + AddUnigramWord(kUNK); + AddUnigramWord(kBOS); + } + } + + ~Writer() { + block_->SetValidSize(reinterpret_cast(gram_.begin()) - static_cast(block_->Get())); + (++block_).Poison(); + } + + // Write context with a bunch of + void StartSentence() { + for (WordIndex *i = gram_.begin(); i != gram_.end() - 1; ++i) { + *i = kBOS; + } + } + + void Append(WordIndex word) { + *(gram_.end() - 1) = word; + Dedupe::MutableIterator at; + bool found = dedupe_.FindOrInsert(DedupeEntry::Construct(gram_.begin()), at); + if (found) { + // Already present. + NGram already(at->key, gram_.Order()); + ++(already.Count()); + // Shift left by one. + memmove(gram_.begin(), gram_.begin() + 1, sizeof(WordIndex) * (gram_.Order() - 1)); + return; + } + // Complete the write. + gram_.Count() = 1; + // Prepare the next n-gram. + if (reinterpret_cast(gram_.begin()) + gram_.TotalSize() != static_cast(block_->Get()) + block_size_) { + NGram last(gram_); + gram_.NextInMemory(); + std::copy(last.begin() + 1, last.end(), gram_.begin()); + return; + } + // Block end. Need to store the context in a temporary buffer. + std::copy(gram_.begin() + 1, gram_.end(), buffer_.get()); + dedupe_.Clear(DedupeEntry::Construct(&dedupe_invalid_[0])); + block_->SetValidSize(block_size_); + gram_.ReBase((++block_)->Get()); + std::copy(buffer_.get(), buffer_.get() + gram_.Order() - 1, gram_.begin()); + } + + private: + void AddUnigramWord(WordIndex index) { + *gram_.begin() = index; + gram_.Count() = 0; + gram_.NextInMemory(); + if (gram_.Base() == static_cast(block_->Get()) + block_size_) { + block_->SetValidSize(block_size_); + gram_.ReBase((++block_)->Get()); + } + } + + util::stream::Link block_; + + NGram gram_; + + // This is the memory behind the invalid value in dedupe_. + std::vector dedupe_invalid_; + // Hash table combiner implementation. + Dedupe dedupe_; + + // Small buffer to hold existing ngrams when shifting across a block boundary. + boost::scoped_array buffer_; + + const std::size_t block_size_; +}; + +} // namespace + +float CorpusCount::DedupeMultiplier(std::size_t order) { + return kProbingMultiplier * static_cast(sizeof(DedupeEntry)) / static_cast(NGram::TotalSize(order)); +} + +CorpusCount::CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block) + : from_(from), vocab_write_(vocab_write), token_count_(token_count), type_count_(type_count), + dedupe_mem_size_(Dedupe::Size(entries_per_block, kProbingMultiplier)), + dedupe_mem_(util::MallocOrThrow(dedupe_mem_size_)) { + token_count_ = 0; + type_count_ = 0; +} + +void CorpusCount::Run(const util::stream::ChainPosition &position) { + UTIL_TIMER("(%w s) Counted n-grams\n"); + + VocabHandout vocab(vocab_write_); + const WordIndex end_sentence = vocab.Lookup(""); + Writer writer(NGram::OrderFromSize(position.GetChain().EntrySize()), position, dedupe_mem_.get(), dedupe_mem_size_); + uint64_t count = 0; + try { + while(true) { + StringPiece line(from_.ReadLine()); + writer.StartSentence(); + for (util::TokenIter w(line, " \t"); w; ++w) { + WordIndex word = vocab.Lookup(*w); + UTIL_THROW_IF(word <= 2, FormatLoadException, "Special word " << *w << " is not allowed in the corpus. I plan to support models containing in the future."); + writer.Append(word); + ++count; + } + writer.Append(end_sentence); + } + } catch (const util::EndOfFileException &e) {} + token_count_ = count; + type_count_ = vocab.Size(); +} + +} // namespace builder +} // namespace lm diff --git a/klm/lm/builder/corpus_count.hh b/klm/lm/builder/corpus_count.hh new file mode 100644 index 00000000..e255bad1 --- /dev/null +++ b/klm/lm/builder/corpus_count.hh @@ -0,0 +1,42 @@ +#ifndef LM_BUILDER_CORPUS_COUNT__ +#define LM_BUILDER_CORPUS_COUNT__ + +#include "lm/word_index.hh" +#include "util/scoped.hh" + +#include +#include +#include + +namespace util { +class FilePiece; +namespace stream { +class ChainPosition; +} // namespace stream +} // namespace util + +namespace lm { +namespace builder { + +class CorpusCount { + public: + // Memory usage will be DedupeMultipler(order) * block_size + total_chain_size + unknown vocab_hash_size + static float DedupeMultiplier(std::size_t order); + + CorpusCount(util::FilePiece &from, int vocab_write, uint64_t &token_count, WordIndex &type_count, std::size_t entries_per_block); + + void Run(const util::stream::ChainPosition &position); + + private: + util::FilePiece &from_; + int vocab_write_; + uint64_t &token_count_; + WordIndex &type_count_; + + std::size_t dedupe_mem_size_; + util::scoped_malloc dedupe_mem_; +}; + +} // namespace builder +} // namespace lm +#endif // LM_BUILDER_CORPUS_COUNT__ diff --git a/klm/lm/builder/corpus_count_test.cc b/klm/lm/builder/corpus_count_test.cc new file mode 100644 index 00000000..8d53ca9d --- /dev/null +++ b/klm/lm/builder/corpus_count_test.cc @@ -0,0 +1,76 @@ +#include "lm/builder/corpus_count.hh" + +#include "lm/builder/ngram.hh" +#include "lm/builder/ngram_stream.hh" + +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/tokenize_piece.hh" +#include "util/stream/chain.hh" +#include "util/stream/stream.hh" + +#define BOOST_TEST_MODULE CorpusCountTest +#include + +namespace lm { namespace builder { namespace { + +#define Check(str, count) { \ + BOOST_REQUIRE(stream); \ + w = stream->begin(); \ + for (util::TokenIter t(str, " "); t; ++t, ++w) { \ + BOOST_CHECK_EQUAL(*t, v[*w]); \ + } \ + BOOST_CHECK_EQUAL((uint64_t)count, stream->Count()); \ + ++stream; \ +} + +BOOST_AUTO_TEST_CASE(Short) { + util::scoped_fd input_file(util::MakeTemp("corpus_count_test_temp")); + const char input[] = "looking on a little more loin\non a little more loin\non foo little more loin\nbar\n\n"; + // Blocks of 10 are + // looking on a little more loin on a little[duplicate] more[duplicate] loin[duplicate] [duplicate] on[duplicate] foo + // little more loin bar + + util::WriteOrThrow(input_file.get(), input, sizeof(input) - 1); + util::FilePiece input_piece(input_file.release(), "temp file"); + + util::stream::ChainConfig config; + config.entry_size = NGram::TotalSize(3); + config.total_memory = config.entry_size * 20; + config.block_count = 2; + + util::scoped_fd vocab(util::MakeTemp("corpus_count_test_vocab")); + + util::stream::Chain chain(config); + NGramStream stream; + uint64_t token_count; + WordIndex type_count; + CorpusCount counter(input_piece, vocab.get(), token_count, type_count, chain.BlockSize() / chain.EntrySize()); + chain >> boost::ref(counter) >> stream >> util::stream::kRecycle; + + const char *v[] = {"", "", "", "looking", "on", "a", "little", "more", "loin", "foo", "bar"}; + + WordIndex *w; + + Check(" looking", 1); + Check(" looking on", 1); + Check("looking on a", 1); + Check("on a little", 2); + Check("a little more", 2); + Check("little more loin", 2); + Check("more loin ", 2); + Check(" on", 2); + Check(" on a", 1); + Check(" on foo", 1); + Check("on foo little", 1); + Check("foo little more", 1); + Check("little more loin", 1); + Check("more loin ", 1); + Check(" bar", 1); + Check(" bar ", 1); + Check(" ", 1); + BOOST_CHECK(!stream); + BOOST_CHECK_EQUAL(sizeof(v) / sizeof(const char*), type_count); +} + +}}} // namespaces diff --git a/klm/lm/builder/discount.hh b/klm/lm/builder/discount.hh new file mode 100644 index 00000000..754fb20d --- /dev/null +++ b/klm/lm/builder/discount.hh @@ -0,0 +1,26 @@ +#ifndef BUILDER_DISCOUNT__ +#define BUILDER_DISCOUNT__ + +#include + +#include + +namespace lm { +namespace builder { + +struct Discount { + float amount[4]; + + float Get(uint64_t count) const { + return amount[std::min(count, 3)]; + } + + float Apply(uint64_t count) const { + return static_cast(count) - Get(count); + } +}; + +} // namespace builder +} // namespace lm + +#endif // BUILDER_DISCOUNT__ diff --git a/klm/lm/builder/header_info.hh b/klm/lm/builder/header_info.hh new file mode 100644 index 00000000..ccca1456 --- /dev/null +++ b/klm/lm/builder/header_info.hh @@ -0,0 +1,20 @@ +#ifndef LM_BUILDER_HEADER_INFO__ +#define LM_BUILDER_HEADER_INFO__ + +#include +#include + +// Some configuration info that is used to add +// comments to the beginning of an ARPA file +struct HeaderInfo { + const std::string input_file; + const uint64_t token_count; + + HeaderInfo(const std::string& input_file_in, uint64_t token_count_in) + : input_file(input_file_in), token_count(token_count_in) {} + + // TODO: Add smoothing type + // TODO: More info if multiple models were interpolated +}; + +#endif diff --git a/klm/lm/builder/initial_probabilities.cc b/klm/lm/builder/initial_probabilities.cc new file mode 100644 index 00000000..58b42a20 --- /dev/null +++ b/klm/lm/builder/initial_probabilities.cc @@ -0,0 +1,136 @@ +#include "lm/builder/initial_probabilities.hh" + +#include "lm/builder/discount.hh" +#include "lm/builder/ngram_stream.hh" +#include "lm/builder/sort.hh" +#include "util/file.hh" +#include "util/stream/chain.hh" +#include "util/stream/io.hh" +#include "util/stream/stream.hh" + +#include + +namespace lm { namespace builder { + +namespace { +struct BufferEntry { + // Gamma from page 20 of Chen and Goodman. + float gamma; + // \sum_w a(c w) for all w. + float denominator; +}; + +// Extract an array of gamma from an array of BufferEntry. +class OnlyGamma { + public: + void Run(const util::stream::ChainPosition &position) { + for (util::stream::Link block_it(position); block_it; ++block_it) { + float *out = static_cast(block_it->Get()); + const float *in = out; + const float *end = static_cast(block_it->ValidEnd()); + for (out += 1, in += 2; in < end; out += 1, in += 2) { + *out = *in; + } + block_it->SetValidSize(block_it->ValidSize() / 2); + } + } +}; + +class AddRight { + public: + AddRight(const Discount &discount, const util::stream::ChainPosition &input) + : discount_(discount), input_(input) {} + + void Run(const util::stream::ChainPosition &output) { + NGramStream in(input_); + util::stream::Stream out(output); + + std::vector previous(in->Order() - 1); + const std::size_t size = sizeof(WordIndex) * previous.size(); + for(; in; ++out) { + memcpy(&previous[0], in->begin(), size); + uint64_t denominator = 0; + uint64_t counts[4]; + memset(counts, 0, sizeof(counts)); + do { + denominator += in->Count(); + ++counts[std::min(in->Count(), static_cast(3))]; + } while (++in && !memcmp(&previous[0], in->begin(), size)); + BufferEntry &entry = *reinterpret_cast(out.Get()); + entry.denominator = static_cast(denominator); + entry.gamma = 0.0; + for (unsigned i = 1; i <= 3; ++i) { + entry.gamma += discount_.Get(i) * static_cast(counts[i]); + } + entry.gamma /= entry.denominator; + } + out.Poison(); + } + + private: + const Discount &discount_; + const util::stream::ChainPosition input_; +}; + +class MergeRight { + public: + MergeRight(bool interpolate_unigrams, const util::stream::ChainPosition &from_adder, const Discount &discount) + : interpolate_unigrams_(interpolate_unigrams), from_adder_(from_adder), discount_(discount) {} + + // calculate the initial probability of each n-gram (before order-interpolation) + // Run() gets invoked once for each order + void Run(const util::stream::ChainPosition &primary) { + util::stream::Stream summed(from_adder_); + + NGramStream grams(primary); + + // Without interpolation, the interpolation weight goes to . + if (grams->Order() == 1 && !interpolate_unigrams_) { + BufferEntry sums(*static_cast(summed.Get())); + assert(*grams->begin() == kUNK); + grams->Value().uninterp.prob = sums.gamma; + grams->Value().uninterp.gamma = 0.0; + while (++grams) { + grams->Value().uninterp.prob = discount_.Apply(grams->Count()) / sums.denominator; + grams->Value().uninterp.gamma = 0.0; + } + ++summed; + return; + } + + std::vector previous(grams->Order() - 1); + const std::size_t size = sizeof(WordIndex) * previous.size(); + for (; grams; ++summed) { + memcpy(&previous[0], grams->begin(), size); + const BufferEntry &sums = *static_cast(summed.Get()); + do { + Payload &pay = grams->Value(); + pay.uninterp.prob = discount_.Apply(pay.count) / sums.denominator; + pay.uninterp.gamma = sums.gamma; + } while (++grams && !memcmp(&previous[0], grams->begin(), size)); + } + } + + private: + bool interpolate_unigrams_; + util::stream::ChainPosition from_adder_; + Discount discount_; +}; + +} // namespace + +void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector &discounts, Chains &primary, Chains &second_in, Chains &gamma_out) { + util::stream::ChainConfig gamma_config = config.adder_out; + gamma_config.entry_size = sizeof(BufferEntry); + for (size_t i = 0; i < primary.size(); ++i) { + util::stream::ChainPosition second(second_in[i].Add()); + second_in[i] >> util::stream::kRecycle; + gamma_out.push_back(gamma_config); + gamma_out[i] >> AddRight(discounts[i], second); + primary[i] >> MergeRight(config.interpolate_unigrams, gamma_out[i].Add(), discounts[i]); + // Don't bother with the OnlyGamma thread for something to discard. + if (i) gamma_out[i] >> OnlyGamma(); + } +} + +}} // namespaces diff --git a/klm/lm/builder/initial_probabilities.hh b/klm/lm/builder/initial_probabilities.hh new file mode 100644 index 00000000..626388eb --- /dev/null +++ b/klm/lm/builder/initial_probabilities.hh @@ -0,0 +1,34 @@ +#ifndef LM_BUILDER_INITIAL_PROBABILITIES__ +#define LM_BUILDER_INITIAL_PROBABILITIES__ + +#include "lm/builder/discount.hh" +#include "util/stream/config.hh" + +#include + +namespace lm { +namespace builder { +class Chains; + +struct InitialProbabilitiesConfig { + // These should be small buffers to keep the adder from getting too far ahead + util::stream::ChainConfig adder_in; + util::stream::ChainConfig adder_out; + // SRILM doesn't normally interpolate unigrams. + bool interpolate_unigrams; +}; + +/* Compute initial (uninterpolated) probabilities + * primary: the normal chain of n-grams. Incoming is context sorted adjusted + * counts. Outgoing has uninterpolated probabilities for use by Interpolate. + * second_in: a second copy of the primary input. Discard the output. + * gamma_out: Computed gamma values are output on these chains in suffix order. + * The values are bare floats and should be buffered for interpolation to + * use. + */ +void InitialProbabilities(const InitialProbabilitiesConfig &config, const std::vector &discounts, Chains &primary, Chains &second_in, Chains &gamma_out); + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_INITIAL_PROBABILITIES__ diff --git a/klm/lm/builder/interpolate.cc b/klm/lm/builder/interpolate.cc new file mode 100644 index 00000000..50026806 --- /dev/null +++ b/klm/lm/builder/interpolate.cc @@ -0,0 +1,65 @@ +#include "lm/builder/interpolate.hh" + +#include "lm/builder/joint_order.hh" +#include "lm/builder/multi_stream.hh" +#include "lm/builder/sort.hh" +#include "lm/lm_exception.hh" + +#include + +namespace lm { namespace builder { +namespace { + +class Callback { + public: + Callback(float uniform_prob, const ChainPositions &backoffs) : backoffs_(backoffs.size()), probs_(backoffs.size() + 2) { + probs_[0] = uniform_prob; + for (std::size_t i = 0; i < backoffs.size(); ++i) { + backoffs_.push_back(backoffs[i]); + } + } + + ~Callback() { + for (std::size_t i = 0; i < backoffs_.size(); ++i) { + if (backoffs_[i]) { + std::cerr << "Backoffs do not match for order " << (i + 1) << std::endl; + abort(); + } + } + } + + void Enter(unsigned order_minus_1, NGram &gram) { + Payload &pay = gram.Value(); + pay.complete.prob = pay.uninterp.prob + pay.uninterp.gamma * probs_[order_minus_1]; + probs_[order_minus_1 + 1] = pay.complete.prob; + pay.complete.prob = log10(pay.complete.prob); + // TODO: this is a hack to skip n-grams that don't appear as context. Pruning will require some different handling. + if (order_minus_1 < backoffs_.size() && *(gram.end() - 1) != kUNK && *(gram.end() - 1) != kEOS) { + pay.complete.backoff = log10(*static_cast(backoffs_[order_minus_1].Get())); + ++backoffs_[order_minus_1]; + } else { + // Not a context. + pay.complete.backoff = 0.0; + } + } + + void Exit(unsigned, const NGram &) const {} + + private: + FixedArray backoffs_; + + std::vector probs_; +}; +} // namespace + +Interpolate::Interpolate(uint64_t unigram_count, const ChainPositions &backoffs) + : uniform_prob_(1.0 / static_cast(unigram_count - 1)), backoffs_(backoffs) {} + +// perform order-wise interpolation +void Interpolate::Run(const ChainPositions &positions) { + assert(positions.size() == backoffs_.size() + 1); + Callback callback(uniform_prob_, backoffs_); + JointOrder(positions, callback); +} + +}} // namespaces diff --git a/klm/lm/builder/interpolate.hh b/klm/lm/builder/interpolate.hh new file mode 100644 index 00000000..9268d404 --- /dev/null +++ b/klm/lm/builder/interpolate.hh @@ -0,0 +1,27 @@ +#ifndef LM_BUILDER_INTERPOLATE__ +#define LM_BUILDER_INTERPOLATE__ + +#include + +#include "lm/builder/multi_stream.hh" + +namespace lm { namespace builder { + +/* Interpolate step. + * Input: suffix sorted n-grams with (p_uninterpolated, gamma) from + * InitialProbabilities. + * Output: suffix sorted n-grams with complete probability + */ +class Interpolate { + public: + explicit Interpolate(uint64_t unigram_count, const ChainPositions &backoffs); + + void Run(const ChainPositions &positions); + + private: + float uniform_prob_; + ChainPositions backoffs_; +}; + +}} // namespaces +#endif // LM_BUILDER_INTERPOLATE__ diff --git a/klm/lm/builder/joint_order.hh b/klm/lm/builder/joint_order.hh new file mode 100644 index 00000000..b5620144 --- /dev/null +++ b/klm/lm/builder/joint_order.hh @@ -0,0 +1,43 @@ +#ifndef LM_BUILDER_JOINT_ORDER__ +#define LM_BUILDER_JOINT_ORDER__ + +#include "lm/builder/multi_stream.hh" +#include "lm/lm_exception.hh" + +#include + +namespace lm { namespace builder { + +template void JointOrder(const ChainPositions &positions, Callback &callback) { + // Allow matching to reference streams[-1]. + NGramStreams streams_with_dummy; + streams_with_dummy.InitWithDummy(positions); + NGramStream *streams = streams_with_dummy.begin() + 1; + + unsigned int order; + for (order = 0; order < positions.size() && streams[order]; ++order) {} + assert(order); // should always have . + unsigned int current = 0; + while (true) { + // Does the context match the lower one? + if (!memcmp(streams[static_cast(current) - 1]->begin(), streams[current]->begin() + Compare::kMatchOffset, sizeof(WordIndex) * current)) { + callback.Enter(current, *streams[current]); + // Transition to looking for extensions. + if (++current < order) continue; + } + // No extension left. + while(true) { + assert(current > 0); + --current; + callback.Exit(current, *streams[current]); + if (++streams[current]) break; + UTIL_THROW_IF(order != current + 1, FormatLoadException, "Detected n-gram without matching suffix"); + order = current; + if (!order) return; + } + } +} + +}} // namespaces + +#endif // LM_BUILDER_JOINT_ORDER__ diff --git a/klm/lm/builder/main.cc b/klm/lm/builder/main.cc new file mode 100644 index 00000000..90b9dca2 --- /dev/null +++ b/klm/lm/builder/main.cc @@ -0,0 +1,94 @@ +#include "lm/builder/pipeline.hh" +#include "util/file.hh" +#include "util/file_piece.hh" +#include "util/usage.hh" + +#include + +#include + +namespace { +class SizeNotify { + public: + SizeNotify(std::size_t &out) : behind_(out) {} + + void operator()(const std::string &from) { + behind_ = util::ParseSize(from); + } + + private: + std::size_t &behind_; +}; + +boost::program_options::typed_value *SizeOption(std::size_t &to, const char *default_value) { + return boost::program_options::value()->notifier(SizeNotify(to))->default_value(default_value); +} + +} // namespace + +int main(int argc, char *argv[]) { + try { + namespace po = boost::program_options; + po::options_description options("Language model building options"); + lm::builder::PipelineConfig pipeline; + + options.add_options() + ("order,o", po::value(&pipeline.order)->required(), "Order of the model") + ("interpolate_unigrams", po::bool_switch(&pipeline.initial_probs.interpolate_unigrams), "Interpolate the unigrams (default: emulate SRILM by not interpolating)") + ("temp_prefix,T", po::value(&pipeline.sort.temp_prefix)->default_value("/tmp/lm"), "Temporary file prefix") + ("memory,S", SizeOption(pipeline.sort.total_memory, util::GuessPhysicalMemory() ? "80%" : "1G"), "Sorting memory") + ("vocab_memory", SizeOption(pipeline.assume_vocab_hash_size, "50M"), "Assume that the vocabulary hash table will use this much memory for purposes of calculating total memory in the count step") + ("minimum_block", SizeOption(pipeline.minimum_block, "8K"), "Minimum block size to allow") + ("sort_block", SizeOption(pipeline.sort.buffer_size, "64M"), "Size of IO operations for sort (determines arity)") + ("block_count", po::value(&pipeline.block_count)->default_value(2), "Block count (per order)") + ("vocab_file", po::value(&pipeline.vocab_file)->default_value(""), "Location to write vocabulary file") + ("verbose_header", po::bool_switch(&pipeline.verbose_header), "Add a verbose header to the ARPA file that includes information such as token count, smoothing type, etc."); + if (argc == 1) { + std::cerr << + "Builds unpruned language models with modified Kneser-Ney smoothing.\n\n" + "Please cite:\n" + "@inproceedings{kenlm,\n" + "author = {Kenneth Heafield},\n" + "title = {{KenLM}: Faster and Smaller Language Model Queries},\n" + "booktitle = {Proceedings of the Sixth Workshop on Statistical Machine Translation},\n" + "month = {July}, year={2011},\n" + "address = {Edinburgh, UK},\n" + "publisher = {Association for Computational Linguistics},\n" + "}\n\n" + "Provide the corpus on stdin. The ARPA file will be written to stdout. Order of\n" + "the model (-o) is the only mandatory option. As this is an on-disk program,\n" + "setting the temporary file location (-T) and sorting memory (-S) is recommended.\n\n" + "Memory sizes are specified like GNU sort: a number followed by a unit character.\n" + "Valid units are \% for percentage of memory (supported platforms only) and (in\n" + "increasing powers of 1024): b, K, M, G, T, P, E, Z, Y. Default is K (*1024).\n\n"; + std::cerr << options << std::endl; + return 1; + } + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, options), vm); + po::notify(vm); + + util::NormalizeTempPrefix(pipeline.sort.temp_prefix); + + lm::builder::InitialProbabilitiesConfig &initial = pipeline.initial_probs; + // TODO: evaluate options for these. + initial.adder_in.total_memory = 32768; + initial.adder_in.block_count = 2; + initial.adder_out.total_memory = 32768; + initial.adder_out.block_count = 2; + pipeline.read_backoffs = initial.adder_out; + + // Read from stdin + try { + lm::builder::Pipeline(pipeline, 0, 1); + } catch (const util::MallocException &e) { + std::cerr << e.what() << std::endl; + std::cerr << "Try rerunning with a more conservative -S setting than " << vm["memory"].as() << std::endl; + return 1; + } + util::PrintUsage(std::cerr); + } catch (const std::exception &e) { + std::cerr << e.what() << std::endl; + return 1; + } +} diff --git a/klm/lm/builder/multi_stream.hh b/klm/lm/builder/multi_stream.hh new file mode 100644 index 00000000..707a98c7 --- /dev/null +++ b/klm/lm/builder/multi_stream.hh @@ -0,0 +1,180 @@ +#ifndef LM_BUILDER_MULTI_STREAM__ +#define LM_BUILDER_MULTI_STREAM__ + +#include "lm/builder/ngram_stream.hh" +#include "util/scoped.hh" +#include "util/stream/chain.hh" + +#include +#include + +#include +#include + +namespace lm { namespace builder { + +template class FixedArray { + public: + explicit FixedArray(std::size_t count) { + Init(count); + } + + FixedArray() : newed_end_(NULL) {} + + void Init(std::size_t count) { + assert(!block_.get()); + block_.reset(malloc(sizeof(T) * count)); + if (!block_.get()) throw std::bad_alloc(); + newed_end_ = begin(); + } + + FixedArray(const FixedArray &from) { + std::size_t size = from.newed_end_ - static_cast(from.block_.get()); + Init(size); + for (std::size_t i = 0; i < size; ++i) { + new(end()) T(from[i]); + Constructed(); + } + } + + ~FixedArray() { clear(); } + + T *begin() { return static_cast(block_.get()); } + const T *begin() const { return static_cast(block_.get()); } + // Always call Constructed after successful completion of new. + T *end() { return newed_end_; } + const T *end() const { return newed_end_; } + + T &back() { return *(end() - 1); } + const T &back() const { return *(end() - 1); } + + std::size_t size() const { return end() - begin(); } + bool empty() const { return begin() == end(); } + + T &operator[](std::size_t i) { return begin()[i]; } + const T &operator[](std::size_t i) const { return begin()[i]; } + + template void push_back(const C &c) { + new (end()) T(c); + Constructed(); + } + + void clear() { + for (T *i = begin(); i != end(); ++i) + i->~T(); + newed_end_ = begin(); + } + + protected: + void Constructed() { + ++newed_end_; + } + + private: + util::scoped_malloc block_; + + T *newed_end_; +}; + +class Chains; + +class ChainPositions : public FixedArray { + public: + ChainPositions() {} + + void Init(Chains &chains); + + explicit ChainPositions(Chains &chains) { + Init(chains); + } +}; + +class Chains : public FixedArray { + private: + template struct CheckForRun { + typedef Chains type; + }; + + public: + explicit Chains(std::size_t limit) : FixedArray(limit) {} + + template typename CheckForRun::type &operator>>(const Worker &worker) { + threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker)); + return *this; + } + + template typename CheckForRun::type &operator>>(const boost::reference_wrapper &worker) { + threads_.push_back(new util::stream::Thread(ChainPositions(*this), worker)); + return *this; + } + + Chains &operator>>(const util::stream::Recycler &recycler) { + for (util::stream::Chain *i = begin(); i != end(); ++i) + *i >> recycler; + return *this; + } + + void Wait(bool release_memory = true) { + threads_.clear(); + for (util::stream::Chain *i = begin(); i != end(); ++i) { + i->Wait(release_memory); + } + } + + private: + boost::ptr_vector threads_; + + Chains(const Chains &); + void operator=(const Chains &); +}; + +inline void ChainPositions::Init(Chains &chains) { + FixedArray::Init(chains.size()); + for (util::stream::Chain *i = chains.begin(); i != chains.end(); ++i) { + new (end()) util::stream::ChainPosition(i->Add()); Constructed(); + } +} + +inline Chains &operator>>(Chains &chains, ChainPositions &positions) { + positions.Init(chains); + return chains; +} + +class NGramStreams : public FixedArray { + public: + NGramStreams() {} + + // This puts a dummy NGramStream at the beginning (useful to algorithms that need to reference something at the beginning). + void InitWithDummy(const ChainPositions &positions) { + FixedArray::Init(positions.size() + 1); + new (end()) NGramStream(); Constructed(); + for (const util::stream::ChainPosition *i = positions.begin(); i != positions.end(); ++i) { + push_back(*i); + } + } + + // Limit restricts to positions[0,limit) + void Init(const ChainPositions &positions, std::size_t limit) { + FixedArray::Init(limit); + for (const util::stream::ChainPosition *i = positions.begin(); i != positions.begin() + limit; ++i) { + push_back(*i); + } + } + void Init(const ChainPositions &positions) { + Init(positions, positions.size()); + } + + NGramStreams(const ChainPositions &positions) { + Init(positions); + } +}; + +inline Chains &operator>>(Chains &chains, NGramStreams &streams) { + ChainPositions positions; + chains >> positions; + streams.Init(positions); + return chains; +} + +}} // namespaces +#endif // LM_BUILDER_MULTI_STREAM__ diff --git a/klm/lm/builder/ngram.hh b/klm/lm/builder/ngram.hh new file mode 100644 index 00000000..2984ed0b --- /dev/null +++ b/klm/lm/builder/ngram.hh @@ -0,0 +1,84 @@ +#ifndef LM_BUILDER_NGRAM__ +#define LM_BUILDER_NGRAM__ + +#include "lm/weights.hh" +#include "lm/word_index.hh" + +#include + +#include +#include +#include + +namespace lm { +namespace builder { + +struct Uninterpolated { + float prob; // Uninterpolated probability. + float gamma; // Interpolation weight for lower order. +}; + +union Payload { + uint64_t count; + Uninterpolated uninterp; + ProbBackoff complete; +}; + +class NGram { + public: + NGram(void *begin, std::size_t order) + : begin_(static_cast(begin)), end_(begin_ + order) {} + + const uint8_t *Base() const { return reinterpret_cast(begin_); } + uint8_t *Base() { return reinterpret_cast(begin_); } + + void ReBase(void *to) { + std::size_t difference = end_ - begin_; + begin_ = reinterpret_cast(to); + end_ = begin_ + difference; + } + + // Would do operator++ but that can get confusing for a stream. + void NextInMemory() { + ReBase(&Value() + 1); + } + + // Lower-case in deference to STL. + const WordIndex *begin() const { return begin_; } + WordIndex *begin() { return begin_; } + const WordIndex *end() const { return end_; } + WordIndex *end() { return end_; } + + const Payload &Value() const { return *reinterpret_cast(end_); } + Payload &Value() { return *reinterpret_cast(end_); } + + uint64_t &Count() { return Value().count; } + const uint64_t Count() const { return Value().count; } + + std::size_t Order() const { return end_ - begin_; } + + static std::size_t TotalSize(std::size_t order) { + return order * sizeof(WordIndex) + sizeof(Payload); + } + std::size_t TotalSize() const { + // Compiler should optimize this. + return TotalSize(Order()); + } + static std::size_t OrderFromSize(std::size_t size) { + std::size_t ret = (size - sizeof(Payload)) / sizeof(WordIndex); + assert(size == TotalSize(ret)); + return ret; + } + + private: + WordIndex *begin_, *end_; +}; + +const WordIndex kUNK = 0; +const WordIndex kBOS = 1; +const WordIndex kEOS = 2; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_NGRAM__ diff --git a/klm/lm/builder/ngram_stream.hh b/klm/lm/builder/ngram_stream.hh new file mode 100644 index 00000000..3c994664 --- /dev/null +++ b/klm/lm/builder/ngram_stream.hh @@ -0,0 +1,55 @@ +#ifndef LM_BUILDER_NGRAM_STREAM__ +#define LM_BUILDER_NGRAM_STREAM__ + +#include "lm/builder/ngram.hh" +#include "util/stream/chain.hh" +#include "util/stream/stream.hh" + +#include + +namespace lm { namespace builder { + +class NGramStream { + public: + NGramStream() : gram_(NULL, 0) {} + + NGramStream(const util::stream::ChainPosition &position) : gram_(NULL, 0) { + Init(position); + } + + void Init(const util::stream::ChainPosition &position) { + stream_.Init(position); + gram_ = NGram(stream_.Get(), NGram::OrderFromSize(position.GetChain().EntrySize())); + } + + NGram &operator*() { return gram_; } + const NGram &operator*() const { return gram_; } + + NGram *operator->() { return &gram_; } + const NGram *operator->() const { return &gram_; } + + void *Get() { return stream_.Get(); } + const void *Get() const { return stream_.Get(); } + + operator bool() const { return stream_; } + bool operator!() const { return !stream_; } + void Poison() { stream_.Poison(); } + + NGramStream &operator++() { + ++stream_; + gram_.ReBase(stream_.Get()); + return *this; + } + + private: + NGram gram_; + util::stream::Stream stream_; +}; + +inline util::stream::Chain &operator>>(util::stream::Chain &chain, NGramStream &str) { + str.Init(chain.Add()); + return chain; +} + +}} // namespaces +#endif // LM_BUILDER_NGRAM_STREAM__ diff --git a/klm/lm/builder/pipeline.cc b/klm/lm/builder/pipeline.cc new file mode 100644 index 00000000..14a1f721 --- /dev/null +++ b/klm/lm/builder/pipeline.cc @@ -0,0 +1,320 @@ +#include "lm/builder/pipeline.hh" + +#include "lm/builder/adjust_counts.hh" +#include "lm/builder/corpus_count.hh" +#include "lm/builder/initial_probabilities.hh" +#include "lm/builder/interpolate.hh" +#include "lm/builder/print.hh" +#include "lm/builder/sort.hh" + +#include "lm/sizes.hh" + +#include "util/exception.hh" +#include "util/file.hh" +#include "util/stream/io.hh" + +#include +#include +#include + +namespace lm { namespace builder { + +namespace { +void PrintStatistics(const std::vector &counts, const std::vector &discounts) { + std::cerr << "Statistics:\n"; + for (size_t i = 0; i < counts.size(); ++i) { + std::cerr << (i + 1) << ' ' << counts[i]; + for (size_t d = 1; d <= 3; ++d) + std::cerr << " D" << d << (d == 3 ? "+=" : "=") << discounts[i].amount[d]; + std::cerr << '\n'; + } +} + +class Master { + public: + explicit Master(const PipelineConfig &config) + : config_(config), chains_(config.order), files_(config.order) { + config_.minimum_block = std::max(NGram::TotalSize(config_.order), config_.minimum_block); + } + + const PipelineConfig &Config() const { return config_; } + + Chains &MutableChains() { return chains_; } + + template Master &operator>>(const T &worker) { + chains_ >> worker; + return *this; + } + + // This takes the (partially) sorted ngrams and sets up for adjusted counts. + void InitForAdjust(util::stream::Sort &ngrams, WordIndex types) { + const std::size_t each_order_min = config_.minimum_block * config_.block_count; + // We know how many unigrams there are. Don't allocate more than needed to them. + const std::size_t min_chains = (config_.order - 1) * each_order_min + + std::min(types * NGram::TotalSize(1), each_order_min); + // Do merge sort with calculated laziness. + const std::size_t merge_using = ngrams.Merge(std::min(config_.TotalMemory() - min_chains, ngrams.DefaultLazy())); + + std::vector count_bounds(1, types); + CreateChains(config_.TotalMemory() - merge_using, count_bounds); + ngrams.Output(chains_.back(), merge_using); + + // Setup unigram file. + files_.push_back(util::MakeTemp(config_.TempPrefix())); + } + + // For initial probabilities, but this is generic. + void SortAndReadTwice(const std::vector &counts, Sorts &sorts, Chains &second, util::stream::ChainConfig second_config) { + // Do merge first before allocating chain memory. + for (std::size_t i = 1; i < config_.order; ++i) { + sorts[i - 1].Merge(0); + } + // There's no lazy merge, so just divide memory amongst the chains. + CreateChains(config_.TotalMemory(), counts); + chains_.back().ActivateProgress(); + chains_[0] >> files_[0].Source(); + second_config.entry_size = NGram::TotalSize(1); + second.push_back(second_config); + second.back() >> files_[0].Source(); + for (std::size_t i = 1; i < config_.order; ++i) { + util::scoped_fd fd(sorts[i - 1].StealCompleted()); + chains_[i].SetProgressTarget(util::SizeOrThrow(fd.get())); + chains_[i] >> util::stream::PRead(util::DupOrThrow(fd.get()), true); + second_config.entry_size = NGram::TotalSize(i + 1); + second.push_back(second_config); + second.back() >> util::stream::PRead(fd.release(), true); + } + } + + // There is no sort after this, so go for broke on lazy merging. + template void MaximumLazyInput(const std::vector &counts, Sorts &sorts) { + // Determine the minimum we can use for all the chains. + std::size_t min_chains = 0; + for (std::size_t i = 0; i < config_.order; ++i) { + min_chains += std::min(counts[i] * NGram::TotalSize(i + 1), static_cast(config_.minimum_block)); + } + std::size_t for_merge = min_chains > config_.TotalMemory() ? 0 : (config_.TotalMemory() - min_chains); + std::vector laziness; + // Prioritize longer n-grams. + for (util::stream::Sort *i = sorts.end() - 1; i >= sorts.begin(); --i) { + laziness.push_back(i->Merge(for_merge)); + assert(for_merge >= laziness.back()); + for_merge -= laziness.back(); + } + std::reverse(laziness.begin(), laziness.end()); + + CreateChains(for_merge + min_chains, counts); + chains_.back().ActivateProgress(); + chains_[0] >> files_[0].Source(); + for (std::size_t i = 1; i < config_.order; ++i) { + sorts[i - 1].Output(chains_[i], laziness[i - 1]); + } + } + + void BufferFinal(const std::vector &counts) { + chains_[0] >> files_[0].Sink(); + for (std::size_t i = 1; i < config_.order; ++i) { + files_.push_back(util::MakeTemp(config_.TempPrefix())); + chains_[i] >> files_[i].Sink(); + } + chains_.Wait(true); + // Use less memory. Because we can. + CreateChains(std::min(config_.sort.buffer_size * config_.order, config_.TotalMemory()), counts); + for (std::size_t i = 0; i < config_.order; ++i) { + chains_[i] >> files_[i].Source(); + } + } + + template void SetupSorts(Sorts &sorts) { + sorts.Init(config_.order - 1); + // Unigrams don't get sorted because their order is always the same. + chains_[0] >> files_[0].Sink(); + for (std::size_t i = 1; i < config_.order; ++i) { + sorts.push_back(chains_[i], config_.sort, Compare(i + 1)); + } + chains_.Wait(true); + } + + private: + // Create chains, allocating memory to them. Totally heuristic. Count + // bounds are upper bounds on the counts or not present. + void CreateChains(std::size_t remaining_mem, const std::vector &count_bounds) { + std::vector assignments; + assignments.reserve(config_.order); + // Start by assigning maximum memory usage (to be refined later). + for (std::size_t i = 0; i < count_bounds.size(); ++i) { + assignments.push_back(static_cast(std::min( + static_cast(remaining_mem), + count_bounds[i] * static_cast(NGram::TotalSize(i + 1))))); + } + assignments.resize(config_.order, remaining_mem); + + // Now we know how much memory everybody wants. How much will they get? + // Proportional to this. + std::vector portions; + // Indices of orders that have yet to be assigned. + std::vector unassigned; + for (std::size_t i = 0; i < config_.order; ++i) { + portions.push_back(static_cast((i+1) * NGram::TotalSize(i+1))); + unassigned.push_back(i); + } + /*If somebody doesn't eat their full dinner, give it to the rest of the + * family. Then somebody else might not eat their full dinner etc. Ends + * when everybody unassigned is hungry. + */ + float sum; + bool found_more; + std::vector block_count(config_.order); + do { + sum = 0.0; + for (std::size_t i = 0; i < unassigned.size(); ++i) { + sum += portions[unassigned[i]]; + } + found_more = false; + // If the proportional assignment is more than needed, give it just what it needs. + for (std::vector::iterator i = unassigned.begin(); i != unassigned.end();) { + if (assignments[*i] <= remaining_mem * (portions[*i] / sum)) { + remaining_mem -= assignments[*i]; + block_count[*i] = 1; + i = unassigned.erase(i); + found_more = true; + } else { + ++i; + } + } + } while (found_more); + for (std::vector::iterator i = unassigned.begin(); i != unassigned.end(); ++i) { + assignments[*i] = remaining_mem * (portions[*i] / sum); + block_count[*i] = config_.block_count; + } + chains_.clear(); + std::cerr << "Chain sizes:"; + for (std::size_t i = 0; i < config_.order; ++i) { + std::cerr << ' ' << (i+1) << ":" << assignments[i]; + chains_.push_back(util::stream::ChainConfig(NGram::TotalSize(i + 1), block_count[i], assignments[i])); + } + std::cerr << std::endl; + } + + PipelineConfig config_; + + Chains chains_; + // Often only unigrams, but sometimes all orders. + FixedArray files_; +}; + +void CountText(int text_file /* input */, int vocab_file /* output */, Master &master, uint64_t &token_count, std::string &text_file_name) { + const PipelineConfig &config = master.Config(); + std::cerr << "=== 1/5 Counting and sorting n-grams ===" << std::endl; + + UTIL_THROW_IF(config.TotalMemory() < config.assume_vocab_hash_size, util::Exception, "Vocab hash size estimate " << config.assume_vocab_hash_size << " exceeds total memory " << config.TotalMemory()); + std::size_t memory_for_chain = + // This much memory to work with after vocab hash table. + static_cast(config.TotalMemory() - config.assume_vocab_hash_size) / + // Solve for block size including the dedupe multiplier for one block. + (static_cast(config.block_count) + CorpusCount::DedupeMultiplier(config.order)) * + // Chain likes memory expressed in terms of total memory. + static_cast(config.block_count); + util::stream::Chain chain(util::stream::ChainConfig(NGram::TotalSize(config.order), config.block_count, memory_for_chain)); + + WordIndex type_count; + util::FilePiece text(text_file, NULL, &std::cerr); + text_file_name = text.FileName(); + CorpusCount counter(text, vocab_file, token_count, type_count, chain.BlockSize() / chain.EntrySize()); + chain >> boost::ref(counter); + + util::stream::Sort sorter(chain, config.sort, SuffixOrder(config.order), AddCombiner()); + chain.Wait(true); + std::cerr << "=== 2/5 Calculating and sorting adjusted counts ===" << std::endl; + master.InitForAdjust(sorter, type_count); +} + +void InitialProbabilities(const std::vector &counts, const std::vector &discounts, Master &master, Sorts &primary, FixedArray &gammas) { + const PipelineConfig &config = master.Config(); + Chains second(config.order); + + { + Sorts sorts; + master.SetupSorts(sorts); + PrintStatistics(counts, discounts); + lm::ngram::ShowSizes(counts); + std::cerr << "=== 3/5 Calculating and sorting initial probabilities ===" << std::endl; + master.SortAndReadTwice(counts, sorts, second, config.initial_probs.adder_in); + } + + Chains gamma_chains(config.order); + InitialProbabilities(config.initial_probs, discounts, master.MutableChains(), second, gamma_chains); + // Don't care about gamma for 0. + gamma_chains[0] >> util::stream::kRecycle; + gammas.Init(config.order - 1); + for (std::size_t i = 1; i < config.order; ++i) { + gammas.push_back(util::MakeTemp(config.TempPrefix())); + gamma_chains[i] >> gammas[i - 1].Sink(); + } + // Has to be done here due to gamma_chains scope. + master.SetupSorts(primary); +} + +void InterpolateProbabilities(const std::vector &counts, Master &master, Sorts &primary, FixedArray &gammas) { + std::cerr << "=== 4/5 Calculating and writing order-interpolated probabilities ===" << std::endl; + const PipelineConfig &config = master.Config(); + master.MaximumLazyInput(counts, primary); + + Chains gamma_chains(config.order - 1); + util::stream::ChainConfig read_backoffs(config.read_backoffs); + read_backoffs.entry_size = sizeof(float); + for (std::size_t i = 0; i < config.order - 1; ++i) { + gamma_chains.push_back(read_backoffs); + gamma_chains.back() >> gammas[i].Source(); + } + master >> Interpolate(counts[0], ChainPositions(gamma_chains)); + gamma_chains >> util::stream::kRecycle; + master.BufferFinal(counts); +} + +} // namespace + +void Pipeline(PipelineConfig config, int text_file, int out_arpa) { + // Some fail-fast sanity checks. + if (config.sort.buffer_size * 4 > config.TotalMemory()) { + config.sort.buffer_size = config.TotalMemory() / 4; + std::cerr << "Warning: changing sort block size to " << config.sort.buffer_size << " bytes due to low total memory." << std::endl; + } + if (config.minimum_block < NGram::TotalSize(config.order)) { + config.minimum_block = NGram::TotalSize(config.order); + std::cerr << "Warning: raising minimum block to " << config.minimum_block << " to fit an ngram in every block." << std::endl; + } + UTIL_THROW_IF(config.sort.buffer_size < config.minimum_block, util::Exception, "Sort block size " << config.sort.buffer_size << " is below the minimum block size " << config.minimum_block << "."); + UTIL_THROW_IF(config.TotalMemory() < config.minimum_block * config.order * config.block_count, util::Exception, + "Not enough memory to fit " << (config.order * config.block_count) << " blocks with minimum size " << config.minimum_block << ". Increase memory to " << (config.minimum_block * config.order * config.block_count) << " bytes or decrease the minimum block size."); + + UTIL_TIMER("(%w s) Total wall time elapsed\n"); + Master master(config); + + util::scoped_fd vocab_file(config.vocab_file.empty() ? + util::MakeTemp(config.TempPrefix()) : + util::CreateOrThrow(config.vocab_file.c_str())); + uint64_t token_count; + std::string text_file_name; + CountText(text_file, vocab_file.get(), master, token_count, text_file_name); + + std::vector counts; + std::vector discounts; + master >> AdjustCounts(counts, discounts); + + { + FixedArray gammas; + Sorts primary; + InitialProbabilities(counts, discounts, master, primary, gammas); + InterpolateProbabilities(counts, master, primary, gammas); + } + + std::cerr << "=== 5/5 Writing ARPA model ===" << std::endl; + VocabReconstitute vocab(vocab_file.get()); + UTIL_THROW_IF(vocab.Size() != counts[0], util::Exception, "Vocab words don't match up. Is there a null byte in the input?"); + HeaderInfo header_info(text_file_name, token_count); + master >> PrintARPA(vocab, counts, (config.verbose_header ? &header_info : NULL), out_arpa) >> util::stream::kRecycle; + master.MutableChains().Wait(true); +} + +}} // namespaces diff --git a/klm/lm/builder/pipeline.hh b/klm/lm/builder/pipeline.hh new file mode 100644 index 00000000..f1d6c5f6 --- /dev/null +++ b/klm/lm/builder/pipeline.hh @@ -0,0 +1,40 @@ +#ifndef LM_BUILDER_PIPELINE__ +#define LM_BUILDER_PIPELINE__ + +#include "lm/builder/initial_probabilities.hh" +#include "lm/builder/header_info.hh" +#include "util/stream/config.hh" +#include "util/file_piece.hh" + +#include +#include + +namespace lm { namespace builder { + +struct PipelineConfig { + std::size_t order; + std::string vocab_file; + util::stream::SortConfig sort; + InitialProbabilitiesConfig initial_probs; + util::stream::ChainConfig read_backoffs; + bool verbose_header; + + // Amount of memory to assume that the vocabulary hash table will use. This + // is subtracted from total memory for CorpusCount. + std::size_t assume_vocab_hash_size; + + // Minimum block size to tolerate. + std::size_t minimum_block; + + // Number of blocks to use. This will be overridden to 1 if everything fits. + std::size_t block_count; + + const std::string &TempPrefix() const { return sort.temp_prefix; } + std::size_t TotalMemory() const { return sort.total_memory; } +}; + +// Takes ownership of text_file. +void Pipeline(PipelineConfig config, int text_file, int out_arpa); + +}} // namespaces +#endif // LM_BUILDER_PIPELINE__ diff --git a/klm/lm/builder/print.cc b/klm/lm/builder/print.cc new file mode 100644 index 00000000..b0323221 --- /dev/null +++ b/klm/lm/builder/print.cc @@ -0,0 +1,135 @@ +#include "lm/builder/print.hh" + +#include "util/double-conversion/double-conversion.h" +#include "util/double-conversion/utils.h" +#include "util/file.hh" +#include "util/mmap.hh" +#include "util/scoped.hh" +#include "util/stream/timer.hh" + +#define BOOST_LEXICAL_CAST_ASSUME_C_LOCALE +#include + +#include + +#include + +namespace lm { namespace builder { + +VocabReconstitute::VocabReconstitute(int fd) { + uint64_t size = util::SizeOrThrow(fd); + util::MapRead(util::POPULATE_OR_READ, fd, 0, size, memory_); + const char *const start = static_cast(memory_.get()); + const char *i; + for (i = start; i != start + size; i += strlen(i) + 1) { + map_.push_back(i); + } + // Last one for LookupPiece. + map_.push_back(i); +} + +namespace { +class OutputManager { + public: + static const std::size_t kOutBuf = 1048576; + + // Does not take ownership of out. + explicit OutputManager(int out) + : buf_(util::MallocOrThrow(kOutBuf)), + builder_(static_cast(buf_.get()), kOutBuf), + // Mostly the default but with inf instead. And no flags. + convert_(double_conversion::DoubleToStringConverter::NO_FLAGS, "inf", "NaN", 'e', -6, 21, 6, 0), + fd_(out) {} + + ~OutputManager() { + Flush(); + } + + OutputManager &operator<<(float value) { + // Odd, but this is the largest number found in the comments. + EnsureRemaining(double_conversion::DoubleToStringConverter::kMaxPrecisionDigits + 8); + convert_.ToShortestSingle(value, &builder_); + return *this; + } + + OutputManager &operator<<(StringPiece str) { + if (str.size() > kOutBuf) { + Flush(); + util::WriteOrThrow(fd_, str.data(), str.size()); + } else { + EnsureRemaining(str.size()); + builder_.AddSubstring(str.data(), str.size()); + } + return *this; + } + + // Inefficient! + OutputManager &operator<<(unsigned val) { + return *this << boost::lexical_cast(val); + } + + OutputManager &operator<<(char c) { + EnsureRemaining(1); + builder_.AddCharacter(c); + return *this; + } + + void Flush() { + util::WriteOrThrow(fd_, buf_.get(), builder_.position()); + builder_.Reset(); + } + + private: + void EnsureRemaining(std::size_t amount) { + if (static_cast(builder_.size() - builder_.position()) < amount) { + Flush(); + } + } + + util::scoped_malloc buf_; + double_conversion::StringBuilder builder_; + double_conversion::DoubleToStringConverter convert_; + int fd_; +}; +} // namespace + +PrintARPA::PrintARPA(const VocabReconstitute &vocab, const std::vector &counts, const HeaderInfo* header_info, int out_fd) + : vocab_(vocab), out_fd_(out_fd) { + std::stringstream stream; + + if (header_info) { + stream << "# Input file: " << header_info->input_file << '\n'; + stream << "# Token count: " << header_info->token_count << '\n'; + stream << "# Smoothing: Modified Kneser-Ney" << '\n'; + } + stream << "\\data\\\n"; + for (size_t i = 0; i < counts.size(); ++i) { + stream << "ngram " << (i+1) << '=' << counts[i] << '\n'; + } + stream << '\n'; + std::string as_string(stream.str()); + util::WriteOrThrow(out_fd, as_string.data(), as_string.size()); +} + +void PrintARPA::Run(const ChainPositions &positions) { + UTIL_TIMER("(%w s) Wrote ARPA file\n"); + OutputManager out(out_fd_); + for (unsigned order = 1; order <= positions.size(); ++order) { + out << "\\" << order << "-grams:" << '\n'; + for (NGramStream stream(positions[order - 1]); stream; ++stream) { + // Correcting for numerical precision issues. Take that IRST. + out << std::min(0.0f, stream->Value().complete.prob) << '\t' << vocab_.Lookup(*stream->begin()); + for (const WordIndex *i = stream->begin() + 1; i != stream->end(); ++i) { + out << ' ' << vocab_.Lookup(*i); + } + float backoff = stream->Value().complete.backoff; + if (backoff != 0.0) + out << '\t' << backoff; + out << '\n'; + } + out << '\n'; + } + out << "\\end\\\n"; +} + +}} // namespaces diff --git a/klm/lm/builder/print.hh b/klm/lm/builder/print.hh new file mode 100644 index 00000000..aa932e75 --- /dev/null +++ b/klm/lm/builder/print.hh @@ -0,0 +1,102 @@ +#ifndef LM_BUILDER_PRINT__ +#define LM_BUILDER_PRINT__ + +#include "lm/builder/ngram.hh" +#include "lm/builder/multi_stream.hh" +#include "lm/builder/header_info.hh" +#include "util/file.hh" +#include "util/mmap.hh" +#include "util/string_piece.hh" + +#include + +#include + +// Warning: print routines read all unigrams before all bigrams before all +// trigrams etc. So if other parts of the chain move jointly, you'll have to +// buffer. + +namespace lm { namespace builder { + +class VocabReconstitute { + public: + // fd must be alive for life of this object; does not take ownership. + explicit VocabReconstitute(int fd); + + const char *Lookup(WordIndex index) const { + assert(index < map_.size() - 1); + return map_[index]; + } + + StringPiece LookupPiece(WordIndex index) const { + return StringPiece(map_[index], map_[index + 1] - 1 - map_[index]); + } + + std::size_t Size() const { + // There's an extra entry to support StringPiece lengths. + return map_.size() - 1; + } + + private: + util::scoped_memory memory_; + std::vector map_; +}; + +// Not defined, only specialized. +template void PrintPayload(std::ostream &to, const Payload &payload); +template <> inline void PrintPayload(std::ostream &to, const Payload &payload) { + to << payload.count; +} +template <> inline void PrintPayload(std::ostream &to, const Payload &payload) { + to << log10(payload.uninterp.prob) << ' ' << log10(payload.uninterp.gamma); +} +template <> inline void PrintPayload(std::ostream &to, const Payload &payload) { + to << payload.complete.prob << ' ' << payload.complete.backoff; +} + +// template parameter is the type stored. +template class Print { + public: + explicit Print(const VocabReconstitute &vocab, std::ostream &to) : vocab_(vocab), to_(to) {} + + void Run(const ChainPositions &chains) { + NGramStreams streams(chains); + for (NGramStream *s = streams.begin(); s != streams.end(); ++s) { + DumpStream(*s); + } + } + + void Run(const util::stream::ChainPosition &position) { + NGramStream stream(position); + DumpStream(stream); + } + + private: + void DumpStream(NGramStream &stream) { + for (; stream; ++stream) { + PrintPayload(to_, stream->Value()); + for (const WordIndex *w = stream->begin(); w != stream->end(); ++w) { + to_ << ' ' << vocab_.Lookup(*w) << '=' << *w; + } + to_ << '\n'; + } + } + + const VocabReconstitute &vocab_; + std::ostream &to_; +}; + +class PrintARPA { + public: + // header_info may be NULL to disable the header + explicit PrintARPA(const VocabReconstitute &vocab, const std::vector &counts, const HeaderInfo* header_info, int out_fd); + + void Run(const ChainPositions &positions); + + private: + const VocabReconstitute &vocab_; + int out_fd_; +}; + +}} // namespaces +#endif // LM_BUILDER_PRINT__ diff --git a/klm/lm/builder/sort.hh b/klm/lm/builder/sort.hh new file mode 100644 index 00000000..9989389b --- /dev/null +++ b/klm/lm/builder/sort.hh @@ -0,0 +1,103 @@ +#ifndef LM_BUILDER_SORT__ +#define LM_BUILDER_SORT__ + +#include "lm/builder/multi_stream.hh" +#include "lm/builder/ngram.hh" +#include "lm/word_index.hh" +#include "util/stream/sort.hh" + +#include "util/stream/timer.hh" + +#include +#include + +namespace lm { +namespace builder { + +template class Comparator : public std::binary_function { + public: + explicit Comparator(std::size_t order) : order_(order) {} + + inline bool operator()(const void *lhs, const void *rhs) const { + return static_cast(this)->Compare(static_cast(lhs), static_cast(rhs)); + } + + std::size_t Order() const { return order_; } + + protected: + std::size_t order_; +}; + +class SuffixOrder : public Comparator { + public: + explicit SuffixOrder(std::size_t order) : Comparator(order) {} + + inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { + for (std::size_t i = order_ - 1; i != 0; --i) { + if (lhs[i] != rhs[i]) + return lhs[i] < rhs[i]; + } + return lhs[0] < rhs[0]; + } + + static const unsigned kMatchOffset = 1; +}; + +class ContextOrder : public Comparator { + public: + explicit ContextOrder(std::size_t order) : Comparator(order) {} + + inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { + for (int i = order_ - 2; i >= 0; --i) { + if (lhs[i] != rhs[i]) + return lhs[i] < rhs[i]; + } + return lhs[order_ - 1] < rhs[order_ - 1]; + } +}; + +class PrefixOrder : public Comparator { + public: + explicit PrefixOrder(std::size_t order) : Comparator(order) {} + + inline bool Compare(const WordIndex *lhs, const WordIndex *rhs) const { + for (std::size_t i = 0; i < order_; ++i) { + if (lhs[i] != rhs[i]) + return lhs[i] < rhs[i]; + } + return false; + } + + static const unsigned kMatchOffset = 0; +}; + +// Sum counts for the same n-gram. +struct AddCombiner { + bool operator()(void *first_void, const void *second_void, const SuffixOrder &compare) const { + NGram first(first_void, compare.Order()); + // There isn't a const version of NGram. + NGram second(const_cast(second_void), compare.Order()); + if (memcmp(first.begin(), second.begin(), sizeof(WordIndex) * compare.Order())) return false; + first.Count() += second.Count(); + return true; + } +}; + +// The combiner is only used on a single chain, so I didn't bother to allow +// that template. +template class Sorts : public FixedArray > { + private: + typedef util::stream::Sort S; + typedef FixedArray P; + + public: + void push_back(util::stream::Chain &chain, const util::stream::SortConfig &config, const Compare &compare) { + new (P::end()) S(chain, config, compare); + P::Constructed(); + } +}; + +} // namespace builder +} // namespace lm + +#endif // LM_BUILDER_SORT__ diff --git a/klm/lm/filter/arpa_io.cc b/klm/lm/filter/arpa_io.cc new file mode 100644 index 00000000..caf8df95 --- /dev/null +++ b/klm/lm/filter/arpa_io.cc @@ -0,0 +1,122 @@ +#include "lm/filter/arpa_io.hh" +#include "util/file_piece.hh" + +#include +#include +#include +#include + +#include +#include +#include + +namespace lm { + +ARPAInputException::ARPAInputException(const StringPiece &message) throw() : what_("Error: ") { + what_.append(message.data(), message.size()); +} + +ARPAInputException::ARPAInputException(const StringPiece &message, const StringPiece &line) throw() { + what_ = "Error: "; + what_.append(message.data(), message.size()); + what_ += " in line '"; + what_.append(line.data(), line.size()); + what_ += "'."; +} + +ARPAOutputException::ARPAOutputException(const char *message, const std::string &file_name) throw() + : what_(std::string(message) + " file " + file_name), file_name_(file_name) { + if (errno) { + char buf[1024]; + buf[0] = 0; +#if (_POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600) && ! _GNU_SOURCE + const char *add = buf; + if (!strerror_r(errno, buf, 1024)) { +#else + const char *add = strerror_r(errno, buf, 1024); + if (add) { +#endif + what_ += " :"; + what_ += add; + } + } +} + +// Seeking is the responsibility of the caller. +void WriteCounts(std::ostream &out, const std::vector &number) { + out << "\n\\data\\\n"; + for (unsigned int i = 0; i < number.size(); ++i) { + out << "ngram " << i+1 << "=" << number[i] << '\n'; + } + out << '\n'; +} + +size_t SizeNeededForCounts(const std::vector &number) { + std::ostringstream buf; + WriteCounts(buf, number); + return buf.tellp(); +} + +bool IsEntirelyWhiteSpace(const StringPiece &line) { + for (size_t i = 0; i < static_cast(line.size()); ++i) { + if (!isspace(line.data()[i])) return false; + } + return true; +} + +ARPAOutput::ARPAOutput(const char *name, size_t buffer_size) : file_name_(name), buffer_(new char[buffer_size]) { + try { + file_.exceptions(std::ostream::eofbit | std::ostream::failbit | std::ostream::badbit); + if (!file_.rdbuf()->pubsetbuf(buffer_.get(), buffer_size)) { + std::cerr << "Warning: could not enlarge buffer for " << name << std::endl; + buffer_.reset(); + } + file_.open(name, std::ios::out | std::ios::binary); + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Opening", file_name_); + } +} + +void ARPAOutput::ReserveForCounts(std::streampos reserve) { + try { + for (std::streampos i = 0; i < reserve; i += std::streampos(1)) { + file_ << '\n'; + } + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing blanks to reserve space for counts to ", file_name_); + } +} + +void ARPAOutput::BeginLength(unsigned int length) { + fast_counter_ = 0; + try { + file_ << '\\' << length << "-grams:" << '\n'; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing n-gram header to ", file_name_); + } +} + +void ARPAOutput::EndLength(unsigned int length) { + try { + file_ << '\n'; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing blank at end of count list to ", file_name_); + } + if (length > counts_.size()) { + counts_.resize(length); + } + counts_[length - 1] = fast_counter_; +} + +void ARPAOutput::Finish() { + try { + file_ << "\\end\\\n"; + file_.seekp(0); + WriteCounts(file_, counts_); + file_ << std::flush; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Finishing including writing counts at beginning to ", file_name_); + } +} + +} // namespace lm diff --git a/klm/lm/filter/arpa_io.hh b/klm/lm/filter/arpa_io.hh new file mode 100644 index 00000000..90f48447 --- /dev/null +++ b/klm/lm/filter/arpa_io.hh @@ -0,0 +1,122 @@ +#ifndef LM_FILTER_ARPA_IO__ +#define LM_FILTER_ARPA_IO__ +/* Input and output for ARPA format language model files. + */ +#include "lm/read_arpa.hh" +#include "util/exception.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include +#include + +#include +#include +#include + +#include +#include + +namespace util { class FilePiece; } + +namespace lm { + +class ARPAInputException : public util::Exception { + public: + explicit ARPAInputException(const StringPiece &message) throw(); + explicit ARPAInputException(const StringPiece &message, const StringPiece &line) throw(); + virtual ~ARPAInputException() throw() {} + + const char *what() const throw() { return what_.c_str(); } + + private: + std::string what_; +}; + +class ARPAOutputException : public std::exception { + public: + ARPAOutputException(const char *prefix, const std::string &file_name) throw(); + virtual ~ARPAOutputException() throw() {} + + const char *what() const throw() { return what_.c_str(); } + + const std::string &File() const throw() { return file_name_; } + + private: + std::string what_; + const std::string file_name_; +}; + +// Handling for the counts of n-grams at the beginning of ARPA files. +size_t SizeNeededForCounts(const std::vector &number); + +/* Writes an ARPA file. This has to be seekable so the counts can be written + * at the end. Hence, I just have it own a std::fstream instead of accepting + * a separately held std::ostream. + */ +class ARPAOutput : boost::noncopyable { + public: + explicit ARPAOutput(const char *name, size_t buffer_size = 65536); + + void ReserveForCounts(std::streampos reserve); + + void BeginLength(unsigned int length); + + void AddNGram(const StringPiece &line) { + try { + file_ << line << '\n'; + } catch (const std::ios_base::failure &f) { + throw ARPAOutputException("Writing an n-gram", file_name_); + } + ++fast_counter_; + } + + void AddNGram(const StringPiece &ngram, const StringPiece &line) { + AddNGram(line); + } + + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + AddNGram(line); + } + + void EndLength(unsigned int length); + + void Finish(); + + private: + const std::string file_name_; + boost::scoped_array buffer_; + std::fstream file_; + size_t fast_counter_; + std::vector counts_; +}; + + +template void ReadNGrams(util::FilePiece &in, unsigned int length, size_t number, Output &out) { + ReadNGramHeader(in, length); + out.BeginLength(length); + for (size_t i = 0; i < number; ++i) { + StringPiece line = in.ReadLine(); + util::TokenIter tabber(line, '\t'); + if (!tabber) throw ARPAInputException("blank line", line); + if (!++tabber) throw ARPAInputException("no tab", line); + + out.AddNGram(*tabber, line); + } + out.EndLength(length); +} + +template void ReadARPA(util::FilePiece &in_lm, Output &out) { + std::vector number; + ReadARPACounts(in_lm, number); + out.ReserveForCounts(SizeNeededForCounts(number)); + for (unsigned int i = 0; i < number.size(); ++i) { + ReadNGrams(in_lm, i + 1, number[i], out); + } + ReadEnd(in_lm); + out.Finish(); +} + +} // namespace lm + +#endif // LM_FILTER_ARPA_IO__ diff --git a/klm/lm/filter/count_io.hh b/klm/lm/filter/count_io.hh new file mode 100644 index 00000000..97c0fa25 --- /dev/null +++ b/klm/lm/filter/count_io.hh @@ -0,0 +1,91 @@ +#ifndef LM_FILTER_COUNT_IO__ +#define LM_FILTER_COUNT_IO__ + +#include +#include +#include + +#include + +#include "util/file_piece.hh" + +namespace lm { + +class CountOutput : boost::noncopyable { + public: + explicit CountOutput(const char *name) : file_(name, std::ios::out) {} + + void AddNGram(const StringPiece &line) { + if (!(file_ << line << '\n')) { + err(3, "Writing counts file failed"); + } + } + + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + AddNGram(line); + } + + void AddNGram(const StringPiece &ngram, const StringPiece &line) { + AddNGram(line); + } + + private: + std::fstream file_; +}; + +class CountBatch { + public: + explicit CountBatch(std::streamsize initial_read) + : initial_read_(initial_read) { + buffer_.reserve(initial_read); + } + + void Read(std::istream &in) { + buffer_.resize(initial_read_); + in.read(&*buffer_.begin(), initial_read_); + buffer_.resize(in.gcount()); + char got; + while (in.get(got) && got != '\n') + buffer_.push_back(got); + } + + template void Send(Output &out) { + for (util::TokenIter line(StringPiece(&*buffer_.begin(), buffer_.size()), '\n'); line; ++line) { + util::TokenIter tabber(*line, '\t'); + if (!tabber) { + std::cerr << "Warning: empty n-gram count line being removed\n"; + continue; + } + util::TokenIter words(*tabber, ' '); + if (!words) { + std::cerr << "Line has a tab but no words.\n"; + continue; + } + out.AddNGram(words, util::TokenIter::end(), *line); + } + } + + private: + std::streamsize initial_read_; + + // This could have been a std::string but that's less happy with raw writes. + std::vector buffer_; +}; + +template void ReadCount(util::FilePiece &in_file, Output &out) { + try { + while (true) { + StringPiece line = in_file.ReadLine(); + util::TokenIter tabber(line, '\t'); + if (!tabber) { + std::cerr << "Warning: empty n-gram count line being removed\n"; + continue; + } + out.AddNGram(*tabber, line); + } + } catch (const util::EndOfFileException &e) {} +} + +} // namespace lm + +#endif // LM_FILTER_COUNT_IO__ diff --git a/klm/lm/filter/format.hh b/klm/lm/filter/format.hh new file mode 100644 index 00000000..7f945b0d --- /dev/null +++ b/klm/lm/filter/format.hh @@ -0,0 +1,250 @@ +#ifndef LM_FILTER_FORMAT_H__ +#define LM_FITLER_FORMAT_H__ + +#include "lm/filter/arpa_io.hh" +#include "lm/filter/count_io.hh" + +#include +#include + +#include + +namespace lm { + +template class MultipleOutput { + private: + typedef boost::ptr_vector Singles; + typedef typename Singles::iterator SinglesIterator; + + public: + MultipleOutput(const char *prefix, size_t number) { + files_.reserve(number); + std::string tmp; + for (unsigned int i = 0; i < number; ++i) { + tmp = prefix; + tmp += boost::lexical_cast(i); + files_.push_back(new Single(tmp.c_str())); + } + } + + void AddNGram(const StringPiece &line) { + for (SinglesIterator i = files_.begin(); i != files_.end(); ++i) + i->AddNGram(line); + } + + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + for (SinglesIterator i = files_.begin(); i != files_.end(); ++i) + i->AddNGram(begin, end, line); + } + + void SingleAddNGram(size_t offset, const StringPiece &line) { + files_[offset].AddNGram(line); + } + + template void SingleAddNGram(size_t offset, const Iterator &begin, const Iterator &end, const StringPiece &line) { + files_[offset].AddNGram(begin, end, line); + } + + protected: + Singles files_; +}; + +class MultipleARPAOutput : public MultipleOutput { + public: + MultipleARPAOutput(const char *prefix, size_t number) : MultipleOutput(prefix, number) {} + + void ReserveForCounts(std::streampos reserve) { + for (boost::ptr_vector::iterator i = files_.begin(); i != files_.end(); ++i) + i->ReserveForCounts(reserve); + } + + void BeginLength(unsigned int length) { + for (boost::ptr_vector::iterator i = files_.begin(); i != files_.end(); ++i) + i->BeginLength(length); + } + + void EndLength(unsigned int length) { + for (boost::ptr_vector::iterator i = files_.begin(); i != files_.end(); ++i) + i->EndLength(length); + } + + void Finish() { + for (boost::ptr_vector::iterator i = files_.begin(); i != files_.end(); ++i) + i->Finish(); + } +}; + +template class DispatchInput { + public: + DispatchInput(Filter &filter, Output &output) : filter_(filter), output_(output) {} + +/* template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line) { + filter_.AddNGram(begin, end, line, output_); + }*/ + + void AddNGram(const StringPiece &ngram, const StringPiece &line) { + filter_.AddNGram(ngram, line, output_); + } + + protected: + Filter &filter_; + Output &output_; +}; + +template class DispatchARPAInput : public DispatchInput { + private: + typedef DispatchInput B; + + public: + DispatchARPAInput(Filter &filter, Output &output) : B(filter, output) {} + + void ReserveForCounts(std::streampos reserve) { B::output_.ReserveForCounts(reserve); } + void BeginLength(unsigned int length) { B::output_.BeginLength(length); } + + void EndLength(unsigned int length) { + B::filter_.Flush(); + B::output_.EndLength(length); + } + void Finish() { B::output_.Finish(); } +}; + +struct ARPAFormat { + typedef ARPAOutput Output; + typedef MultipleARPAOutput Multiple; + static void Copy(util::FilePiece &in, Output &out) { + ReadARPA(in, out); + } + template static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) { + DispatchARPAInput dispatcher(filter, output); + ReadARPA(in, dispatcher); + } +}; + +struct CountFormat { + typedef CountOutput Output; + typedef MultipleOutput Multiple; + static void Copy(util::FilePiece &in, Output &out) { + ReadCount(in, out); + } + template static void RunFilter(util::FilePiece &in, Filter &filter, Out &output) { + DispatchInput dispatcher(filter, output); + ReadCount(in, dispatcher); + } +}; + +/* For multithreading, the buffer classes hold batches of filter inputs and + * outputs in memory. The strings get reused a lot, so keep them around + * instead of clearing each time. + */ +class InputBuffer { + public: + InputBuffer() : actual_(0) {} + + void Reserve(size_t size) { lines_.reserve(size); } + + template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + if (lines_.size() == actual_) lines_.resize(lines_.size() + 1); + // TODO avoid this copy. + std::string &copied = lines_[actual_].line; + copied.assign(line.data(), line.size()); + lines_[actual_].ngram.set(copied.data() + (ngram.data() - line.data()), ngram.size()); + ++actual_; + } + + template void CallFilter(Filter &filter, Output &output) const { + for (std::vector::const_iterator i = lines_.begin(); i != lines_.begin() + actual_; ++i) { + filter.AddNGram(i->ngram, i->line, output); + } + } + + void Clear() { actual_ = 0; } + bool Empty() { return actual_ == 0; } + size_t Size() { return actual_; } + + private: + struct Line { + std::string line; + StringPiece ngram; + }; + + size_t actual_; + + std::vector lines_; +}; + +class BinaryOutputBuffer { + public: + BinaryOutputBuffer() {} + + void Reserve(size_t size) { + lines_.reserve(size); + } + + void AddNGram(const StringPiece &line) { + lines_.push_back(line); + } + + template void Flush(Output &output) { + for (std::vector::const_iterator i = lines_.begin(); i != lines_.end(); ++i) { + output.AddNGram(*i); + } + lines_.clear(); + } + + private: + std::vector lines_; +}; + +class MultipleOutputBuffer { + public: + MultipleOutputBuffer() : last_(NULL) {} + + void Reserve(size_t size) { + annotated_.reserve(size); + } + + void AddNGram(const StringPiece &line) { + annotated_.resize(annotated_.size() + 1); + annotated_.back().line = line; + } + + void SingleAddNGram(size_t offset, const StringPiece &line) { + if ((line.data() == last_.data()) && (line.length() == last_.length())) { + annotated_.back().systems.push_back(offset); + } else { + annotated_.resize(annotated_.size() + 1); + annotated_.back().systems.push_back(offset); + annotated_.back().line = line; + last_ = line; + } + } + + template void Flush(Output &output) { + for (std::vector::const_iterator i = annotated_.begin(); i != annotated_.end(); ++i) { + if (i->systems.empty()) { + output.AddNGram(i->line); + } else { + for (std::vector::const_iterator j = i->systems.begin(); j != i->systems.end(); ++j) { + output.SingleAddNGram(*j, i->line); + } + } + } + annotated_.clear(); + } + + private: + struct Annotated { + // If this is empty, send to all systems. + // A filter should never send to all systems and send to a single one. + std::vector systems; + StringPiece line; + }; + + StringPiece last_; + + std::vector annotated_; +}; + +} // namespace lm + +#endif // LM_FILTER_FORMAT_H__ diff --git a/klm/lm/filter/main.cc b/klm/lm/filter/main.cc new file mode 100644 index 00000000..c42243e2 --- /dev/null +++ b/klm/lm/filter/main.cc @@ -0,0 +1,249 @@ +#include "lm/filter/arpa_io.hh" +#include "lm/filter/format.hh" +#include "lm/filter/phrase.hh" +#ifndef NTHREAD +#include "lm/filter/thread.hh" +#endif +#include "lm/filter/vocab.hh" +#include "lm/filter/wrapper.hh" +#include "util/file_piece.hh" + +#include + +#include +#include +#include +#include + +namespace lm { +namespace { + +void DisplayHelp(const char *name) { + std::cerr + << "Usage: " << name << " mode [context] [phrase] [raw|arpa] [threads:m] [batch_size:m] (vocab|model):input_file output_file\n\n" + "copy mode just copies, but makes the format nicer for e.g. irstlm's broken\n" + " parser.\n" + "single mode treats the entire input as a single sentence.\n" + "multiple mode filters to multiple sentences in parallel. Each sentence is on\n" + " a separate line. A separate file is created for each file by appending the\n" + " 0-indexed line number to the output file name.\n" + "union mode produces one filtered model that is the union of models created by\n" + " multiple mode.\n\n" + "context means only the context (all but last word) has to pass the filter, but\n" + " the entire n-gram is output.\n\n" + "phrase means that the vocabulary is actually tab-delimited phrases and that the\n" + " phrases can generate the n-gram when assembled in arbitrary order and\n" + " clipped. Currently works with multiple or union mode.\n\n" + "The file format is set by [raw|arpa] with default arpa:\n" + "raw means space-separated tokens, optionally followed by a tab and arbitrary\n" + " text. This is useful for ngram count files.\n" + "arpa means the ARPA file format for n-gram language models.\n\n" +#ifndef NTHREAD + "threads:m sets m threads (default: conccurrency detected by boost)\n" + "batch_size:m sets the batch size for threading. Expect memory usage from this\n" + " of 2*threads*batch_size n-grams.\n\n" +#else + "This binary was compiled with -DNTHREAD, disabling threading. If you wanted\n" + " threading, compile without this flag against Boost >=1.42.0.\n\n" +#endif + "There are two inputs: vocabulary and model. Either may be given as a file\n" + " while the other is on stdin. Specify the type given as a file using\n" + " vocab: or model: before the file name. \n\n" + "For ARPA format, the output must be seekable. For raw format, it can be a\n" + " stream i.e. /dev/stdout\n"; +} + +typedef enum {MODE_COPY, MODE_SINGLE, MODE_MULTIPLE, MODE_UNION} FilterMode; +typedef enum {FORMAT_ARPA, FORMAT_COUNT} Format; + +struct Config { + Config() : +#ifndef NTHREAD + batch_size(25000), + threads(boost::thread::hardware_concurrency()), +#endif + phrase(false), + context(false), + format(FORMAT_ARPA) + { +#ifndef NTHREAD + if (!threads) threads = 1; +#endif + } + +#ifndef NTHREAD + size_t batch_size; + size_t threads; +#endif + bool phrase; + bool context; + FilterMode mode; + Format format; +}; + +template void RunThreadedFilter(const Config &config, util::FilePiece &in_lm, Filter &filter, Output &output) { +#ifndef NTHREAD + if (config.threads == 1) { +#endif + Format::RunFilter(in_lm, filter, output); +#ifndef NTHREAD + } else { + typedef Controller Threaded; + Threaded threading(config.batch_size, config.threads * 2, config.threads, filter, output); + Format::RunFilter(in_lm, threading, output); + } +#endif +} + +template void RunContextFilter(const Config &config, util::FilePiece &in_lm, Filter filter, Output &output) { + if (config.context) { + ContextFilter context_filter(filter); + RunThreadedFilter, OutputBuffer, Output>(config, in_lm, context_filter, output); + } else { + RunThreadedFilter(config, in_lm, filter, output); + } +} + +template void DispatchBinaryFilter(const Config &config, util::FilePiece &in_lm, const Binary &binary, typename Format::Output &out) { + typedef BinaryFilter Filter; + RunContextFilter(config, in_lm, Filter(binary), out); +} + +template void DispatchFilterModes(const Config &config, std::istream &in_vocab, util::FilePiece &in_lm, const char *out_name) { + if (config.mode == MODE_MULTIPLE) { + if (config.phrase) { + typedef phrase::Multiple Filter; + phrase::Substrings substrings; + typename Format::Multiple out(out_name, phrase::ReadMultiple(in_vocab, substrings)); + RunContextFilter(config, in_lm, Filter(substrings), out); + } else { + typedef vocab::Multiple Filter; + boost::unordered_map > words; + typename Format::Multiple out(out_name, vocab::ReadMultiple(in_vocab, words)); + RunContextFilter(config, in_lm, Filter(words), out); + } + return; + } + + typename Format::Output out(out_name); + + if (config.mode == MODE_COPY) { + Format::Copy(in_lm, out); + return; + } + + if (config.mode == MODE_SINGLE) { + vocab::Single::Words words; + vocab::ReadSingle(in_vocab, words); + DispatchBinaryFilter(config, in_lm, vocab::Single(words), out); + return; + } + + if (config.mode == MODE_UNION) { + if (config.phrase) { + phrase::Substrings substrings; + phrase::ReadMultiple(in_vocab, substrings); + DispatchBinaryFilter(config, in_lm, phrase::Union(substrings), out); + } else { + vocab::Union::Words words; + vocab::ReadMultiple(in_vocab, words); + DispatchBinaryFilter(config, in_lm, vocab::Union(words), out); + } + return; + } +} + +} // namespace +} // namespace lm + +int main(int argc, char *argv[]) { + if (argc < 4) { + lm::DisplayHelp(argv[0]); + return 1; + } + + // I used to have boost::program_options, but some users didn't want to compile boost. + lm::Config config; + boost::optional mode; + for (int i = 1; i < argc - 2; ++i) { + const char *str = argv[i]; + if (!std::strcmp(str, "copy")) { + mode = lm::MODE_COPY; + } else if (!std::strcmp(str, "single")) { + mode = lm::MODE_SINGLE; + } else if (!std::strcmp(str, "multiple")) { + mode = lm::MODE_MULTIPLE; + } else if (!std::strcmp(str, "union")) { + mode = lm::MODE_UNION; + } else if (!std::strcmp(str, "phrase")) { + config.phrase = true; + } else if (!std::strcmp(str, "context")) { + config.context = true; + } else if (!std::strcmp(str, "arpa")) { + config.format = lm::FORMAT_ARPA; + } else if (!std::strcmp(str, "raw")) { + config.format = lm::FORMAT_COUNT; +#ifndef NTHREAD + } else if (!std::strncmp(str, "threads:", 8)) { + config.threads = boost::lexical_cast(str + 8); + if (!config.threads) { + std::cerr << "Specify at least one thread." << std::endl; + return 1; + } + } else if (!std::strncmp(str, "batch_size:", 11)) { + config.batch_size = boost::lexical_cast(str + 11); + if (config.batch_size < 5000) { + std::cerr << "Batch size must be at least one and should probably be >= 5000" << std::endl; + if (!config.batch_size) return 1; + } +#endif + } else { + lm::DisplayHelp(argv[0]); + return 1; + } + } + + if (!mode) { + lm::DisplayHelp(argv[0]); + return 1; + } + config.mode = *mode; + + if (config.phrase && config.mode != lm::MODE_UNION && mode != lm::MODE_MULTIPLE) { + std::cerr << "Phrase constraint currently only works in multiple or union mode. If you really need it for single, put everything on one line and use union." << std::endl; + return 1; + } + + bool cmd_is_model = true; + const char *cmd_input = argv[argc - 2]; + if (!strncmp(cmd_input, "vocab:", 6)) { + cmd_is_model = false; + cmd_input += 6; + } else if (!strncmp(cmd_input, "model:", 6)) { + cmd_input += 6; + } else if (strchr(cmd_input, ':')) { + errx(1, "Specify vocab: or model: before the input file name, not \"%s\"", cmd_input); + } else { + std::cerr << "Assuming that " << cmd_input << " is a model file" << std::endl; + } + std::ifstream cmd_file; + std::istream *vocab; + if (cmd_is_model) { + vocab = &std::cin; + } else { + cmd_file.open(cmd_input, std::ios::in); + if (!cmd_file) { + err(2, "Could not open input file %s", cmd_input); + } + vocab = &cmd_file; + } + + util::FilePiece model(cmd_is_model ? util::OpenReadOrThrow(cmd_input) : 0, cmd_is_model ? cmd_input : NULL, &std::cerr); + + if (config.format == lm::FORMAT_ARPA) { + lm::DispatchFilterModes(config, *vocab, model, argv[argc - 1]); + } else if (config.format == lm::FORMAT_COUNT) { + lm::DispatchFilterModes(config, *vocab, model, argv[argc - 1]); + } + return 0; +} diff --git a/klm/lm/filter/phrase.cc b/klm/lm/filter/phrase.cc new file mode 100644 index 00000000..1bef2a3f --- /dev/null +++ b/klm/lm/filter/phrase.cc @@ -0,0 +1,281 @@ +#include "lm/filter/phrase.hh" + +#include "lm/filter/format.hh" + +#include +#include +#include +#include +#include +#include + +#include + +namespace lm { +namespace phrase { + +unsigned int ReadMultiple(std::istream &in, Substrings &out) { + bool sentence_content = false; + unsigned int sentence_id = 0; + std::vector phrase; + std::string word; + while (in) { + char c; + // Gather a word. + while (!isspace(c = in.get()) && in) word += c; + // Treat EOF like a newline. + if (!in) c = '\n'; + // Add the word to the phrase. + if (!word.empty()) { + phrase.push_back(util::MurmurHashNative(word.data(), word.size())); + word.clear(); + } + if (c == ' ') continue; + // It's more than just a space. Close out the phrase. + if (!phrase.empty()) { + sentence_content = true; + out.AddPhrase(sentence_id, phrase.begin(), phrase.end()); + phrase.clear(); + } + if (c == '\t' || c == '\v') continue; + // It's more than a space or tab: a newline. + if (sentence_content) { + ++sentence_id; + sentence_content = false; + } + } + if (!in.eof()) in.exceptions(std::istream::failbit | std::istream::badbit); + return sentence_id + sentence_content; +} + +namespace detail { const StringPiece kEndSentence(""); } + +namespace { + +typedef unsigned int Sentence; +typedef std::vector Sentences; + +class Vertex; + +class Arc { + public: + Arc() {} + + // For arcs from one vertex to another. + void SetPhrase(Vertex &from, Vertex &to, const Sentences &intersect) { + Set(to, intersect); + from_ = &from; + } + + /* For arcs from before the n-gram begins to somewhere in the n-gram (right + * aligned). These have no from_ vertex; it implictly matches every + * sentence. This also handles when the n-gram is a substring of a phrase. + */ + void SetRight(Vertex &to, const Sentences &complete) { + Set(to, complete); + from_ = NULL; + } + + Sentence Current() const { + return *current_; + } + + bool Empty() const { + return current_ == last_; + } + + /* When this function returns: + * If Empty() then there's nothing left from this intersection. + * + * If Current() == to then to is part of the intersection. + * + * Otherwise, Current() > to. In this case, to is not part of the + * intersection and neither is anything < Current(). To determine if + * any value >= Current() is in the intersection, call LowerBound again + * with the value. + */ + void LowerBound(const Sentence to); + + private: + void Set(Vertex &to, const Sentences &sentences); + + const Sentence *current_; + const Sentence *last_; + Vertex *from_; +}; + +struct ArcGreater : public std::binary_function { + bool operator()(const Arc *first, const Arc *second) const { + return first->Current() > second->Current(); + } +}; + +class Vertex { + public: + Vertex() : current_(0) {} + + Sentence Current() const { + return current_; + } + + bool Empty() const { + return incoming_.empty(); + } + + void LowerBound(const Sentence to); + + private: + friend class Arc; + + void AddIncoming(Arc *arc) { + if (!arc->Empty()) incoming_.push(arc); + } + + unsigned int current_; + std::priority_queue, ArcGreater> incoming_; +}; + +void Arc::LowerBound(const Sentence to) { + current_ = std::lower_bound(current_, last_, to); + // If *current_ > to, don't advance from_. The intervening values of + // from_ may be useful for another one of its outgoing arcs. + if (!from_ || Empty() || (Current() > to)) return; + assert(Current() == to); + from_->LowerBound(to); + if (from_->Empty()) { + current_ = last_; + return; + } + assert(from_->Current() >= to); + if (from_->Current() > to) { + current_ = std::lower_bound(current_ + 1, last_, from_->Current()); + } +} + +void Arc::Set(Vertex &to, const Sentences &sentences) { + current_ = &*sentences.begin(); + last_ = &*sentences.end(); + to.AddIncoming(this); +} + +void Vertex::LowerBound(const Sentence to) { + if (Empty()) return; + // Union lower bound. + while (true) { + Arc *top = incoming_.top(); + if (top->Current() > to) { + current_ = top->Current(); + return; + } + // If top->Current() == to, we still need to verify that's an actual + // element and not just a bound. + incoming_.pop(); + top->LowerBound(to); + if (!top->Empty()) { + incoming_.push(top); + if (top->Current() == to) { + current_ = to; + return; + } + } else if (Empty()) { + return; + } + } +} + +void BuildGraph(const Substrings &phrase, const std::vector &hashes, Vertex *const vertices, Arc *free_arc) { + assert(!hashes.empty()); + + const Hash *const first_word = &*hashes.begin(); + const Hash *const last_word = &*hashes.end() - 1; + + Hash hash = 0; + const Sentences *found; + // Phrases starting at or before the first word in the n-gram. + { + Vertex *vertex = vertices; + for (const Hash *word = first_word; ; ++word, ++vertex) { + hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word); + // Now hash is [hashes.begin(), word]. + if (word == last_word) { + if (phrase.FindSubstring(hash, found)) + (free_arc++)->SetRight(*vertex, *found); + break; + } + if (!phrase.FindRight(hash, found)) break; + (free_arc++)->SetRight(*vertex, *found); + } + } + + // Phrases starting at the second or later word in the n-gram. + Vertex *vertex_from = vertices; + for (const Hash *word_from = first_word + 1; word_from != &*hashes.end(); ++word_from, ++vertex_from) { + hash = 0; + Vertex *vertex_to = vertex_from + 1; + for (const Hash *word_to = word_from; ; ++word_to, ++vertex_to) { + // Notice that word_to and vertex_to have the same index. + hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word_to); + // Now hash covers [word_from, word_to]. + if (word_to == last_word) { + if (phrase.FindLeft(hash, found)) + (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found); + break; + } + if (!phrase.FindPhrase(hash, found)) break; + (free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found); + } + } +} + +} // namespace + +namespace detail { + +} // namespace detail + +bool Union::Evaluate() { + assert(!hashes_.empty()); + // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. + Vertex vertices[hashes_.size()]; + // One for every substring. + Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; + BuildGraph(substrings_, hashes_, vertices, arcs); + Vertex &last_vertex = vertices[hashes_.size() - 1]; + + unsigned int lower = 0; + while (true) { + last_vertex.LowerBound(lower); + if (last_vertex.Empty()) return false; + if (last_vertex.Current() == lower) return true; + lower = last_vertex.Current(); + } +} + +template void Multiple::Evaluate(const StringPiece &line, Output &output) { + assert(!hashes_.empty()); + // Usually there are at most 6 words in an n-gram, so stack allocation is reasonable. + Vertex vertices[hashes_.size()]; + // One for every substring. + Arc arcs[((hashes_.size() + 1) * hashes_.size()) / 2]; + BuildGraph(substrings_, hashes_, vertices, arcs); + Vertex &last_vertex = vertices[hashes_.size() - 1]; + + unsigned int lower = 0; + while (true) { + last_vertex.LowerBound(lower); + if (last_vertex.Empty()) return; + if (last_vertex.Current() == lower) { + output.SingleAddNGram(lower, line); + ++lower; + } else { + lower = last_vertex.Current(); + } + } +} + +template void Multiple::Evaluate(const StringPiece &line, CountFormat::Multiple &output); +template void Multiple::Evaluate(const StringPiece &line, ARPAFormat::Multiple &output); +template void Multiple::Evaluate(const StringPiece &line, MultipleOutputBuffer &output); + +} // namespace phrase +} // namespace lm diff --git a/klm/lm/filter/phrase.hh b/klm/lm/filter/phrase.hh new file mode 100644 index 00000000..07479dea --- /dev/null +++ b/klm/lm/filter/phrase.hh @@ -0,0 +1,153 @@ +#ifndef LM_FILTER_PHRASE_H__ +#define LM_FILTER_PHRASE_H__ + +#include "util/murmur_hash.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include + +#include +#include + +#define LM_FILTER_PHRASE_METHOD(caps, lower) \ +bool Find##caps(Hash key, const std::vector *&out) const {\ + Table::const_iterator i(table_.find(key));\ + if (i==table_.end()) return false; \ + out = &i->second.lower; \ + return true; \ +} + +namespace lm { +namespace phrase { + +typedef uint64_t Hash; + +class Substrings { + private: + /* This is the value in a hash table where the key is a string. It indicates + * four sets of sentences: + * substring is sentences with a phrase containing the key as a substring. + * left is sentencess with a phrase that begins with the key (left aligned). + * right is sentences with a phrase that ends with the key (right aligned). + * phrase is sentences where the key is a phrase. + * Each set is encoded as a vector of sentence ids in increasing order. + */ + struct SentenceRelation { + std::vector substring, left, right, phrase; + }; + /* Most of the CPU is hash table lookups, so let's not complicate it with + * vector equality comparisons. If a collision happens, the SentenceRelation + * structure will contain the union of sentence ids over the colliding strings. + * In that case, the filter will be slightly more permissive. + * The key here is the same as boost's hash of std::vector. + */ + typedef boost::unordered_map Table; + + public: + Substrings() {} + + /* If the string isn't a substring of any phrase, return NULL. Otherwise, + * return a pointer to std::vector listing sentences with + * matching phrases. This set may be empty for Left, Right, or Phrase. + * Example: const std::vector *FindSubstring(Hash key) + */ + LM_FILTER_PHRASE_METHOD(Substring, substring) + LM_FILTER_PHRASE_METHOD(Left, left) + LM_FILTER_PHRASE_METHOD(Right, right) + LM_FILTER_PHRASE_METHOD(Phrase, phrase) + + // sentence_id must be non-decreasing. Iterators are over words in the phrase. + template void AddPhrase(unsigned int sentence_id, const Iterator &begin, const Iterator &end) { + // Iterate over all substrings. + for (Iterator start = begin; start != end; ++start) { + Hash hash = 0; + SentenceRelation *relation; + for (Iterator finish = start; finish != end; ++finish) { + hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *finish); + // Now hash is of [start, finish]. + relation = &table_[hash]; + AppendSentence(relation->substring, sentence_id); + if (start == begin) AppendSentence(relation->left, sentence_id); + } + AppendSentence(relation->right, sentence_id); + if (start == begin) AppendSentence(relation->phrase, sentence_id); + } + } + + private: + void AppendSentence(std::vector &vec, unsigned int sentence_id) { + if (vec.empty() || vec.back() != sentence_id) vec.push_back(sentence_id); + } + + Table table_; +}; + +// Read a file with one sentence per line containing tab-delimited phrases of +// space-separated words. +unsigned int ReadMultiple(std::istream &in, Substrings &out); + +namespace detail { +extern const StringPiece kEndSentence; + +template void MakeHashes(Iterator i, const Iterator &end, std::vector &hashes) { + hashes.clear(); + if (i == end) return; + // TODO: check strict phrase boundaries after and before . For now, just skip tags. + if ((i->data()[0] == '<') && (i->data()[i->size() - 1] == '>')) { + ++i; + } + for (; i != end && (*i != kEndSentence); ++i) { + hashes.push_back(util::MurmurHashNative(i->data(), i->size())); + } +} + +} // namespace detail + +class Union { + public: + explicit Union(const Substrings &substrings) : substrings_(substrings) {} + + template bool PassNGram(const Iterator &begin, const Iterator &end) { + detail::MakeHashes(begin, end, hashes_); + return hashes_.empty() || Evaluate(); + } + + private: + bool Evaluate(); + + std::vector hashes_; + + const Substrings &substrings_; +}; + +class Multiple { + public: + explicit Multiple(const Substrings &substrings) : substrings_(substrings) {} + + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { + detail::MakeHashes(begin, end, hashes_); + if (hashes_.empty()) { + output.AddNGram(line); + return; + } + Evaluate(line, output); + } + + template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + AddNGram(util::TokenIter(ngram, ' '), util::TokenIter::end(), line, output); + } + + void Flush() const {} + + private: + template void Evaluate(const StringPiece &line, Output &output); + + std::vector hashes_; + + const Substrings &substrings_; +}; + +} // namespace phrase +} // namespace lm +#endif // LM_FILTER_PHRASE_H__ diff --git a/klm/lm/filter/thread.hh b/klm/lm/filter/thread.hh new file mode 100644 index 00000000..e785b263 --- /dev/null +++ b/klm/lm/filter/thread.hh @@ -0,0 +1,167 @@ +#ifndef LM_FILTER_THREAD_H__ +#define LM_FILTER_THREAD_H__ + +#include "util/thread_pool.hh" + +#include + +#include +#include + +namespace lm { + +template class ThreadBatch { + public: + ThreadBatch() {} + + void Reserve(size_t size) { + input_.Reserve(size); + output_.Reserve(size); + } + + // File reading thread. + InputBuffer &Fill(uint64_t sequence) { + sequence_ = sequence; + // Why wait until now to clear instead of after output? free in the same + // thread as allocated. + input_.Clear(); + return input_; + } + + // Filter worker thread. + template void CallFilter(Filter &filter) { + input_.CallFilter(filter, output_); + } + + uint64_t Sequence() const { return sequence_; } + + // File writing thread. + template void Flush(RealOutput &output) { + output_.Flush(output); + } + + private: + InputBuffer input_; + OutputBuffer output_; + + uint64_t sequence_; +}; + +template class FilterWorker { + public: + typedef Batch *Request; + + FilterWorker(const Filter &filter, util::PCQueue &done) : filter_(filter), done_(done) {} + + void operator()(Request request) { + request->CallFilter(filter_); + done_.Produce(request); + } + + private: + Filter filter_; + + util::PCQueue &done_; +}; + +// There should only be one OutputWorker. +template class OutputWorker { + public: + typedef Batch *Request; + + OutputWorker(Output &output, util::PCQueue &done) : output_(output), done_(done), base_sequence_(0) {} + + void operator()(Request request) { + assert(request->Sequence() >= base_sequence_); + // Assemble the output in order. + uint64_t pos = request->Sequence() - base_sequence_; + if (pos >= ordering_.size()) { + ordering_.resize(pos + 1, NULL); + } + ordering_[pos] = request; + while (!ordering_.empty() && ordering_.front()) { + ordering_.front()->Flush(output_); + done_.Produce(ordering_.front()); + ordering_.pop_front(); + ++base_sequence_; + } + } + + private: + Output &output_; + + util::PCQueue &done_; + + std::deque ordering_; + + uint64_t base_sequence_; +}; + +template class Controller : boost::noncopyable { + private: + typedef ThreadBatch Batch; + + public: + Controller(size_t batch_size, size_t queue, size_t workers, const Filter &filter, RealOutput &output) + : batch_size_(batch_size), queue_size_(queue), + batches_(queue), + to_read_(queue), + output_(queue, 1, boost::in_place(boost::ref(output), boost::ref(to_read_)), NULL), + filter_(queue, workers, boost::in_place(boost::ref(filter), boost::ref(output_.In())), NULL), + sequence_(0) { + for (size_t i = 0; i < queue; ++i) { + batches_[i].Reserve(batch_size); + local_read_.push(&batches_[i]); + } + NewInput(); + } + + void AddNGram(const StringPiece &ngram, const StringPiece &line, RealOutput &output) { + input_->AddNGram(ngram, line, output); + if (input_->Size() == batch_size_) { + FlushInput(); + NewInput(); + } + } + + void Flush() { + FlushInput(); + while (local_read_.size() < queue_size_) { + MoveRead(); + } + NewInput(); + } + + private: + void FlushInput() { + if (input_->Empty()) return; + filter_.Produce(local_read_.top()); + local_read_.pop(); + if (local_read_.empty()) MoveRead(); + } + + void NewInput() { + input_ = &local_read_.top()->Fill(sequence_++); + } + + void MoveRead() { + local_read_.push(to_read_.Consume()); + } + + const size_t batch_size_; + const size_t queue_size_; + + std::vector batches_; + + util::PCQueue to_read_; + std::stack local_read_; + util::ThreadPool > output_; + util::ThreadPool > filter_; + + uint64_t sequence_; + InputBuffer *input_; +}; + +} // namespace lm + +#endif // LM_FILTER_THREAD_H__ diff --git a/klm/lm/filter/vocab.cc b/klm/lm/filter/vocab.cc new file mode 100644 index 00000000..7ee4e84b --- /dev/null +++ b/klm/lm/filter/vocab.cc @@ -0,0 +1,54 @@ +#include "lm/filter/vocab.hh" + +#include +#include + +#include +#include + +namespace lm { +namespace vocab { + +void ReadSingle(std::istream &in, boost::unordered_set &out) { + in.exceptions(std::istream::badbit); + std::string word; + while (in >> word) { + out.insert(word); + } +} + +namespace { +bool IsLineEnd(std::istream &in) { + int got; + do { + got = in.get(); + if (!in) return true; + if (got == '\n') return true; + } while (isspace(got)); + in.unget(); + return false; +} +}// namespace + +// Read space separated words in enter separated lines. These lines can be +// very long, so don't read an entire line at a time. +unsigned int ReadMultiple(std::istream &in, boost::unordered_map > &out) { + in.exceptions(std::istream::badbit); + unsigned int sentence = 0; + bool used_id = false; + std::string word; + while (in >> word) { + used_id = true; + std::vector &posting = out[word]; + if (posting.empty() || (posting.back() != sentence)) + posting.push_back(sentence); + if (IsLineEnd(in)) { + ++sentence; + used_id = false; + } + } + return sentence + used_id; +} + +} // namespace vocab +} // namespace lm diff --git a/klm/lm/filter/vocab.hh b/klm/lm/filter/vocab.hh new file mode 100644 index 00000000..e2b6adff --- /dev/null +++ b/klm/lm/filter/vocab.hh @@ -0,0 +1,132 @@ +#ifndef LM_FILTER_VOCAB_H__ +#define LM_FILTER_VOCAB_H__ + +// Vocabulary-based filters for language models. + +#include "util/multi_intersection.hh" +#include "util/string_piece.hh" +#include "util/tokenize_piece.hh" + +#include +#include +#include +#include + +#include +#include + +namespace lm { +namespace vocab { + +void ReadSingle(std::istream &in, boost::unordered_set &out); + +// Read one sentence vocabulary per line. Return the number of sentences. +unsigned int ReadMultiple(std::istream &in, boost::unordered_map > &out); + +/* Is this a special tag like or ? This actually includes anything + * surrounded with < and >, which most tokenizers separate for real words, so + * this should not catch real words as it looks at a single token. + */ +inline bool IsTag(const StringPiece &value) { + // The parser should never give an empty string. + assert(!value.empty()); + return (value.data()[0] == '<' && value.data()[value.size() - 1] == '>'); +} + +class Single { + public: + typedef boost::unordered_set Words; + + explicit Single(const Words &vocab) : vocab_(vocab) {} + + template bool PassNGram(const Iterator &begin, const Iterator &end) { + for (Iterator i = begin; i != end; ++i) { + if (IsTag(*i)) continue; + if (FindStringPiece(vocab_, *i) == vocab_.end()) return false; + } + return true; + } + + private: + const Words &vocab_; +}; + +class Union { + public: + typedef boost::unordered_map > Words; + + explicit Union(const Words &vocabs) : vocabs_(vocabs) {} + + template bool PassNGram(const Iterator &begin, const Iterator &end) { + sets_.clear(); + + for (Iterator i(begin); i != end; ++i) { + if (IsTag(*i)) continue; + Words::const_iterator found(FindStringPiece(vocabs_, *i)); + if (vocabs_.end() == found) return false; + sets_.push_back(boost::iterator_range(&*found->second.begin(), &*found->second.end())); + } + return (sets_.empty() || util::FirstIntersection(sets_)); + } + + private: + const Words &vocabs_; + + std::vector > sets_; +}; + +class Multiple { + public: + typedef boost::unordered_map > Words; + + Multiple(const Words &vocabs) : vocabs_(vocabs) {} + + private: + // Callback from AllIntersection that does AddNGram. + template class Callback { + public: + Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {} + + void operator()(unsigned int index) { + out_.SingleAddNGram(index, line_); + } + + private: + Output &out_; + const StringPiece &line_; + }; + + public: + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { + sets_.clear(); + for (Iterator i(begin); i != end; ++i) { + if (IsTag(*i)) continue; + Words::const_iterator found(FindStringPiece(vocabs_, *i)); + if (vocabs_.end() == found) return; + sets_.push_back(boost::iterator_range(&*found->second.begin(), &*found->second.end())); + } + if (sets_.empty()) { + output.AddNGram(line); + return; + } + + Callback cb(output, line); + util::AllIntersection(sets_, cb); + } + + template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + AddNGram(util::TokenIter(ngram, ' '), util::TokenIter::end(), line, output); + } + + void Flush() const {} + + private: + const Words &vocabs_; + + std::vector > sets_; +}; + +} // namespace vocab +} // namespace lm + +#endif // LM_FILTER_VOCAB_H__ diff --git a/klm/lm/filter/wrapper.hh b/klm/lm/filter/wrapper.hh new file mode 100644 index 00000000..90b07a08 --- /dev/null +++ b/klm/lm/filter/wrapper.hh @@ -0,0 +1,58 @@ +#ifndef LM_FILTER_WRAPPER_H__ +#define LM_FILTER_WRAPPER_H__ + +#include "util/string_piece.hh" + +#include +#include +#include + +namespace lm { + +// Provide a single-output filter with the same interface as a +// multiple-output filter so clients code against one interface. +template class BinaryFilter { + public: + // Binary modes are just references (and a set) and it makes the API cleaner to copy them. + explicit BinaryFilter(Binary binary) : binary_(binary) {} + + template void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) { + if (binary_.PassNGram(begin, end)) + output.AddNGram(line); + } + + template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + AddNGram(util::TokenIter(ngram, ' '), util::TokenIter::end(), line, output); + } + + void Flush() const {} + + private: + Binary binary_; +}; + +// Wrap another filter to pay attention only to context words +template class ContextFilter { + public: + typedef FilterT Filter; + + explicit ContextFilter(Filter &backend) : backend_(backend) {} + + template void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) { + pieces_.clear(); + // TODO: this copy could be avoided by a lookahead iterator. + std::copy(util::TokenIter(ngram, ' '), util::TokenIter::end(), std::back_insert_iterator >(pieces_)); + backend_.AddNGram(pieces_.begin(), pieces_.end() - !pieces_.empty(), line, output); + } + + void Flush() const {} + + private: + std::vector pieces_; + + Filter backend_; +}; + +} // namespace lm + +#endif // LM_FILTER_WRAPPER_H__ diff --git a/klm/lm/model_test.cc b/klm/lm/model_test.cc index 32084b5b..eb159094 100644 --- a/klm/lm/model_test.cc +++ b/klm/lm/model_test.cc @@ -1,6 +1,7 @@ #include "lm/model.hh" #include +#include #define BOOST_TEST_MODULE ModelTest #include @@ -22,17 +23,20 @@ std::ostream &operator<<(std::ostream &o, const State &state) { namespace { +// Stupid bjam reverses the command line arguments randomly. const char *TestLocation() { - if (boost::unit_test::framework::master_test_suite().argc < 2) { + if (boost::unit_test::framework::master_test_suite().argc < 3) { return "test.arpa"; } - return boost::unit_test::framework::master_test_suite().argv[1]; + char **argv = boost::unit_test::framework::master_test_suite().argv; + return argv[strstr(argv[1], "nounk") ? 2 : 1]; } const char *TestNoUnkLocation() { if (boost::unit_test::framework::master_test_suite().argc < 3) { return "test_nounk.arpa"; } - return boost::unit_test::framework::master_test_suite().argv[2]; + char **argv = boost::unit_test::framework::master_test_suite().argv; + return argv[strstr(argv[1], "nounk") ? 1 : 2]; } template State GetState(const Model &model, const char *word, const State &in) { diff --git a/klm/lm/read_arpa.cc b/klm/lm/read_arpa.cc index b709fef9..9ea08798 100644 --- a/klm/lm/read_arpa.cc +++ b/klm/lm/read_arpa.cc @@ -1,6 +1,7 @@ #include "lm/read_arpa.hh" #include "lm/blank.hh" +#include "util/file.hh" #include #include @@ -45,8 +46,14 @@ uint64_t ReadCount(const std::string &from) { void ReadARPACounts(util::FilePiece &in, std::vector &number) { number.clear(); - StringPiece line; - while (IsEntirelyWhiteSpace(line = in.ReadLine())) {} + StringPiece line = in.ReadLine(); + // In general, ARPA files can have arbitrary text before "\data\" + // But in KenLM, we require such lines to start with "#", so that + // we can do stricter error checking + while (IsEntirelyWhiteSpace(line) || line.starts_with("#")) { + line = in.ReadLine(); + } + if (line != "\\data\\") { if ((line.size() >= 2) && (line.data()[0] == 0x1f) && (static_cast(line.data()[1]) == 0x8b)) { UTIL_THROW(FormatLoadException, "Looks like a gzip file. If this is an ARPA file, pipe " << in.FileName() << " through zcat. If this already in binary format, you need to decompress it because mmap doesn't work on top of gzip."); diff --git a/klm/lm/sizes.cc b/klm/lm/sizes.cc new file mode 100644 index 00000000..55ad586c --- /dev/null +++ b/klm/lm/sizes.cc @@ -0,0 +1,63 @@ +#include "lm/sizes.hh" +#include "lm/model.hh" +#include "util/file_piece.hh" + +#include +#include + +namespace lm { +namespace ngram { + +void ShowSizes(const std::vector &counts, const lm::ngram::Config &config) { + uint64_t sizes[6]; + sizes[0] = ProbingModel::Size(counts, config); + sizes[1] = RestProbingModel::Size(counts, config); + sizes[2] = TrieModel::Size(counts, config); + sizes[3] = QuantTrieModel::Size(counts, config); + sizes[4] = ArrayTrieModel::Size(counts, config); + sizes[5] = QuantArrayTrieModel::Size(counts, config); + uint64_t max_length = *std::max_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); + uint64_t min_length = *std::min_element(sizes, sizes + sizeof(sizes) / sizeof(uint64_t)); + uint64_t divide; + char prefix; + if (min_length < (1 << 10) * 10) { + prefix = ' '; + divide = 1; + } else if (min_length < (1 << 20) * 10) { + prefix = 'k'; + divide = 1 << 10; + } else if (min_length < (1ULL << 30) * 10) { + prefix = 'M'; + divide = 1 << 20; + } else { + prefix = 'G'; + divide = 1 << 30; + } + long int length = std::max(2, static_cast(ceil(log10((double) max_length / divide)))); + std::cerr << "Memory estimate for binary LM:\ntype "; + + // right align bytes. + for (long int i = 0; i < length - 2; ++i) std::cerr << ' '; + + std::cerr << prefix << "B\n" + "probing " << std::setw(length) << (sizes[0] / divide) << " assuming -p " << config.probing_multiplier << "\n" + "probing " << std::setw(length) << (sizes[1] / divide) << " assuming -r models -p " << config.probing_multiplier << "\n" + "trie " << std::setw(length) << (sizes[2] / divide) << " without quantization\n" + "trie " << std::setw(length) << (sizes[3] / divide) << " assuming -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits << " quantization \n" + "trie " << std::setw(length) << (sizes[4] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " array pointer compression\n" + "trie " << std::setw(length) << (sizes[5] / divide) << " assuming -a " << (unsigned)config.pointer_bhiksha_bits << " -q " << (unsigned)config.prob_bits << " -b " << (unsigned)config.backoff_bits<< " array pointer compression and quantization\n"; +} + +void ShowSizes(const std::vector &counts) { + lm::ngram::Config config; + ShowSizes(counts, config); +} + +void ShowSizes(const char *file, const lm::ngram::Config &config) { + std::vector counts; + util::FilePiece f(file); + lm::ReadARPACounts(f, counts); + ShowSizes(counts, config); +} + +}} //namespaces diff --git a/klm/lm/sizes.hh b/klm/lm/sizes.hh new file mode 100644 index 00000000..85abade7 --- /dev/null +++ b/klm/lm/sizes.hh @@ -0,0 +1,17 @@ +#ifndef LM_SIZES__ +#define LM_SIZES__ + +#include + +#include + +namespace lm { namespace ngram { + +struct Config; + +void ShowSizes(const std::vector &counts, const lm::ngram::Config &config); +void ShowSizes(const std::vector &counts); +void ShowSizes(const char *file, const lm::ngram::Config &config); + +}} // namespaces +#endif // LM_SIZES__ diff --git a/klm/lm/state.hh b/klm/lm/state.hh index 551510a8..d8e6c132 100644 --- a/klm/lm/state.hh +++ b/klm/lm/state.hh @@ -56,14 +56,14 @@ inline uint64_t hash_value(const State &state, uint64_t seed = 0) { struct Left { bool operator==(const Left &other) const { return - (length == other.length) && - pointers[length - 1] == other.pointers[length - 1] && - full == other.full; + length == other.length && + (!length || (pointers[length - 1] == other.pointers[length - 1] && full == other.full)); } int Compare(const Left &other) const { if (length < other.length) return -1; if (length > other.length) return 1; + if (length == 0) return 0; // Must be full. if (pointers[length - 1] > other.pointers[length - 1]) return 1; if (pointers[length - 1] < other.pointers[length - 1]) return -1; return (int)full - (int)other.full; diff --git a/klm/lm/trie_sort.cc b/klm/lm/trie_sort.cc index 8663e94e..dc542bb3 100644 --- a/klm/lm/trie_sort.cc +++ b/klm/lm/trie_sort.cc @@ -65,13 +65,13 @@ class PartialViewProxy { typedef util::ProxyIterator PartialIter; -FILE *DiskFlush(const void *mem_begin, const void *mem_end, const util::TempMaker &maker) { - util::scoped_fd file(maker.Make()); +FILE *DiskFlush(const void *mem_begin, const void *mem_end, const std::string &temp_prefix) { + util::scoped_fd file(util::MakeTemp(temp_prefix)); util::WriteOrThrow(file.get(), mem_begin, (uint8_t*)mem_end - (uint8_t*)mem_begin); return util::FDOpenOrThrow(file); } -FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &maker, std::size_t entry_size, unsigned char order) { +FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const std::string &temp_prefix, std::size_t entry_size, unsigned char order) { const size_t context_size = sizeof(WordIndex) * (order - 1); // Sort just the contexts using the same memory. PartialIter context_begin(PartialViewProxy(begin + sizeof(WordIndex), entry_size, context_size)); @@ -84,7 +84,7 @@ FILE *WriteContextFile(uint8_t *begin, uint8_t *end, const util::TempMaker &make #endif (context_begin, context_end, util::SizedCompare(EntryCompare(order - 1))); - util::scoped_FILE out(maker.MakeFile()); + util::scoped_FILE out(util::FMakeTemp(temp_prefix)); // Write out to file and uniqueify at the same time. Could have used unique_copy if there was an appropriate OutputIterator. if (context_begin == context_end) return out.release(); @@ -114,12 +114,12 @@ struct FirstCombine { } }; -template FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const util::TempMaker &maker, std::size_t weights_size, unsigned char order, const Combine &combine) { +template FILE *MergeSortedFiles(FILE *first_file, FILE *second_file, const std::string &temp_prefix, std::size_t weights_size, unsigned char order, const Combine &combine) { std::size_t entry_size = sizeof(WordIndex) * order + weights_size; RecordReader first, second; first.Init(first_file, entry_size); second.Init(second_file, entry_size); - util::scoped_FILE out_file(maker.MakeFile()); + util::scoped_FILE out_file(util::FMakeTemp(temp_prefix)); EntryCompare less(order); while (first && second) { if (less(first.Data(), second.Data())) { @@ -177,9 +177,8 @@ void RecordReader::Rewind() { } SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, size_t buffer, const std::string &file_prefix, SortedVocabulary &vocab) { - util::TempMaker maker(file_prefix); PositiveProbWarn warn(config.positive_log_probability); - unigram_.reset(maker.Make()); + unigram_.reset(util::MakeTemp(file_prefix)); { // In case appears. size_t size_out = (counts[0] + 1) * sizeof(ProbBackoff); @@ -202,7 +201,7 @@ SortedFiles::SortedFiles(const Config &config, util::FilePiece &f, std::vector &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) { +void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, const std::string &file_prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size) { ReadNGramHeader(f, order); const size_t count = counts[order - 1]; // Size of weights. Does it include backoff? @@ -261,8 +260,8 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo std::sort #endif (NGramIter(proxy_begin), NGramIter(proxy_end), util::SizedCompare(EntryCompare(order))); - files.push_back(DiskFlush(begin, out_end, maker)); - contexts.push_back(WriteContextFile(begin, out_end, maker, entry_size, order)); + files.push_back(DiskFlush(begin, out_end, file_prefix)); + contexts.push_back(WriteContextFile(begin, out_end, file_prefix, entry_size, order)); done += (out_end - begin) / entry_size; } @@ -270,10 +269,10 @@ void SortedFiles::ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vo // All individual files created. Merge them. while (files.size() > 1) { - files.push_back(MergeSortedFiles(files[0], files[1], maker, weights_size, order, ThrowCombine())); + files.push_back(MergeSortedFiles(files[0], files[1], file_prefix, weights_size, order, ThrowCombine())); files_closer.PopFront(); files_closer.PopFront(); - contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], maker, 0, order - 1, FirstCombine())); + contexts.push_back(MergeSortedFiles(contexts[0], contexts[1], file_prefix, 0, order - 1, FirstCombine())); contexts_closer.PopFront(); contexts_closer.PopFront(); } diff --git a/klm/lm/trie_sort.hh b/klm/lm/trie_sort.hh index 2197b80c..1afd9562 100644 --- a/klm/lm/trie_sort.hh +++ b/klm/lm/trie_sort.hh @@ -18,7 +18,6 @@ namespace util { class FilePiece; -class TempMaker; } // namespace util namespace lm { @@ -101,7 +100,7 @@ class SortedFiles { } private: - void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, const util::TempMaker &maker, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size); + void ConvertToSorted(util::FilePiece &f, const SortedVocabulary &vocab, const std::vector &counts, const std::string &prefix, unsigned char order, PositiveProbWarn &warn, void *mem, std::size_t mem_size); util::scoped_fd unigram_; diff --git a/klm/search/Jamfile b/klm/search/Jamfile deleted file mode 100644 index bc95c53a..00000000 --- a/klm/search/Jamfile +++ /dev/null @@ -1,5 +0,0 @@ -lib search : weights.cc vertex.cc vertex_generator.cc edge_generator.cc rule.cc ../lm//kenlm ../util//kenutil /top//boost_system : : : .. ; - -import testing ; - -unit-test weights_test : weights_test.cc search /top//boost_unit_test_framework ; diff --git a/klm/test.sh b/klm/test.sh deleted file mode 100755 index fb33300a..00000000 --- a/klm/test.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -#Run tests. Requires Boost. -set -e -./compile.sh -for i in util/{bit_packing,file_piece,joint_sort,key_value_packing,probing_hash_table,sorted_uniform}_test lm/{model,left}_test; do - g++ -I. -O3 $CXXFLAGS $i.cc {lm,util}/*.o -lboost_test_exec_monitor -lz -o $i - pushd $(dirname $i) >/dev/null && ./$(basename $i) || echo "$i failed"; popd >/dev/null -done diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 3ab7560f..294ebc0a 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -29,6 +29,7 @@ libklm_util_a_SOURCES = \ joint_sort.hh \ mmap.hh \ murmur_hash.hh \ + pcqueue.hh \ pool.hh \ probing_hash_table.hh \ proxy_iterator.hh \ @@ -37,6 +38,7 @@ libklm_util_a_SOURCES = \ sized_iterator.hh \ sorted_uniform.hh \ string_piece.hh \ + thread_pool.hh \ tokenize_piece.hh \ usage.hh \ ersatz_progress.cc \ @@ -48,7 +50,8 @@ libklm_util_a_SOURCES = \ murmur_hash.cc \ pool.cc \ read_compressed.cc \ + scoped.cc \ string_piece.cc \ usage.cc -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm -I$(top_srcdir)/klm/util/double-conversion diff --git a/klm/util/double-conversion/LICENSE b/klm/util/double-conversion/LICENSE new file mode 100644 index 00000000..933718a9 --- /dev/null +++ b/klm/util/double-conversion/LICENSE @@ -0,0 +1,26 @@ +Copyright 2006-2011, the V8 project authors. All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + * Neither the name of Google Inc. nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/klm/util/double-conversion/Makefile.am b/klm/util/double-conversion/Makefile.am new file mode 100644 index 00000000..eb6616f7 --- /dev/null +++ b/klm/util/double-conversion/Makefile.am @@ -0,0 +1,23 @@ +noinst_LIBRARIES = libklm_util_double.a + +libklm_util_double_a_SOURCES = \ + bignum-dtoa.h \ + bignum.h \ + cached-powers.h \ + diy-fp.h \ + double-conversion.h \ + fast-dtoa.h \ + fixed-dtoa.h \ + ieee.h \ + strtod.h \ + utils.h \ + bignum.cc \ + bignum-dtoa.cc \ + cached-powers.cc \ + diy-fp.cc \ + double-conversion.cc \ + fast-dtoa.cc \ + fixed-dtoa.cc \ + strtod.cc + +AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm -I$(top_srcdir)/klm/util/double-conversion diff --git a/klm/util/double-conversion/bignum-dtoa.cc b/klm/util/double-conversion/bignum-dtoa.cc new file mode 100644 index 00000000..b6c2e85d --- /dev/null +++ b/klm/util/double-conversion/bignum-dtoa.cc @@ -0,0 +1,640 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include + +#include "bignum-dtoa.h" + +#include "bignum.h" +#include "ieee.h" + +namespace double_conversion { + +static int NormalizedExponent(uint64_t significand, int exponent) { + ASSERT(significand != 0); + while ((significand & Double::kHiddenBit) == 0) { + significand = significand << 1; + exponent = exponent - 1; + } + return exponent; +} + + +// Forward declarations: +// Returns an estimation of k such that 10^(k-1) <= v < 10^k. +static int EstimatePower(int exponent); +// Computes v / 10^estimated_power exactly, as a ratio of two bignums, numerator +// and denominator. +static void InitialScaledStartValues(uint64_t significand, + int exponent, + bool lower_boundary_is_closer, + int estimated_power, + bool need_boundary_deltas, + Bignum* numerator, + Bignum* denominator, + Bignum* delta_minus, + Bignum* delta_plus); +// Multiplies numerator/denominator so that its values lies in the range 1-10. +// Returns decimal_point s.t. +// v = numerator'/denominator' * 10^(decimal_point-1) +// where numerator' and denominator' are the values of numerator and +// denominator after the call to this function. +static void FixupMultiply10(int estimated_power, bool is_even, + int* decimal_point, + Bignum* numerator, Bignum* denominator, + Bignum* delta_minus, Bignum* delta_plus); +// Generates digits from the left to the right and stops when the generated +// digits yield the shortest decimal representation of v. +static void GenerateShortestDigits(Bignum* numerator, Bignum* denominator, + Bignum* delta_minus, Bignum* delta_plus, + bool is_even, + Vector buffer, int* length); +// Generates 'requested_digits' after the decimal point. +static void BignumToFixed(int requested_digits, int* decimal_point, + Bignum* numerator, Bignum* denominator, + Vector(buffer), int* length); +// Generates 'count' digits of numerator/denominator. +// Once 'count' digits have been produced rounds the result depending on the +// remainder (remainders of exactly .5 round upwards). Might update the +// decimal_point when rounding up (for example for 0.9999). +static void GenerateCountedDigits(int count, int* decimal_point, + Bignum* numerator, Bignum* denominator, + Vector(buffer), int* length); + + +void BignumDtoa(double v, BignumDtoaMode mode, int requested_digits, + Vector buffer, int* length, int* decimal_point) { + ASSERT(v > 0); + ASSERT(!Double(v).IsSpecial()); + uint64_t significand; + int exponent; + bool lower_boundary_is_closer; + if (mode == BIGNUM_DTOA_SHORTEST_SINGLE) { + float f = static_cast(v); + ASSERT(f == v); + significand = Single(f).Significand(); + exponent = Single(f).Exponent(); + lower_boundary_is_closer = Single(f).LowerBoundaryIsCloser(); + } else { + significand = Double(v).Significand(); + exponent = Double(v).Exponent(); + lower_boundary_is_closer = Double(v).LowerBoundaryIsCloser(); + } + bool need_boundary_deltas = + (mode == BIGNUM_DTOA_SHORTEST || mode == BIGNUM_DTOA_SHORTEST_SINGLE); + + bool is_even = (significand & 1) == 0; + int normalized_exponent = NormalizedExponent(significand, exponent); + // estimated_power might be too low by 1. + int estimated_power = EstimatePower(normalized_exponent); + + // Shortcut for Fixed. + // The requested digits correspond to the digits after the point. If the + // number is much too small, then there is no need in trying to get any + // digits. + if (mode == BIGNUM_DTOA_FIXED && -estimated_power - 1 > requested_digits) { + buffer[0] = '\0'; + *length = 0; + // Set decimal-point to -requested_digits. This is what Gay does. + // Note that it should not have any effect anyways since the string is + // empty. + *decimal_point = -requested_digits; + return; + } + + Bignum numerator; + Bignum denominator; + Bignum delta_minus; + Bignum delta_plus; + // Make sure the bignum can grow large enough. The smallest double equals + // 4e-324. In this case the denominator needs fewer than 324*4 binary digits. + // The maximum double is 1.7976931348623157e308 which needs fewer than + // 308*4 binary digits. + ASSERT(Bignum::kMaxSignificantBits >= 324*4); + InitialScaledStartValues(significand, exponent, lower_boundary_is_closer, + estimated_power, need_boundary_deltas, + &numerator, &denominator, + &delta_minus, &delta_plus); + // We now have v = (numerator / denominator) * 10^estimated_power. + FixupMultiply10(estimated_power, is_even, decimal_point, + &numerator, &denominator, + &delta_minus, &delta_plus); + // We now have v = (numerator / denominator) * 10^(decimal_point-1), and + // 1 <= (numerator + delta_plus) / denominator < 10 + switch (mode) { + case BIGNUM_DTOA_SHORTEST: + case BIGNUM_DTOA_SHORTEST_SINGLE: + GenerateShortestDigits(&numerator, &denominator, + &delta_minus, &delta_plus, + is_even, buffer, length); + break; + case BIGNUM_DTOA_FIXED: + BignumToFixed(requested_digits, decimal_point, + &numerator, &denominator, + buffer, length); + break; + case BIGNUM_DTOA_PRECISION: + GenerateCountedDigits(requested_digits, decimal_point, + &numerator, &denominator, + buffer, length); + break; + default: + UNREACHABLE(); + } + buffer[*length] = '\0'; +} + + +// The procedure starts generating digits from the left to the right and stops +// when the generated digits yield the shortest decimal representation of v. A +// decimal representation of v is a number lying closer to v than to any other +// double, so it converts to v when read. +// +// This is true if d, the decimal representation, is between m- and m+, the +// upper and lower boundaries. d must be strictly between them if !is_even. +// m- := (numerator - delta_minus) / denominator +// m+ := (numerator + delta_plus) / denominator +// +// Precondition: 0 <= (numerator+delta_plus) / denominator < 10. +// If 1 <= (numerator+delta_plus) / denominator < 10 then no leading 0 digit +// will be produced. This should be the standard precondition. +static void GenerateShortestDigits(Bignum* numerator, Bignum* denominator, + Bignum* delta_minus, Bignum* delta_plus, + bool is_even, + Vector buffer, int* length) { + // Small optimization: if delta_minus and delta_plus are the same just reuse + // one of the two bignums. + if (Bignum::Equal(*delta_minus, *delta_plus)) { + delta_plus = delta_minus; + } + *length = 0; + while (true) { + uint16_t digit; + digit = numerator->DivideModuloIntBignum(*denominator); + ASSERT(digit <= 9); // digit is a uint16_t and therefore always positive. + // digit = numerator / denominator (integer division). + // numerator = numerator % denominator. + buffer[(*length)++] = digit + '0'; + + // Can we stop already? + // If the remainder of the division is less than the distance to the lower + // boundary we can stop. In this case we simply round down (discarding the + // remainder). + // Similarly we test if we can round up (using the upper boundary). + bool in_delta_room_minus; + bool in_delta_room_plus; + if (is_even) { + in_delta_room_minus = Bignum::LessEqual(*numerator, *delta_minus); + } else { + in_delta_room_minus = Bignum::Less(*numerator, *delta_minus); + } + if (is_even) { + in_delta_room_plus = + Bignum::PlusCompare(*numerator, *delta_plus, *denominator) >= 0; + } else { + in_delta_room_plus = + Bignum::PlusCompare(*numerator, *delta_plus, *denominator) > 0; + } + if (!in_delta_room_minus && !in_delta_room_plus) { + // Prepare for next iteration. + numerator->Times10(); + delta_minus->Times10(); + // We optimized delta_plus to be equal to delta_minus (if they share the + // same value). So don't multiply delta_plus if they point to the same + // object. + if (delta_minus != delta_plus) { + delta_plus->Times10(); + } + } else if (in_delta_room_minus && in_delta_room_plus) { + // Let's see if 2*numerator < denominator. + // If yes, then the next digit would be < 5 and we can round down. + int compare = Bignum::PlusCompare(*numerator, *numerator, *denominator); + if (compare < 0) { + // Remaining digits are less than .5. -> Round down (== do nothing). + } else if (compare > 0) { + // Remaining digits are more than .5 of denominator. -> Round up. + // Note that the last digit could not be a '9' as otherwise the whole + // loop would have stopped earlier. + // We still have an assert here in case the preconditions were not + // satisfied. + ASSERT(buffer[(*length) - 1] != '9'); + buffer[(*length) - 1]++; + } else { + // Halfway case. + // TODO(floitsch): need a way to solve half-way cases. + // For now let's round towards even (since this is what Gay seems to + // do). + + if ((buffer[(*length) - 1] - '0') % 2 == 0) { + // Round down => Do nothing. + } else { + ASSERT(buffer[(*length) - 1] != '9'); + buffer[(*length) - 1]++; + } + } + return; + } else if (in_delta_room_minus) { + // Round down (== do nothing). + return; + } else { // in_delta_room_plus + // Round up. + // Note again that the last digit could not be '9' since this would have + // stopped the loop earlier. + // We still have an ASSERT here, in case the preconditions were not + // satisfied. + ASSERT(buffer[(*length) -1] != '9'); + buffer[(*length) - 1]++; + return; + } + } +} + + +// Let v = numerator / denominator < 10. +// Then we generate 'count' digits of d = x.xxxxx... (without the decimal point) +// from left to right. Once 'count' digits have been produced we decide wether +// to round up or down. Remainders of exactly .5 round upwards. Numbers such +// as 9.999999 propagate a carry all the way, and change the +// exponent (decimal_point), when rounding upwards. +static void GenerateCountedDigits(int count, int* decimal_point, + Bignum* numerator, Bignum* denominator, + Vector(buffer), int* length) { + ASSERT(count >= 0); + for (int i = 0; i < count - 1; ++i) { + uint16_t digit; + digit = numerator->DivideModuloIntBignum(*denominator); + ASSERT(digit <= 9); // digit is a uint16_t and therefore always positive. + // digit = numerator / denominator (integer division). + // numerator = numerator % denominator. + buffer[i] = digit + '0'; + // Prepare for next iteration. + numerator->Times10(); + } + // Generate the last digit. + uint16_t digit; + digit = numerator->DivideModuloIntBignum(*denominator); + if (Bignum::PlusCompare(*numerator, *numerator, *denominator) >= 0) { + digit++; + } + buffer[count - 1] = digit + '0'; + // Correct bad digits (in case we had a sequence of '9's). Propagate the + // carry until we hat a non-'9' or til we reach the first digit. + for (int i = count - 1; i > 0; --i) { + if (buffer[i] != '0' + 10) break; + buffer[i] = '0'; + buffer[i - 1]++; + } + if (buffer[0] == '0' + 10) { + // Propagate a carry past the top place. + buffer[0] = '1'; + (*decimal_point)++; + } + *length = count; +} + + +// Generates 'requested_digits' after the decimal point. It might omit +// trailing '0's. If the input number is too small then no digits at all are +// generated (ex.: 2 fixed digits for 0.00001). +// +// Input verifies: 1 <= (numerator + delta) / denominator < 10. +static void BignumToFixed(int requested_digits, int* decimal_point, + Bignum* numerator, Bignum* denominator, + Vector(buffer), int* length) { + // Note that we have to look at more than just the requested_digits, since + // a number could be rounded up. Example: v=0.5 with requested_digits=0. + // Even though the power of v equals 0 we can't just stop here. + if (-(*decimal_point) > requested_digits) { + // The number is definitively too small. + // Ex: 0.001 with requested_digits == 1. + // Set decimal-point to -requested_digits. This is what Gay does. + // Note that it should not have any effect anyways since the string is + // empty. + *decimal_point = -requested_digits; + *length = 0; + return; + } else if (-(*decimal_point) == requested_digits) { + // We only need to verify if the number rounds down or up. + // Ex: 0.04 and 0.06 with requested_digits == 1. + ASSERT(*decimal_point == -requested_digits); + // Initially the fraction lies in range (1, 10]. Multiply the denominator + // by 10 so that we can compare more easily. + denominator->Times10(); + if (Bignum::PlusCompare(*numerator, *numerator, *denominator) >= 0) { + // If the fraction is >= 0.5 then we have to include the rounded + // digit. + buffer[0] = '1'; + *length = 1; + (*decimal_point)++; + } else { + // Note that we caught most of similar cases earlier. + *length = 0; + } + return; + } else { + // The requested digits correspond to the digits after the point. + // The variable 'needed_digits' includes the digits before the point. + int needed_digits = (*decimal_point) + requested_digits; + GenerateCountedDigits(needed_digits, decimal_point, + numerator, denominator, + buffer, length); + } +} + + +// Returns an estimation of k such that 10^(k-1) <= v < 10^k where +// v = f * 2^exponent and 2^52 <= f < 2^53. +// v is hence a normalized double with the given exponent. The output is an +// approximation for the exponent of the decimal approimation .digits * 10^k. +// +// The result might undershoot by 1 in which case 10^k <= v < 10^k+1. +// Note: this property holds for v's upper boundary m+ too. +// 10^k <= m+ < 10^k+1. +// (see explanation below). +// +// Examples: +// EstimatePower(0) => 16 +// EstimatePower(-52) => 0 +// +// Note: e >= 0 => EstimatedPower(e) > 0. No similar claim can be made for e<0. +static int EstimatePower(int exponent) { + // This function estimates log10 of v where v = f*2^e (with e == exponent). + // Note that 10^floor(log10(v)) <= v, but v <= 10^ceil(log10(v)). + // Note that f is bounded by its container size. Let p = 53 (the double's + // significand size). Then 2^(p-1) <= f < 2^p. + // + // Given that log10(v) == log2(v)/log2(10) and e+(len(f)-1) is quite close + // to log2(v) the function is simplified to (e+(len(f)-1)/log2(10)). + // The computed number undershoots by less than 0.631 (when we compute log3 + // and not log10). + // + // Optimization: since we only need an approximated result this computation + // can be performed on 64 bit integers. On x86/x64 architecture the speedup is + // not really measurable, though. + // + // Since we want to avoid overshooting we decrement by 1e10 so that + // floating-point imprecisions don't affect us. + // + // Explanation for v's boundary m+: the computation takes advantage of + // the fact that 2^(p-1) <= f < 2^p. Boundaries still satisfy this requirement + // (even for denormals where the delta can be much more important). + + const double k1Log10 = 0.30102999566398114; // 1/lg(10) + + // For doubles len(f) == 53 (don't forget the hidden bit). + const int kSignificandSize = Double::kSignificandSize; + double estimate = ceil((exponent + kSignificandSize - 1) * k1Log10 - 1e-10); + return static_cast(estimate); +} + + +// See comments for InitialScaledStartValues. +static void InitialScaledStartValuesPositiveExponent( + uint64_t significand, int exponent, + int estimated_power, bool need_boundary_deltas, + Bignum* numerator, Bignum* denominator, + Bignum* delta_minus, Bignum* delta_plus) { + // A positive exponent implies a positive power. + ASSERT(estimated_power >= 0); + // Since the estimated_power is positive we simply multiply the denominator + // by 10^estimated_power. + + // numerator = v. + numerator->AssignUInt64(significand); + numerator->ShiftLeft(exponent); + // denominator = 10^estimated_power. + denominator->AssignPowerUInt16(10, estimated_power); + + if (need_boundary_deltas) { + // Introduce a common denominator so that the deltas to the boundaries are + // integers. + denominator->ShiftLeft(1); + numerator->ShiftLeft(1); + // Let v = f * 2^e, then m+ - v = 1/2 * 2^e; With the common + // denominator (of 2) delta_plus equals 2^e. + delta_plus->AssignUInt16(1); + delta_plus->ShiftLeft(exponent); + // Same for delta_minus. The adjustments if f == 2^p-1 are done later. + delta_minus->AssignUInt16(1); + delta_minus->ShiftLeft(exponent); + } +} + + +// See comments for InitialScaledStartValues +static void InitialScaledStartValuesNegativeExponentPositivePower( + uint64_t significand, int exponent, + int estimated_power, bool need_boundary_deltas, + Bignum* numerator, Bignum* denominator, + Bignum* delta_minus, Bignum* delta_plus) { + // v = f * 2^e with e < 0, and with estimated_power >= 0. + // This means that e is close to 0 (have a look at how estimated_power is + // computed). + + // numerator = significand + // since v = significand * 2^exponent this is equivalent to + // numerator = v * / 2^-exponent + numerator->AssignUInt64(significand); + // denominator = 10^estimated_power * 2^-exponent (with exponent < 0) + denominator->AssignPowerUInt16(10, estimated_power); + denominator->ShiftLeft(-exponent); + + if (need_boundary_deltas) { + // Introduce a common denominator so that the deltas to the boundaries are + // integers. + denominator->ShiftLeft(1); + numerator->ShiftLeft(1); + // Let v = f * 2^e, then m+ - v = 1/2 * 2^e; With the common + // denominator (of 2) delta_plus equals 2^e. + // Given that the denominator already includes v's exponent the distance + // to the boundaries is simply 1. + delta_plus->AssignUInt16(1); + // Same for delta_minus. The adjustments if f == 2^p-1 are done later. + delta_minus->AssignUInt16(1); + } +} + + +// See comments for InitialScaledStartValues +static void InitialScaledStartValuesNegativeExponentNegativePower( + uint64_t significand, int exponent, + int estimated_power, bool need_boundary_deltas, + Bignum* numerator, Bignum* denominator, + Bignum* delta_minus, Bignum* delta_plus) { + // Instead of multiplying the denominator with 10^estimated_power we + // multiply all values (numerator and deltas) by 10^-estimated_power. + + // Use numerator as temporary container for power_ten. + Bignum* power_ten = numerator; + power_ten->AssignPowerUInt16(10, -estimated_power); + + if (need_boundary_deltas) { + // Since power_ten == numerator we must make a copy of 10^estimated_power + // before we complete the computation of the numerator. + // delta_plus = delta_minus = 10^estimated_power + delta_plus->AssignBignum(*power_ten); + delta_minus->AssignBignum(*power_ten); + } + + // numerator = significand * 2 * 10^-estimated_power + // since v = significand * 2^exponent this is equivalent to + // numerator = v * 10^-estimated_power * 2 * 2^-exponent. + // Remember: numerator has been abused as power_ten. So no need to assign it + // to itself. + ASSERT(numerator == power_ten); + numerator->MultiplyByUInt64(significand); + + // denominator = 2 * 2^-exponent with exponent < 0. + denominator->AssignUInt16(1); + denominator->ShiftLeft(-exponent); + + if (need_boundary_deltas) { + // Introduce a common denominator so that the deltas to the boundaries are + // integers. + numerator->ShiftLeft(1); + denominator->ShiftLeft(1); + // With this shift the boundaries have their correct value, since + // delta_plus = 10^-estimated_power, and + // delta_minus = 10^-estimated_power. + // These assignments have been done earlier. + // The adjustments if f == 2^p-1 (lower boundary is closer) are done later. + } +} + + +// Let v = significand * 2^exponent. +// Computes v / 10^estimated_power exactly, as a ratio of two bignums, numerator +// and denominator. The functions GenerateShortestDigits and +// GenerateCountedDigits will then convert this ratio to its decimal +// representation d, with the required accuracy. +// Then d * 10^estimated_power is the representation of v. +// (Note: the fraction and the estimated_power might get adjusted before +// generating the decimal representation.) +// +// The initial start values consist of: +// - a scaled numerator: s.t. numerator/denominator == v / 10^estimated_power. +// - a scaled (common) denominator. +// optionally (used by GenerateShortestDigits to decide if it has the shortest +// decimal converting back to v): +// - v - m-: the distance to the lower boundary. +// - m+ - v: the distance to the upper boundary. +// +// v, m+, m-, and therefore v - m- and m+ - v all share the same denominator. +// +// Let ep == estimated_power, then the returned values will satisfy: +// v / 10^ep = numerator / denominator. +// v's boundarys m- and m+: +// m- / 10^ep == v / 10^ep - delta_minus / denominator +// m+ / 10^ep == v / 10^ep + delta_plus / denominator +// Or in other words: +// m- == v - delta_minus * 10^ep / denominator; +// m+ == v + delta_plus * 10^ep / denominator; +// +// Since 10^(k-1) <= v < 10^k (with k == estimated_power) +// or 10^k <= v < 10^(k+1) +// we then have 0.1 <= numerator/denominator < 1 +// or 1 <= numerator/denominator < 10 +// +// It is then easy to kickstart the digit-generation routine. +// +// The boundary-deltas are only filled if the mode equals BIGNUM_DTOA_SHORTEST +// or BIGNUM_DTOA_SHORTEST_SINGLE. + +static void InitialScaledStartValues(uint64_t significand, + int exponent, + bool lower_boundary_is_closer, + int estimated_power, + bool need_boundary_deltas, + Bignum* numerator, + Bignum* denominator, + Bignum* delta_minus, + Bignum* delta_plus) { + if (exponent >= 0) { + InitialScaledStartValuesPositiveExponent( + significand, exponent, estimated_power, need_boundary_deltas, + numerator, denominator, delta_minus, delta_plus); + } else if (estimated_power >= 0) { + InitialScaledStartValuesNegativeExponentPositivePower( + significand, exponent, estimated_power, need_boundary_deltas, + numerator, denominator, delta_minus, delta_plus); + } else { + InitialScaledStartValuesNegativeExponentNegativePower( + significand, exponent, estimated_power, need_boundary_deltas, + numerator, denominator, delta_minus, delta_plus); + } + + if (need_boundary_deltas && lower_boundary_is_closer) { + // The lower boundary is closer at half the distance of "normal" numbers. + // Increase the common denominator and adapt all but the delta_minus. + denominator->ShiftLeft(1); // *2 + numerator->ShiftLeft(1); // *2 + delta_plus->ShiftLeft(1); // *2 + } +} + + +// This routine multiplies numerator/denominator so that its values lies in the +// range 1-10. That is after a call to this function we have: +// 1 <= (numerator + delta_plus) /denominator < 10. +// Let numerator the input before modification and numerator' the argument +// after modification, then the output-parameter decimal_point is such that +// numerator / denominator * 10^estimated_power == +// numerator' / denominator' * 10^(decimal_point - 1) +// In some cases estimated_power was too low, and this is already the case. We +// then simply adjust the power so that 10^(k-1) <= v < 10^k (with k == +// estimated_power) but do not touch the numerator or denominator. +// Otherwise the routine multiplies the numerator and the deltas by 10. +static void FixupMultiply10(int estimated_power, bool is_even, + int* decimal_point, + Bignum* numerator, Bignum* denominator, + Bignum* delta_minus, Bignum* delta_plus) { + bool in_range; + if (is_even) { + // For IEEE doubles half-way cases (in decimal system numbers ending with 5) + // are rounded to the closest floating-point number with even significand. + in_range = Bignum::PlusCompare(*numerator, *delta_plus, *denominator) >= 0; + } else { + in_range = Bignum::PlusCompare(*numerator, *delta_plus, *denominator) > 0; + } + if (in_range) { + // Since numerator + delta_plus >= denominator we already have + // 1 <= numerator/denominator < 10. Simply update the estimated_power. + *decimal_point = estimated_power + 1; + } else { + *decimal_point = estimated_power; + numerator->Times10(); + if (Bignum::Equal(*delta_minus, *delta_plus)) { + delta_minus->Times10(); + delta_plus->AssignBignum(*delta_minus); + } else { + delta_minus->Times10(); + delta_plus->Times10(); + } + } +} + +} // namespace double_conversion diff --git a/klm/util/double-conversion/bignum-dtoa.h b/klm/util/double-conversion/bignum-dtoa.h new file mode 100644 index 00000000..34b96199 --- /dev/null +++ b/klm/util/double-conversion/bignum-dtoa.h @@ -0,0 +1,84 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef DOUBLE_CONVERSION_BIGNUM_DTOA_H_ +#define DOUBLE_CONVERSION_BIGNUM_DTOA_H_ + +#include "utils.h" + +namespace double_conversion { + +enum BignumDtoaMode { + // Return the shortest correct representation. + // For example the output of 0.299999999999999988897 is (the less accurate but + // correct) 0.3. + BIGNUM_DTOA_SHORTEST, + // Same as BIGNUM_DTOA_SHORTEST but for single-precision floats. + BIGNUM_DTOA_SHORTEST_SINGLE, + // Return a fixed number of digits after the decimal point. + // For instance fixed(0.1, 4) becomes 0.1000 + // If the input number is big, the output will be big. + BIGNUM_DTOA_FIXED, + // Return a fixed number of digits, no matter what the exponent is. + BIGNUM_DTOA_PRECISION +}; + +// Converts the given double 'v' to ascii. +// The result should be interpreted as buffer * 10^(point-length). +// The buffer will be null-terminated. +// +// The input v must be > 0 and different from NaN, and Infinity. +// +// The output depends on the given mode: +// - SHORTEST: produce the least amount of digits for which the internal +// identity requirement is still satisfied. If the digits are printed +// (together with the correct exponent) then reading this number will give +// 'v' again. The buffer will choose the representation that is closest to +// 'v'. If there are two at the same distance, than the number is round up. +// In this mode the 'requested_digits' parameter is ignored. +// - FIXED: produces digits necessary to print a given number with +// 'requested_digits' digits after the decimal point. The produced digits +// might be too short in which case the caller has to fill the gaps with '0's. +// Example: toFixed(0.001, 5) is allowed to return buffer="1", point=-2. +// Halfway cases are rounded up. The call toFixed(0.15, 2) thus returns +// buffer="2", point=0. +// Note: the length of the returned buffer has no meaning wrt the significance +// of its digits. That is, just because it contains '0's does not mean that +// any other digit would not satisfy the internal identity requirement. +// - PRECISION: produces 'requested_digits' where the first digit is not '0'. +// Even though the length of produced digits usually equals +// 'requested_digits', the function is allowed to return fewer digits, in +// which case the caller has to fill the missing digits with '0's. +// Halfway cases are again rounded up. +// 'BignumDtoa' expects the given buffer to be big enough to hold all digits +// and a terminating null-character. +void BignumDtoa(double v, BignumDtoaMode mode, int requested_digits, + Vector buffer, int* length, int* point); + +} // namespace double_conversion + +#endif // DOUBLE_CONVERSION_BIGNUM_DTOA_H_ diff --git a/klm/util/double-conversion/bignum.cc b/klm/util/double-conversion/bignum.cc new file mode 100644 index 00000000..747491a0 --- /dev/null +++ b/klm/util/double-conversion/bignum.cc @@ -0,0 +1,764 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "bignum.h" +#include "utils.h" + +namespace double_conversion { + +Bignum::Bignum() + : bigits_(bigits_buffer_, kBigitCapacity), used_digits_(0), exponent_(0) { + for (int i = 0; i < kBigitCapacity; ++i) { + bigits_[i] = 0; + } +} + + +template +static int BitSize(S value) { + return 8 * sizeof(value); +} + +// Guaranteed to lie in one Bigit. +void Bignum::AssignUInt16(uint16_t value) { + ASSERT(kBigitSize >= BitSize(value)); + Zero(); + if (value == 0) return; + + EnsureCapacity(1); + bigits_[0] = value; + used_digits_ = 1; +} + + +void Bignum::AssignUInt64(uint64_t value) { + const int kUInt64Size = 64; + + Zero(); + if (value == 0) return; + + int needed_bigits = kUInt64Size / kBigitSize + 1; + EnsureCapacity(needed_bigits); + for (int i = 0; i < needed_bigits; ++i) { + bigits_[i] = value & kBigitMask; + value = value >> kBigitSize; + } + used_digits_ = needed_bigits; + Clamp(); +} + + +void Bignum::AssignBignum(const Bignum& other) { + exponent_ = other.exponent_; + for (int i = 0; i < other.used_digits_; ++i) { + bigits_[i] = other.bigits_[i]; + } + // Clear the excess digits (if there were any). + for (int i = other.used_digits_; i < used_digits_; ++i) { + bigits_[i] = 0; + } + used_digits_ = other.used_digits_; +} + + +static uint64_t ReadUInt64(Vector buffer, + int from, + int digits_to_read) { + uint64_t result = 0; + for (int i = from; i < from + digits_to_read; ++i) { + int digit = buffer[i] - '0'; + ASSERT(0 <= digit && digit <= 9); + result = result * 10 + digit; + } + return result; +} + + +void Bignum::AssignDecimalString(Vector value) { + // 2^64 = 18446744073709551616 > 10^19 + const int kMaxUint64DecimalDigits = 19; + Zero(); + int length = value.length(); + int pos = 0; + // Let's just say that each digit needs 4 bits. + while (length >= kMaxUint64DecimalDigits) { + uint64_t digits = ReadUInt64(value, pos, kMaxUint64DecimalDigits); + pos += kMaxUint64DecimalDigits; + length -= kMaxUint64DecimalDigits; + MultiplyByPowerOfTen(kMaxUint64DecimalDigits); + AddUInt64(digits); + } + uint64_t digits = ReadUInt64(value, pos, length); + MultiplyByPowerOfTen(length); + AddUInt64(digits); + Clamp(); +} + + +static int HexCharValue(char c) { + if ('0' <= c && c <= '9') return c - '0'; + if ('a' <= c && c <= 'f') return 10 + c - 'a'; + if ('A' <= c && c <= 'F') return 10 + c - 'A'; + UNREACHABLE(); + return 0; // To make compiler happy. +} + + +void Bignum::AssignHexString(Vector value) { + Zero(); + int length = value.length(); + + int needed_bigits = length * 4 / kBigitSize + 1; + EnsureCapacity(needed_bigits); + int string_index = length - 1; + for (int i = 0; i < needed_bigits - 1; ++i) { + // These bigits are guaranteed to be "full". + Chunk current_bigit = 0; + for (int j = 0; j < kBigitSize / 4; j++) { + current_bigit += HexCharValue(value[string_index--]) << (j * 4); + } + bigits_[i] = current_bigit; + } + used_digits_ = needed_bigits - 1; + + Chunk most_significant_bigit = 0; // Could be = 0; + for (int j = 0; j <= string_index; ++j) { + most_significant_bigit <<= 4; + most_significant_bigit += HexCharValue(value[j]); + } + if (most_significant_bigit != 0) { + bigits_[used_digits_] = most_significant_bigit; + used_digits_++; + } + Clamp(); +} + + +void Bignum::AddUInt64(uint64_t operand) { + if (operand == 0) return; + Bignum other; + other.AssignUInt64(operand); + AddBignum(other); +} + + +void Bignum::AddBignum(const Bignum& other) { + ASSERT(IsClamped()); + ASSERT(other.IsClamped()); + + // If this has a greater exponent than other append zero-bigits to this. + // After this call exponent_ <= other.exponent_. + Align(other); + + // There are two possibilities: + // aaaaaaaaaaa 0000 (where the 0s represent a's exponent) + // bbbbb 00000000 + // ---------------- + // ccccccccccc 0000 + // or + // aaaaaaaaaa 0000 + // bbbbbbbbb 0000000 + // ----------------- + // cccccccccccc 0000 + // In both cases we might need a carry bigit. + + EnsureCapacity(1 + Max(BigitLength(), other.BigitLength()) - exponent_); + Chunk carry = 0; + int bigit_pos = other.exponent_ - exponent_; + ASSERT(bigit_pos >= 0); + for (int i = 0; i < other.used_digits_; ++i) { + Chunk sum = bigits_[bigit_pos] + other.bigits_[i] + carry; + bigits_[bigit_pos] = sum & kBigitMask; + carry = sum >> kBigitSize; + bigit_pos++; + } + + while (carry != 0) { + Chunk sum = bigits_[bigit_pos] + carry; + bigits_[bigit_pos] = sum & kBigitMask; + carry = sum >> kBigitSize; + bigit_pos++; + } + used_digits_ = Max(bigit_pos, used_digits_); + ASSERT(IsClamped()); +} + + +void Bignum::SubtractBignum(const Bignum& other) { + ASSERT(IsClamped()); + ASSERT(other.IsClamped()); + // We require this to be bigger than other. + ASSERT(LessEqual(other, *this)); + + Align(other); + + int offset = other.exponent_ - exponent_; + Chunk borrow = 0; + int i; + for (i = 0; i < other.used_digits_; ++i) { + ASSERT((borrow == 0) || (borrow == 1)); + Chunk difference = bigits_[i + offset] - other.bigits_[i] - borrow; + bigits_[i + offset] = difference & kBigitMask; + borrow = difference >> (kChunkSize - 1); + } + while (borrow != 0) { + Chunk difference = bigits_[i + offset] - borrow; + bigits_[i + offset] = difference & kBigitMask; + borrow = difference >> (kChunkSize - 1); + ++i; + } + Clamp(); +} + + +void Bignum::ShiftLeft(int shift_amount) { + if (used_digits_ == 0) return; + exponent_ += shift_amount / kBigitSize; + int local_shift = shift_amount % kBigitSize; + EnsureCapacity(used_digits_ + 1); + BigitsShiftLeft(local_shift); +} + + +void Bignum::MultiplyByUInt32(uint32_t factor) { + if (factor == 1) return; + if (factor == 0) { + Zero(); + return; + } + if (used_digits_ == 0) return; + + // The product of a bigit with the factor is of size kBigitSize + 32. + // Assert that this number + 1 (for the carry) fits into double chunk. + ASSERT(kDoubleChunkSize >= kBigitSize + 32 + 1); + DoubleChunk carry = 0; + for (int i = 0; i < used_digits_; ++i) { + DoubleChunk product = static_cast(factor) * bigits_[i] + carry; + bigits_[i] = static_cast(product & kBigitMask); + carry = (product >> kBigitSize); + } + while (carry != 0) { + EnsureCapacity(used_digits_ + 1); + bigits_[used_digits_] = carry & kBigitMask; + used_digits_++; + carry >>= kBigitSize; + } +} + + +void Bignum::MultiplyByUInt64(uint64_t factor) { + if (factor == 1) return; + if (factor == 0) { + Zero(); + return; + } + ASSERT(kBigitSize < 32); + uint64_t carry = 0; + uint64_t low = factor & 0xFFFFFFFF; + uint64_t high = factor >> 32; + for (int i = 0; i < used_digits_; ++i) { + uint64_t product_low = low * bigits_[i]; + uint64_t product_high = high * bigits_[i]; + uint64_t tmp = (carry & kBigitMask) + product_low; + bigits_[i] = tmp & kBigitMask; + carry = (carry >> kBigitSize) + (tmp >> kBigitSize) + + (product_high << (32 - kBigitSize)); + } + while (carry != 0) { + EnsureCapacity(used_digits_ + 1); + bigits_[used_digits_] = carry & kBigitMask; + used_digits_++; + carry >>= kBigitSize; + } +} + + +void Bignum::MultiplyByPowerOfTen(int exponent) { + const uint64_t kFive27 = UINT64_2PART_C(0x6765c793, fa10079d); + const uint16_t kFive1 = 5; + const uint16_t kFive2 = kFive1 * 5; + const uint16_t kFive3 = kFive2 * 5; + const uint16_t kFive4 = kFive3 * 5; + const uint16_t kFive5 = kFive4 * 5; + const uint16_t kFive6 = kFive5 * 5; + const uint32_t kFive7 = kFive6 * 5; + const uint32_t kFive8 = kFive7 * 5; + const uint32_t kFive9 = kFive8 * 5; + const uint32_t kFive10 = kFive9 * 5; + const uint32_t kFive11 = kFive10 * 5; + const uint32_t kFive12 = kFive11 * 5; + const uint32_t kFive13 = kFive12 * 5; + const uint32_t kFive1_to_12[] = + { kFive1, kFive2, kFive3, kFive4, kFive5, kFive6, + kFive7, kFive8, kFive9, kFive10, kFive11, kFive12 }; + + ASSERT(exponent >= 0); + if (exponent == 0) return; + if (used_digits_ == 0) return; + + // We shift by exponent at the end just before returning. + int remaining_exponent = exponent; + while (remaining_exponent >= 27) { + MultiplyByUInt64(kFive27); + remaining_exponent -= 27; + } + while (remaining_exponent >= 13) { + MultiplyByUInt32(kFive13); + remaining_exponent -= 13; + } + if (remaining_exponent > 0) { + MultiplyByUInt32(kFive1_to_12[remaining_exponent - 1]); + } + ShiftLeft(exponent); +} + + +void Bignum::Square() { + ASSERT(IsClamped()); + int product_length = 2 * used_digits_; + EnsureCapacity(product_length); + + // Comba multiplication: compute each column separately. + // Example: r = a2a1a0 * b2b1b0. + // r = 1 * a0b0 + + // 10 * (a1b0 + a0b1) + + // 100 * (a2b0 + a1b1 + a0b2) + + // 1000 * (a2b1 + a1b2) + + // 10000 * a2b2 + // + // In the worst case we have to accumulate nb-digits products of digit*digit. + // + // Assert that the additional number of bits in a DoubleChunk are enough to + // sum up used_digits of Bigit*Bigit. + if ((1 << (2 * (kChunkSize - kBigitSize))) <= used_digits_) { + UNIMPLEMENTED(); + } + DoubleChunk accumulator = 0; + // First shift the digits so we don't overwrite them. + int copy_offset = used_digits_; + for (int i = 0; i < used_digits_; ++i) { + bigits_[copy_offset + i] = bigits_[i]; + } + // We have two loops to avoid some 'if's in the loop. + for (int i = 0; i < used_digits_; ++i) { + // Process temporary digit i with power i. + // The sum of the two indices must be equal to i. + int bigit_index1 = i; + int bigit_index2 = 0; + // Sum all of the sub-products. + while (bigit_index1 >= 0) { + Chunk chunk1 = bigits_[copy_offset + bigit_index1]; + Chunk chunk2 = bigits_[copy_offset + bigit_index2]; + accumulator += static_cast(chunk1) * chunk2; + bigit_index1--; + bigit_index2++; + } + bigits_[i] = static_cast(accumulator) & kBigitMask; + accumulator >>= kBigitSize; + } + for (int i = used_digits_; i < product_length; ++i) { + int bigit_index1 = used_digits_ - 1; + int bigit_index2 = i - bigit_index1; + // Invariant: sum of both indices is again equal to i. + // Inner loop runs 0 times on last iteration, emptying accumulator. + while (bigit_index2 < used_digits_) { + Chunk chunk1 = bigits_[copy_offset + bigit_index1]; + Chunk chunk2 = bigits_[copy_offset + bigit_index2]; + accumulator += static_cast(chunk1) * chunk2; + bigit_index1--; + bigit_index2++; + } + // The overwritten bigits_[i] will never be read in further loop iterations, + // because bigit_index1 and bigit_index2 are always greater + // than i - used_digits_. + bigits_[i] = static_cast(accumulator) & kBigitMask; + accumulator >>= kBigitSize; + } + // Since the result was guaranteed to lie inside the number the + // accumulator must be 0 now. + ASSERT(accumulator == 0); + + // Don't forget to update the used_digits and the exponent. + used_digits_ = product_length; + exponent_ *= 2; + Clamp(); +} + + +void Bignum::AssignPowerUInt16(uint16_t base, int power_exponent) { + ASSERT(base != 0); + ASSERT(power_exponent >= 0); + if (power_exponent == 0) { + AssignUInt16(1); + return; + } + Zero(); + int shifts = 0; + // We expect base to be in range 2-32, and most often to be 10. + // It does not make much sense to implement different algorithms for counting + // the bits. + while ((base & 1) == 0) { + base >>= 1; + shifts++; + } + int bit_size = 0; + int tmp_base = base; + while (tmp_base != 0) { + tmp_base >>= 1; + bit_size++; + } + int final_size = bit_size * power_exponent; + // 1 extra bigit for the shifting, and one for rounded final_size. + EnsureCapacity(final_size / kBigitSize + 2); + + // Left to Right exponentiation. + int mask = 1; + while (power_exponent >= mask) mask <<= 1; + + // The mask is now pointing to the bit above the most significant 1-bit of + // power_exponent. + // Get rid of first 1-bit; + mask >>= 2; + uint64_t this_value = base; + + bool delayed_multipliciation = false; + const uint64_t max_32bits = 0xFFFFFFFF; + while (mask != 0 && this_value <= max_32bits) { + this_value = this_value * this_value; + // Verify that there is enough space in this_value to perform the + // multiplication. The first bit_size bits must be 0. + if ((power_exponent & mask) != 0) { + uint64_t base_bits_mask = + ~((static_cast(1) << (64 - bit_size)) - 1); + bool high_bits_zero = (this_value & base_bits_mask) == 0; + if (high_bits_zero) { + this_value *= base; + } else { + delayed_multipliciation = true; + } + } + mask >>= 1; + } + AssignUInt64(this_value); + if (delayed_multipliciation) { + MultiplyByUInt32(base); + } + + // Now do the same thing as a bignum. + while (mask != 0) { + Square(); + if ((power_exponent & mask) != 0) { + MultiplyByUInt32(base); + } + mask >>= 1; + } + + // And finally add the saved shifts. + ShiftLeft(shifts * power_exponent); +} + + +// Precondition: this/other < 16bit. +uint16_t Bignum::DivideModuloIntBignum(const Bignum& other) { + ASSERT(IsClamped()); + ASSERT(other.IsClamped()); + ASSERT(other.used_digits_ > 0); + + // Easy case: if we have less digits than the divisor than the result is 0. + // Note: this handles the case where this == 0, too. + if (BigitLength() < other.BigitLength()) { + return 0; + } + + Align(other); + + uint16_t result = 0; + + // Start by removing multiples of 'other' until both numbers have the same + // number of digits. + while (BigitLength() > other.BigitLength()) { + // This naive approach is extremely inefficient if the this divided other + // might be big. This function is implemented for doubleToString where + // the result should be small (less than 10). + ASSERT(other.bigits_[other.used_digits_ - 1] >= ((1 << kBigitSize) / 16)); + // Remove the multiples of the first digit. + // Example this = 23 and other equals 9. -> Remove 2 multiples. + result += bigits_[used_digits_ - 1]; + SubtractTimes(other, bigits_[used_digits_ - 1]); + } + + ASSERT(BigitLength() == other.BigitLength()); + + // Both bignums are at the same length now. + // Since other has more than 0 digits we know that the access to + // bigits_[used_digits_ - 1] is safe. + Chunk this_bigit = bigits_[used_digits_ - 1]; + Chunk other_bigit = other.bigits_[other.used_digits_ - 1]; + + if (other.used_digits_ == 1) { + // Shortcut for easy (and common) case. + int quotient = this_bigit / other_bigit; + bigits_[used_digits_ - 1] = this_bigit - other_bigit * quotient; + result += quotient; + Clamp(); + return result; + } + + int division_estimate = this_bigit / (other_bigit + 1); + result += division_estimate; + SubtractTimes(other, division_estimate); + + if (other_bigit * (division_estimate + 1) > this_bigit) { + // No need to even try to subtract. Even if other's remaining digits were 0 + // another subtraction would be too much. + return result; + } + + while (LessEqual(other, *this)) { + SubtractBignum(other); + result++; + } + return result; +} + + +template +static int SizeInHexChars(S number) { + ASSERT(number > 0); + int result = 0; + while (number != 0) { + number >>= 4; + result++; + } + return result; +} + + +static char HexCharOfValue(int value) { + ASSERT(0 <= value && value <= 16); + if (value < 10) return value + '0'; + return value - 10 + 'A'; +} + + +bool Bignum::ToHexString(char* buffer, int buffer_size) const { + ASSERT(IsClamped()); + // Each bigit must be printable as separate hex-character. + ASSERT(kBigitSize % 4 == 0); + const int kHexCharsPerBigit = kBigitSize / 4; + + if (used_digits_ == 0) { + if (buffer_size < 2) return false; + buffer[0] = '0'; + buffer[1] = '\0'; + return true; + } + // We add 1 for the terminating '\0' character. + int needed_chars = (BigitLength() - 1) * kHexCharsPerBigit + + SizeInHexChars(bigits_[used_digits_ - 1]) + 1; + if (needed_chars > buffer_size) return false; + int string_index = needed_chars - 1; + buffer[string_index--] = '\0'; + for (int i = 0; i < exponent_; ++i) { + for (int j = 0; j < kHexCharsPerBigit; ++j) { + buffer[string_index--] = '0'; + } + } + for (int i = 0; i < used_digits_ - 1; ++i) { + Chunk current_bigit = bigits_[i]; + for (int j = 0; j < kHexCharsPerBigit; ++j) { + buffer[string_index--] = HexCharOfValue(current_bigit & 0xF); + current_bigit >>= 4; + } + } + // And finally the last bigit. + Chunk most_significant_bigit = bigits_[used_digits_ - 1]; + while (most_significant_bigit != 0) { + buffer[string_index--] = HexCharOfValue(most_significant_bigit & 0xF); + most_significant_bigit >>= 4; + } + return true; +} + + +Bignum::Chunk Bignum::BigitAt(int index) const { + if (index >= BigitLength()) return 0; + if (index < exponent_) return 0; + return bigits_[index - exponent_]; +} + + +int Bignum::Compare(const Bignum& a, const Bignum& b) { + ASSERT(a.IsClamped()); + ASSERT(b.IsClamped()); + int bigit_length_a = a.BigitLength(); + int bigit_length_b = b.BigitLength(); + if (bigit_length_a < bigit_length_b) return -1; + if (bigit_length_a > bigit_length_b) return +1; + for (int i = bigit_length_a - 1; i >= Min(a.exponent_, b.exponent_); --i) { + Chunk bigit_a = a.BigitAt(i); + Chunk bigit_b = b.BigitAt(i); + if (bigit_a < bigit_b) return -1; + if (bigit_a > bigit_b) return +1; + // Otherwise they are equal up to this digit. Try the next digit. + } + return 0; +} + + +int Bignum::PlusCompare(const Bignum& a, const Bignum& b, const Bignum& c) { + ASSERT(a.IsClamped()); + ASSERT(b.IsClamped()); + ASSERT(c.IsClamped()); + if (a.BigitLength() < b.BigitLength()) { + return PlusCompare(b, a, c); + } + if (a.BigitLength() + 1 < c.BigitLength()) return -1; + if (a.BigitLength() > c.BigitLength()) return +1; + // The exponent encodes 0-bigits. So if there are more 0-digits in 'a' than + // 'b' has digits, then the bigit-length of 'a'+'b' must be equal to the one + // of 'a'. + if (a.exponent_ >= b.BigitLength() && a.BigitLength() < c.BigitLength()) { + return -1; + } + + Chunk borrow = 0; + // Starting at min_exponent all digits are == 0. So no need to compare them. + int min_exponent = Min(Min(a.exponent_, b.exponent_), c.exponent_); + for (int i = c.BigitLength() - 1; i >= min_exponent; --i) { + Chunk chunk_a = a.BigitAt(i); + Chunk chunk_b = b.BigitAt(i); + Chunk chunk_c = c.BigitAt(i); + Chunk sum = chunk_a + chunk_b; + if (sum > chunk_c + borrow) { + return +1; + } else { + borrow = chunk_c + borrow - sum; + if (borrow > 1) return -1; + borrow <<= kBigitSize; + } + } + if (borrow == 0) return 0; + return -1; +} + + +void Bignum::Clamp() { + while (used_digits_ > 0 && bigits_[used_digits_ - 1] == 0) { + used_digits_--; + } + if (used_digits_ == 0) { + // Zero. + exponent_ = 0; + } +} + + +bool Bignum::IsClamped() const { + return used_digits_ == 0 || bigits_[used_digits_ - 1] != 0; +} + + +void Bignum::Zero() { + for (int i = 0; i < used_digits_; ++i) { + bigits_[i] = 0; + } + used_digits_ = 0; + exponent_ = 0; +} + + +void Bignum::Align(const Bignum& other) { + if (exponent_ > other.exponent_) { + // If "X" represents a "hidden" digit (by the exponent) then we are in the + // following case (a == this, b == other): + // a: aaaaaaXXXX or a: aaaaaXXX + // b: bbbbbbX b: bbbbbbbbXX + // We replace some of the hidden digits (X) of a with 0 digits. + // a: aaaaaa000X or a: aaaaa0XX + int zero_digits = exponent_ - other.exponent_; + EnsureCapacity(used_digits_ + zero_digits); + for (int i = used_digits_ - 1; i >= 0; --i) { + bigits_[i + zero_digits] = bigits_[i]; + } + for (int i = 0; i < zero_digits; ++i) { + bigits_[i] = 0; + } + used_digits_ += zero_digits; + exponent_ -= zero_digits; + ASSERT(used_digits_ >= 0); + ASSERT(exponent_ >= 0); + } +} + + +void Bignum::BigitsShiftLeft(int shift_amount) { + ASSERT(shift_amount < kBigitSize); + ASSERT(shift_amount >= 0); + Chunk carry = 0; + for (int i = 0; i < used_digits_; ++i) { + Chunk new_carry = bigits_[i] >> (kBigitSize - shift_amount); + bigits_[i] = ((bigits_[i] << shift_amount) + carry) & kBigitMask; + carry = new_carry; + } + if (carry != 0) { + bigits_[used_digits_] = carry; + used_digits_++; + } +} + + +void Bignum::SubtractTimes(const Bignum& other, int factor) { + ASSERT(exponent_ <= other.exponent_); + if (factor < 3) { + for (int i = 0; i < factor; ++i) { + SubtractBignum(other); + } + return; + } + Chunk borrow = 0; + int exponent_diff = other.exponent_ - exponent_; + for (int i = 0; i < other.used_digits_; ++i) { + DoubleChunk product = static_cast(factor) * other.bigits_[i]; + DoubleChunk remove = borrow + product; + Chunk difference = bigits_[i + exponent_diff] - (remove & kBigitMask); + bigits_[i + exponent_diff] = difference & kBigitMask; + borrow = static_cast((difference >> (kChunkSize - 1)) + + (remove >> kBigitSize)); + } + for (int i = other.used_digits_ + exponent_diff; i < used_digits_; ++i) { + if (borrow == 0) return; + Chunk difference = bigits_[i] - borrow; + bigits_[i] = difference & kBigitMask; + borrow = difference >> (kChunkSize - 1); + ++i; + } + Clamp(); +} + + +} // namespace double_conversion diff --git a/klm/util/double-conversion/bignum.h b/klm/util/double-conversion/bignum.h new file mode 100644 index 00000000..5ec3544f --- /dev/null +++ b/klm/util/double-conversion/bignum.h @@ -0,0 +1,145 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef DOUBLE_CONVERSION_BIGNUM_H_ +#define DOUBLE_CONVERSION_BIGNUM_H_ + +#include "utils.h" + +namespace double_conversion { + +class Bignum { + public: + // 3584 = 128 * 28. We can represent 2^3584 > 10^1000 accurately. + // This bignum can encode much bigger numbers, since it contains an + // exponent. + static const int kMaxSignificantBits = 3584; + + Bignum(); + void AssignUInt16(uint16_t value); + void AssignUInt64(uint64_t value); + void AssignBignum(const Bignum& other); + + void AssignDecimalString(Vector value); + void AssignHexString(Vector value); + + void AssignPowerUInt16(uint16_t base, int exponent); + + void AddUInt16(uint16_t operand); + void AddUInt64(uint64_t operand); + void AddBignum(const Bignum& other); + // Precondition: this >= other. + void SubtractBignum(const Bignum& other); + + void Square(); + void ShiftLeft(int shift_amount); + void MultiplyByUInt32(uint32_t factor); + void MultiplyByUInt64(uint64_t factor); + void MultiplyByPowerOfTen(int exponent); + void Times10() { return MultiplyByUInt32(10); } + // Pseudocode: + // int result = this / other; + // this = this % other; + // In the worst case this function is in O(this/other). + uint16_t DivideModuloIntBignum(const Bignum& other); + + bool ToHexString(char* buffer, int buffer_size) const; + + // Returns + // -1 if a < b, + // 0 if a == b, and + // +1 if a > b. + static int Compare(const Bignum& a, const Bignum& b); + static bool Equal(const Bignum& a, const Bignum& b) { + return Compare(a, b) == 0; + } + static bool LessEqual(const Bignum& a, const Bignum& b) { + return Compare(a, b) <= 0; + } + static bool Less(const Bignum& a, const Bignum& b) { + return Compare(a, b) < 0; + } + // Returns Compare(a + b, c); + static int PlusCompare(const Bignum& a, const Bignum& b, const Bignum& c); + // Returns a + b == c + static bool PlusEqual(const Bignum& a, const Bignum& b, const Bignum& c) { + return PlusCompare(a, b, c) == 0; + } + // Returns a + b <= c + static bool PlusLessEqual(const Bignum& a, const Bignum& b, const Bignum& c) { + return PlusCompare(a, b, c) <= 0; + } + // Returns a + b < c + static bool PlusLess(const Bignum& a, const Bignum& b, const Bignum& c) { + return PlusCompare(a, b, c) < 0; + } + private: + typedef uint32_t Chunk; + typedef uint64_t DoubleChunk; + + static const int kChunkSize = sizeof(Chunk) * 8; + static const int kDoubleChunkSize = sizeof(DoubleChunk) * 8; + // With bigit size of 28 we loose some bits, but a double still fits easily + // into two chunks, and more importantly we can use the Comba multiplication. + static const int kBigitSize = 28; + static const Chunk kBigitMask = (1 << kBigitSize) - 1; + // Every instance allocates kBigitLength chunks on the stack. Bignums cannot + // grow. There are no checks if the stack-allocated space is sufficient. + static const int kBigitCapacity = kMaxSignificantBits / kBigitSize; + + void EnsureCapacity(int size) { + if (size > kBigitCapacity) { + UNREACHABLE(); + } + } + void Align(const Bignum& other); + void Clamp(); + bool IsClamped() const; + void Zero(); + // Requires this to have enough capacity (no tests done). + // Updates used_digits_ if necessary. + // shift_amount must be < kBigitSize. + void BigitsShiftLeft(int shift_amount); + // BigitLength includes the "hidden" digits encoded in the exponent. + int BigitLength() const { return used_digits_ + exponent_; } + Chunk BigitAt(int index) const; + void SubtractTimes(const Bignum& other, int factor); + + Chunk bigits_buffer_[kBigitCapacity]; + // A vector backed by bigits_buffer_. This way accesses to the array are + // checked for out-of-bounds errors. + Vector bigits_; + int used_digits_; + // The Bignum's value equals value(bigits_) * 2^(exponent_ * kBigitSize). + int exponent_; + + DISALLOW_COPY_AND_ASSIGN(Bignum); +}; + +} // namespace double_conversion + +#endif // DOUBLE_CONVERSION_BIGNUM_H_ diff --git a/klm/util/double-conversion/cached-powers.cc b/klm/util/double-conversion/cached-powers.cc new file mode 100644 index 00000000..c6764291 --- /dev/null +++ b/klm/util/double-conversion/cached-powers.cc @@ -0,0 +1,175 @@ +// Copyright 2006-2008 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include +#include + +#include "utils.h" + +#include "cached-powers.h" + +namespace double_conversion { + +struct CachedPower { + uint64_t significand; + int16_t binary_exponent; + int16_t decimal_exponent; +}; + +static const CachedPower kCachedPowers[] = { + {UINT64_2PART_C(0xfa8fd5a0, 081c0288), -1220, -348}, + {UINT64_2PART_C(0xbaaee17f, a23ebf76), -1193, -340}, + {UINT64_2PART_C(0x8b16fb20, 3055ac76), -1166, -332}, + {UINT64_2PART_C(0xcf42894a, 5dce35ea), -1140, -324}, + {UINT64_2PART_C(0x9a6bb0aa, 55653b2d), -1113, -316}, + {UINT64_2PART_C(0xe61acf03, 3d1a45df), -1087, -308}, + {UINT64_2PART_C(0xab70fe17, c79ac6ca), -1060, -300}, + {UINT64_2PART_C(0xff77b1fc, bebcdc4f), -1034, -292}, + {UINT64_2PART_C(0xbe5691ef, 416bd60c), -1007, -284}, + {UINT64_2PART_C(0x8dd01fad, 907ffc3c), -980, -276}, + {UINT64_2PART_C(0xd3515c28, 31559a83), -954, -268}, + {UINT64_2PART_C(0x9d71ac8f, ada6c9b5), -927, -260}, + {UINT64_2PART_C(0xea9c2277, 23ee8bcb), -901, -252}, + {UINT64_2PART_C(0xaecc4991, 4078536d), -874, -244}, + {UINT64_2PART_C(0x823c1279, 5db6ce57), -847, -236}, + {UINT64_2PART_C(0xc2109436, 4dfb5637), -821, -228}, + {UINT64_2PART_C(0x9096ea6f, 3848984f), -794, -220}, + {UINT64_2PART_C(0xd77485cb, 25823ac7), -768, -212}, + {UINT64_2PART_C(0xa086cfcd, 97bf97f4), -741, -204}, + {UINT64_2PART_C(0xef340a98, 172aace5), -715, -196}, + {UINT64_2PART_C(0xb23867fb, 2a35b28e), -688, -188}, + {UINT64_2PART_C(0x84c8d4df, d2c63f3b), -661, -180}, + {UINT64_2PART_C(0xc5dd4427, 1ad3cdba), -635, -172}, + {UINT64_2PART_C(0x936b9fce, bb25c996), -608, -164}, + {UINT64_2PART_C(0xdbac6c24, 7d62a584), -582, -156}, + {UINT64_2PART_C(0xa3ab6658, 0d5fdaf6), -555, -148}, + {UINT64_2PART_C(0xf3e2f893, dec3f126), -529, -140}, + {UINT64_2PART_C(0xb5b5ada8, aaff80b8), -502, -132}, + {UINT64_2PART_C(0x87625f05, 6c7c4a8b), -475, -124}, + {UINT64_2PART_C(0xc9bcff60, 34c13053), -449, -116}, + {UINT64_2PART_C(0x964e858c, 91ba2655), -422, -108}, + {UINT64_2PART_C(0xdff97724, 70297ebd), -396, -100}, + {UINT64_2PART_C(0xa6dfbd9f, b8e5b88f), -369, -92}, + {UINT64_2PART_C(0xf8a95fcf, 88747d94), -343, -84}, + {UINT64_2PART_C(0xb9447093, 8fa89bcf), -316, -76}, + {UINT64_2PART_C(0x8a08f0f8, bf0f156b), -289, -68}, + {UINT64_2PART_C(0xcdb02555, 653131b6), -263, -60}, + {UINT64_2PART_C(0x993fe2c6, d07b7fac), -236, -52}, + {UINT64_2PART_C(0xe45c10c4, 2a2b3b06), -210, -44}, + {UINT64_2PART_C(0xaa242499, 697392d3), -183, -36}, + {UINT64_2PART_C(0xfd87b5f2, 8300ca0e), -157, -28}, + {UINT64_2PART_C(0xbce50864, 92111aeb), -130, -20}, + {UINT64_2PART_C(0x8cbccc09, 6f5088cc), -103, -12}, + {UINT64_2PART_C(0xd1b71758, e219652c), -77, -4}, + {UINT64_2PART_C(0x9c400000, 00000000), -50, 4}, + {UINT64_2PART_C(0xe8d4a510, 00000000), -24, 12}, + {UINT64_2PART_C(0xad78ebc5, ac620000), 3, 20}, + {UINT64_2PART_C(0x813f3978, f8940984), 30, 28}, + {UINT64_2PART_C(0xc097ce7b, c90715b3), 56, 36}, + {UINT64_2PART_C(0x8f7e32ce, 7bea5c70), 83, 44}, + {UINT64_2PART_C(0xd5d238a4, abe98068), 109, 52}, + {UINT64_2PART_C(0x9f4f2726, 179a2245), 136, 60}, + {UINT64_2PART_C(0xed63a231, d4c4fb27), 162, 68}, + {UINT64_2PART_C(0xb0de6538, 8cc8ada8), 189, 76}, + {UINT64_2PART_C(0x83c7088e, 1aab65db), 216, 84}, + {UINT64_2PART_C(0xc45d1df9, 42711d9a), 242, 92}, + {UINT64_2PART_C(0x924d692c, a61be758), 269, 100}, + {UINT64_2PART_C(0xda01ee64, 1a708dea), 295, 108}, + {UINT64_2PART_C(0xa26da399, 9aef774a), 322, 116}, + {UINT64_2PART_C(0xf209787b, b47d6b85), 348, 124}, + {UINT64_2PART_C(0xb454e4a1, 79dd1877), 375, 132}, + {UINT64_2PART_C(0x865b8692, 5b9bc5c2), 402, 140}, + {UINT64_2PART_C(0xc83553c5, c8965d3d), 428, 148}, + {UINT64_2PART_C(0x952ab45c, fa97a0b3), 455, 156}, + {UINT64_2PART_C(0xde469fbd, 99a05fe3), 481, 164}, + {UINT64_2PART_C(0xa59bc234, db398c25), 508, 172}, + {UINT64_2PART_C(0xf6c69a72, a3989f5c), 534, 180}, + {UINT64_2PART_C(0xb7dcbf53, 54e9bece), 561, 188}, + {UINT64_2PART_C(0x88fcf317, f22241e2), 588, 196}, + {UINT64_2PART_C(0xcc20ce9b, d35c78a5), 614, 204}, + {UINT64_2PART_C(0x98165af3, 7b2153df), 641, 212}, + {UINT64_2PART_C(0xe2a0b5dc, 971f303a), 667, 220}, + {UINT64_2PART_C(0xa8d9d153, 5ce3b396), 694, 228}, + {UINT64_2PART_C(0xfb9b7cd9, a4a7443c), 720, 236}, + {UINT64_2PART_C(0xbb764c4c, a7a44410), 747, 244}, + {UINT64_2PART_C(0x8bab8eef, b6409c1a), 774, 252}, + {UINT64_2PART_C(0xd01fef10, a657842c), 800, 260}, + {UINT64_2PART_C(0x9b10a4e5, e9913129), 827, 268}, + {UINT64_2PART_C(0xe7109bfb, a19c0c9d), 853, 276}, + {UINT64_2PART_C(0xac2820d9, 623bf429), 880, 284}, + {UINT64_2PART_C(0x80444b5e, 7aa7cf85), 907, 292}, + {UINT64_2PART_C(0xbf21e440, 03acdd2d), 933, 300}, + {UINT64_2PART_C(0x8e679c2f, 5e44ff8f), 960, 308}, + {UINT64_2PART_C(0xd433179d, 9c8cb841), 986, 316}, + {UINT64_2PART_C(0x9e19db92, b4e31ba9), 1013, 324}, + {UINT64_2PART_C(0xeb96bf6e, badf77d9), 1039, 332}, + {UINT64_2PART_C(0xaf87023b, 9bf0ee6b), 1066, 340}, +}; + +static const int kCachedPowersLength = ARRAY_SIZE(kCachedPowers); +static const int kCachedPowersOffset = 348; // -1 * the first decimal_exponent. +static const double kD_1_LOG2_10 = 0.30102999566398114; // 1 / lg(10) +// Difference between the decimal exponents in the table above. +const int PowersOfTenCache::kDecimalExponentDistance = 8; +const int PowersOfTenCache::kMinDecimalExponent = -348; +const int PowersOfTenCache::kMaxDecimalExponent = 340; + +void PowersOfTenCache::GetCachedPowerForBinaryExponentRange( + int min_exponent, + int max_exponent, + DiyFp* power, + int* decimal_exponent) { + int kQ = DiyFp::kSignificandSize; + double k = ceil((min_exponent + kQ - 1) * kD_1_LOG2_10); + int foo = kCachedPowersOffset; + int index = + (foo + static_cast(k) - 1) / kDecimalExponentDistance + 1; + ASSERT(0 <= index && index < kCachedPowersLength); + CachedPower cached_power = kCachedPowers[index]; + ASSERT(min_exponent <= cached_power.binary_exponent); + ASSERT(cached_power.binary_exponent <= max_exponent); + *decimal_exponent = cached_power.decimal_exponent; + *power = DiyFp(cached_power.significand, cached_power.binary_exponent); +} + + +void PowersOfTenCache::GetCachedPowerForDecimalExponent(int requested_exponent, + DiyFp* power, + int* found_exponent) { + ASSERT(kMinDecimalExponent <= requested_exponent); + ASSERT(requested_exponent < kMaxDecimalExponent + kDecimalExponentDistance); + int index = + (requested_exponent + kCachedPowersOffset) / kDecimalExponentDistance; + CachedPower cached_power = kCachedPowers[index]; + *power = DiyFp(cached_power.significand, cached_power.binary_exponent); + *found_exponent = cached_power.decimal_exponent; + ASSERT(*found_exponent <= requested_exponent); + ASSERT(requested_exponent < *found_exponent + kDecimalExponentDistance); +} + +} // namespace double_conversion diff --git a/klm/util/double-conversion/cached-powers.h b/klm/util/double-conversion/cached-powers.h new file mode 100644 index 00000000..61a50614 --- /dev/null +++ b/klm/util/double-conversion/cached-powers.h @@ -0,0 +1,64 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef DOUBLE_CONVERSION_CACHED_POWERS_H_ +#define DOUBLE_CONVERSION_CACHED_POWERS_H_ + +#include "diy-fp.h" + +namespace double_conversion { + +class PowersOfTenCache { + public: + + // Not all powers of ten are cached. The decimal exponent of two neighboring + // cached numbers will differ by kDecimalExponentDistance. + static const int kDecimalExponentDistance; + + static const int kMinDecimalExponent; + static const int kMaxDecimalExponent; + + // Returns a cached power-of-ten with a binary exponent in the range + // [min_exponent; max_exponent] (boundaries included). + static void GetCachedPowerForBinaryExponentRange(int min_exponent, + int max_exponent, + DiyFp* power, + int* decimal_exponent); + + // Returns a cached power of ten x ~= 10^k such that + // k <= decimal_exponent < k + kCachedPowersDecimalDistance. + // The given decimal_exponent must satisfy + // kMinDecimalExponent <= requested_exponent, and + // requested_exponent < kMaxDecimalExponent + kDecimalExponentDistance. + static void GetCachedPowerForDecimalExponent(int requested_exponent, + DiyFp* power, + int* found_exponent); +}; + +} // namespace double_conversion + +#endif // DOUBLE_CONVERSION_CACHED_POWERS_H_ diff --git a/klm/util/double-conversion/diy-fp.cc b/klm/util/double-conversion/diy-fp.cc new file mode 100644 index 00000000..ddd1891b --- /dev/null +++ b/klm/util/double-conversion/diy-fp.cc @@ -0,0 +1,57 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +#include "diy-fp.h" +#include "utils.h" + +namespace double_conversion { + +void DiyFp::Multiply(const DiyFp& other) { + // Simply "emulates" a 128 bit multiplication. + // However: the resulting number only contains 64 bits. The least + // significant 64 bits are only used for rounding the most significant 64 + // bits. + const uint64_t kM32 = 0xFFFFFFFFU; + uint64_t a = f_ >> 32; + uint64_t b = f_ & kM32; + uint64_t c = other.f_ >> 32; + uint64_t d = other.f_ & kM32; + uint64_t ac = a * c; + uint64_t bc = b * c; + uint64_t ad = a * d; + uint64_t bd = b * d; + uint64_t tmp = (bd >> 32) + (ad & kM32) + (bc & kM32); + // By adding 1U << 31 to tmp we round the final result. + // Halfway cases will be round up. + tmp += 1U << 31; + uint64_t result_f = ac + (ad >> 32) + (bc >> 32) + (tmp >> 32); + e_ += other.e_ + 64; + f_ = result_f; +} + +} // namespace double_conversion diff --git a/klm/util/double-conversion/diy-fp.h b/klm/util/double-conversion/diy-fp.h new file mode 100644 index 00000000..9dcf8fbd --- /dev/null +++ b/klm/util/double-conversion/diy-fp.h @@ -0,0 +1,118 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef DOUBLE_CONVERSION_DIY_FP_H_ +#define DOUBLE_CONVERSION_DIY_FP_H_ + +#include "utils.h" + +namespace double_conversion { + +// This "Do It Yourself Floating Point" class implements a floating-point number +// with a uint64 significand and an int exponent. Normalized DiyFp numbers will +// have the most significant bit of the significand set. +// Multiplication and Subtraction do not normalize their results. +// DiyFp are not designed to contain special doubles (NaN and Infinity). +class DiyFp { + public: + static const int kSignificandSize = 64; + + DiyFp() : f_(0), e_(0) {} + DiyFp(uint64_t f, int e) : f_(f), e_(e) {} + + // this = this - other. + // The exponents of both numbers must be the same and the significand of this + // must be bigger than the significand of other. + // The result will not be normalized. + void Subtract(const DiyFp& other) { + ASSERT(e_ == other.e_); + ASSERT(f_ >= other.f_); + f_ -= other.f_; + } + + // Returns a - b. + // The exponents of both numbers must be the same and this must be bigger + // than other. The result will not be normalized. + static DiyFp Minus(const DiyFp& a, const DiyFp& b) { + DiyFp result = a; + result.Subtract(b); + return result; + } + + + // this = this * other. + void Multiply(const DiyFp& other); + + // returns a * b; + static DiyFp Times(const DiyFp& a, const DiyFp& b) { + DiyFp result = a; + result.Multiply(b); + return result; + } + + void Normalize() { + ASSERT(f_ != 0); + uint64_t f = f_; + int e = e_; + + // This method is mainly called for normalizing boundaries. In general + // boundaries need to be shifted by 10 bits. We thus optimize for this case. + const uint64_t k10MSBits = UINT64_2PART_C(0xFFC00000, 00000000); + while ((f & k10MSBits) == 0) { + f <<= 10; + e -= 10; + } + while ((f & kUint64MSB) == 0) { + f <<= 1; + e--; + } + f_ = f; + e_ = e; + } + + static DiyFp Normalize(const DiyFp& a) { + DiyFp result = a; + result.Normalize(); + return result; + } + + uint64_t f() const { return f_; } + int e() const { return e_; } + + void set_f(uint64_t new_value) { f_ = new_value; } + void set_e(int new_value) { e_ = new_value; } + + private: + static const uint64_t kUint64MSB = UINT64_2PART_C(0x80000000, 00000000); + + uint64_t f_; + int e_; +}; + +} // namespace double_conversion + +#endif // DOUBLE_CONVERSION_DIY_FP_H_ diff --git a/klm/util/double-conversion/double-conversion.cc b/klm/util/double-conversion/double-conversion.cc new file mode 100644 index 00000000..febba6cd --- /dev/null +++ b/klm/util/double-conversion/double-conversion.cc @@ -0,0 +1,889 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include + +#include "double-conversion.h" + +#include "bignum-dtoa.h" +#include "fast-dtoa.h" +#include "fixed-dtoa.h" +#include "ieee.h" +#include "strtod.h" +#include "utils.h" + +namespace double_conversion { + +const DoubleToStringConverter& DoubleToStringConverter::EcmaScriptConverter() { + int flags = UNIQUE_ZERO | EMIT_POSITIVE_EXPONENT_SIGN; + static DoubleToStringConverter converter(flags, + "Infinity", + "NaN", + 'e', + -6, 21, + 6, 0); + return converter; +} + + +bool DoubleToStringConverter::HandleSpecialValues( + double value, + StringBuilder* result_builder) const { + Double double_inspect(value); + if (double_inspect.IsInfinite()) { + if (infinity_symbol_ == NULL) return false; + if (value < 0) { + result_builder->AddCharacter('-'); + } + result_builder->AddString(infinity_symbol_); + return true; + } + if (double_inspect.IsNan()) { + if (nan_symbol_ == NULL) return false; + result_builder->AddString(nan_symbol_); + return true; + } + return false; +} + + +void DoubleToStringConverter::CreateExponentialRepresentation( + const char* decimal_digits, + int length, + int exponent, + StringBuilder* result_builder) const { + ASSERT(length != 0); + result_builder->AddCharacter(decimal_digits[0]); + if (length != 1) { + result_builder->AddCharacter('.'); + result_builder->AddSubstring(&decimal_digits[1], length-1); + } + result_builder->AddCharacter(exponent_character_); + if (exponent < 0) { + result_builder->AddCharacter('-'); + exponent = -exponent; + } else { + if ((flags_ & EMIT_POSITIVE_EXPONENT_SIGN) != 0) { + result_builder->AddCharacter('+'); + } + } + if (exponent == 0) { + result_builder->AddCharacter('0'); + return; + } + ASSERT(exponent < 1e4); + const int kMaxExponentLength = 5; + char buffer[kMaxExponentLength + 1]; + buffer[kMaxExponentLength] = '\0'; + int first_char_pos = kMaxExponentLength; + while (exponent > 0) { + buffer[--first_char_pos] = '0' + (exponent % 10); + exponent /= 10; + } + result_builder->AddSubstring(&buffer[first_char_pos], + kMaxExponentLength - first_char_pos); +} + + +void DoubleToStringConverter::CreateDecimalRepresentation( + const char* decimal_digits, + int length, + int decimal_point, + int digits_after_point, + StringBuilder* result_builder) const { + // Create a representation that is padded with zeros if needed. + if (decimal_point <= 0) { + // "0.00000decimal_rep". + result_builder->AddCharacter('0'); + if (digits_after_point > 0) { + result_builder->AddCharacter('.'); + result_builder->AddPadding('0', -decimal_point); + ASSERT(length <= digits_after_point - (-decimal_point)); + result_builder->AddSubstring(decimal_digits, length); + int remaining_digits = digits_after_point - (-decimal_point) - length; + result_builder->AddPadding('0', remaining_digits); + } + } else if (decimal_point >= length) { + // "decimal_rep0000.00000" or "decimal_rep.0000" + result_builder->AddSubstring(decimal_digits, length); + result_builder->AddPadding('0', decimal_point - length); + if (digits_after_point > 0) { + result_builder->AddCharacter('.'); + result_builder->AddPadding('0', digits_after_point); + } + } else { + // "decima.l_rep000" + ASSERT(digits_after_point > 0); + result_builder->AddSubstring(decimal_digits, decimal_point); + result_builder->AddCharacter('.'); + ASSERT(length - decimal_point <= digits_after_point); + result_builder->AddSubstring(&decimal_digits[decimal_point], + length - decimal_point); + int remaining_digits = digits_after_point - (length - decimal_point); + result_builder->AddPadding('0', remaining_digits); + } + if (digits_after_point == 0) { + if ((flags_ & EMIT_TRAILING_DECIMAL_POINT) != 0) { + result_builder->AddCharacter('.'); + } + if ((flags_ & EMIT_TRAILING_ZERO_AFTER_POINT) != 0) { + result_builder->AddCharacter('0'); + } + } +} + + +bool DoubleToStringConverter::ToShortestIeeeNumber( + double value, + StringBuilder* result_builder, + DoubleToStringConverter::DtoaMode mode) const { + ASSERT(mode == SHORTEST || mode == SHORTEST_SINGLE); + if (Double(value).IsSpecial()) { + return HandleSpecialValues(value, result_builder); + } + + int decimal_point; + bool sign; + const int kDecimalRepCapacity = kBase10MaximalLength + 1; + char decimal_rep[kDecimalRepCapacity]; + int decimal_rep_length; + + DoubleToAscii(value, mode, 0, decimal_rep, kDecimalRepCapacity, + &sign, &decimal_rep_length, &decimal_point); + + bool unique_zero = (flags_ & UNIQUE_ZERO) != 0; + if (sign && (value != 0.0 || !unique_zero)) { + result_builder->AddCharacter('-'); + } + + int exponent = decimal_point - 1; + if ((decimal_in_shortest_low_ <= exponent) && + (exponent < decimal_in_shortest_high_)) { + CreateDecimalRepresentation(decimal_rep, decimal_rep_length, + decimal_point, + Max(0, decimal_rep_length - decimal_point), + result_builder); + } else { + CreateExponentialRepresentation(decimal_rep, decimal_rep_length, exponent, + result_builder); + } + return true; +} + + +bool DoubleToStringConverter::ToFixed(double value, + int requested_digits, + StringBuilder* result_builder) const { + ASSERT(kMaxFixedDigitsBeforePoint == 60); + const double kFirstNonFixed = 1e60; + + if (Double(value).IsSpecial()) { + return HandleSpecialValues(value, result_builder); + } + + if (requested_digits > kMaxFixedDigitsAfterPoint) return false; + if (value >= kFirstNonFixed || value <= -kFirstNonFixed) return false; + + // Find a sufficiently precise decimal representation of n. + int decimal_point; + bool sign; + // Add space for the '\0' byte. + const int kDecimalRepCapacity = + kMaxFixedDigitsBeforePoint + kMaxFixedDigitsAfterPoint + 1; + char decimal_rep[kDecimalRepCapacity]; + int decimal_rep_length; + DoubleToAscii(value, FIXED, requested_digits, + decimal_rep, kDecimalRepCapacity, + &sign, &decimal_rep_length, &decimal_point); + + bool unique_zero = ((flags_ & UNIQUE_ZERO) != 0); + if (sign && (value != 0.0 || !unique_zero)) { + result_builder->AddCharacter('-'); + } + + CreateDecimalRepresentation(decimal_rep, decimal_rep_length, decimal_point, + requested_digits, result_builder); + return true; +} + + +bool DoubleToStringConverter::ToExponential( + double value, + int requested_digits, + StringBuilder* result_builder) const { + if (Double(value).IsSpecial()) { + return HandleSpecialValues(value, result_builder); + } + + if (requested_digits < -1) return false; + if (requested_digits > kMaxExponentialDigits) return false; + + int decimal_point; + bool sign; + // Add space for digit before the decimal point and the '\0' character. + const int kDecimalRepCapacity = kMaxExponentialDigits + 2; + ASSERT(kDecimalRepCapacity > kBase10MaximalLength); + char decimal_rep[kDecimalRepCapacity]; + int decimal_rep_length; + + if (requested_digits == -1) { + DoubleToAscii(value, SHORTEST, 0, + decimal_rep, kDecimalRepCapacity, + &sign, &decimal_rep_length, &decimal_point); + } else { + DoubleToAscii(value, PRECISION, requested_digits + 1, + decimal_rep, kDecimalRepCapacity, + &sign, &decimal_rep_length, &decimal_point); + ASSERT(decimal_rep_length <= requested_digits + 1); + + for (int i = decimal_rep_length; i < requested_digits + 1; ++i) { + decimal_rep[i] = '0'; + } + decimal_rep_length = requested_digits + 1; + } + + bool unique_zero = ((flags_ & UNIQUE_ZERO) != 0); + if (sign && (value != 0.0 || !unique_zero)) { + result_builder->AddCharacter('-'); + } + + int exponent = decimal_point - 1; + CreateExponentialRepresentation(decimal_rep, + decimal_rep_length, + exponent, + result_builder); + return true; +} + + +bool DoubleToStringConverter::ToPrecision(double value, + int precision, + StringBuilder* result_builder) const { + if (Double(value).IsSpecial()) { + return HandleSpecialValues(value, result_builder); + } + + if (precision < kMinPrecisionDigits || precision > kMaxPrecisionDigits) { + return false; + } + + // Find a sufficiently precise decimal representation of n. + int decimal_point; + bool sign; + // Add one for the terminating null character. + const int kDecimalRepCapacity = kMaxPrecisionDigits + 1; + char decimal_rep[kDecimalRepCapacity]; + int decimal_rep_length; + + DoubleToAscii(value, PRECISION, precision, + decimal_rep, kDecimalRepCapacity, + &sign, &decimal_rep_length, &decimal_point); + ASSERT(decimal_rep_length <= precision); + + bool unique_zero = ((flags_ & UNIQUE_ZERO) != 0); + if (sign && (value != 0.0 || !unique_zero)) { + result_builder->AddCharacter('-'); + } + + // The exponent if we print the number as x.xxeyyy. That is with the + // decimal point after the first digit. + int exponent = decimal_point - 1; + + int extra_zero = ((flags_ & EMIT_TRAILING_ZERO_AFTER_POINT) != 0) ? 1 : 0; + if ((-decimal_point + 1 > max_leading_padding_zeroes_in_precision_mode_) || + (decimal_point - precision + extra_zero > + max_trailing_padding_zeroes_in_precision_mode_)) { + // Fill buffer to contain 'precision' digits. + // Usually the buffer is already at the correct length, but 'DoubleToAscii' + // is allowed to return less characters. + for (int i = decimal_rep_length; i < precision; ++i) { + decimal_rep[i] = '0'; + } + + CreateExponentialRepresentation(decimal_rep, + precision, + exponent, + result_builder); + } else { + CreateDecimalRepresentation(decimal_rep, decimal_rep_length, decimal_point, + Max(0, precision - decimal_point), + result_builder); + } + return true; +} + + +static BignumDtoaMode DtoaToBignumDtoaMode( + DoubleToStringConverter::DtoaMode dtoa_mode) { + switch (dtoa_mode) { + case DoubleToStringConverter::SHORTEST: return BIGNUM_DTOA_SHORTEST; + case DoubleToStringConverter::SHORTEST_SINGLE: + return BIGNUM_DTOA_SHORTEST_SINGLE; + case DoubleToStringConverter::FIXED: return BIGNUM_DTOA_FIXED; + case DoubleToStringConverter::PRECISION: return BIGNUM_DTOA_PRECISION; + default: + UNREACHABLE(); + return BIGNUM_DTOA_SHORTEST; // To silence compiler. + } +} + + +void DoubleToStringConverter::DoubleToAscii(double v, + DtoaMode mode, + int requested_digits, + char* buffer, + int buffer_length, + bool* sign, + int* length, + int* point) { + Vector vector(buffer, buffer_length); + ASSERT(!Double(v).IsSpecial()); + ASSERT(mode == SHORTEST || mode == SHORTEST_SINGLE || requested_digits >= 0); + + if (Double(v).Sign() < 0) { + *sign = true; + v = -v; + } else { + *sign = false; + } + + if (mode == PRECISION && requested_digits == 0) { + vector[0] = '\0'; + *length = 0; + return; + } + + if (v == 0) { + vector[0] = '0'; + vector[1] = '\0'; + *length = 1; + *point = 1; + return; + } + + bool fast_worked; + switch (mode) { + case SHORTEST: + fast_worked = FastDtoa(v, FAST_DTOA_SHORTEST, 0, vector, length, point); + break; + case SHORTEST_SINGLE: + fast_worked = FastDtoa(v, FAST_DTOA_SHORTEST_SINGLE, 0, + vector, length, point); + break; + case FIXED: + fast_worked = FastFixedDtoa(v, requested_digits, vector, length, point); + break; + case PRECISION: + fast_worked = FastDtoa(v, FAST_DTOA_PRECISION, requested_digits, + vector, length, point); + break; + default: + UNREACHABLE(); + fast_worked = false; + } + if (fast_worked) return; + + // If the fast dtoa didn't succeed use the slower bignum version. + BignumDtoaMode bignum_mode = DtoaToBignumDtoaMode(mode); + BignumDtoa(v, bignum_mode, requested_digits, vector, length, point); + vector[*length] = '\0'; +} + + +// Consumes the given substring from the iterator. +// Returns false, if the substring does not match. +static bool ConsumeSubString(const char** current, + const char* end, + const char* substring) { + ASSERT(**current == *substring); + for (substring++; *substring != '\0'; substring++) { + ++*current; + if (*current == end || **current != *substring) return false; + } + ++*current; + return true; +} + + +// Maximum number of significant digits in decimal representation. +// The longest possible double in decimal representation is +// (2^53 - 1) * 2 ^ -1074 that is (2 ^ 53 - 1) * 5 ^ 1074 / 10 ^ 1074 +// (768 digits). If we parse a number whose first digits are equal to a +// mean of 2 adjacent doubles (that could have up to 769 digits) the result +// must be rounded to the bigger one unless the tail consists of zeros, so +// we don't need to preserve all the digits. +const int kMaxSignificantDigits = 772; + + +// Returns true if a nonspace found and false if the end has reached. +static inline bool AdvanceToNonspace(const char** current, const char* end) { + while (*current != end) { + if (**current != ' ') return true; + ++*current; + } + return false; +} + + +static bool isDigit(int x, int radix) { + return (x >= '0' && x <= '9' && x < '0' + radix) + || (radix > 10 && x >= 'a' && x < 'a' + radix - 10) + || (radix > 10 && x >= 'A' && x < 'A' + radix - 10); +} + + +static double SignedZero(bool sign) { + return sign ? -0.0 : 0.0; +} + + +// Parsing integers with radix 2, 4, 8, 16, 32. Assumes current != end. +template +static double RadixStringToIeee(const char* current, + const char* end, + bool sign, + bool allow_trailing_junk, + double junk_string_value, + bool read_as_double, + const char** trailing_pointer) { + ASSERT(current != end); + + const int kDoubleSize = Double::kSignificandSize; + const int kSingleSize = Single::kSignificandSize; + const int kSignificandSize = read_as_double? kDoubleSize: kSingleSize; + + // Skip leading 0s. + while (*current == '0') { + ++current; + if (current == end) { + *trailing_pointer = end; + return SignedZero(sign); + } + } + + int64_t number = 0; + int exponent = 0; + const int radix = (1 << radix_log_2); + + do { + int digit; + if (*current >= '0' && *current <= '9' && *current < '0' + radix) { + digit = static_cast(*current) - '0'; + } else if (radix > 10 && *current >= 'a' && *current < 'a' + radix - 10) { + digit = static_cast(*current) - 'a' + 10; + } else if (radix > 10 && *current >= 'A' && *current < 'A' + radix - 10) { + digit = static_cast(*current) - 'A' + 10; + } else { + if (allow_trailing_junk || !AdvanceToNonspace(¤t, end)) { + break; + } else { + return junk_string_value; + } + } + + number = number * radix + digit; + int overflow = static_cast(number >> kSignificandSize); + if (overflow != 0) { + // Overflow occurred. Need to determine which direction to round the + // result. + int overflow_bits_count = 1; + while (overflow > 1) { + overflow_bits_count++; + overflow >>= 1; + } + + int dropped_bits_mask = ((1 << overflow_bits_count) - 1); + int dropped_bits = static_cast(number) & dropped_bits_mask; + number >>= overflow_bits_count; + exponent = overflow_bits_count; + + bool zero_tail = true; + while (true) { + ++current; + if (current == end || !isDigit(*current, radix)) break; + zero_tail = zero_tail && *current == '0'; + exponent += radix_log_2; + } + + if (!allow_trailing_junk && AdvanceToNonspace(¤t, end)) { + return junk_string_value; + } + + int middle_value = (1 << (overflow_bits_count - 1)); + if (dropped_bits > middle_value) { + number++; // Rounding up. + } else if (dropped_bits == middle_value) { + // Rounding to even to consistency with decimals: half-way case rounds + // up if significant part is odd and down otherwise. + if ((number & 1) != 0 || !zero_tail) { + number++; // Rounding up. + } + } + + // Rounding up may cause overflow. + if ((number & ((int64_t)1 << kSignificandSize)) != 0) { + exponent++; + number >>= 1; + } + break; + } + ++current; + } while (current != end); + + ASSERT(number < ((int64_t)1 << kSignificandSize)); + ASSERT(static_cast(static_cast(number)) == number); + + *trailing_pointer = current; + + if (exponent == 0) { + if (sign) { + if (number == 0) return -0.0; + number = -number; + } + return static_cast(number); + } + + ASSERT(number != 0); + return Double(DiyFp(number, exponent)).value(); +} + + +double StringToDoubleConverter::StringToIeee( + const char* input, + int length, + int* processed_characters_count, + bool read_as_double) const { + const char* current = input; + const char* end = input + length; + + *processed_characters_count = 0; + + const bool allow_trailing_junk = (flags_ & ALLOW_TRAILING_JUNK) != 0; + const bool allow_leading_spaces = (flags_ & ALLOW_LEADING_SPACES) != 0; + const bool allow_trailing_spaces = (flags_ & ALLOW_TRAILING_SPACES) != 0; + const bool allow_spaces_after_sign = (flags_ & ALLOW_SPACES_AFTER_SIGN) != 0; + + // To make sure that iterator dereferencing is valid the following + // convention is used: + // 1. Each '++current' statement is followed by check for equality to 'end'. + // 2. If AdvanceToNonspace returned false then current == end. + // 3. If 'current' becomes equal to 'end' the function returns or goes to + // 'parsing_done'. + // 4. 'current' is not dereferenced after the 'parsing_done' label. + // 5. Code before 'parsing_done' may rely on 'current != end'. + if (current == end) return empty_string_value_; + + if (allow_leading_spaces || allow_trailing_spaces) { + if (!AdvanceToNonspace(¤t, end)) { + *processed_characters_count = current - input; + return empty_string_value_; + } + if (!allow_leading_spaces && (input != current)) { + // No leading spaces allowed, but AdvanceToNonspace moved forward. + return junk_string_value_; + } + } + + // The longest form of simplified number is: "-.1eXXX\0". + const int kBufferSize = kMaxSignificantDigits + 10; + char buffer[kBufferSize]; // NOLINT: size is known at compile time. + int buffer_pos = 0; + + // Exponent will be adjusted if insignificant digits of the integer part + // or insignificant leading zeros of the fractional part are dropped. + int exponent = 0; + int significant_digits = 0; + int insignificant_digits = 0; + bool nonzero_digit_dropped = false; + + bool sign = false; + + if (*current == '+' || *current == '-') { + sign = (*current == '-'); + ++current; + const char* next_non_space = current; + // Skip following spaces (if allowed). + if (!AdvanceToNonspace(&next_non_space, end)) return junk_string_value_; + if (!allow_spaces_after_sign && (current != next_non_space)) { + return junk_string_value_; + } + current = next_non_space; + } + + if (infinity_symbol_ != NULL) { + if (*current == infinity_symbol_[0]) { + if (!ConsumeSubString(¤t, end, infinity_symbol_)) { + return junk_string_value_; + } + + if (!(allow_trailing_spaces || allow_trailing_junk) && (current != end)) { + return junk_string_value_; + } + if (!allow_trailing_junk && AdvanceToNonspace(¤t, end)) { + return junk_string_value_; + } + + ASSERT(buffer_pos == 0); + *processed_characters_count = current - input; + return sign ? -Double::Infinity() : Double::Infinity(); + } + } + + if (nan_symbol_ != NULL) { + if (*current == nan_symbol_[0]) { + if (!ConsumeSubString(¤t, end, nan_symbol_)) { + return junk_string_value_; + } + + if (!(allow_trailing_spaces || allow_trailing_junk) && (current != end)) { + return junk_string_value_; + } + if (!allow_trailing_junk && AdvanceToNonspace(¤t, end)) { + return junk_string_value_; + } + + ASSERT(buffer_pos == 0); + *processed_characters_count = current - input; + return sign ? -Double::NaN() : Double::NaN(); + } + } + + bool leading_zero = false; + if (*current == '0') { + ++current; + if (current == end) { + *processed_characters_count = current - input; + return SignedZero(sign); + } + + leading_zero = true; + + // It could be hexadecimal value. + if ((flags_ & ALLOW_HEX) && (*current == 'x' || *current == 'X')) { + ++current; + if (current == end || !isDigit(*current, 16)) { + return junk_string_value_; // "0x". + } + + const char* tail_pointer = NULL; + double result = RadixStringToIeee<4>(current, + end, + sign, + allow_trailing_junk, + junk_string_value_, + read_as_double, + &tail_pointer); + if (tail_pointer != NULL) { + if (allow_trailing_spaces) AdvanceToNonspace(&tail_pointer, end); + *processed_characters_count = tail_pointer - input; + } + return result; + } + + // Ignore leading zeros in the integer part. + while (*current == '0') { + ++current; + if (current == end) { + *processed_characters_count = current - input; + return SignedZero(sign); + } + } + } + + bool octal = leading_zero && (flags_ & ALLOW_OCTALS) != 0; + + // Copy significant digits of the integer part (if any) to the buffer. + while (*current >= '0' && *current <= '9') { + if (significant_digits < kMaxSignificantDigits) { + ASSERT(buffer_pos < kBufferSize); + buffer[buffer_pos++] = static_cast(*current); + significant_digits++; + // Will later check if it's an octal in the buffer. + } else { + insignificant_digits++; // Move the digit into the exponential part. + nonzero_digit_dropped = nonzero_digit_dropped || *current != '0'; + } + octal = octal && *current < '8'; + ++current; + if (current == end) goto parsing_done; + } + + if (significant_digits == 0) { + octal = false; + } + + if (*current == '.') { + if (octal && !allow_trailing_junk) return junk_string_value_; + if (octal) goto parsing_done; + + ++current; + if (current == end) { + if (significant_digits == 0 && !leading_zero) { + return junk_string_value_; + } else { + goto parsing_done; + } + } + + if (significant_digits == 0) { + // octal = false; + // Integer part consists of 0 or is absent. Significant digits start after + // leading zeros (if any). + while (*current == '0') { + ++current; + if (current == end) { + *processed_characters_count = current - input; + return SignedZero(sign); + } + exponent--; // Move this 0 into the exponent. + } + } + + // There is a fractional part. + // We don't emit a '.', but adjust the exponent instead. + while (*current >= '0' && *current <= '9') { + if (significant_digits < kMaxSignificantDigits) { + ASSERT(buffer_pos < kBufferSize); + buffer[buffer_pos++] = static_cast(*current); + significant_digits++; + exponent--; + } else { + // Ignore insignificant digits in the fractional part. + nonzero_digit_dropped = nonzero_digit_dropped || *current != '0'; + } + ++current; + if (current == end) goto parsing_done; + } + } + + if (!leading_zero && exponent == 0 && significant_digits == 0) { + // If leading_zeros is true then the string contains zeros. + // If exponent < 0 then string was [+-]\.0*... + // If significant_digits != 0 the string is not equal to 0. + // Otherwise there are no digits in the string. + return junk_string_value_; + } + + // Parse exponential part. + if (*current == 'e' || *current == 'E') { + if (octal && !allow_trailing_junk) return junk_string_value_; + if (octal) goto parsing_done; + ++current; + if (current == end) { + if (allow_trailing_junk) { + goto parsing_done; + } else { + return junk_string_value_; + } + } + char sign = '+'; + if (*current == '+' || *current == '-') { + sign = static_cast(*current); + ++current; + if (current == end) { + if (allow_trailing_junk) { + goto parsing_done; + } else { + return junk_string_value_; + } + } + } + + if (current == end || *current < '0' || *current > '9') { + if (allow_trailing_junk) { + goto parsing_done; + } else { + return junk_string_value_; + } + } + + const int max_exponent = INT_MAX / 2; + ASSERT(-max_exponent / 2 <= exponent && exponent <= max_exponent / 2); + int num = 0; + do { + // Check overflow. + int digit = *current - '0'; + if (num >= max_exponent / 10 + && !(num == max_exponent / 10 && digit <= max_exponent % 10)) { + num = max_exponent; + } else { + num = num * 10 + digit; + } + ++current; + } while (current != end && *current >= '0' && *current <= '9'); + + exponent += (sign == '-' ? -num : num); + } + + if (!(allow_trailing_spaces || allow_trailing_junk) && (current != end)) { + return junk_string_value_; + } + if (!allow_trailing_junk && AdvanceToNonspace(¤t, end)) { + return junk_string_value_; + } + if (allow_trailing_spaces) { + AdvanceToNonspace(¤t, end); + } + + parsing_done: + exponent += insignificant_digits; + + if (octal) { + double result; + const char* tail_pointer = NULL; + result = RadixStringToIeee<3>(buffer, + buffer + buffer_pos, + sign, + allow_trailing_junk, + junk_string_value_, + read_as_double, + &tail_pointer); + ASSERT(tail_pointer != NULL); + *processed_characters_count = current - input; + return result; + } + + if (nonzero_digit_dropped) { + buffer[buffer_pos++] = '1'; + exponent--; + } + + ASSERT(buffer_pos < kBufferSize); + buffer[buffer_pos] = '\0'; + + double converted; + if (read_as_double) { + converted = Strtod(Vector(buffer, buffer_pos), exponent); + } else { + converted = Strtof(Vector(buffer, buffer_pos), exponent); + } + *processed_characters_count = current - input; + return sign? -converted: converted; +} + +} // namespace double_conversion diff --git a/klm/util/double-conversion/double-conversion.h b/klm/util/double-conversion/double-conversion.h new file mode 100644 index 00000000..1c3387d4 --- /dev/null +++ b/klm/util/double-conversion/double-conversion.h @@ -0,0 +1,536 @@ +// Copyright 2012 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef DOUBLE_CONVERSION_DOUBLE_CONVERSION_H_ +#define DOUBLE_CONVERSION_DOUBLE_CONVERSION_H_ + +#include "utils.h" + +namespace double_conversion { + +class DoubleToStringConverter { + public: + // When calling ToFixed with a double > 10^kMaxFixedDigitsBeforePoint + // or a requested_digits parameter > kMaxFixedDigitsAfterPoint then the + // function returns false. + static const int kMaxFixedDigitsBeforePoint = 60; + static const int kMaxFixedDigitsAfterPoint = 60; + + // When calling ToExponential with a requested_digits + // parameter > kMaxExponentialDigits then the function returns false. + static const int kMaxExponentialDigits = 120; + + // When calling ToPrecision with a requested_digits + // parameter < kMinPrecisionDigits or requested_digits > kMaxPrecisionDigits + // then the function returns false. + static const int kMinPrecisionDigits = 1; + static const int kMaxPrecisionDigits = 120; + + enum Flags { + NO_FLAGS = 0, + EMIT_POSITIVE_EXPONENT_SIGN = 1, + EMIT_TRAILING_DECIMAL_POINT = 2, + EMIT_TRAILING_ZERO_AFTER_POINT = 4, + UNIQUE_ZERO = 8 + }; + + // Flags should be a bit-or combination of the possible Flags-enum. + // - NO_FLAGS: no special flags. + // - EMIT_POSITIVE_EXPONENT_SIGN: when the number is converted into exponent + // form, emits a '+' for positive exponents. Example: 1.2e+2. + // - EMIT_TRAILING_DECIMAL_POINT: when the input number is an integer and is + // converted into decimal format then a trailing decimal point is appended. + // Example: 2345.0 is converted to "2345.". + // - EMIT_TRAILING_ZERO_AFTER_POINT: in addition to a trailing decimal point + // emits a trailing '0'-character. This flag requires the + // EXMIT_TRAILING_DECIMAL_POINT flag. + // Example: 2345.0 is converted to "2345.0". + // - UNIQUE_ZERO: "-0.0" is converted to "0.0". + // + // Infinity symbol and nan_symbol provide the string representation for these + // special values. If the string is NULL and the special value is encountered + // then the conversion functions return false. + // + // The exponent_character is used in exponential representations. It is + // usually 'e' or 'E'. + // + // When converting to the shortest representation the converter will + // represent input numbers in decimal format if they are in the interval + // [10^decimal_in_shortest_low; 10^decimal_in_shortest_high[ + // (lower boundary included, greater boundary excluded). + // Example: with decimal_in_shortest_low = -6 and + // decimal_in_shortest_high = 21: + // ToShortest(0.000001) -> "0.000001" + // ToShortest(0.0000001) -> "1e-7" + // ToShortest(111111111111111111111.0) -> "111111111111111110000" + // ToShortest(100000000000000000000.0) -> "100000000000000000000" + // ToShortest(1111111111111111111111.0) -> "1.1111111111111111e+21" + // + // When converting to precision mode the converter may add + // max_leading_padding_zeroes before returning the number in exponential + // format. + // Example with max_leading_padding_zeroes_in_precision_mode = 6. + // ToPrecision(0.0000012345, 2) -> "0.0000012" + // ToPrecision(0.00000012345, 2) -> "1.2e-7" + // Similarily the converter may add up to + // max_trailing_padding_zeroes_in_precision_mode in precision mode to avoid + // returning an exponential representation. A zero added by the + // EMIT_TRAILING_ZERO_AFTER_POINT flag is counted for this limit. + // Examples for max_trailing_padding_zeroes_in_precision_mode = 1: + // ToPrecision(230.0, 2) -> "230" + // ToPrecision(230.0, 2) -> "230." with EMIT_TRAILING_DECIMAL_POINT. + // ToPrecision(230.0, 2) -> "2.3e2" with EMIT_TRAILING_ZERO_AFTER_POINT. + DoubleToStringConverter(int flags, + const char* infinity_symbol, + const char* nan_symbol, + char exponent_character, + int decimal_in_shortest_low, + int decimal_in_shortest_high, + int max_leading_padding_zeroes_in_precision_mode, + int max_trailing_padding_zeroes_in_precision_mode) + : flags_(flags), + infinity_symbol_(infinity_symbol), + nan_symbol_(nan_symbol), + exponent_character_(exponent_character), + decimal_in_shortest_low_(decimal_in_shortest_low), + decimal_in_shortest_high_(decimal_in_shortest_high), + max_leading_padding_zeroes_in_precision_mode_( + max_leading_padding_zeroes_in_precision_mode), + max_trailing_padding_zeroes_in_precision_mode_( + max_trailing_padding_zeroes_in_precision_mode) { + // When 'trailing zero after the point' is set, then 'trailing point' + // must be set too. + ASSERT(((flags & EMIT_TRAILING_DECIMAL_POINT) != 0) || + !((flags & EMIT_TRAILING_ZERO_AFTER_POINT) != 0)); + } + + // Returns a converter following the EcmaScript specification. + static const DoubleToStringConverter& EcmaScriptConverter(); + + // Computes the shortest string of digits that correctly represent the input + // number. Depending on decimal_in_shortest_low and decimal_in_shortest_high + // (see constructor) it then either returns a decimal representation, or an + // exponential representation. + // Example with decimal_in_shortest_low = -6, + // decimal_in_shortest_high = 21, + // EMIT_POSITIVE_EXPONENT_SIGN activated, and + // EMIT_TRAILING_DECIMAL_POINT deactived: + // ToShortest(0.000001) -> "0.000001" + // ToShortest(0.0000001) -> "1e-7" + // ToShortest(111111111111111111111.0) -> "111111111111111110000" + // ToShortest(100000000000000000000.0) -> "100000000000000000000" + // ToShortest(1111111111111111111111.0) -> "1.1111111111111111e+21" + // + // Note: the conversion may round the output if the returned string + // is accurate enough to uniquely identify the input-number. + // For example the most precise representation of the double 9e59 equals + // "899999999999999918767229449717619953810131273674690656206848", but + // the converter will return the shorter (but still correct) "9e59". + // + // Returns true if the conversion succeeds. The conversion always succeeds + // except when the input value is special and no infinity_symbol or + // nan_symbol has been given to the constructor. + bool ToShortest(double value, StringBuilder* result_builder) const { + return ToShortestIeeeNumber(value, result_builder, SHORTEST); + } + + // Same as ToShortest, but for single-precision floats. + bool ToShortestSingle(float value, StringBuilder* result_builder) const { + return ToShortestIeeeNumber(value, result_builder, SHORTEST_SINGLE); + } + + + // Computes a decimal representation with a fixed number of digits after the + // decimal point. The last emitted digit is rounded. + // + // Examples: + // ToFixed(3.12, 1) -> "3.1" + // ToFixed(3.1415, 3) -> "3.142" + // ToFixed(1234.56789, 4) -> "1234.5679" + // ToFixed(1.23, 5) -> "1.23000" + // ToFixed(0.1, 4) -> "0.1000" + // ToFixed(1e30, 2) -> "1000000000000000019884624838656.00" + // ToFixed(0.1, 30) -> "0.100000000000000005551115123126" + // ToFixed(0.1, 17) -> "0.10000000000000001" + // + // If requested_digits equals 0, then the tail of the result depends on + // the EMIT_TRAILING_DECIMAL_POINT and EMIT_TRAILING_ZERO_AFTER_POINT. + // Examples, for requested_digits == 0, + // let EMIT_TRAILING_DECIMAL_POINT and EMIT_TRAILING_ZERO_AFTER_POINT be + // - false and false: then 123.45 -> 123 + // 0.678 -> 1 + // - true and false: then 123.45 -> 123. + // 0.678 -> 1. + // - true and true: then 123.45 -> 123.0 + // 0.678 -> 1.0 + // + // Returns true if the conversion succeeds. The conversion always succeeds + // except for the following cases: + // - the input value is special and no infinity_symbol or nan_symbol has + // been provided to the constructor, + // - 'value' > 10^kMaxFixedDigitsBeforePoint, or + // - 'requested_digits' > kMaxFixedDigitsAfterPoint. + // The last two conditions imply that the result will never contain more than + // 1 + kMaxFixedDigitsBeforePoint + 1 + kMaxFixedDigitsAfterPoint characters + // (one additional character for the sign, and one for the decimal point). + bool ToFixed(double value, + int requested_digits, + StringBuilder* result_builder) const; + + // Computes a representation in exponential format with requested_digits + // after the decimal point. The last emitted digit is rounded. + // If requested_digits equals -1, then the shortest exponential representation + // is computed. + // + // Examples with EMIT_POSITIVE_EXPONENT_SIGN deactivated, and + // exponent_character set to 'e'. + // ToExponential(3.12, 1) -> "3.1e0" + // ToExponential(5.0, 3) -> "5.000e0" + // ToExponential(0.001, 2) -> "1.00e-3" + // ToExponential(3.1415, -1) -> "3.1415e0" + // ToExponential(3.1415, 4) -> "3.1415e0" + // ToExponential(3.1415, 3) -> "3.142e0" + // ToExponential(123456789000000, 3) -> "1.235e14" + // ToExponential(1000000000000000019884624838656.0, -1) -> "1e30" + // ToExponential(1000000000000000019884624838656.0, 32) -> + // "1.00000000000000001988462483865600e30" + // ToExponential(1234, 0) -> "1e3" + // + // Returns true if the conversion succeeds. The conversion always succeeds + // except for the following cases: + // - the input value is special and no infinity_symbol or nan_symbol has + // been provided to the constructor, + // - 'requested_digits' > kMaxExponentialDigits. + // The last condition implies that the result will never contain more than + // kMaxExponentialDigits + 8 characters (the sign, the digit before the + // decimal point, the decimal point, the exponent character, the + // exponent's sign, and at most 3 exponent digits). + bool ToExponential(double value, + int requested_digits, + StringBuilder* result_builder) const; + + // Computes 'precision' leading digits of the given 'value' and returns them + // either in exponential or decimal format, depending on + // max_{leading|trailing}_padding_zeroes_in_precision_mode (given to the + // constructor). + // The last computed digit is rounded. + // + // Example with max_leading_padding_zeroes_in_precision_mode = 6. + // ToPrecision(0.0000012345, 2) -> "0.0000012" + // ToPrecision(0.00000012345, 2) -> "1.2e-7" + // Similarily the converter may add up to + // max_trailing_padding_zeroes_in_precision_mode in precision mode to avoid + // returning an exponential representation. A zero added by the + // EMIT_TRAILING_ZERO_AFTER_POINT flag is counted for this limit. + // Examples for max_trailing_padding_zeroes_in_precision_mode = 1: + // ToPrecision(230.0, 2) -> "230" + // ToPrecision(230.0, 2) -> "230." with EMIT_TRAILING_DECIMAL_POINT. + // ToPrecision(230.0, 2) -> "2.3e2" with EMIT_TRAILING_ZERO_AFTER_POINT. + // Examples for max_trailing_padding_zeroes_in_precision_mode = 3, and no + // EMIT_TRAILING_ZERO_AFTER_POINT: + // ToPrecision(123450.0, 6) -> "123450" + // ToPrecision(123450.0, 5) -> "123450" + // ToPrecision(123450.0, 4) -> "123500" + // ToPrecision(123450.0, 3) -> "123000" + // ToPrecision(123450.0, 2) -> "1.2e5" + // + // Returns true if the conversion succeeds. The conversion always succeeds + // except for the following cases: + // - the input value is special and no infinity_symbol or nan_symbol has + // been provided to the constructor, + // - precision < kMinPericisionDigits + // - precision > kMaxPrecisionDigits + // The last condition implies that the result will never contain more than + // kMaxPrecisionDigits + 7 characters (the sign, the decimal point, the + // exponent character, the exponent's sign, and at most 3 exponent digits). + bool ToPrecision(double value, + int precision, + StringBuilder* result_builder) const; + + enum DtoaMode { + // Produce the shortest correct representation. + // For example the output of 0.299999999999999988897 is (the less accurate + // but correct) 0.3. + SHORTEST, + // Same as SHORTEST, but for single-precision floats. + SHORTEST_SINGLE, + // Produce a fixed number of digits after the decimal point. + // For instance fixed(0.1, 4) becomes 0.1000 + // If the input number is big, the output will be big. + FIXED, + // Fixed number of digits (independent of the decimal point). + PRECISION + }; + + // The maximal number of digits that are needed to emit a double in base 10. + // A higher precision can be achieved by using more digits, but the shortest + // accurate representation of any double will never use more digits than + // kBase10MaximalLength. + // Note that DoubleToAscii null-terminates its input. So the given buffer + // should be at least kBase10MaximalLength + 1 characters long. + static const int kBase10MaximalLength = 17; + + // Converts the given double 'v' to ascii. 'v' must not be NaN, +Infinity, or + // -Infinity. In SHORTEST_SINGLE-mode this restriction also applies to 'v' + // after it has been casted to a single-precision float. That is, in this + // mode static_cast(v) must not be NaN, +Infinity or -Infinity. + // + // The result should be interpreted as buffer * 10^(point-length). + // + // The output depends on the given mode: + // - SHORTEST: produce the least amount of digits for which the internal + // identity requirement is still satisfied. If the digits are printed + // (together with the correct exponent) then reading this number will give + // 'v' again. The buffer will choose the representation that is closest to + // 'v'. If there are two at the same distance, than the one farther away + // from 0 is chosen (halfway cases - ending with 5 - are rounded up). + // In this mode the 'requested_digits' parameter is ignored. + // - SHORTEST_SINGLE: same as SHORTEST but with single-precision. + // - FIXED: produces digits necessary to print a given number with + // 'requested_digits' digits after the decimal point. The produced digits + // might be too short in which case the caller has to fill the remainder + // with '0's. + // Example: toFixed(0.001, 5) is allowed to return buffer="1", point=-2. + // Halfway cases are rounded towards +/-Infinity (away from 0). The call + // toFixed(0.15, 2) thus returns buffer="2", point=0. + // The returned buffer may contain digits that would be truncated from the + // shortest representation of the input. + // - PRECISION: produces 'requested_digits' where the first digit is not '0'. + // Even though the length of produced digits usually equals + // 'requested_digits', the function is allowed to return fewer digits, in + // which case the caller has to fill the missing digits with '0's. + // Halfway cases are again rounded away from 0. + // DoubleToAscii expects the given buffer to be big enough to hold all + // digits and a terminating null-character. In SHORTEST-mode it expects a + // buffer of at least kBase10MaximalLength + 1. In all other modes the + // requested_digits parameter and the padding-zeroes limit the size of the + // output. Don't forget the decimal point, the exponent character and the + // terminating null-character when computing the maximal output size. + // The given length is only used in debug mode to ensure the buffer is big + // enough. + static void DoubleToAscii(double v, + DtoaMode mode, + int requested_digits, + char* buffer, + int buffer_length, + bool* sign, + int* length, + int* point); + + private: + // Implementation for ToShortest and ToShortestSingle. + bool ToShortestIeeeNumber(double value, + StringBuilder* result_builder, + DtoaMode mode) const; + + // If the value is a special value (NaN or Infinity) constructs the + // corresponding string using the configured infinity/nan-symbol. + // If either of them is NULL or the value is not special then the + // function returns false. + bool HandleSpecialValues(double value, StringBuilder* result_builder) const; + // Constructs an exponential representation (i.e. 1.234e56). + // The given exponent assumes a decimal point after the first decimal digit. + void CreateExponentialRepresentation(const char* decimal_digits, + int length, + int exponent, + StringBuilder* result_builder) const; + // Creates a decimal representation (i.e 1234.5678). + void CreateDecimalRepresentation(const char* decimal_digits, + int length, + int decimal_point, + int digits_after_point, + StringBuilder* result_builder) const; + + const int flags_; + const char* const infinity_symbol_; + const char* const nan_symbol_; + const char exponent_character_; + const int decimal_in_shortest_low_; + const int decimal_in_shortest_high_; + const int max_leading_padding_zeroes_in_precision_mode_; + const int max_trailing_padding_zeroes_in_precision_mode_; + + DISALLOW_IMPLICIT_CONSTRUCTORS(DoubleToStringConverter); +}; + + +class StringToDoubleConverter { + public: + // Enumeration for allowing octals and ignoring junk when converting + // strings to numbers. + enum Flags { + NO_FLAGS = 0, + ALLOW_HEX = 1, + ALLOW_OCTALS = 2, + ALLOW_TRAILING_JUNK = 4, + ALLOW_LEADING_SPACES = 8, + ALLOW_TRAILING_SPACES = 16, + ALLOW_SPACES_AFTER_SIGN = 32 + }; + + // Flags should be a bit-or combination of the possible Flags-enum. + // - NO_FLAGS: no special flags. + // - ALLOW_HEX: recognizes the prefix "0x". Hex numbers may only be integers. + // Ex: StringToDouble("0x1234") -> 4660.0 + // In StringToDouble("0x1234.56") the characters ".56" are trailing + // junk. The result of the call is hence dependent on + // the ALLOW_TRAILING_JUNK flag and/or the junk value. + // With this flag "0x" is a junk-string. Even with ALLOW_TRAILING_JUNK, + // the string will not be parsed as "0" followed by junk. + // + // - ALLOW_OCTALS: recognizes the prefix "0" for octals: + // If a sequence of octal digits starts with '0', then the number is + // read as octal integer. Octal numbers may only be integers. + // Ex: StringToDouble("01234") -> 668.0 + // StringToDouble("012349") -> 12349.0 // Not a sequence of octal + // // digits. + // In StringToDouble("01234.56") the characters ".56" are trailing + // junk. The result of the call is hence dependent on + // the ALLOW_TRAILING_JUNK flag and/or the junk value. + // In StringToDouble("01234e56") the characters "e56" are trailing + // junk, too. + // - ALLOW_TRAILING_JUNK: ignore trailing characters that are not part of + // a double literal. + // - ALLOW_LEADING_SPACES: skip over leading spaces. + // - ALLOW_TRAILING_SPACES: ignore trailing spaces. + // - ALLOW_SPACES_AFTER_SIGN: ignore spaces after the sign. + // Ex: StringToDouble("- 123.2") -> -123.2. + // StringToDouble("+ 123.2") -> 123.2 + // + // empty_string_value is returned when an empty string is given as input. + // If ALLOW_LEADING_SPACES or ALLOW_TRAILING_SPACES are set, then a string + // containing only spaces is converted to the 'empty_string_value', too. + // + // junk_string_value is returned when + // a) ALLOW_TRAILING_JUNK is not set, and a junk character (a character not + // part of a double-literal) is found. + // b) ALLOW_TRAILING_JUNK is set, but the string does not start with a + // double literal. + // + // infinity_symbol and nan_symbol are strings that are used to detect + // inputs that represent infinity and NaN. They can be null, in which case + // they are ignored. + // The conversion routine first reads any possible signs. Then it compares the + // following character of the input-string with the first character of + // the infinity, and nan-symbol. If either matches, the function assumes, that + // a match has been found, and expects the following input characters to match + // the remaining characters of the special-value symbol. + // This means that the following restrictions apply to special-value symbols: + // - they must not start with signs ('+', or '-'), + // - they must not have the same first character. + // - they must not start with digits. + // + // Examples: + // flags = ALLOW_HEX | ALLOW_TRAILING_JUNK, + // empty_string_value = 0.0, + // junk_string_value = NaN, + // infinity_symbol = "infinity", + // nan_symbol = "nan": + // StringToDouble("0x1234") -> 4660.0. + // StringToDouble("0x1234K") -> 4660.0. + // StringToDouble("") -> 0.0 // empty_string_value. + // StringToDouble(" ") -> NaN // junk_string_value. + // StringToDouble(" 1") -> NaN // junk_string_value. + // StringToDouble("0x") -> NaN // junk_string_value. + // StringToDouble("-123.45") -> -123.45. + // StringToDouble("--123.45") -> NaN // junk_string_value. + // StringToDouble("123e45") -> 123e45. + // StringToDouble("123E45") -> 123e45. + // StringToDouble("123e+45") -> 123e45. + // StringToDouble("123E-45") -> 123e-45. + // StringToDouble("123e") -> 123.0 // trailing junk ignored. + // StringToDouble("123e-") -> 123.0 // trailing junk ignored. + // StringToDouble("+NaN") -> NaN // NaN string literal. + // StringToDouble("-infinity") -> -inf. // infinity literal. + // StringToDouble("Infinity") -> NaN // junk_string_value. + // + // flags = ALLOW_OCTAL | ALLOW_LEADING_SPACES, + // empty_string_value = 0.0, + // junk_string_value = NaN, + // infinity_symbol = NULL, + // nan_symbol = NULL: + // StringToDouble("0x1234") -> NaN // junk_string_value. + // StringToDouble("01234") -> 668.0. + // StringToDouble("") -> 0.0 // empty_string_value. + // StringToDouble(" ") -> 0.0 // empty_string_value. + // StringToDouble(" 1") -> 1.0 + // StringToDouble("0x") -> NaN // junk_string_value. + // StringToDouble("0123e45") -> NaN // junk_string_value. + // StringToDouble("01239E45") -> 1239e45. + // StringToDouble("-infinity") -> NaN // junk_string_value. + // StringToDouble("NaN") -> NaN // junk_string_value. + StringToDoubleConverter(int flags, + double empty_string_value, + double junk_string_value, + const char* infinity_symbol, + const char* nan_symbol) + : flags_(flags), + empty_string_value_(empty_string_value), + junk_string_value_(junk_string_value), + infinity_symbol_(infinity_symbol), + nan_symbol_(nan_symbol) { + } + + // Performs the conversion. + // The output parameter 'processed_characters_count' is set to the number + // of characters that have been processed to read the number. + // Spaces than are processed with ALLOW_{LEADING|TRAILING}_SPACES are included + // in the 'processed_characters_count'. Trailing junk is never included. + double StringToDouble(const char* buffer, + int length, + int* processed_characters_count) const { + return StringToIeee(buffer, length, processed_characters_count, true); + } + + // Same as StringToDouble but reads a float. + // Note that this is not equivalent to static_cast(StringToDouble(...)) + // due to potential double-rounding. + float StringToFloat(const char* buffer, + int length, + int* processed_characters_count) const { + return static_cast(StringToIeee(buffer, length, + processed_characters_count, false)); + } + + private: + const int flags_; + const double empty_string_value_; + const double junk_string_value_; + const char* const infinity_symbol_; + const char* const nan_symbol_; + + double StringToIeee(const char* buffer, + int length, + int* processed_characters_count, + bool read_as_double) const; + + DISALLOW_IMPLICIT_CONSTRUCTORS(StringToDoubleConverter); +}; + +} // namespace double_conversion + +#endif // DOUBLE_CONVERSION_DOUBLE_CONVERSION_H_ diff --git a/klm/util/double-conversion/fast-dtoa.cc b/klm/util/double-conversion/fast-dtoa.cc new file mode 100644 index 00000000..1a0f8235 --- /dev/null +++ b/klm/util/double-conversion/fast-dtoa.cc @@ -0,0 +1,664 @@ +// Copyright 2012 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include "fast-dtoa.h" + +#include "cached-powers.h" +#include "diy-fp.h" +#include "ieee.h" + +namespace double_conversion { + +// The minimal and maximal target exponent define the range of w's binary +// exponent, where 'w' is the result of multiplying the input by a cached power +// of ten. +// +// A different range might be chosen on a different platform, to optimize digit +// generation, but a smaller range requires more powers of ten to be cached. +static const int kMinimalTargetExponent = -60; +static const int kMaximalTargetExponent = -32; + + +// Adjusts the last digit of the generated number, and screens out generated +// solutions that may be inaccurate. A solution may be inaccurate if it is +// outside the safe interval, or if we cannot prove that it is closer to the +// input than a neighboring representation of the same length. +// +// Input: * buffer containing the digits of too_high / 10^kappa +// * the buffer's length +// * distance_too_high_w == (too_high - w).f() * unit +// * unsafe_interval == (too_high - too_low).f() * unit +// * rest = (too_high - buffer * 10^kappa).f() * unit +// * ten_kappa = 10^kappa * unit +// * unit = the common multiplier +// Output: returns true if the buffer is guaranteed to contain the closest +// representable number to the input. +// Modifies the generated digits in the buffer to approach (round towards) w. +static bool RoundWeed(Vector buffer, + int length, + uint64_t distance_too_high_w, + uint64_t unsafe_interval, + uint64_t rest, + uint64_t ten_kappa, + uint64_t unit) { + uint64_t small_distance = distance_too_high_w - unit; + uint64_t big_distance = distance_too_high_w + unit; + // Let w_low = too_high - big_distance, and + // w_high = too_high - small_distance. + // Note: w_low < w < w_high + // + // The real w (* unit) must lie somewhere inside the interval + // ]w_low; w_high[ (often written as "(w_low; w_high)") + + // Basically the buffer currently contains a number in the unsafe interval + // ]too_low; too_high[ with too_low < w < too_high + // + // too_high - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // ^v 1 unit ^ ^ ^ ^ + // boundary_high --------------------- . . . . + // ^v 1 unit . . . . + // - - - - - - - - - - - - - - - - - - - + - - + - - - - - - . . + // . . ^ . . + // . big_distance . . . + // . . . . rest + // small_distance . . . . + // v . . . . + // w_high - - - - - - - - - - - - - - - - - - . . . . + // ^v 1 unit . . . . + // w ---------------------------------------- . . . . + // ^v 1 unit v . . . + // w_low - - - - - - - - - - - - - - - - - - - - - . . . + // . . v + // buffer --------------------------------------------------+-------+-------- + // . . + // safe_interval . + // v . + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - . + // ^v 1 unit . + // boundary_low ------------------------- unsafe_interval + // ^v 1 unit v + // too_low - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + // + // + // Note that the value of buffer could lie anywhere inside the range too_low + // to too_high. + // + // boundary_low, boundary_high and w are approximations of the real boundaries + // and v (the input number). They are guaranteed to be precise up to one unit. + // In fact the error is guaranteed to be strictly less than one unit. + // + // Anything that lies outside the unsafe interval is guaranteed not to round + // to v when read again. + // Anything that lies inside the safe interval is guaranteed to round to v + // when read again. + // If the number inside the buffer lies inside the unsafe interval but not + // inside the safe interval then we simply do not know and bail out (returning + // false). + // + // Similarly we have to take into account the imprecision of 'w' when finding + // the closest representation of 'w'. If we have two potential + // representations, and one is closer to both w_low and w_high, then we know + // it is closer to the actual value v. + // + // By generating the digits of too_high we got the largest (closest to + // too_high) buffer that is still in the unsafe interval. In the case where + // w_high < buffer < too_high we try to decrement the buffer. + // This way the buffer approaches (rounds towards) w. + // There are 3 conditions that stop the decrementation process: + // 1) the buffer is already below w_high + // 2) decrementing the buffer would make it leave the unsafe interval + // 3) decrementing the buffer would yield a number below w_high and farther + // away than the current number. In other words: + // (buffer{-1} < w_high) && w_high - buffer{-1} > buffer - w_high + // Instead of using the buffer directly we use its distance to too_high. + // Conceptually rest ~= too_high - buffer + // We need to do the following tests in this order to avoid over- and + // underflows. + ASSERT(rest <= unsafe_interval); + while (rest < small_distance && // Negated condition 1 + unsafe_interval - rest >= ten_kappa && // Negated condition 2 + (rest + ten_kappa < small_distance || // buffer{-1} > w_high + small_distance - rest >= rest + ten_kappa - small_distance)) { + buffer[length - 1]--; + rest += ten_kappa; + } + + // We have approached w+ as much as possible. We now test if approaching w- + // would require changing the buffer. If yes, then we have two possible + // representations close to w, but we cannot decide which one is closer. + if (rest < big_distance && + unsafe_interval - rest >= ten_kappa && + (rest + ten_kappa < big_distance || + big_distance - rest > rest + ten_kappa - big_distance)) { + return false; + } + + // Weeding test. + // The safe interval is [too_low + 2 ulp; too_high - 2 ulp] + // Since too_low = too_high - unsafe_interval this is equivalent to + // [too_high - unsafe_interval + 4 ulp; too_high - 2 ulp] + // Conceptually we have: rest ~= too_high - buffer + return (2 * unit <= rest) && (rest <= unsafe_interval - 4 * unit); +} + + +// Rounds the buffer upwards if the result is closer to v by possibly adding +// 1 to the buffer. If the precision of the calculation is not sufficient to +// round correctly, return false. +// The rounding might shift the whole buffer in which case the kappa is +// adjusted. For example "99", kappa = 3 might become "10", kappa = 4. +// +// If 2*rest > ten_kappa then the buffer needs to be round up. +// rest can have an error of +/- 1 unit. This function accounts for the +// imprecision and returns false, if the rounding direction cannot be +// unambiguously determined. +// +// Precondition: rest < ten_kappa. +static bool RoundWeedCounted(Vector buffer, + int length, + uint64_t rest, + uint64_t ten_kappa, + uint64_t unit, + int* kappa) { + ASSERT(rest < ten_kappa); + // The following tests are done in a specific order to avoid overflows. They + // will work correctly with any uint64 values of rest < ten_kappa and unit. + // + // If the unit is too big, then we don't know which way to round. For example + // a unit of 50 means that the real number lies within rest +/- 50. If + // 10^kappa == 40 then there is no way to tell which way to round. + if (unit >= ten_kappa) return false; + // Even if unit is just half the size of 10^kappa we are already completely + // lost. (And after the previous test we know that the expression will not + // over/underflow.) + if (ten_kappa - unit <= unit) return false; + // If 2 * (rest + unit) <= 10^kappa we can safely round down. + if ((ten_kappa - rest > rest) && (ten_kappa - 2 * rest >= 2 * unit)) { + return true; + } + // If 2 * (rest - unit) >= 10^kappa, then we can safely round up. + if ((rest > unit) && (ten_kappa - (rest - unit) <= (rest - unit))) { + // Increment the last digit recursively until we find a non '9' digit. + buffer[length - 1]++; + for (int i = length - 1; i > 0; --i) { + if (buffer[i] != '0' + 10) break; + buffer[i] = '0'; + buffer[i - 1]++; + } + // If the first digit is now '0'+ 10 we had a buffer with all '9's. With the + // exception of the first digit all digits are now '0'. Simply switch the + // first digit to '1' and adjust the kappa. Example: "99" becomes "10" and + // the power (the kappa) is increased. + if (buffer[0] == '0' + 10) { + buffer[0] = '1'; + (*kappa) += 1; + } + return true; + } + return false; +} + +// Returns the biggest power of ten that is less than or equal to the given +// number. We furthermore receive the maximum number of bits 'number' has. +// +// Returns power == 10^(exponent_plus_one-1) such that +// power <= number < power * 10. +// If number_bits == 0 then 0^(0-1) is returned. +// The number of bits must be <= 32. +// Precondition: number < (1 << (number_bits + 1)). + +// Inspired by the method for finding an integer log base 10 from here: +// http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog10 +static unsigned int const kSmallPowersOfTen[] = + {0, 1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, + 1000000000}; + +static void BiggestPowerTen(uint32_t number, + int number_bits, + uint32_t* power, + int* exponent_plus_one) { + ASSERT(number < (1u << (number_bits + 1))); + // 1233/4096 is approximately 1/lg(10). + int exponent_plus_one_guess = ((number_bits + 1) * 1233 >> 12); + // We increment to skip over the first entry in the kPowersOf10 table. + // Note: kPowersOf10[i] == 10^(i-1). + exponent_plus_one_guess++; + // We don't have any guarantees that 2^number_bits <= number. + // TODO(floitsch): can we change the 'while' into an 'if'? We definitely see + // number < (2^number_bits - 1), but I haven't encountered + // number < (2^number_bits - 2) yet. + while (number < kSmallPowersOfTen[exponent_plus_one_guess]) { + exponent_plus_one_guess--; + } + *power = kSmallPowersOfTen[exponent_plus_one_guess]; + *exponent_plus_one = exponent_plus_one_guess; +} + +// Generates the digits of input number w. +// w is a floating-point number (DiyFp), consisting of a significand and an +// exponent. Its exponent is bounded by kMinimalTargetExponent and +// kMaximalTargetExponent. +// Hence -60 <= w.e() <= -32. +// +// Returns false if it fails, in which case the generated digits in the buffer +// should not be used. +// Preconditions: +// * low, w and high are correct up to 1 ulp (unit in the last place). That +// is, their error must be less than a unit of their last digits. +// * low.e() == w.e() == high.e() +// * low < w < high, and taking into account their error: low~ <= high~ +// * kMinimalTargetExponent <= w.e() <= kMaximalTargetExponent +// Postconditions: returns false if procedure fails. +// otherwise: +// * buffer is not null-terminated, but len contains the number of digits. +// * buffer contains the shortest possible decimal digit-sequence +// such that LOW < buffer * 10^kappa < HIGH, where LOW and HIGH are the +// correct values of low and high (without their error). +// * if more than one decimal representation gives the minimal number of +// decimal digits then the one closest to W (where W is the correct value +// of w) is chosen. +// Remark: this procedure takes into account the imprecision of its input +// numbers. If the precision is not enough to guarantee all the postconditions +// then false is returned. This usually happens rarely (~0.5%). +// +// Say, for the sake of example, that +// w.e() == -48, and w.f() == 0x1234567890abcdef +// w's value can be computed by w.f() * 2^w.e() +// We can obtain w's integral digits by simply shifting w.f() by -w.e(). +// -> w's integral part is 0x1234 +// w's fractional part is therefore 0x567890abcdef. +// Printing w's integral part is easy (simply print 0x1234 in decimal). +// In order to print its fraction we repeatedly multiply the fraction by 10 and +// get each digit. Example the first digit after the point would be computed by +// (0x567890abcdef * 10) >> 48. -> 3 +// The whole thing becomes slightly more complicated because we want to stop +// once we have enough digits. That is, once the digits inside the buffer +// represent 'w' we can stop. Everything inside the interval low - high +// represents w. However we have to pay attention to low, high and w's +// imprecision. +static bool DigitGen(DiyFp low, + DiyFp w, + DiyFp high, + Vector buffer, + int* length, + int* kappa) { + ASSERT(low.e() == w.e() && w.e() == high.e()); + ASSERT(low.f() + 1 <= high.f() - 1); + ASSERT(kMinimalTargetExponent <= w.e() && w.e() <= kMaximalTargetExponent); + // low, w and high are imprecise, but by less than one ulp (unit in the last + // place). + // If we remove (resp. add) 1 ulp from low (resp. high) we are certain that + // the new numbers are outside of the interval we want the final + // representation to lie in. + // Inversely adding (resp. removing) 1 ulp from low (resp. high) would yield + // numbers that are certain to lie in the interval. We will use this fact + // later on. + // We will now start by generating the digits within the uncertain + // interval. Later we will weed out representations that lie outside the safe + // interval and thus _might_ lie outside the correct interval. + uint64_t unit = 1; + DiyFp too_low = DiyFp(low.f() - unit, low.e()); + DiyFp too_high = DiyFp(high.f() + unit, high.e()); + // too_low and too_high are guaranteed to lie outside the interval we want the + // generated number in. + DiyFp unsafe_interval = DiyFp::Minus(too_high, too_low); + // We now cut the input number into two parts: the integral digits and the + // fractionals. We will not write any decimal separator though, but adapt + // kappa instead. + // Reminder: we are currently computing the digits (stored inside the buffer) + // such that: too_low < buffer * 10^kappa < too_high + // We use too_high for the digit_generation and stop as soon as possible. + // If we stop early we effectively round down. + DiyFp one = DiyFp(static_cast(1) << -w.e(), w.e()); + // Division by one is a shift. + uint32_t integrals = static_cast(too_high.f() >> -one.e()); + // Modulo by one is an and. + uint64_t fractionals = too_high.f() & (one.f() - 1); + uint32_t divisor; + int divisor_exponent_plus_one; + BiggestPowerTen(integrals, DiyFp::kSignificandSize - (-one.e()), + &divisor, &divisor_exponent_plus_one); + *kappa = divisor_exponent_plus_one; + *length = 0; + // Loop invariant: buffer = too_high / 10^kappa (integer division) + // The invariant holds for the first iteration: kappa has been initialized + // with the divisor exponent + 1. And the divisor is the biggest power of ten + // that is smaller than integrals. + while (*kappa > 0) { + int digit = integrals / divisor; + buffer[*length] = '0' + digit; + (*length)++; + integrals %= divisor; + (*kappa)--; + // Note that kappa now equals the exponent of the divisor and that the + // invariant thus holds again. + uint64_t rest = + (static_cast(integrals) << -one.e()) + fractionals; + // Invariant: too_high = buffer * 10^kappa + DiyFp(rest, one.e()) + // Reminder: unsafe_interval.e() == one.e() + if (rest < unsafe_interval.f()) { + // Rounding down (by not emitting the remaining digits) yields a number + // that lies within the unsafe interval. + return RoundWeed(buffer, *length, DiyFp::Minus(too_high, w).f(), + unsafe_interval.f(), rest, + static_cast(divisor) << -one.e(), unit); + } + divisor /= 10; + } + + // The integrals have been generated. We are at the point of the decimal + // separator. In the following loop we simply multiply the remaining digits by + // 10 and divide by one. We just need to pay attention to multiply associated + // data (like the interval or 'unit'), too. + // Note that the multiplication by 10 does not overflow, because w.e >= -60 + // and thus one.e >= -60. + ASSERT(one.e() >= -60); + ASSERT(fractionals < one.f()); + ASSERT(UINT64_2PART_C(0xFFFFFFFF, FFFFFFFF) / 10 >= one.f()); + while (true) { + fractionals *= 10; + unit *= 10; + unsafe_interval.set_f(unsafe_interval.f() * 10); + // Integer division by one. + int digit = static_cast(fractionals >> -one.e()); + buffer[*length] = '0' + digit; + (*length)++; + fractionals &= one.f() - 1; // Modulo by one. + (*kappa)--; + if (fractionals < unsafe_interval.f()) { + return RoundWeed(buffer, *length, DiyFp::Minus(too_high, w).f() * unit, + unsafe_interval.f(), fractionals, one.f(), unit); + } + } +} + + + +// Generates (at most) requested_digits digits of input number w. +// w is a floating-point number (DiyFp), consisting of a significand and an +// exponent. Its exponent is bounded by kMinimalTargetExponent and +// kMaximalTargetExponent. +// Hence -60 <= w.e() <= -32. +// +// Returns false if it fails, in which case the generated digits in the buffer +// should not be used. +// Preconditions: +// * w is correct up to 1 ulp (unit in the last place). That +// is, its error must be strictly less than a unit of its last digit. +// * kMinimalTargetExponent <= w.e() <= kMaximalTargetExponent +// +// Postconditions: returns false if procedure fails. +// otherwise: +// * buffer is not null-terminated, but length contains the number of +// digits. +// * the representation in buffer is the most precise representation of +// requested_digits digits. +// * buffer contains at most requested_digits digits of w. If there are less +// than requested_digits digits then some trailing '0's have been removed. +// * kappa is such that +// w = buffer * 10^kappa + eps with |eps| < 10^kappa / 2. +// +// Remark: This procedure takes into account the imprecision of its input +// numbers. If the precision is not enough to guarantee all the postconditions +// then false is returned. This usually happens rarely, but the failure-rate +// increases with higher requested_digits. +static bool DigitGenCounted(DiyFp w, + int requested_digits, + Vector buffer, + int* length, + int* kappa) { + ASSERT(kMinimalTargetExponent <= w.e() && w.e() <= kMaximalTargetExponent); + ASSERT(kMinimalTargetExponent >= -60); + ASSERT(kMaximalTargetExponent <= -32); + // w is assumed to have an error less than 1 unit. Whenever w is scaled we + // also scale its error. + uint64_t w_error = 1; + // We cut the input number into two parts: the integral digits and the + // fractional digits. We don't emit any decimal separator, but adapt kappa + // instead. Example: instead of writing "1.2" we put "12" into the buffer and + // increase kappa by 1. + DiyFp one = DiyFp(static_cast(1) << -w.e(), w.e()); + // Division by one is a shift. + uint32_t integrals = static_cast(w.f() >> -one.e()); + // Modulo by one is an and. + uint64_t fractionals = w.f() & (one.f() - 1); + uint32_t divisor; + int divisor_exponent_plus_one; + BiggestPowerTen(integrals, DiyFp::kSignificandSize - (-one.e()), + &divisor, &divisor_exponent_plus_one); + *kappa = divisor_exponent_plus_one; + *length = 0; + + // Loop invariant: buffer = w / 10^kappa (integer division) + // The invariant holds for the first iteration: kappa has been initialized + // with the divisor exponent + 1. And the divisor is the biggest power of ten + // that is smaller than 'integrals'. + while (*kappa > 0) { + int digit = integrals / divisor; + buffer[*length] = '0' + digit; + (*length)++; + requested_digits--; + integrals %= divisor; + (*kappa)--; + // Note that kappa now equals the exponent of the divisor and that the + // invariant thus holds again. + if (requested_digits == 0) break; + divisor /= 10; + } + + if (requested_digits == 0) { + uint64_t rest = + (static_cast(integrals) << -one.e()) + fractionals; + return RoundWeedCounted(buffer, *length, rest, + static_cast(divisor) << -one.e(), w_error, + kappa); + } + + // The integrals have been generated. We are at the point of the decimal + // separator. In the following loop we simply multiply the remaining digits by + // 10 and divide by one. We just need to pay attention to multiply associated + // data (the 'unit'), too. + // Note that the multiplication by 10 does not overflow, because w.e >= -60 + // and thus one.e >= -60. + ASSERT(one.e() >= -60); + ASSERT(fractionals < one.f()); + ASSERT(UINT64_2PART_C(0xFFFFFFFF, FFFFFFFF) / 10 >= one.f()); + while (requested_digits > 0 && fractionals > w_error) { + fractionals *= 10; + w_error *= 10; + // Integer division by one. + int digit = static_cast(fractionals >> -one.e()); + buffer[*length] = '0' + digit; + (*length)++; + requested_digits--; + fractionals &= one.f() - 1; // Modulo by one. + (*kappa)--; + } + if (requested_digits != 0) return false; + return RoundWeedCounted(buffer, *length, fractionals, one.f(), w_error, + kappa); +} + + +// Provides a decimal representation of v. +// Returns true if it succeeds, otherwise the result cannot be trusted. +// There will be *length digits inside the buffer (not null-terminated). +// If the function returns true then +// v == (double) (buffer * 10^decimal_exponent). +// The digits in the buffer are the shortest representation possible: no +// 0.09999999999999999 instead of 0.1. The shorter representation will even be +// chosen even if the longer one would be closer to v. +// The last digit will be closest to the actual v. That is, even if several +// digits might correctly yield 'v' when read again, the closest will be +// computed. +static bool Grisu3(double v, + FastDtoaMode mode, + Vector buffer, + int* length, + int* decimal_exponent) { + DiyFp w = Double(v).AsNormalizedDiyFp(); + // boundary_minus and boundary_plus are the boundaries between v and its + // closest floating-point neighbors. Any number strictly between + // boundary_minus and boundary_plus will round to v when convert to a double. + // Grisu3 will never output representations that lie exactly on a boundary. + DiyFp boundary_minus, boundary_plus; + if (mode == FAST_DTOA_SHORTEST) { + Double(v).NormalizedBoundaries(&boundary_minus, &boundary_plus); + } else { + ASSERT(mode == FAST_DTOA_SHORTEST_SINGLE); + float single_v = static_cast(v); + Single(single_v).NormalizedBoundaries(&boundary_minus, &boundary_plus); + } + ASSERT(boundary_plus.e() == w.e()); + DiyFp ten_mk; // Cached power of ten: 10^-k + int mk; // -k + int ten_mk_minimal_binary_exponent = + kMinimalTargetExponent - (w.e() + DiyFp::kSignificandSize); + int ten_mk_maximal_binary_exponent = + kMaximalTargetExponent - (w.e() + DiyFp::kSignificandSize); + PowersOfTenCache::GetCachedPowerForBinaryExponentRange( + ten_mk_minimal_binary_exponent, + ten_mk_maximal_binary_exponent, + &ten_mk, &mk); + ASSERT((kMinimalTargetExponent <= w.e() + ten_mk.e() + + DiyFp::kSignificandSize) && + (kMaximalTargetExponent >= w.e() + ten_mk.e() + + DiyFp::kSignificandSize)); + // Note that ten_mk is only an approximation of 10^-k. A DiyFp only contains a + // 64 bit significand and ten_mk is thus only precise up to 64 bits. + + // The DiyFp::Times procedure rounds its result, and ten_mk is approximated + // too. The variable scaled_w (as well as scaled_boundary_minus/plus) are now + // off by a small amount. + // In fact: scaled_w - w*10^k < 1ulp (unit in the last place) of scaled_w. + // In other words: let f = scaled_w.f() and e = scaled_w.e(), then + // (f-1) * 2^e < w*10^k < (f+1) * 2^e + DiyFp scaled_w = DiyFp::Times(w, ten_mk); + ASSERT(scaled_w.e() == + boundary_plus.e() + ten_mk.e() + DiyFp::kSignificandSize); + // In theory it would be possible to avoid some recomputations by computing + // the difference between w and boundary_minus/plus (a power of 2) and to + // compute scaled_boundary_minus/plus by subtracting/adding from + // scaled_w. However the code becomes much less readable and the speed + // enhancements are not terriffic. + DiyFp scaled_boundary_minus = DiyFp::Times(boundary_minus, ten_mk); + DiyFp scaled_boundary_plus = DiyFp::Times(boundary_plus, ten_mk); + + // DigitGen will generate the digits of scaled_w. Therefore we have + // v == (double) (scaled_w * 10^-mk). + // Set decimal_exponent == -mk and pass it to DigitGen. If scaled_w is not an + // integer than it will be updated. For instance if scaled_w == 1.23 then + // the buffer will be filled with "123" und the decimal_exponent will be + // decreased by 2. + int kappa; + bool result = DigitGen(scaled_boundary_minus, scaled_w, scaled_boundary_plus, + buffer, length, &kappa); + *decimal_exponent = -mk + kappa; + return result; +} + + +// The "counted" version of grisu3 (see above) only generates requested_digits +// number of digits. This version does not generate the shortest representation, +// and with enough requested digits 0.1 will at some point print as 0.9999999... +// Grisu3 is too imprecise for real halfway cases (1.5 will not work) and +// therefore the rounding strategy for halfway cases is irrelevant. +static bool Grisu3Counted(double v, + int requested_digits, + Vector buffer, + int* length, + int* decimal_exponent) { + DiyFp w = Double(v).AsNormalizedDiyFp(); + DiyFp ten_mk; // Cached power of ten: 10^-k + int mk; // -k + int ten_mk_minimal_binary_exponent = + kMinimalTargetExponent - (w.e() + DiyFp::kSignificandSize); + int ten_mk_maximal_binary_exponent = + kMaximalTargetExponent - (w.e() + DiyFp::kSignificandSize); + PowersOfTenCache::GetCachedPowerForBinaryExponentRange( + ten_mk_minimal_binary_exponent, + ten_mk_maximal_binary_exponent, + &ten_mk, &mk); + ASSERT((kMinimalTargetExponent <= w.e() + ten_mk.e() + + DiyFp::kSignificandSize) && + (kMaximalTargetExponent >= w.e() + ten_mk.e() + + DiyFp::kSignificandSize)); + // Note that ten_mk is only an approximation of 10^-k. A DiyFp only contains a + // 64 bit significand and ten_mk is thus only precise up to 64 bits. + + // The DiyFp::Times procedure rounds its result, and ten_mk is approximated + // too. The variable scaled_w (as well as scaled_boundary_minus/plus) are now + // off by a small amount. + // In fact: scaled_w - w*10^k < 1ulp (unit in the last place) of scaled_w. + // In other words: let f = scaled_w.f() and e = scaled_w.e(), then + // (f-1) * 2^e < w*10^k < (f+1) * 2^e + DiyFp scaled_w = DiyFp::Times(w, ten_mk); + + // We now have (double) (scaled_w * 10^-mk). + // DigitGen will generate the first requested_digits digits of scaled_w and + // return together with a kappa such that scaled_w ~= buffer * 10^kappa. (It + // will not always be exactly the same since DigitGenCounted only produces a + // limited number of digits.) + int kappa; + bool result = DigitGenCounted(scaled_w, requested_digits, + buffer, length, &kappa); + *decimal_exponent = -mk + kappa; + return result; +} + + +bool FastDtoa(double v, + FastDtoaMode mode, + int requested_digits, + Vector buffer, + int* length, + int* decimal_point) { + ASSERT(v > 0); + ASSERT(!Double(v).IsSpecial()); + + bool result = false; + int decimal_exponent = 0; + switch (mode) { + case FAST_DTOA_SHORTEST: + case FAST_DTOA_SHORTEST_SINGLE: + result = Grisu3(v, mode, buffer, length, &decimal_exponent); + break; + case FAST_DTOA_PRECISION: + result = Grisu3Counted(v, requested_digits, + buffer, length, &decimal_exponent); + break; + default: + UNREACHABLE(); + } + if (result) { + *decimal_point = *length + decimal_exponent; + buffer[*length] = '\0'; + } + return result; +} + +} // namespace double_conversion diff --git a/klm/util/double-conversion/fast-dtoa.h b/klm/util/double-conversion/fast-dtoa.h new file mode 100644 index 00000000..5f1e8eee --- /dev/null +++ b/klm/util/double-conversion/fast-dtoa.h @@ -0,0 +1,88 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef DOUBLE_CONVERSION_FAST_DTOA_H_ +#define DOUBLE_CONVERSION_FAST_DTOA_H_ + +#include "utils.h" + +namespace double_conversion { + +enum FastDtoaMode { + // Computes the shortest representation of the given input. The returned + // result will be the most accurate number of this length. Longer + // representations might be more accurate. + FAST_DTOA_SHORTEST, + // Same as FAST_DTOA_SHORTEST but for single-precision floats. + FAST_DTOA_SHORTEST_SINGLE, + // Computes a representation where the precision (number of digits) is + // given as input. The precision is independent of the decimal point. + FAST_DTOA_PRECISION +}; + +// FastDtoa will produce at most kFastDtoaMaximalLength digits. This does not +// include the terminating '\0' character. +static const int kFastDtoaMaximalLength = 17; +// Same for single-precision numbers. +static const int kFastDtoaMaximalSingleLength = 9; + +// Provides a decimal representation of v. +// The result should be interpreted as buffer * 10^(point - length). +// +// Precondition: +// * v must be a strictly positive finite double. +// +// Returns true if it succeeds, otherwise the result can not be trusted. +// There will be *length digits inside the buffer followed by a null terminator. +// If the function returns true and mode equals +// - FAST_DTOA_SHORTEST, then +// the parameter requested_digits is ignored. +// The result satisfies +// v == (double) (buffer * 10^(point - length)). +// The digits in the buffer are the shortest representation possible. E.g. +// if 0.099999999999 and 0.1 represent the same double then "1" is returned +// with point = 0. +// The last digit will be closest to the actual v. That is, even if several +// digits might correctly yield 'v' when read again, the buffer will contain +// the one closest to v. +// - FAST_DTOA_PRECISION, then +// the buffer contains requested_digits digits. +// the difference v - (buffer * 10^(point-length)) is closest to zero for +// all possible representations of requested_digits digits. +// If there are two values that are equally close, then FastDtoa returns +// false. +// For both modes the buffer must be large enough to hold the result. +bool FastDtoa(double d, + FastDtoaMode mode, + int requested_digits, + Vector buffer, + int* length, + int* decimal_point); + +} // namespace double_conversion + +#endif // DOUBLE_CONVERSION_FAST_DTOA_H_ diff --git a/klm/util/double-conversion/fixed-dtoa.cc b/klm/util/double-conversion/fixed-dtoa.cc new file mode 100644 index 00000000..d56b1449 --- /dev/null +++ b/klm/util/double-conversion/fixed-dtoa.cc @@ -0,0 +1,402 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include + +#include "fixed-dtoa.h" +#include "ieee.h" + +namespace double_conversion { + +// Represents a 128bit type. This class should be replaced by a native type on +// platforms that support 128bit integers. +class UInt128 { + public: + UInt128() : high_bits_(0), low_bits_(0) { } + UInt128(uint64_t high, uint64_t low) : high_bits_(high), low_bits_(low) { } + + void Multiply(uint32_t multiplicand) { + uint64_t accumulator; + + accumulator = (low_bits_ & kMask32) * multiplicand; + uint32_t part = static_cast(accumulator & kMask32); + accumulator >>= 32; + accumulator = accumulator + (low_bits_ >> 32) * multiplicand; + low_bits_ = (accumulator << 32) + part; + accumulator >>= 32; + accumulator = accumulator + (high_bits_ & kMask32) * multiplicand; + part = static_cast(accumulator & kMask32); + accumulator >>= 32; + accumulator = accumulator + (high_bits_ >> 32) * multiplicand; + high_bits_ = (accumulator << 32) + part; + ASSERT((accumulator >> 32) == 0); + } + + void Shift(int shift_amount) { + ASSERT(-64 <= shift_amount && shift_amount <= 64); + if (shift_amount == 0) { + return; + } else if (shift_amount == -64) { + high_bits_ = low_bits_; + low_bits_ = 0; + } else if (shift_amount == 64) { + low_bits_ = high_bits_; + high_bits_ = 0; + } else if (shift_amount <= 0) { + high_bits_ <<= -shift_amount; + high_bits_ += low_bits_ >> (64 + shift_amount); + low_bits_ <<= -shift_amount; + } else { + low_bits_ >>= shift_amount; + low_bits_ += high_bits_ << (64 - shift_amount); + high_bits_ >>= shift_amount; + } + } + + // Modifies *this to *this MOD (2^power). + // Returns *this DIV (2^power). + int DivModPowerOf2(int power) { + if (power >= 64) { + int result = static_cast(high_bits_ >> (power - 64)); + high_bits_ -= static_cast(result) << (power - 64); + return result; + } else { + uint64_t part_low = low_bits_ >> power; + uint64_t part_high = high_bits_ << (64 - power); + int result = static_cast(part_low + part_high); + high_bits_ = 0; + low_bits_ -= part_low << power; + return result; + } + } + + bool IsZero() const { + return high_bits_ == 0 && low_bits_ == 0; + } + + int BitAt(int position) { + if (position >= 64) { + return static_cast(high_bits_ >> (position - 64)) & 1; + } else { + return static_cast(low_bits_ >> position) & 1; + } + } + + private: + static const uint64_t kMask32 = 0xFFFFFFFF; + // Value == (high_bits_ << 64) + low_bits_ + uint64_t high_bits_; + uint64_t low_bits_; +}; + + +static const int kDoubleSignificandSize = 53; // Includes the hidden bit. + + +static void FillDigits32FixedLength(uint32_t number, int requested_length, + Vector buffer, int* length) { + for (int i = requested_length - 1; i >= 0; --i) { + buffer[(*length) + i] = '0' + number % 10; + number /= 10; + } + *length += requested_length; +} + + +static void FillDigits32(uint32_t number, Vector buffer, int* length) { + int number_length = 0; + // We fill the digits in reverse order and exchange them afterwards. + while (number != 0) { + int digit = number % 10; + number /= 10; + buffer[(*length) + number_length] = '0' + digit; + number_length++; + } + // Exchange the digits. + int i = *length; + int j = *length + number_length - 1; + while (i < j) { + char tmp = buffer[i]; + buffer[i] = buffer[j]; + buffer[j] = tmp; + i++; + j--; + } + *length += number_length; +} + + +static void FillDigits64FixedLength(uint64_t number, int requested_length, + Vector buffer, int* length) { + const uint32_t kTen7 = 10000000; + // For efficiency cut the number into 3 uint32_t parts, and print those. + uint32_t part2 = static_cast(number % kTen7); + number /= kTen7; + uint32_t part1 = static_cast(number % kTen7); + uint32_t part0 = static_cast(number / kTen7); + + FillDigits32FixedLength(part0, 3, buffer, length); + FillDigits32FixedLength(part1, 7, buffer, length); + FillDigits32FixedLength(part2, 7, buffer, length); +} + + +static void FillDigits64(uint64_t number, Vector buffer, int* length) { + const uint32_t kTen7 = 10000000; + // For efficiency cut the number into 3 uint32_t parts, and print those. + uint32_t part2 = static_cast(number % kTen7); + number /= kTen7; + uint32_t part1 = static_cast(number % kTen7); + uint32_t part0 = static_cast(number / kTen7); + + if (part0 != 0) { + FillDigits32(part0, buffer, length); + FillDigits32FixedLength(part1, 7, buffer, length); + FillDigits32FixedLength(part2, 7, buffer, length); + } else if (part1 != 0) { + FillDigits32(part1, buffer, length); + FillDigits32FixedLength(part2, 7, buffer, length); + } else { + FillDigits32(part2, buffer, length); + } +} + + +static void RoundUp(Vector buffer, int* length, int* decimal_point) { + // An empty buffer represents 0. + if (*length == 0) { + buffer[0] = '1'; + *decimal_point = 1; + *length = 1; + return; + } + // Round the last digit until we either have a digit that was not '9' or until + // we reached the first digit. + buffer[(*length) - 1]++; + for (int i = (*length) - 1; i > 0; --i) { + if (buffer[i] != '0' + 10) { + return; + } + buffer[i] = '0'; + buffer[i - 1]++; + } + // If the first digit is now '0' + 10, we would need to set it to '0' and add + // a '1' in front. However we reach the first digit only if all following + // digits had been '9' before rounding up. Now all trailing digits are '0' and + // we simply switch the first digit to '1' and update the decimal-point + // (indicating that the point is now one digit to the right). + if (buffer[0] == '0' + 10) { + buffer[0] = '1'; + (*decimal_point)++; + } +} + + +// The given fractionals number represents a fixed-point number with binary +// point at bit (-exponent). +// Preconditions: +// -128 <= exponent <= 0. +// 0 <= fractionals * 2^exponent < 1 +// The buffer holds the result. +// The function will round its result. During the rounding-process digits not +// generated by this function might be updated, and the decimal-point variable +// might be updated. If this function generates the digits 99 and the buffer +// already contained "199" (thus yielding a buffer of "19999") then a +// rounding-up will change the contents of the buffer to "20000". +static void FillFractionals(uint64_t fractionals, int exponent, + int fractional_count, Vector buffer, + int* length, int* decimal_point) { + ASSERT(-128 <= exponent && exponent <= 0); + // 'fractionals' is a fixed-point number, with binary point at bit + // (-exponent). Inside the function the non-converted remainder of fractionals + // is a fixed-point number, with binary point at bit 'point'. + if (-exponent <= 64) { + // One 64 bit number is sufficient. + ASSERT(fractionals >> 56 == 0); + int point = -exponent; + for (int i = 0; i < fractional_count; ++i) { + if (fractionals == 0) break; + // Instead of multiplying by 10 we multiply by 5 and adjust the point + // location. This way the fractionals variable will not overflow. + // Invariant at the beginning of the loop: fractionals < 2^point. + // Initially we have: point <= 64 and fractionals < 2^56 + // After each iteration the point is decremented by one. + // Note that 5^3 = 125 < 128 = 2^7. + // Therefore three iterations of this loop will not overflow fractionals + // (even without the subtraction at the end of the loop body). At this + // time point will satisfy point <= 61 and therefore fractionals < 2^point + // and any further multiplication of fractionals by 5 will not overflow. + fractionals *= 5; + point--; + int digit = static_cast(fractionals >> point); + buffer[*length] = '0' + digit; + (*length)++; + fractionals -= static_cast(digit) << point; + } + // If the first bit after the point is set we have to round up. + if (((fractionals >> (point - 1)) & 1) == 1) { + RoundUp(buffer, length, decimal_point); + } + } else { // We need 128 bits. + ASSERT(64 < -exponent && -exponent <= 128); + UInt128 fractionals128 = UInt128(fractionals, 0); + fractionals128.Shift(-exponent - 64); + int point = 128; + for (int i = 0; i < fractional_count; ++i) { + if (fractionals128.IsZero()) break; + // As before: instead of multiplying by 10 we multiply by 5 and adjust the + // point location. + // This multiplication will not overflow for the same reasons as before. + fractionals128.Multiply(5); + point--; + int digit = fractionals128.DivModPowerOf2(point); + buffer[*length] = '0' + digit; + (*length)++; + } + if (fractionals128.BitAt(point - 1) == 1) { + RoundUp(buffer, length, decimal_point); + } + } +} + + +// Removes leading and trailing zeros. +// If leading zeros are removed then the decimal point position is adjusted. +static void TrimZeros(Vector buffer, int* length, int* decimal_point) { + while (*length > 0 && buffer[(*length) - 1] == '0') { + (*length)--; + } + int first_non_zero = 0; + while (first_non_zero < *length && buffer[first_non_zero] == '0') { + first_non_zero++; + } + if (first_non_zero != 0) { + for (int i = first_non_zero; i < *length; ++i) { + buffer[i - first_non_zero] = buffer[i]; + } + *length -= first_non_zero; + *decimal_point -= first_non_zero; + } +} + + +bool FastFixedDtoa(double v, + int fractional_count, + Vector buffer, + int* length, + int* decimal_point) { + const uint32_t kMaxUInt32 = 0xFFFFFFFF; + uint64_t significand = Double(v).Significand(); + int exponent = Double(v).Exponent(); + // v = significand * 2^exponent (with significand a 53bit integer). + // If the exponent is larger than 20 (i.e. we may have a 73bit number) then we + // don't know how to compute the representation. 2^73 ~= 9.5*10^21. + // If necessary this limit could probably be increased, but we don't need + // more. + if (exponent > 20) return false; + if (fractional_count > 20) return false; + *length = 0; + // At most kDoubleSignificandSize bits of the significand are non-zero. + // Given a 64 bit integer we have 11 0s followed by 53 potentially non-zero + // bits: 0..11*..0xxx..53*..xx + if (exponent + kDoubleSignificandSize > 64) { + // The exponent must be > 11. + // + // We know that v = significand * 2^exponent. + // And the exponent > 11. + // We simplify the task by dividing v by 10^17. + // The quotient delivers the first digits, and the remainder fits into a 64 + // bit number. + // Dividing by 10^17 is equivalent to dividing by 5^17*2^17. + const uint64_t kFive17 = UINT64_2PART_C(0xB1, A2BC2EC5); // 5^17 + uint64_t divisor = kFive17; + int divisor_power = 17; + uint64_t dividend = significand; + uint32_t quotient; + uint64_t remainder; + // Let v = f * 2^e with f == significand and e == exponent. + // Then need q (quotient) and r (remainder) as follows: + // v = q * 10^17 + r + // f * 2^e = q * 10^17 + r + // f * 2^e = q * 5^17 * 2^17 + r + // If e > 17 then + // f * 2^(e-17) = q * 5^17 + r/2^17 + // else + // f = q * 5^17 * 2^(17-e) + r/2^e + if (exponent > divisor_power) { + // We only allow exponents of up to 20 and therefore (17 - e) <= 3 + dividend <<= exponent - divisor_power; + quotient = static_cast(dividend / divisor); + remainder = (dividend % divisor) << divisor_power; + } else { + divisor <<= divisor_power - exponent; + quotient = static_cast(dividend / divisor); + remainder = (dividend % divisor) << exponent; + } + FillDigits32(quotient, buffer, length); + FillDigits64FixedLength(remainder, divisor_power, buffer, length); + *decimal_point = *length; + } else if (exponent >= 0) { + // 0 <= exponent <= 11 + significand <<= exponent; + FillDigits64(significand, buffer, length); + *decimal_point = *length; + } else if (exponent > -kDoubleSignificandSize) { + // We have to cut the number. + uint64_t integrals = significand >> -exponent; + uint64_t fractionals = significand - (integrals << -exponent); + if (integrals > kMaxUInt32) { + FillDigits64(integrals, buffer, length); + } else { + FillDigits32(static_cast(integrals), buffer, length); + } + *decimal_point = *length; + FillFractionals(fractionals, exponent, fractional_count, + buffer, length, decimal_point); + } else if (exponent < -128) { + // This configuration (with at most 20 digits) means that all digits must be + // 0. + ASSERT(fractional_count <= 20); + buffer[0] = '\0'; + *length = 0; + *decimal_point = -fractional_count; + } else { + *decimal_point = 0; + FillFractionals(significand, exponent, fractional_count, + buffer, length, decimal_point); + } + TrimZeros(buffer, length, decimal_point); + buffer[*length] = '\0'; + if ((*length) == 0) { + // The string is empty and the decimal_point thus has no importance. Mimick + // Gay's dtoa and and set it to -fractional_count. + *decimal_point = -fractional_count; + } + return true; +} + +} // namespace double_conversion diff --git a/klm/util/double-conversion/fixed-dtoa.h b/klm/util/double-conversion/fixed-dtoa.h new file mode 100644 index 00000000..3bdd08e2 --- /dev/null +++ b/klm/util/double-conversion/fixed-dtoa.h @@ -0,0 +1,56 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef DOUBLE_CONVERSION_FIXED_DTOA_H_ +#define DOUBLE_CONVERSION_FIXED_DTOA_H_ + +#include "utils.h" + +namespace double_conversion { + +// Produces digits necessary to print a given number with +// 'fractional_count' digits after the decimal point. +// The buffer must be big enough to hold the result plus one terminating null +// character. +// +// The produced digits might be too short in which case the caller has to fill +// the gaps with '0's. +// Example: FastFixedDtoa(0.001, 5, ...) is allowed to return buffer = "1", and +// decimal_point = -2. +// Halfway cases are rounded towards +/-Infinity (away from 0). The call +// FastFixedDtoa(0.15, 2, ...) thus returns buffer = "2", decimal_point = 0. +// The returned buffer may contain digits that would be truncated from the +// shortest representation of the input. +// +// This method only works for some parameters. If it can't handle the input it +// returns false. The output is null-terminated when the function succeeds. +bool FastFixedDtoa(double v, int fractional_count, + Vector buffer, int* length, int* decimal_point); + +} // namespace double_conversion + +#endif // DOUBLE_CONVERSION_FIXED_DTOA_H_ diff --git a/klm/util/double-conversion/ieee.h b/klm/util/double-conversion/ieee.h new file mode 100644 index 00000000..839dc47d --- /dev/null +++ b/klm/util/double-conversion/ieee.h @@ -0,0 +1,398 @@ +// Copyright 2012 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef DOUBLE_CONVERSION_DOUBLE_H_ +#define DOUBLE_CONVERSION_DOUBLE_H_ + +#include "diy-fp.h" + +namespace double_conversion { + +// We assume that doubles and uint64_t have the same endianness. +static uint64_t double_to_uint64(double d) { return BitCast(d); } +static double uint64_to_double(uint64_t d64) { return BitCast(d64); } +static uint32_t float_to_uint32(float f) { return BitCast(f); } +static float uint32_to_float(uint32_t d32) { return BitCast(d32); } + +// Helper functions for doubles. +class Double { + public: + static const uint64_t kSignMask = UINT64_2PART_C(0x80000000, 00000000); + static const uint64_t kExponentMask = UINT64_2PART_C(0x7FF00000, 00000000); + static const uint64_t kSignificandMask = UINT64_2PART_C(0x000FFFFF, FFFFFFFF); + static const uint64_t kHiddenBit = UINT64_2PART_C(0x00100000, 00000000); + static const int kPhysicalSignificandSize = 52; // Excludes the hidden bit. + static const int kSignificandSize = 53; + + Double() : d64_(0) {} + explicit Double(double d) : d64_(double_to_uint64(d)) {} + explicit Double(uint64_t d64) : d64_(d64) {} + explicit Double(DiyFp diy_fp) + : d64_(DiyFpToUint64(diy_fp)) {} + + // The value encoded by this Double must be greater or equal to +0.0. + // It must not be special (infinity, or NaN). + DiyFp AsDiyFp() const { + ASSERT(Sign() > 0); + ASSERT(!IsSpecial()); + return DiyFp(Significand(), Exponent()); + } + + // The value encoded by this Double must be strictly greater than 0. + DiyFp AsNormalizedDiyFp() const { + ASSERT(value() > 0.0); + uint64_t f = Significand(); + int e = Exponent(); + + // The current double could be a denormal. + while ((f & kHiddenBit) == 0) { + f <<= 1; + e--; + } + // Do the final shifts in one go. + f <<= DiyFp::kSignificandSize - kSignificandSize; + e -= DiyFp::kSignificandSize - kSignificandSize; + return DiyFp(f, e); + } + + // Returns the double's bit as uint64. + uint64_t AsUint64() const { + return d64_; + } + + // Returns the next greater double. Returns +infinity on input +infinity. + double NextDouble() const { + if (d64_ == kInfinity) return Double(kInfinity).value(); + if (Sign() < 0 && Significand() == 0) { + // -0.0 + return 0.0; + } + if (Sign() < 0) { + return Double(d64_ - 1).value(); + } else { + return Double(d64_ + 1).value(); + } + } + + double PreviousDouble() const { + if (d64_ == (kInfinity | kSignMask)) return -Double::Infinity(); + if (Sign() < 0) { + return Double(d64_ + 1).value(); + } else { + if (Significand() == 0) return -0.0; + return Double(d64_ - 1).value(); + } + } + + int Exponent() const { + if (IsDenormal()) return kDenormalExponent; + + uint64_t d64 = AsUint64(); + int biased_e = + static_cast((d64 & kExponentMask) >> kPhysicalSignificandSize); + return biased_e - kExponentBias; + } + + uint64_t Significand() const { + uint64_t d64 = AsUint64(); + uint64_t significand = d64 & kSignificandMask; + if (!IsDenormal()) { + return significand + kHiddenBit; + } else { + return significand; + } + } + + // Returns true if the double is a denormal. + bool IsDenormal() const { + uint64_t d64 = AsUint64(); + return (d64 & kExponentMask) == 0; + } + + // We consider denormals not to be special. + // Hence only Infinity and NaN are special. + bool IsSpecial() const { + uint64_t d64 = AsUint64(); + return (d64 & kExponentMask) == kExponentMask; + } + + bool IsNan() const { + uint64_t d64 = AsUint64(); + return ((d64 & kExponentMask) == kExponentMask) && + ((d64 & kSignificandMask) != 0); + } + + bool IsInfinite() const { + uint64_t d64 = AsUint64(); + return ((d64 & kExponentMask) == kExponentMask) && + ((d64 & kSignificandMask) == 0); + } + + int Sign() const { + uint64_t d64 = AsUint64(); + return (d64 & kSignMask) == 0? 1: -1; + } + + // Precondition: the value encoded by this Double must be greater or equal + // than +0.0. + DiyFp UpperBoundary() const { + ASSERT(Sign() > 0); + return DiyFp(Significand() * 2 + 1, Exponent() - 1); + } + + // Computes the two boundaries of this. + // The bigger boundary (m_plus) is normalized. The lower boundary has the same + // exponent as m_plus. + // Precondition: the value encoded by this Double must be greater than 0. + void NormalizedBoundaries(DiyFp* out_m_minus, DiyFp* out_m_plus) const { + ASSERT(value() > 0.0); + DiyFp v = this->AsDiyFp(); + DiyFp m_plus = DiyFp::Normalize(DiyFp((v.f() << 1) + 1, v.e() - 1)); + DiyFp m_minus; + if (LowerBoundaryIsCloser()) { + m_minus = DiyFp((v.f() << 2) - 1, v.e() - 2); + } else { + m_minus = DiyFp((v.f() << 1) - 1, v.e() - 1); + } + m_minus.set_f(m_minus.f() << (m_minus.e() - m_plus.e())); + m_minus.set_e(m_plus.e()); + *out_m_plus = m_plus; + *out_m_minus = m_minus; + } + + bool LowerBoundaryIsCloser() const { + // The boundary is closer if the significand is of the form f == 2^p-1 then + // the lower boundary is closer. + // Think of v = 1000e10 and v- = 9999e9. + // Then the boundary (== (v - v-)/2) is not just at a distance of 1e9 but + // at a distance of 1e8. + // The only exception is for the smallest normal: the largest denormal is + // at the same distance as its successor. + // Note: denormals have the same exponent as the smallest normals. + bool physical_significand_is_zero = ((AsUint64() & kSignificandMask) == 0); + return physical_significand_is_zero && (Exponent() != kDenormalExponent); + } + + double value() const { return uint64_to_double(d64_); } + + // Returns the significand size for a given order of magnitude. + // If v = f*2^e with 2^p-1 <= f <= 2^p then p+e is v's order of magnitude. + // This function returns the number of significant binary digits v will have + // once it's encoded into a double. In almost all cases this is equal to + // kSignificandSize. The only exceptions are denormals. They start with + // leading zeroes and their effective significand-size is hence smaller. + static int SignificandSizeForOrderOfMagnitude(int order) { + if (order >= (kDenormalExponent + kSignificandSize)) { + return kSignificandSize; + } + if (order <= kDenormalExponent) return 0; + return order - kDenormalExponent; + } + + static double Infinity() { + return Double(kInfinity).value(); + } + + static double NaN() { + return Double(kNaN).value(); + } + + private: + static const int kExponentBias = 0x3FF + kPhysicalSignificandSize; + static const int kDenormalExponent = -kExponentBias + 1; + static const int kMaxExponent = 0x7FF - kExponentBias; + static const uint64_t kInfinity = UINT64_2PART_C(0x7FF00000, 00000000); + static const uint64_t kNaN = UINT64_2PART_C(0x7FF80000, 00000000); + + const uint64_t d64_; + + static uint64_t DiyFpToUint64(DiyFp diy_fp) { + uint64_t significand = diy_fp.f(); + int exponent = diy_fp.e(); + while (significand > kHiddenBit + kSignificandMask) { + significand >>= 1; + exponent++; + } + if (exponent >= kMaxExponent) { + return kInfinity; + } + if (exponent < kDenormalExponent) { + return 0; + } + while (exponent > kDenormalExponent && (significand & kHiddenBit) == 0) { + significand <<= 1; + exponent--; + } + uint64_t biased_exponent; + if (exponent == kDenormalExponent && (significand & kHiddenBit) == 0) { + biased_exponent = 0; + } else { + biased_exponent = static_cast(exponent + kExponentBias); + } + return (significand & kSignificandMask) | + (biased_exponent << kPhysicalSignificandSize); + } +}; + +class Single { + public: + static const uint32_t kSignMask = 0x80000000; + static const uint32_t kExponentMask = 0x7F800000; + static const uint32_t kSignificandMask = 0x007FFFFF; + static const uint32_t kHiddenBit = 0x00800000; + static const int kPhysicalSignificandSize = 23; // Excludes the hidden bit. + static const int kSignificandSize = 24; + + Single() : d32_(0) {} + explicit Single(float f) : d32_(float_to_uint32(f)) {} + explicit Single(uint32_t d32) : d32_(d32) {} + + // The value encoded by this Single must be greater or equal to +0.0. + // It must not be special (infinity, or NaN). + DiyFp AsDiyFp() const { + ASSERT(Sign() > 0); + ASSERT(!IsSpecial()); + return DiyFp(Significand(), Exponent()); + } + + // Returns the single's bit as uint64. + uint32_t AsUint32() const { + return d32_; + } + + int Exponent() const { + if (IsDenormal()) return kDenormalExponent; + + uint32_t d32 = AsUint32(); + int biased_e = + static_cast((d32 & kExponentMask) >> kPhysicalSignificandSize); + return biased_e - kExponentBias; + } + + uint32_t Significand() const { + uint32_t d32 = AsUint32(); + uint32_t significand = d32 & kSignificandMask; + if (!IsDenormal()) { + return significand + kHiddenBit; + } else { + return significand; + } + } + + // Returns true if the single is a denormal. + bool IsDenormal() const { + uint32_t d32 = AsUint32(); + return (d32 & kExponentMask) == 0; + } + + // We consider denormals not to be special. + // Hence only Infinity and NaN are special. + bool IsSpecial() const { + uint32_t d32 = AsUint32(); + return (d32 & kExponentMask) == kExponentMask; + } + + bool IsNan() const { + uint32_t d32 = AsUint32(); + return ((d32 & kExponentMask) == kExponentMask) && + ((d32 & kSignificandMask) != 0); + } + + bool IsInfinite() const { + uint32_t d32 = AsUint32(); + return ((d32 & kExponentMask) == kExponentMask) && + ((d32 & kSignificandMask) == 0); + } + + int Sign() const { + uint32_t d32 = AsUint32(); + return (d32 & kSignMask) == 0? 1: -1; + } + + // Computes the two boundaries of this. + // The bigger boundary (m_plus) is normalized. The lower boundary has the same + // exponent as m_plus. + // Precondition: the value encoded by this Single must be greater than 0. + void NormalizedBoundaries(DiyFp* out_m_minus, DiyFp* out_m_plus) const { + ASSERT(value() > 0.0); + DiyFp v = this->AsDiyFp(); + DiyFp m_plus = DiyFp::Normalize(DiyFp((v.f() << 1) + 1, v.e() - 1)); + DiyFp m_minus; + if (LowerBoundaryIsCloser()) { + m_minus = DiyFp((v.f() << 2) - 1, v.e() - 2); + } else { + m_minus = DiyFp((v.f() << 1) - 1, v.e() - 1); + } + m_minus.set_f(m_minus.f() << (m_minus.e() - m_plus.e())); + m_minus.set_e(m_plus.e()); + *out_m_plus = m_plus; + *out_m_minus = m_minus; + } + + // Precondition: the value encoded by this Single must be greater or equal + // than +0.0. + DiyFp UpperBoundary() const { + ASSERT(Sign() > 0); + return DiyFp(Significand() * 2 + 1, Exponent() - 1); + } + + bool LowerBoundaryIsCloser() const { + // The boundary is closer if the significand is of the form f == 2^p-1 then + // the lower boundary is closer. + // Think of v = 1000e10 and v- = 9999e9. + // Then the boundary (== (v - v-)/2) is not just at a distance of 1e9 but + // at a distance of 1e8. + // The only exception is for the smallest normal: the largest denormal is + // at the same distance as its successor. + // Note: denormals have the same exponent as the smallest normals. + bool physical_significand_is_zero = ((AsUint32() & kSignificandMask) == 0); + return physical_significand_is_zero && (Exponent() != kDenormalExponent); + } + + float value() const { return uint32_to_float(d32_); } + + static float Infinity() { + return Single(kInfinity).value(); + } + + static float NaN() { + return Single(kNaN).value(); + } + + private: + static const int kExponentBias = 0x7F + kPhysicalSignificandSize; + static const int kDenormalExponent = -kExponentBias + 1; + static const int kMaxExponent = 0xFF - kExponentBias; + static const uint32_t kInfinity = 0x7F800000; + static const uint32_t kNaN = 0x7FC00000; + + const uint32_t d32_; +}; + +} // namespace double_conversion + +#endif // DOUBLE_CONVERSION_DOUBLE_H_ diff --git a/klm/util/double-conversion/strtod.cc b/klm/util/double-conversion/strtod.cc new file mode 100644 index 00000000..9758989f --- /dev/null +++ b/klm/util/double-conversion/strtod.cc @@ -0,0 +1,554 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include +#include + +#include "strtod.h" +#include "bignum.h" +#include "cached-powers.h" +#include "ieee.h" + +namespace double_conversion { + +// 2^53 = 9007199254740992. +// Any integer with at most 15 decimal digits will hence fit into a double +// (which has a 53bit significand) without loss of precision. +static const int kMaxExactDoubleIntegerDecimalDigits = 15; +// 2^64 = 18446744073709551616 > 10^19 +static const int kMaxUint64DecimalDigits = 19; + +// Max double: 1.7976931348623157 x 10^308 +// Min non-zero double: 4.9406564584124654 x 10^-324 +// Any x >= 10^309 is interpreted as +infinity. +// Any x <= 10^-324 is interpreted as 0. +// Note that 2.5e-324 (despite being smaller than the min double) will be read +// as non-zero (equal to the min non-zero double). +static const int kMaxDecimalPower = 309; +static const int kMinDecimalPower = -324; + +// 2^64 = 18446744073709551616 +static const uint64_t kMaxUint64 = UINT64_2PART_C(0xFFFFFFFF, FFFFFFFF); + + +static const double exact_powers_of_ten[] = { + 1.0, // 10^0 + 10.0, + 100.0, + 1000.0, + 10000.0, + 100000.0, + 1000000.0, + 10000000.0, + 100000000.0, + 1000000000.0, + 10000000000.0, // 10^10 + 100000000000.0, + 1000000000000.0, + 10000000000000.0, + 100000000000000.0, + 1000000000000000.0, + 10000000000000000.0, + 100000000000000000.0, + 1000000000000000000.0, + 10000000000000000000.0, + 100000000000000000000.0, // 10^20 + 1000000000000000000000.0, + // 10^22 = 0x21e19e0c9bab2400000 = 0x878678326eac9 * 2^22 + 10000000000000000000000.0 +}; +static const int kExactPowersOfTenSize = ARRAY_SIZE(exact_powers_of_ten); + +// Maximum number of significant digits in the decimal representation. +// In fact the value is 772 (see conversions.cc), but to give us some margin +// we round up to 780. +static const int kMaxSignificantDecimalDigits = 780; + +static Vector TrimLeadingZeros(Vector buffer) { + for (int i = 0; i < buffer.length(); i++) { + if (buffer[i] != '0') { + return buffer.SubVector(i, buffer.length()); + } + } + return Vector(buffer.start(), 0); +} + + +static Vector TrimTrailingZeros(Vector buffer) { + for (int i = buffer.length() - 1; i >= 0; --i) { + if (buffer[i] != '0') { + return buffer.SubVector(0, i + 1); + } + } + return Vector(buffer.start(), 0); +} + + +static void CutToMaxSignificantDigits(Vector buffer, + int exponent, + char* significant_buffer, + int* significant_exponent) { + for (int i = 0; i < kMaxSignificantDecimalDigits - 1; ++i) { + significant_buffer[i] = buffer[i]; + } + // The input buffer has been trimmed. Therefore the last digit must be + // different from '0'. + ASSERT(buffer[buffer.length() - 1] != '0'); + // Set the last digit to be non-zero. This is sufficient to guarantee + // correct rounding. + significant_buffer[kMaxSignificantDecimalDigits - 1] = '1'; + *significant_exponent = + exponent + (buffer.length() - kMaxSignificantDecimalDigits); +} + + +// Trims the buffer and cuts it to at most kMaxSignificantDecimalDigits. +// If possible the input-buffer is reused, but if the buffer needs to be +// modified (due to cutting), then the input needs to be copied into the +// buffer_copy_space. +static void TrimAndCut(Vector buffer, int exponent, + char* buffer_copy_space, int space_size, + Vector* trimmed, int* updated_exponent) { + Vector left_trimmed = TrimLeadingZeros(buffer); + Vector right_trimmed = TrimTrailingZeros(left_trimmed); + exponent += left_trimmed.length() - right_trimmed.length(); + if (right_trimmed.length() > kMaxSignificantDecimalDigits) { + ASSERT(space_size >= kMaxSignificantDecimalDigits); + CutToMaxSignificantDigits(right_trimmed, exponent, + buffer_copy_space, updated_exponent); + *trimmed = Vector(buffer_copy_space, + kMaxSignificantDecimalDigits); + } else { + *trimmed = right_trimmed; + *updated_exponent = exponent; + } +} + + +// Reads digits from the buffer and converts them to a uint64. +// Reads in as many digits as fit into a uint64. +// When the string starts with "1844674407370955161" no further digit is read. +// Since 2^64 = 18446744073709551616 it would still be possible read another +// digit if it was less or equal than 6, but this would complicate the code. +static uint64_t ReadUint64(Vector buffer, + int* number_of_read_digits) { + uint64_t result = 0; + int i = 0; + while (i < buffer.length() && result <= (kMaxUint64 / 10 - 1)) { + int digit = buffer[i++] - '0'; + ASSERT(0 <= digit && digit <= 9); + result = 10 * result + digit; + } + *number_of_read_digits = i; + return result; +} + + +// Reads a DiyFp from the buffer. +// The returned DiyFp is not necessarily normalized. +// If remaining_decimals is zero then the returned DiyFp is accurate. +// Otherwise it has been rounded and has error of at most 1/2 ulp. +static void ReadDiyFp(Vector buffer, + DiyFp* result, + int* remaining_decimals) { + int read_digits; + uint64_t significand = ReadUint64(buffer, &read_digits); + if (buffer.length() == read_digits) { + *result = DiyFp(significand, 0); + *remaining_decimals = 0; + } else { + // Round the significand. + if (buffer[read_digits] >= '5') { + significand++; + } + // Compute the binary exponent. + int exponent = 0; + *result = DiyFp(significand, exponent); + *remaining_decimals = buffer.length() - read_digits; + } +} + + +static bool DoubleStrtod(Vector trimmed, + int exponent, + double* result) { +#if !defined(DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS) + // On x86 the floating-point stack can be 64 or 80 bits wide. If it is + // 80 bits wide (as is the case on Linux) then double-rounding occurs and the + // result is not accurate. + // We know that Windows32 uses 64 bits and is therefore accurate. + // Note that the ARM simulator is compiled for 32bits. It therefore exhibits + // the same problem. + return false; +#endif + if (trimmed.length() <= kMaxExactDoubleIntegerDecimalDigits) { + int read_digits; + // The trimmed input fits into a double. + // If the 10^exponent (resp. 10^-exponent) fits into a double too then we + // can compute the result-double simply by multiplying (resp. dividing) the + // two numbers. + // This is possible because IEEE guarantees that floating-point operations + // return the best possible approximation. + if (exponent < 0 && -exponent < kExactPowersOfTenSize) { + // 10^-exponent fits into a double. + *result = static_cast(ReadUint64(trimmed, &read_digits)); + ASSERT(read_digits == trimmed.length()); + *result /= exact_powers_of_ten[-exponent]; + return true; + } + if (0 <= exponent && exponent < kExactPowersOfTenSize) { + // 10^exponent fits into a double. + *result = static_cast(ReadUint64(trimmed, &read_digits)); + ASSERT(read_digits == trimmed.length()); + *result *= exact_powers_of_ten[exponent]; + return true; + } + int remaining_digits = + kMaxExactDoubleIntegerDecimalDigits - trimmed.length(); + if ((0 <= exponent) && + (exponent - remaining_digits < kExactPowersOfTenSize)) { + // The trimmed string was short and we can multiply it with + // 10^remaining_digits. As a result the remaining exponent now fits + // into a double too. + *result = static_cast(ReadUint64(trimmed, &read_digits)); + ASSERT(read_digits == trimmed.length()); + *result *= exact_powers_of_ten[remaining_digits]; + *result *= exact_powers_of_ten[exponent - remaining_digits]; + return true; + } + } + return false; +} + + +// Returns 10^exponent as an exact DiyFp. +// The given exponent must be in the range [1; kDecimalExponentDistance[. +static DiyFp AdjustmentPowerOfTen(int exponent) { + ASSERT(0 < exponent); + ASSERT(exponent < PowersOfTenCache::kDecimalExponentDistance); + // Simply hardcode the remaining powers for the given decimal exponent + // distance. + ASSERT(PowersOfTenCache::kDecimalExponentDistance == 8); + switch (exponent) { + case 1: return DiyFp(UINT64_2PART_C(0xa0000000, 00000000), -60); + case 2: return DiyFp(UINT64_2PART_C(0xc8000000, 00000000), -57); + case 3: return DiyFp(UINT64_2PART_C(0xfa000000, 00000000), -54); + case 4: return DiyFp(UINT64_2PART_C(0x9c400000, 00000000), -50); + case 5: return DiyFp(UINT64_2PART_C(0xc3500000, 00000000), -47); + case 6: return DiyFp(UINT64_2PART_C(0xf4240000, 00000000), -44); + case 7: return DiyFp(UINT64_2PART_C(0x98968000, 00000000), -40); + default: + UNREACHABLE(); + return DiyFp(0, 0); + } +} + + +// If the function returns true then the result is the correct double. +// Otherwise it is either the correct double or the double that is just below +// the correct double. +static bool DiyFpStrtod(Vector buffer, + int exponent, + double* result) { + DiyFp input; + int remaining_decimals; + ReadDiyFp(buffer, &input, &remaining_decimals); + // Since we may have dropped some digits the input is not accurate. + // If remaining_decimals is different than 0 than the error is at most + // .5 ulp (unit in the last place). + // We don't want to deal with fractions and therefore keep a common + // denominator. + const int kDenominatorLog = 3; + const int kDenominator = 1 << kDenominatorLog; + // Move the remaining decimals into the exponent. + exponent += remaining_decimals; + int error = (remaining_decimals == 0 ? 0 : kDenominator / 2); + + int old_e = input.e(); + input.Normalize(); + error <<= old_e - input.e(); + + ASSERT(exponent <= PowersOfTenCache::kMaxDecimalExponent); + if (exponent < PowersOfTenCache::kMinDecimalExponent) { + *result = 0.0; + return true; + } + DiyFp cached_power; + int cached_decimal_exponent; + PowersOfTenCache::GetCachedPowerForDecimalExponent(exponent, + &cached_power, + &cached_decimal_exponent); + + if (cached_decimal_exponent != exponent) { + int adjustment_exponent = exponent - cached_decimal_exponent; + DiyFp adjustment_power = AdjustmentPowerOfTen(adjustment_exponent); + input.Multiply(adjustment_power); + if (kMaxUint64DecimalDigits - buffer.length() >= adjustment_exponent) { + // The product of input with the adjustment power fits into a 64 bit + // integer. + ASSERT(DiyFp::kSignificandSize == 64); + } else { + // The adjustment power is exact. There is hence only an error of 0.5. + error += kDenominator / 2; + } + } + + input.Multiply(cached_power); + // The error introduced by a multiplication of a*b equals + // error_a + error_b + error_a*error_b/2^64 + 0.5 + // Substituting a with 'input' and b with 'cached_power' we have + // error_b = 0.5 (all cached powers have an error of less than 0.5 ulp), + // error_ab = 0 or 1 / kDenominator > error_a*error_b/ 2^64 + int error_b = kDenominator / 2; + int error_ab = (error == 0 ? 0 : 1); // We round up to 1. + int fixed_error = kDenominator / 2; + error += error_b + error_ab + fixed_error; + + old_e = input.e(); + input.Normalize(); + error <<= old_e - input.e(); + + // See if the double's significand changes if we add/subtract the error. + int order_of_magnitude = DiyFp::kSignificandSize + input.e(); + int effective_significand_size = + Double::SignificandSizeForOrderOfMagnitude(order_of_magnitude); + int precision_digits_count = + DiyFp::kSignificandSize - effective_significand_size; + if (precision_digits_count + kDenominatorLog >= DiyFp::kSignificandSize) { + // This can only happen for very small denormals. In this case the + // half-way multiplied by the denominator exceeds the range of an uint64. + // Simply shift everything to the right. + int shift_amount = (precision_digits_count + kDenominatorLog) - + DiyFp::kSignificandSize + 1; + input.set_f(input.f() >> shift_amount); + input.set_e(input.e() + shift_amount); + // We add 1 for the lost precision of error, and kDenominator for + // the lost precision of input.f(). + error = (error >> shift_amount) + 1 + kDenominator; + precision_digits_count -= shift_amount; + } + // We use uint64_ts now. This only works if the DiyFp uses uint64_ts too. + ASSERT(DiyFp::kSignificandSize == 64); + ASSERT(precision_digits_count < 64); + uint64_t one64 = 1; + uint64_t precision_bits_mask = (one64 << precision_digits_count) - 1; + uint64_t precision_bits = input.f() & precision_bits_mask; + uint64_t half_way = one64 << (precision_digits_count - 1); + precision_bits *= kDenominator; + half_way *= kDenominator; + DiyFp rounded_input(input.f() >> precision_digits_count, + input.e() + precision_digits_count); + if (precision_bits >= half_way + error) { + rounded_input.set_f(rounded_input.f() + 1); + } + // If the last_bits are too close to the half-way case than we are too + // inaccurate and round down. In this case we return false so that we can + // fall back to a more precise algorithm. + + *result = Double(rounded_input).value(); + if (half_way - error < precision_bits && precision_bits < half_way + error) { + // Too imprecise. The caller will have to fall back to a slower version. + // However the returned number is guaranteed to be either the correct + // double, or the next-lower double. + return false; + } else { + return true; + } +} + + +// Returns +// - -1 if buffer*10^exponent < diy_fp. +// - 0 if buffer*10^exponent == diy_fp. +// - +1 if buffer*10^exponent > diy_fp. +// Preconditions: +// buffer.length() + exponent <= kMaxDecimalPower + 1 +// buffer.length() + exponent > kMinDecimalPower +// buffer.length() <= kMaxDecimalSignificantDigits +static int CompareBufferWithDiyFp(Vector buffer, + int exponent, + DiyFp diy_fp) { + ASSERT(buffer.length() + exponent <= kMaxDecimalPower + 1); + ASSERT(buffer.length() + exponent > kMinDecimalPower); + ASSERT(buffer.length() <= kMaxSignificantDecimalDigits); + // Make sure that the Bignum will be able to hold all our numbers. + // Our Bignum implementation has a separate field for exponents. Shifts will + // consume at most one bigit (< 64 bits). + // ln(10) == 3.3219... + ASSERT(((kMaxDecimalPower + 1) * 333 / 100) < Bignum::kMaxSignificantBits); + Bignum buffer_bignum; + Bignum diy_fp_bignum; + buffer_bignum.AssignDecimalString(buffer); + diy_fp_bignum.AssignUInt64(diy_fp.f()); + if (exponent >= 0) { + buffer_bignum.MultiplyByPowerOfTen(exponent); + } else { + diy_fp_bignum.MultiplyByPowerOfTen(-exponent); + } + if (diy_fp.e() > 0) { + diy_fp_bignum.ShiftLeft(diy_fp.e()); + } else { + buffer_bignum.ShiftLeft(-diy_fp.e()); + } + return Bignum::Compare(buffer_bignum, diy_fp_bignum); +} + + +// Returns true if the guess is the correct double. +// Returns false, when guess is either correct or the next-lower double. +static bool ComputeGuess(Vector trimmed, int exponent, + double* guess) { + if (trimmed.length() == 0) { + *guess = 0.0; + return true; + } + if (exponent + trimmed.length() - 1 >= kMaxDecimalPower) { + *guess = Double::Infinity(); + return true; + } + if (exponent + trimmed.length() <= kMinDecimalPower) { + *guess = 0.0; + return true; + } + + if (DoubleStrtod(trimmed, exponent, guess) || + DiyFpStrtod(trimmed, exponent, guess)) { + return true; + } + if (*guess == Double::Infinity()) { + return true; + } + return false; +} + +double Strtod(Vector buffer, int exponent) { + char copy_buffer[kMaxSignificantDecimalDigits]; + Vector trimmed; + int updated_exponent; + TrimAndCut(buffer, exponent, copy_buffer, kMaxSignificantDecimalDigits, + &trimmed, &updated_exponent); + exponent = updated_exponent; + + double guess; + bool is_correct = ComputeGuess(trimmed, exponent, &guess); + if (is_correct) return guess; + + DiyFp upper_boundary = Double(guess).UpperBoundary(); + int comparison = CompareBufferWithDiyFp(trimmed, exponent, upper_boundary); + if (comparison < 0) { + return guess; + } else if (comparison > 0) { + return Double(guess).NextDouble(); + } else if ((Double(guess).Significand() & 1) == 0) { + // Round towards even. + return guess; + } else { + return Double(guess).NextDouble(); + } +} + +float Strtof(Vector buffer, int exponent) { + char copy_buffer[kMaxSignificantDecimalDigits]; + Vector trimmed; + int updated_exponent; + TrimAndCut(buffer, exponent, copy_buffer, kMaxSignificantDecimalDigits, + &trimmed, &updated_exponent); + exponent = updated_exponent; + + double double_guess; + bool is_correct = ComputeGuess(trimmed, exponent, &double_guess); + + float float_guess = static_cast(double_guess); + if (float_guess == double_guess) { + // This shortcut triggers for integer values. + return float_guess; + } + + // We must catch double-rounding. Say the double has been rounded up, and is + // now a boundary of a float, and rounds up again. This is why we have to + // look at previous too. + // Example (in decimal numbers): + // input: 12349 + // high-precision (4 digits): 1235 + // low-precision (3 digits): + // when read from input: 123 + // when rounded from high precision: 124. + // To do this we simply look at the neigbors of the correct result and see + // if they would round to the same float. If the guess is not correct we have + // to look at four values (since two different doubles could be the correct + // double). + + double double_next = Double(double_guess).NextDouble(); + double double_previous = Double(double_guess).PreviousDouble(); + + float f1 = static_cast(double_previous); + float f2 = float_guess; + float f3 = static_cast(double_next); + float f4; + if (is_correct) { + f4 = f3; + } else { + double double_next2 = Double(double_next).NextDouble(); + f4 = static_cast(double_next2); + } + ASSERT(f1 <= f2 && f2 <= f3 && f3 <= f4); + + // If the guess doesn't lie near a single-precision boundary we can simply + // return its float-value. + if (f1 == f4) { + return float_guess; + } + + ASSERT((f1 != f2 && f2 == f3 && f3 == f4) || + (f1 == f2 && f2 != f3 && f3 == f4) || + (f1 == f2 && f2 == f3 && f3 != f4)); + + // guess and next are the two possible canditates (in the same way that + // double_guess was the lower candidate for a double-precision guess). + float guess = f1; + float next = f4; + DiyFp upper_boundary; + if (guess == 0.0f) { + float min_float = 1e-45f; + upper_boundary = Double(static_cast(min_float) / 2).AsDiyFp(); + } else { + upper_boundary = Single(guess).UpperBoundary(); + } + int comparison = CompareBufferWithDiyFp(trimmed, exponent, upper_boundary); + if (comparison < 0) { + return guess; + } else if (comparison > 0) { + return next; + } else if ((Single(guess).Significand() & 1) == 0) { + // Round towards even. + return guess; + } else { + return next; + } +} + +} // namespace double_conversion diff --git a/klm/util/double-conversion/strtod.h b/klm/util/double-conversion/strtod.h new file mode 100644 index 00000000..ed0293b8 --- /dev/null +++ b/klm/util/double-conversion/strtod.h @@ -0,0 +1,45 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef DOUBLE_CONVERSION_STRTOD_H_ +#define DOUBLE_CONVERSION_STRTOD_H_ + +#include "utils.h" + +namespace double_conversion { + +// The buffer must only contain digits in the range [0-9]. It must not +// contain a dot or a sign. It must not start with '0', and must not be empty. +double Strtod(Vector buffer, int exponent); + +// The buffer must only contain digits in the range [0-9]. It must not +// contain a dot or a sign. It must not start with '0', and must not be empty. +float Strtof(Vector buffer, int exponent); + +} // namespace double_conversion + +#endif // DOUBLE_CONVERSION_STRTOD_H_ diff --git a/klm/util/double-conversion/utils.h b/klm/util/double-conversion/utils.h new file mode 100644 index 00000000..767094b8 --- /dev/null +++ b/klm/util/double-conversion/utils.h @@ -0,0 +1,313 @@ +// Copyright 2010 the V8 project authors. All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef DOUBLE_CONVERSION_UTILS_H_ +#define DOUBLE_CONVERSION_UTILS_H_ + +#include +#include + +#include +#ifndef ASSERT +#define ASSERT(condition) (assert(condition)) +#endif +#ifndef UNIMPLEMENTED +#define UNIMPLEMENTED() (abort()) +#endif +#ifndef UNREACHABLE +#define UNREACHABLE() (abort()) +#endif + +// Double operations detection based on target architecture. +// Linux uses a 80bit wide floating point stack on x86. This induces double +// rounding, which in turn leads to wrong results. +// An easy way to test if the floating-point operations are correct is to +// evaluate: 89255.0/1e22. If the floating-point stack is 64 bits wide then +// the result is equal to 89255e-22. +// The best way to test this, is to create a division-function and to compare +// the output of the division with the expected result. (Inlining must be +// disabled.) +// On Linux,x86 89255e-22 != Div_double(89255.0/1e22) +#if defined(_M_X64) || defined(__x86_64__) || \ + defined(__ARMEL__) || defined(__avr32__) || \ + defined(__hppa__) || defined(__ia64__) || \ + defined(__mips__) || defined(__powerpc__) || \ + defined(__sparc__) || defined(__sparc) || defined(__s390__) || \ + defined(__SH4__) || defined(__alpha__) || \ + defined(_MIPS_ARCH_MIPS32R2) +#define DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS 1 +#elif defined(_M_IX86) || defined(__i386__) || defined(__i386) +#if defined(_WIN32) +// Windows uses a 64bit wide floating point stack. +#define DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS 1 +#else +#undef DOUBLE_CONVERSION_CORRECT_DOUBLE_OPERATIONS +#endif // _WIN32 +#else +#error Target architecture was not detected as supported by Double-Conversion. +#endif + + +#if defined(_WIN32) && !defined(__MINGW32__) + +typedef signed char int8_t; +typedef unsigned char uint8_t; +typedef short int16_t; // NOLINT +typedef unsigned short uint16_t; // NOLINT +typedef int int32_t; +typedef unsigned int uint32_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +// intptr_t and friends are defined in crtdefs.h through stdio.h. + +#else + +#include + +#endif + +// The following macro works on both 32 and 64-bit platforms. +// Usage: instead of writing 0x1234567890123456 +// write UINT64_2PART_C(0x12345678,90123456); +#define UINT64_2PART_C(a, b) (((static_cast(a) << 32) + 0x##b##u)) + + +// The expression ARRAY_SIZE(a) is a compile-time constant of type +// size_t which represents the number of elements of the given +// array. You should only use ARRAY_SIZE on statically allocated +// arrays. +#ifndef ARRAY_SIZE +#define ARRAY_SIZE(a) \ + ((sizeof(a) / sizeof(*(a))) / \ + static_cast(!(sizeof(a) % sizeof(*(a))))) +#endif + +// A macro to disallow the evil copy constructor and operator= functions +// This should be used in the private: declarations for a class +#ifndef DISALLOW_COPY_AND_ASSIGN +#define DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&); \ + void operator=(const TypeName&) +#endif + +// A macro to disallow all the implicit constructors, namely the +// default constructor, copy constructor and operator= functions. +// +// This should be used in the private: declarations for a class +// that wants to prevent anyone from instantiating it. This is +// especially useful for classes containing only static methods. +#ifndef DISALLOW_IMPLICIT_CONSTRUCTORS +#define DISALLOW_IMPLICIT_CONSTRUCTORS(TypeName) \ + TypeName(); \ + DISALLOW_COPY_AND_ASSIGN(TypeName) +#endif + +namespace double_conversion { + +static const int kCharSize = sizeof(char); + +// Returns the maximum of the two parameters. +template +static T Max(T a, T b) { + return a < b ? b : a; +} + + +// Returns the minimum of the two parameters. +template +static T Min(T a, T b) { + return a < b ? a : b; +} + + +inline int StrLength(const char* string) { + size_t length = strlen(string); + ASSERT(length == static_cast(static_cast(length))); + return static_cast(length); +} + +// This is a simplified version of V8's Vector class. +template +class Vector { + public: + Vector() : start_(NULL), length_(0) {} + Vector(T* data, int length) : start_(data), length_(length) { + ASSERT(length == 0 || (length > 0 && data != NULL)); + } + + // Returns a vector using the same backing storage as this one, + // spanning from and including 'from', to but not including 'to'. + Vector SubVector(int from, int to) { + ASSERT(to <= length_); + ASSERT(from < to); + ASSERT(0 <= from); + return Vector(start() + from, to - from); + } + + // Returns the length of the vector. + int length() const { return length_; } + + // Returns whether or not the vector is empty. + bool is_empty() const { return length_ == 0; } + + // Returns the pointer to the start of the data in the vector. + T* start() const { return start_; } + + // Access individual vector elements - checks bounds in debug mode. + T& operator[](int index) const { + ASSERT(0 <= index && index < length_); + return start_[index]; + } + + T& first() { return start_[0]; } + + T& last() { return start_[length_ - 1]; } + + private: + T* start_; + int length_; +}; + + +// Helper class for building result strings in a character buffer. The +// purpose of the class is to use safe operations that checks the +// buffer bounds on all operations in debug mode. +class StringBuilder { + public: + StringBuilder(char* buffer, int size) + : buffer_(buffer, size), position_(0) { } + + ~StringBuilder() { if (!is_finalized()) Finalize(); } + + int size() const { return buffer_.length(); } + + // Get the current position in the builder. + int position() const { + ASSERT(!is_finalized()); + return position_; + } + + // Reset the position. + void Reset() { position_ = 0; } + + // Add a single character to the builder. It is not allowed to add + // 0-characters; use the Finalize() method to terminate the string + // instead. + void AddCharacter(char c) { + ASSERT(c != '\0'); + ASSERT(!is_finalized() && position_ < buffer_.length()); + buffer_[position_++] = c; + } + + // Add an entire string to the builder. Uses strlen() internally to + // compute the length of the input string. + void AddString(const char* s) { + AddSubstring(s, StrLength(s)); + } + + // Add the first 'n' characters of the given string 's' to the + // builder. The input string must have enough characters. + void AddSubstring(const char* s, int n) { + ASSERT(!is_finalized() && position_ + n < buffer_.length()); + ASSERT(static_cast(n) <= strlen(s)); + memmove(&buffer_[position_], s, n * kCharSize); + position_ += n; + } + + + // Add character padding to the builder. If count is non-positive, + // nothing is added to the builder. + void AddPadding(char c, int count) { + for (int i = 0; i < count; i++) { + AddCharacter(c); + } + } + + // Finalize the string by 0-terminating it and returning the buffer. + char* Finalize() { + ASSERT(!is_finalized() && position_ < buffer_.length()); + buffer_[position_] = '\0'; + // Make sure nobody managed to add a 0-character to the + // buffer while building the string. + ASSERT(strlen(buffer_.start()) == static_cast(position_)); + position_ = -1; + ASSERT(is_finalized()); + return buffer_.start(); + } + + private: + Vector buffer_; + int position_; + + bool is_finalized() const { return position_ < 0; } + + DISALLOW_IMPLICIT_CONSTRUCTORS(StringBuilder); +}; + +// The type-based aliasing rule allows the compiler to assume that pointers of +// different types (for some definition of different) never alias each other. +// Thus the following code does not work: +// +// float f = foo(); +// int fbits = *(int*)(&f); +// +// The compiler 'knows' that the int pointer can't refer to f since the types +// don't match, so the compiler may cache f in a register, leaving random data +// in fbits. Using C++ style casts makes no difference, however a pointer to +// char data is assumed to alias any other pointer. This is the 'memcpy +// exception'. +// +// Bit_cast uses the memcpy exception to move the bits from a variable of one +// type of a variable of another type. Of course the end result is likely to +// be implementation dependent. Most compilers (gcc-4.2 and MSVC 2005) +// will completely optimize BitCast away. +// +// There is an additional use for BitCast. +// Recent gccs will warn when they see casts that may result in breakage due to +// the type-based aliasing rule. If you have checked that there is no breakage +// you can use BitCast to cast one pointer type to another. This confuses gcc +// enough that it can no longer see that you have cast one pointer type to +// another thus avoiding the warning. +template +inline Dest BitCast(const Source& source) { + // Compile time assertion: sizeof(Dest) == sizeof(Source) + // A compile error here means your Dest and Source have different sizes. + typedef char VerifySizesAreEqual[sizeof(Dest) == sizeof(Source) ? 1 : -1]; + + Dest dest; + memmove(&dest, &source, sizeof(dest)); + return dest; +} + +template +inline Dest BitCast(Source* source) { + return BitCast(reinterpret_cast(source)); +} + +} // namespace double_conversion + +#endif // DOUBLE_CONVERSION_UTILS_H_ diff --git a/klm/util/ersatz_progress.cc b/klm/util/ersatz_progress.cc index eb635ad8..498ab5c5 100644 --- a/klm/util/ersatz_progress.cc +++ b/klm/util/ersatz_progress.cc @@ -9,6 +9,8 @@ namespace util { namespace { const unsigned char kWidth = 100; } +const char kProgressBanner[] = "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"; + ErsatzProgress::ErsatzProgress() : current_(0), next_(std::numeric_limits::max()), complete_(next_), out_(NULL) {} ErsatzProgress::~ErsatzProgress() { @@ -22,7 +24,7 @@ ErsatzProgress::ErsatzProgress(uint64_t complete, std::ostream *to, const std::s return; } if (!message.empty()) *out_ << message << '\n'; - *out_ << "----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100\n"; + *out_ << kProgressBanner; } void ErsatzProgress::Milestone() { @@ -38,7 +40,7 @@ void ErsatzProgress::Milestone() { next_ = std::numeric_limits::max(); out_ = NULL; } else { - next_ = std::max(next_, (stone * complete_) / kWidth); + next_ = std::max(next_, ((stone + 1) * complete_ + kWidth - 1) / kWidth); } } diff --git a/klm/util/ersatz_progress.hh b/klm/util/ersatz_progress.hh index 9909736d..b94399a8 100644 --- a/klm/util/ersatz_progress.hh +++ b/klm/util/ersatz_progress.hh @@ -10,6 +10,9 @@ // boost. Also adds option to print nothing. namespace util { + +extern const char kProgressBanner[]; + class ErsatzProgress { public: // No output. @@ -32,7 +35,6 @@ class ErsatzProgress { void Set(uint64_t to) { if ((current_ = to) >= next_) Milestone(); - Milestone(); } void Finished() { diff --git a/klm/util/exception.cc b/klm/util/exception.cc index 3806e6de..557c3986 100644 --- a/klm/util/exception.cc +++ b/klm/util/exception.cc @@ -79,11 +79,6 @@ ErrnoException::ErrnoException() throw() : errno_(errno) { ErrnoException::~ErrnoException() throw() {} -EndOfFileException::EndOfFileException() throw() { - *this << "End of file"; -} -EndOfFileException::~EndOfFileException() throw() {} - OverflowException::OverflowException() throw() {} OverflowException::~OverflowException() throw() {} diff --git a/klm/util/exception.hh b/klm/util/exception.hh index 0165a7a3..74046cf9 100644 --- a/klm/util/exception.hh +++ b/klm/util/exception.hh @@ -44,7 +44,7 @@ class Exception : public std::exception { }; /* This implements the normal operator<< for Exception and all its children. - * SNIFAE means it only applies to Exception. Think of this as an ersatz + * SFINAE means it only applies to Exception. Think of this as an ersatz * boost::enable_if. */ template typename Except::template ExceptionTag::Identity operator<<(Except &e, const Data &data) { @@ -62,30 +62,26 @@ template typename Except::template ExceptionTag= 3 #define UTIL_UNLIKELY(x) __builtin_expect (!!(x), 0) @@ -93,15 +89,16 @@ template typename Except::template ExceptionTag #include +#include #include #include +#include #include #include #include @@ -37,6 +42,18 @@ scoped_FILE::~scoped_FILE() { } } +// Note that ErrnoException records errno before NameFromFD is called. +FDException::FDException(int fd) throw() : fd_(fd), name_guess_(NameFromFD(fd)) { + *this << "in " << name_guess_ << ' '; +} + +FDException::~FDException() throw() {} + +EndOfFileException::EndOfFileException() throw() { + *this << "End of file"; +} +EndOfFileException::~EndOfFileException() throw() {} + int OpenReadOrThrow(const char *name) { int ret; #if defined(_WIN32) || defined(_WIN64) @@ -61,19 +78,36 @@ uint64_t SizeFile(int fd) { #if defined(_WIN32) || defined(_WIN64) __int64 ret = _filelengthi64(fd); return (ret == -1) ? kBadSize : ret; +#else // Not windows. + +#ifdef OS_ANDROID + struct stat64 sb; + int ret = fstat64(fd, &sb); #else struct stat sb; - if (fstat(fd, &sb) == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; + int ret = fstat(fd, &sb); +#endif + if (ret == -1 || (!sb.st_size && !S_ISREG(sb.st_mode))) return kBadSize; return sb.st_size; #endif } +uint64_t SizeOrThrow(int fd) { + uint64_t ret = SizeFile(fd); + UTIL_THROW_IF_ARG(ret == kBadSize, FDException, (fd), "Failed to size"); + return ret; +} + void ResizeOrThrow(int fd, uint64_t to) { + UTIL_THROW_IF_ARG( #if defined(_WIN32) || defined(_WIN64) - UTIL_THROW_IF(_chsize_s(fd, to), ErrnoException, "Resizing to " << to << " bytes failed"); + _chsize_s +#elif defined(OS_ANDROID) + ftruncate64 #else - UTIL_THROW_IF(ftruncate(fd, to), ErrnoException, "Resizing to " << to << " bytes failed"); + ftruncate #endif + (fd, to), FDException, (fd), "while resizing to " << to << " bytes"); } std::size_t PartialRead(int fd, void *to, std::size_t amount) { @@ -81,9 +115,13 @@ std::size_t PartialRead(int fd, void *to, std::size_t amount) { amount = min(static_cast(INT_MAX), amount); int ret = _read(fd, to, amount); #else - ssize_t ret = read(fd, to, amount); + errno = 0; + ssize_t ret; + do { + ret = read(fd, to, amount); + } while (ret == -1 && errno == EINTR); #endif - UTIL_THROW_IF(ret < 0, ErrnoException, "Reading " << amount << " from fd " << fd << " failed."); + UTIL_THROW_IF_ARG(ret < 0, FDException, (fd), "while reading " << amount << " bytes"); return static_cast(ret); } @@ -91,7 +129,7 @@ void ReadOrThrow(int fd, void *to_void, std::size_t amount) { uint8_t *to = static_cast(to_void); while (amount) { std::size_t ret = PartialRead(fd, to, amount); - UTIL_THROW_IF(ret == 0, EndOfFileException, " in fd " << fd << " but there should be " << amount << " more bytes to read."); + UTIL_THROW_IF(ret == 0, EndOfFileException, " in " << NameFromFD(fd) << " but there should be " << amount << " more bytes to read."); amount -= ret; to += ret; } @@ -109,40 +147,86 @@ std::size_t ReadOrEOF(int fd, void *to_void, std::size_t amount) { return amount; } +void PReadOrThrow(int fd, void *to_void, std::size_t size, uint64_t off) { + uint8_t *to = static_cast(to_void); +#if defined(_WIN32) || defined(_WIN64) + UTIL_THROW(Exception, "TODO: PReadOrThrow for windows using ReadFile http://stackoverflow.com/questions/766477/are-there-equivalents-to-pread-on-different-platforms"); +#else + for (;size ;) { + ssize_t ret; + errno = 0; + do { +#ifdef OS_ANDROID + ret = pread64(fd, to, size, off); +#else + ret = pread(fd, to, size, off); +#endif + } while (ret == -1 && errno == EINTR); + if (ret <= 0) { + UTIL_THROW_IF(ret == 0, EndOfFileException, " for reading " << size << " bytes at " << off << " from " << NameFromFD(fd)); + UTIL_THROW_ARG(FDException, (fd), "while reading " << size << " bytes at offset " << off); + } + size -= ret; + off += ret; + to += ret; + } +#endif +} + void WriteOrThrow(int fd, const void *data_void, std::size_t size) { const uint8_t *data = static_cast(data_void); while (size) { #if defined(_WIN32) || defined(_WIN64) int ret = write(fd, data, min(static_cast(INT_MAX), size)); #else - ssize_t ret = write(fd, data, size); + errno = 0; + ssize_t ret; + do { + ret = write(fd, data, size); + } while (ret == -1 && errno == EINTR); #endif - if (ret < 1) UTIL_THROW(util::ErrnoException, "Write failed"); + UTIL_THROW_IF_ARG(ret < 1, FDException, (fd), "while writing " << size << " bytes"); data += ret; size -= ret; } } void WriteOrThrow(FILE *to, const void *data, std::size_t size) { - assert(size); - UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), util::ErrnoException, "Short write; requested size " << size); + if (!size) return; + UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), ErrnoException, "Short write; requested size " << size); } void FSyncOrThrow(int fd) { // Apparently windows doesn't have fsync? #if !defined(_WIN32) && !defined(_WIN64) - UTIL_THROW_IF(-1 == fsync(fd), ErrnoException, "Sync of " << fd << " failed."); + UTIL_THROW_IF_ARG(-1 == fsync(fd), FDException, (fd), "while syncing"); #endif } namespace { + +// Static assert for 64-bit off_t size. +#if !defined(_WIN32) && !defined(_WIN64) && !defined(OS_ANDROID) +template struct CheckOffT; +template <> struct CheckOffT<8> { + struct True {}; +}; +// If there's a compiler error on the next line, then off_t isn't 64 bit. And +// that makes me a sad panda. +typedef CheckOffT::True IgnoredType; +#endif + +// Can't we all just get along? void InternalSeek(int fd, int64_t off, int whence) { + UTIL_THROW_IF_ARG( #if defined(_WIN32) || defined(_WIN64) - UTIL_THROW_IF((__int64)-1 == _lseeki64(fd, off, whence), ErrnoException, "Windows seek failed"); - + (__int64)-1 == _lseeki64(fd, off, whence), +#elif defined(OS_ANDROID) + (off64_t)-1 == lseek64(fd, off, whence), #else - UTIL_THROW_IF((off_t)-1 == lseek(fd, off, whence), ErrnoException, "Seek failed"); + (off_t)-1 == lseek(fd, off, whence), #endif + FDException, (fd), "while seeking to " << off << " whence " << whence); } } // namespace @@ -160,22 +244,18 @@ void SeekEnd(int fd) { std::FILE *FDOpenOrThrow(scoped_fd &file) { std::FILE *ret = fdopen(file.get(), "r+b"); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get()); + UTIL_THROW_IF_ARG(!ret, FDException, (file.get()), "Could not fdopen for write"); file.release(); return ret; } std::FILE *FDOpenReadOrThrow(scoped_fd &file) { std::FILE *ret = fdopen(file.get(), "rb"); - if (!ret) UTIL_THROW(util::ErrnoException, "Could not fdopen descriptor " << file.get()); + UTIL_THROW_IF_ARG(!ret, FDException, (file.get()), "Could not fdopen for read"); file.release(); return ret; } -TempMaker::TempMaker(const std::string &prefix) : base_(prefix) { - base_ += "XXXXXX"; -} - // Sigh. Windows temporary file creation is full of race conditions. #if defined(_WIN32) || defined(_WIN64) /* mkstemp extracted from libc/sysdeps/posix/tempname.c. Copyright @@ -292,23 +372,87 @@ int mkstemp_and_unlink(char *tmpl) { int ret = mkstemp(tmpl); if (ret != -1) { - UTIL_THROW_IF(unlink(tmpl), util::ErrnoException, "Failed to delete " << tmpl); + UTIL_THROW_IF(unlink(tmpl), ErrnoException, "while deleting delete " << tmpl); } return ret; } #endif -int TempMaker::Make() const { - std::string name(base_); +// If it's a directory, add a /. This lets users say -T /tmp without creating +// /tmpAAAAAA +void NormalizeTempPrefix(std::string &base) { + if (base.empty()) return; + if (base[base.size() - 1] == '/') return; + struct stat sb; + // It's fine for it to not exist. + if (-1 == stat(base.c_str(), &sb)) return; + if (S_ISDIR(sb.st_mode)) base += '/'; +} + +int MakeTemp(const std::string &base) { + std::string name(base); + name += "XXXXXX"; name.push_back(0); int ret; - UTIL_THROW_IF(-1 == (ret = mkstemp_and_unlink(&name[0])), util::ErrnoException, "Failed to make a temporary based on " << base_); + UTIL_THROW_IF(-1 == (ret = mkstemp_and_unlink(&name[0])), ErrnoException, "while making a temporary based on " << base); return ret; } -std::FILE *TempMaker::MakeFile() const { - util::scoped_fd file(Make()); +std::FILE *FMakeTemp(const std::string &base) { + util::scoped_fd file(MakeTemp(base)); return FDOpenOrThrow(file); } +int DupOrThrow(int fd) { + int ret = dup(fd); + UTIL_THROW_IF_ARG(ret == -1, FDException, (fd), "in duplicating the file descriptor"); + return ret; +} + +namespace { +// Try to name things but be willing to fail too. +bool TryName(int fd, std::string &out) { +#if defined(_WIN32) || defined(_WIN64) + return false; +#else + std::string name("/proc/self/fd/"); + std::ostringstream convert; + convert << fd; + name += convert.str(); + + struct stat sb; + if (-1 == lstat(name.c_str(), &sb)) + return false; + out.resize(sb.st_size + 1); + ssize_t ret = readlink(name.c_str(), &out[0], sb.st_size + 1); + if (-1 == ret) + return false; + if (ret > sb.st_size) { + // Increased in size?! + return false; + } + out.resize(ret); + // Don't use the non-file names. + if (!out.empty() && out[0] != '/') + return false; + return true; +#endif +} +} // namespace + +std::string NameFromFD(int fd) { + std::string ret; + if (TryName(fd, ret)) return ret; + switch (fd) { + case 0: return "stdin"; + case 1: return "stdout"; + case 2: return "stderr"; + } + ret = "fd "; + std::ostringstream convert; + convert << fd; + ret += convert.str(); + return ret; +} + } // namespace util diff --git a/klm/util/file.hh b/klm/util/file.hh index c24580d6..be88431d 100644 --- a/klm/util/file.hh +++ b/klm/util/file.hh @@ -1,6 +1,8 @@ #ifndef UTIL_FILE__ #define UTIL_FILE__ +#include "util/exception.hh" + #include #include #include @@ -17,7 +19,7 @@ class scoped_fd { ~scoped_fd(); - void reset(int to) { + void reset(int to = -1) { scoped_fd other(fd_); fd_ = to; } @@ -63,6 +65,32 @@ class scoped_FILE { std::FILE *file_; }; +/* Thrown for any operation where the fd is known. */ +class FDException : public ErrnoException { + public: + explicit FDException(int fd) throw(); + + virtual ~FDException() throw(); + + // This may no longer be valid if the exception was thrown past open. + int FD() const { return fd_; } + + // Guess from NameFromFD. + const std::string &NameGuess() const { return name_guess_; } + + private: + int fd_; + + std::string name_guess_; +}; + +// End of file reached. +class EndOfFileException : public Exception { + public: + EndOfFileException() throw(); + ~EndOfFileException() throw(); +}; + // Open for read only. int OpenReadOrThrow(const char *name); // Create file if it doesn't exist, truncate if it does. Opened for write. @@ -71,12 +99,15 @@ int CreateOrThrow(const char *name); // Return value for SizeFile when it can't size properly. const uint64_t kBadSize = (uint64_t)-1; uint64_t SizeFile(int fd); +uint64_t SizeOrThrow(int fd); void ResizeOrThrow(int fd, uint64_t to); std::size_t PartialRead(int fd, void *to, std::size_t size); void ReadOrThrow(int fd, void *to, std::size_t size); std::size_t ReadOrEOF(int fd, void *to_void, std::size_t size); +// Positioned: unix only for now. +void PReadOrThrow(int fd, void *to, std::size_t size, uint64_t off); void WriteOrThrow(int fd, const void *data_void, std::size_t size); void WriteOrThrow(FILE *to, const void *data, std::size_t size); @@ -91,17 +122,20 @@ void SeekEnd(int fd); std::FILE *FDOpenOrThrow(scoped_fd &file); std::FILE *FDOpenReadOrThrow(scoped_fd &file); -class TempMaker { - public: - explicit TempMaker(const std::string &prefix); +// Temporary files +// Append a / if base is a directory. +void NormalizeTempPrefix(std::string &base); +int MakeTemp(const std::string &prefix); +std::FILE *FMakeTemp(const std::string &prefix); - // These will already be unlinked for you. - int Make() const; - std::FILE *MakeFile() const; +// dup an fd. +int DupOrThrow(int fd); - private: - std::string base_; -}; +/* Attempt get file name from fd. This won't always work (i.e. on Windows or + * a pipe). The file might have been renamed. It's intended for diagnostics + * and logging only. + */ +std::string NameFromFD(int fd); } // namespace util diff --git a/klm/util/file_piece.cc b/klm/util/file_piece.cc index 5a208eff..fbfa0e0e 100644 --- a/klm/util/file_piece.cc +++ b/klm/util/file_piece.cc @@ -1,13 +1,15 @@ #include "util/file_piece.hh" +#include "util/double-conversion/double-conversion.h" #include "util/exception.hh" #include "util/file.hh" #include "util/mmap.hh" -#ifdef WIN32 + +#if defined(_WIN32) || defined(_WIN64) #include #else #include -#endif // WIN32 +#endif #include #include @@ -34,10 +36,17 @@ FilePiece::FilePiece(const char *name, std::ostream *show_progress, std::size_t Initialize(name, show_progress, min_buffer); } -FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std::size_t min_buffer) : +namespace { +std::string NamePossiblyFind(int fd, const char *name) { + if (name) return name; + return NameFromFD(fd); +} +} // namespace + +FilePiece::FilePiece(int fd, const char *name, std::ostream *show_progress, std::size_t min_buffer) : file_(fd), total_size_(SizeFile(file_.get())), page_(SizePage()), - progress_(total_size_, total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + name) { - Initialize(name, show_progress, min_buffer); + progress_(total_size_, total_size_ == kBadSize ? NULL : show_progress, std::string("Reading ") + NamePossiblyFind(fd, name)) { + Initialize(NamePossiblyFind(fd, name).c_str(), show_progress, min_buffer); } FilePiece::~FilePiece() {} @@ -103,21 +112,33 @@ void FilePiece::Initialize(const char *name, std::ostream *show_progress, std::s } namespace { -void ParseNumber(const char *begin, char *&end, float &out) { -#if defined(sun) || defined(WIN32) - out = static_cast(strtod(begin, &end)); -#else - out = strtof(begin, &end); -#endif + +static const double_conversion::StringToDoubleConverter kConverter( + double_conversion::StringToDoubleConverter::ALLOW_TRAILING_JUNK | double_conversion::StringToDoubleConverter::ALLOW_LEADING_SPACES, + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + "inf", + "NaN"); + +void ParseNumber(const char *begin, const char *&end, float &out) { + int count; + out = kConverter.StringToFloat(begin, end - begin, &count); + end = begin + count; } -void ParseNumber(const char *begin, char *&end, double &out) { - out = strtod(begin, &end); +void ParseNumber(const char *begin, const char *&end, double &out) { + int count; + out = kConverter.StringToDouble(begin, end - begin, &count); + end = begin + count; } -void ParseNumber(const char *begin, char *&end, long int &out) { - out = strtol(begin, &end, 10); +void ParseNumber(const char *begin, const char *&end, long int &out) { + char *silly_end; + out = strtol(begin, &silly_end, 10); + end = silly_end; } -void ParseNumber(const char *begin, char *&end, unsigned long int &out) { - out = strtoul(begin, &end, 10); +void ParseNumber(const char *begin, const char *&end, unsigned long int &out) { + char *silly_end; + out = strtoul(begin, &silly_end, 10); + end = silly_end; } } // namespace @@ -127,16 +148,17 @@ template T FilePiece::ReadNumber() { if (at_end_) { // Hallucinate a null off the end of the file. std::string buffer(position_, position_end_); - char *end; + const char *buf = buffer.c_str(); + const char *end = buf + buffer.size(); T ret; - ParseNumber(buffer.c_str(), end, ret); - if (buffer.c_str() == end) throw ParseNumberException(buffer); - position_ += end - buffer.c_str(); + ParseNumber(buf, end, ret); + if (buf == end) throw ParseNumberException(buffer); + position_ += end - buf; return ret; } Shift(); } - char *end; + const char *end = last_space_; T ret; ParseNumber(position_, end, ret); if (end == position_) throw ParseNumberException(ReadDelimited()); diff --git a/klm/util/file_piece.hh b/klm/util/file_piece.hh index 39bd1581..53310976 100644 --- a/klm/util/file_piece.hh +++ b/klm/util/file_piece.hh @@ -29,7 +29,7 @@ class FilePiece { // 1 MB default. explicit FilePiece(const char *file, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576); // Takes ownership of fd. name is used for messages. - explicit FilePiece(int fd, const char *name, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576); + explicit FilePiece(int fd, const char *name = NULL, std::ostream *show_progress = NULL, std::size_t min_buffer = 1048576); ~FilePiece(); diff --git a/klm/util/file_piece_test.cc b/klm/util/file_piece_test.cc index e79ece7a..91e4c559 100644 --- a/klm/util/file_piece_test.cc +++ b/klm/util/file_piece_test.cc @@ -1,6 +1,7 @@ // Tests might fail if you have creative characters in your path. Sue me. #include "util/file_piece.hh" +#include "util/file.hh" #include "util/scoped.hh" #define BOOST_TEST_MODULE FilePieceTest diff --git a/klm/util/have.hh b/klm/util/have.hh index 85b838e4..1523c0c5 100644 --- a/klm/util/have.hh +++ b/klm/util/have.hh @@ -10,8 +10,4 @@ //#define HAVE_BOOST #endif -#ifdef HAVE_CONFIG_H -#include "config.h" -#endif - #endif // UTIL_HAVE__ diff --git a/klm/util/multi_intersection.hh b/klm/util/multi_intersection.hh new file mode 100644 index 00000000..8334d39d --- /dev/null +++ b/klm/util/multi_intersection.hh @@ -0,0 +1,80 @@ +#ifndef UTIL_MULTI_INTERSECTION__ +#define UTIL_MULTI_INTERSECTION__ + +#include +#include + +#include +#include +#include + +namespace util { + +namespace detail { +template struct RangeLessBySize : public std::binary_function { + bool operator()(const Range &left, const Range &right) const { + return left.size() < right.size(); + } +}; + +/* Takes sets specified by their iterators and a boost::optional containing + * the lowest intersection if any. Each set must be sorted in increasing + * order. sets is changed to truncate the beginning of each sequence to the + * location of the match or an empty set. Precondition: sets is not empty + * since the intersection over null is the universe and this function does not + * know the universe. + */ +template boost::optional::value_type> FirstIntersectionSorted(std::vector > &sets, const Less &less = std::less::value_type>()) { + typedef std::vector > Sets; + typedef typename std::iterator_traits::value_type Value; + + assert(!sets.empty()); + + if (sets.front().empty()) return boost::optional(); + // Possibly suboptimal to copy for general Value; makes unsigned int go slightly faster. + Value highest(sets.front().front()); + for (typename Sets::iterator i(sets.begin()); i != sets.end(); ) { + i->advance_begin(std::lower_bound(i->begin(), i->end(), highest, less) - i->begin()); + if (i->empty()) return boost::optional(); + if (less(highest, i->front())) { + highest = i->front(); + // start over + i = sets.begin(); + } else { + ++i; + } + } + return boost::optional(highest); +} + +} // namespace detail + +template boost::optional::value_type> FirstIntersection(std::vector > &sets, const Less less) { + assert(!sets.empty()); + + std::sort(sets.begin(), sets.end(), detail::RangeLessBySize >()); + return detail::FirstIntersectionSorted(sets, less); +} + +template boost::optional::value_type> FirstIntersection(std::vector > &sets) { + return FirstIntersection(sets, std::less::value_type>()); +} + +template void AllIntersection(std::vector > &sets, Output &out, const Less less) { + typedef typename std::iterator_traits::value_type Value; + assert(!sets.empty()); + + std::sort(sets.begin(), sets.end(), detail::RangeLessBySize >()); + boost::optional ret; + for (boost::optional ret; ret = detail::FirstIntersectionSorted(sets, less); sets.front().advance_begin(1)) { + out(*ret); + } +} + +template void AllIntersection(std::vector > &sets, Output &out) { + AllIntersection(sets, out, std::less::value_type>()); +} + +} // namespace util + +#endif // UTIL_MULTI_INTERSECTION__ diff --git a/klm/util/multi_intersection_test.cc b/klm/util/multi_intersection_test.cc new file mode 100644 index 00000000..970afc17 --- /dev/null +++ b/klm/util/multi_intersection_test.cc @@ -0,0 +1,63 @@ +#include "util/multi_intersection.hh" + +#define BOOST_TEST_MODULE MultiIntersectionTest +#include + +namespace util { +namespace { + +BOOST_AUTO_TEST_CASE(Empty) { + std::vector > sets; + + sets.push_back(boost::iterator_range(static_cast(NULL), static_cast(NULL))); + BOOST_CHECK(!FirstIntersection(sets)); +} + +BOOST_AUTO_TEST_CASE(Single) { + std::vector nums; + nums.push_back(1); + nums.push_back(4); + nums.push_back(100); + std::vector::const_iterator> > sets; + sets.push_back(nums); + + boost::optional ret(FirstIntersection(sets)); + + BOOST_REQUIRE(ret); + BOOST_CHECK_EQUAL(static_cast(1), *ret); +} + +template boost::iterator_range RangeFromArray(const T (&arr)[len]) { + return boost::iterator_range(arr, arr + len); +} + +BOOST_AUTO_TEST_CASE(MultiNone) { + unsigned int nums0[] = {1, 3, 4, 22}; + unsigned int nums1[] = {2, 5, 12}; + unsigned int nums2[] = {4, 17}; + + std::vector > sets; + sets.push_back(RangeFromArray(nums0)); + sets.push_back(RangeFromArray(nums1)); + sets.push_back(RangeFromArray(nums2)); + + BOOST_CHECK(!FirstIntersection(sets)); +} + +BOOST_AUTO_TEST_CASE(MultiOne) { + unsigned int nums0[] = {1, 3, 4, 17, 22}; + unsigned int nums1[] = {2, 5, 12, 17}; + unsigned int nums2[] = {4, 17}; + + std::vector > sets; + sets.push_back(RangeFromArray(nums0)); + sets.push_back(RangeFromArray(nums1)); + sets.push_back(RangeFromArray(nums2)); + + boost::optional ret(FirstIntersection(sets)); + BOOST_REQUIRE(ret); + BOOST_CHECK_EQUAL(static_cast(17), *ret); +} + +} // namespace +} // namespace util diff --git a/klm/util/pcqueue.hh b/klm/util/pcqueue.hh new file mode 100644 index 00000000..3df8749b --- /dev/null +++ b/klm/util/pcqueue.hh @@ -0,0 +1,105 @@ +#ifndef UTIL_PCQUEUE__ +#define UTIL_PCQUEUE__ + +#include +#include +#include +#include + +#include + +namespace util { + +inline void WaitSemaphore (boost::interprocess::interprocess_semaphore &on) { + while (1) { + try { + on.wait(); + break; + } + catch (boost::interprocess::interprocess_exception &e) { + if (e.get_native_error() != EINTR) throw; + } + } +} + +/* Producer consumer queue safe for multiple producers and multiple consumers. + * T must be default constructable and have operator=. + * The value is copied twice for Consume(T &out) or three times for Consume(), + * so larger objects should be passed via pointer. + * Strong exception guarantee if operator= throws. Undefined if semaphores throw. + */ +template class PCQueue : boost::noncopyable { + public: + explicit PCQueue(size_t size) + : empty_(size), used_(0), + storage_(new T[size]), + end_(storage_.get() + size), + produce_at_(storage_.get()), + consume_at_(storage_.get()) {} + + // Add a value to the queue. + void Produce(const T &val) { + WaitSemaphore(empty_); + { + boost::unique_lock produce_lock(produce_at_mutex_); + try { + *produce_at_ = val; + } + catch (...) { + empty_.post(); + throw; + } + if (++produce_at_ == end_) produce_at_ = storage_.get(); + } + used_.post(); + } + + // Consume a value, assigning it to out. + T& Consume(T &out) { + WaitSemaphore(used_); + { + boost::unique_lock consume_lock(consume_at_mutex_); + try { + out = *consume_at_; + } + catch (...) { + used_.post(); + throw; + } + if (++consume_at_ == end_) consume_at_ = storage_.get(); + } + empty_.post(); + return out; + } + + // Convenience version of Consume that copies the value to return. + // The other version is faster. + T Consume() { + T ret; + Consume(ret); + return ret; + } + + private: + // Number of empty spaces in storage_. + boost::interprocess::interprocess_semaphore empty_; + // Number of occupied spaces in storage_. + boost::interprocess::interprocess_semaphore used_; + + boost::scoped_array storage_; + + T *const end_; + + // Index for next write in storage_. + T *produce_at_; + boost::mutex produce_at_mutex_; + + // Index for next read from storage_. + T *consume_at_; + boost::mutex consume_at_mutex_; + +}; + +} // namespace util + +#endif // UTIL_PCQUEUE__ diff --git a/klm/util/pool.cc b/klm/util/pool.cc index 2dffd06f..429ba158 100644 --- a/klm/util/pool.cc +++ b/klm/util/pool.cc @@ -1,5 +1,7 @@ #include "util/pool.hh" +#include "util/scoped.hh" + #include namespace util { @@ -24,8 +26,7 @@ void Pool::FreeAll() { void *Pool::More(std::size_t size) { std::size_t amount = std::max(static_cast(32) << free_list_.size(), size); - uint8_t *ret = static_cast(malloc(amount)); - if (!ret) throw std::bad_alloc(); + uint8_t *ret = static_cast(MallocOrThrow(amount)); free_list_.push_back(ret); current_ = ret + size; current_end_ = ret + amount; diff --git a/klm/util/probing_hash_table.hh b/klm/util/probing_hash_table.hh index 4a8aff35..6780489d 100644 --- a/klm/util/probing_hash_table.hh +++ b/klm/util/probing_hash_table.hh @@ -126,6 +126,11 @@ template + +namespace util { + +MallocException::MallocException(std::size_t requested) throw() { + *this << "for " << requested << " bytes "; +} + +MallocException::~MallocException() throw() {} + +void *MallocOrThrow(std::size_t requested) { + void *ret; + UTIL_THROW_IF_ARG(!(ret = std::malloc(requested)), MallocException, (requested), "in malloc"); + return ret; +} + +scoped_malloc::~scoped_malloc() { + std::free(p_); +} + +void scoped_malloc::call_realloc(std::size_t to) { + void *ret; + UTIL_THROW_IF_ARG(!(ret = std::realloc(p_, to)) && to, MallocException, (to), "in realloc"); + p_ = ret; +} + +} // namespace util diff --git a/klm/util/scoped.hh b/klm/util/scoped.hh index d62c6df1..d0a5aabd 100644 --- a/klm/util/scoped.hh +++ b/klm/util/scoped.hh @@ -4,28 +4,31 @@ #include "util/exception.hh" #include -#include namespace util { +class MallocException : public ErrnoException { + public: + explicit MallocException(std::size_t requested) throw(); + ~MallocException() throw(); +}; + +void *MallocOrThrow(std::size_t requested); + class scoped_malloc { public: scoped_malloc() : p_(NULL) {} scoped_malloc(void *p) : p_(p) {} - ~scoped_malloc() { std::free(p_); } + ~scoped_malloc(); void reset(void *p = NULL) { scoped_malloc other(p_); p_ = p; } - void call_realloc(std::size_t to) { - void *ret; - UTIL_THROW_IF(!(ret = std::realloc(p_, to)) && to, util::ErrnoException, "realloc to " << to << " bytes failed."); - p_ = ret; - } + void call_realloc(std::size_t to); void *get() { return p_; } const void *get() const { return p_; } diff --git a/klm/util/stream/block.hh b/klm/util/stream/block.hh new file mode 100644 index 00000000..11aa991e --- /dev/null +++ b/klm/util/stream/block.hh @@ -0,0 +1,43 @@ +#ifndef UTIL_STREAM_BLOCK__ +#define UTIL_STREAM_BLOCK__ + +#include +#include + +namespace util { +namespace stream { + +class Block { + public: + Block() : mem_(NULL), valid_size_(0) {} + + Block(void *mem, std::size_t size) : mem_(mem), valid_size_(size) {} + + void SetValidSize(std::size_t to) { valid_size_ = to; } + // Read might fill in less than Allocated at EOF. + std::size_t ValidSize() const { return valid_size_; } + + void *Get() { return mem_; } + const void *Get() const { return mem_; } + + const void *ValidEnd() const { + return reinterpret_cast(mem_) + valid_size_; + } + + operator bool() const { return mem_ != NULL; } + bool operator!() const { return mem_ == NULL; } + + private: + friend class Link; + void SetToPoison() { + mem_ = NULL; + } + + void *mem_; + std::size_t valid_size_; +}; + +} // namespace stream +} // namespace util + +#endif // UTIL_STREAM_BLOCK__ diff --git a/klm/util/stream/chain.cc b/klm/util/stream/chain.cc new file mode 100644 index 00000000..46708c60 --- /dev/null +++ b/klm/util/stream/chain.cc @@ -0,0 +1,155 @@ +#include "util/stream/chain.hh" + +#include "util/stream/io.hh" + +#include "util/exception.hh" +#include "util/pcqueue.hh" + +#include +#include +#include + +#include +#include + +namespace util { +namespace stream { + +ChainConfigException::ChainConfigException() throw() { *this << "Chain configured with "; } +ChainConfigException::~ChainConfigException() throw() {} + +Thread::~Thread() { + thread_.join(); +} + +void Thread::UnhandledException(const std::exception &e) { + std::cerr << e.what() << std::endl; + abort(); +} + +void Recycler::Run(const ChainPosition &position) { + for (Link l(position); l; ++l) { + l->SetValidSize(position.GetChain().BlockSize()); + } +} + +const Recycler kRecycle = Recycler(); + +Chain::Chain(const ChainConfig &config) : config_(config), complete_called_(false) { + UTIL_THROW_IF(!config.entry_size, ChainConfigException, "zero-size entries."); + UTIL_THROW_IF(!config.block_count, ChainConfigException, "block count zero"); + UTIL_THROW_IF(config.total_memory < config.entry_size * config.block_count, ChainConfigException, config.total_memory << " total memory, too small for " << config.block_count << " blocks of containing entries of size " << config.entry_size); + // Round down block size to a multiple of entry size. + block_size_ = config.total_memory / (config.block_count * config.entry_size) * config.entry_size; +} + +Chain::~Chain() { + Wait(); +} + +ChainPosition Chain::Add() { + if (!Running()) Start(); + PCQueue &in = queues_.back(); + queues_.push_back(new PCQueue(config_.block_count)); + return ChainPosition(in, queues_.back(), this, progress_); +} + +Chain &Chain::operator>>(const WriteAndRecycle &writer) { + threads_.push_back(new Thread(Complete(), writer)); + return *this; +} + +void Chain::Wait(bool release_memory) { + if (queues_.empty()) { + assert(threads_.empty()); + return; // Nothing to wait for. + } + if (!complete_called_) CompleteLoop(); + threads_.clear(); + for (std::size_t i = 0; queues_.front().Consume(); ++i) { + if (i == config_.block_count) { + std::cerr << "Chain ending without poison." << std::endl; + abort(); + } + } + queues_.clear(); + progress_.Finished(); + complete_called_ = false; + if (release_memory) memory_.reset(); +} + +void Chain::Start() { + Wait(false); + if (!memory_.get()) { + // Allocate memory. + assert(threads_.empty()); + assert(queues_.empty()); + std::size_t malloc_size = block_size_ * config_.block_count; + memory_.reset(MallocOrThrow(malloc_size)); + } + // This queue can accomodate all blocks. + queues_.push_back(new PCQueue(config_.block_count)); + // Populate the lead queue with blocks. + uint8_t *base = static_cast(memory_.get()); + for (std::size_t i = 0; i < config_.block_count; ++i) { + queues_.front().Produce(Block(base, block_size_)); + base += block_size_; + } +} + +ChainPosition Chain::Complete() { + assert(Running()); + UTIL_THROW_IF(complete_called_, util::Exception, "CompleteLoop() called twice"); + complete_called_ = true; + return ChainPosition(queues_.back(), queues_.front(), this, progress_); +} + +Link::Link() : in_(NULL), out_(NULL), poisoned_(true) {} + +void Link::Init(const ChainPosition &position) { + UTIL_THROW_IF(in_, util::Exception, "Link::Init twice"); + in_ = position.in_; + out_ = position.out_; + poisoned_ = false; + progress_ = position.progress_; + in_->Consume(current_); +} + +Link::Link(const ChainPosition &position) : in_(NULL) { + Init(position); +} + +Link::~Link() { + if (current_) { + // Probably an exception unwinding. + std::cerr << "Last input should have been poison." << std::endl; +// abort(); + } else { + if (!poisoned_) { + // Pass the poison! + out_->Produce(current_); + } + } +} + +Link &Link::operator++() { + assert(current_); + progress_ += current_.ValidSize(); + out_->Produce(current_); + in_->Consume(current_); + if (!current_) { + poisoned_ = true; + out_->Produce(current_); + } + return *this; +} + +void Link::Poison() { + assert(!poisoned_); + current_.SetToPoison(); + out_->Produce(current_); + poisoned_ = true; +} + +} // namespace stream +} // namespace util diff --git a/klm/util/stream/chain.hh b/klm/util/stream/chain.hh new file mode 100644 index 00000000..154b9b33 --- /dev/null +++ b/klm/util/stream/chain.hh @@ -0,0 +1,198 @@ +#ifndef UTIL_STREAM_CHAIN__ +#define UTIL_STREAM_CHAIN__ + +#include "util/stream/block.hh" +#include "util/stream/config.hh" +#include "util/stream/multi_progress.hh" +#include "util/scoped.hh" + +#include +#include + +#include + +#include + +namespace util { +template class PCQueue; +namespace stream { + +class ChainConfigException : public Exception { + public: + ChainConfigException() throw(); + ~ChainConfigException() throw(); +}; + +class Chain; +// Specifies position in chain for Link constructor. +class ChainPosition { + public: + const Chain &GetChain() const { return *chain_; } + private: + friend class Chain; + friend class Link; + ChainPosition(PCQueue &in, PCQueue &out, Chain *chain, MultiProgress &progress) + : in_(&in), out_(&out), chain_(chain), progress_(progress.Add()) {} + + PCQueue *in_, *out_; + + Chain *chain_; + + WorkerProgress progress_; +}; + +// Position is usually ChainPosition but if there are multiple streams involved, this can be ChainPositions. +class Thread { + public: + template Thread(const Position &position, const Worker &worker) + : thread_(boost::ref(*this), position, worker) {} + + ~Thread(); + + template void operator()(const Position &position, Worker &worker) { + try { + worker.Run(position); + } catch (const std::exception &e) { + UnhandledException(e); + } + } + + private: + void UnhandledException(const std::exception &e); + + boost::thread thread_; +}; + +class Recycler { + public: + void Run(const ChainPosition &position); +}; + +extern const Recycler kRecycle; +class WriteAndRecycle; + +class Chain { + private: + template struct CheckForRun { + typedef Chain type; + }; + + public: + explicit Chain(const ChainConfig &config); + + ~Chain(); + + void ActivateProgress() { + assert(!Running()); + progress_.Activate(); + } + + void SetProgressTarget(uint64_t target) { + progress_.SetTarget(target); + } + + std::size_t EntrySize() const { + return config_.entry_size; + } + std::size_t BlockSize() const { + return block_size_; + } + + // Two ways to add to the chain: Add() or operator>>. + ChainPosition Add(); + + // This is for adding threaded workers with a Run method. + template typename CheckForRun::type &operator>>(const Worker &worker) { + assert(!complete_called_); + threads_.push_back(new Thread(Add(), worker)); + return *this; + } + + // Avoid copying the worker. + template typename CheckForRun::type &operator>>(const boost::reference_wrapper &worker) { + assert(!complete_called_); + threads_.push_back(new Thread(Add(), worker)); + return *this; + } + + // Note that Link and Stream also define operator>> outside this class. + + // To complete the loop, call CompleteLoop(), >> kRecycle, or the destructor. + void CompleteLoop() { + threads_.push_back(new Thread(Complete(), kRecycle)); + } + + Chain &operator>>(const Recycler &recycle) { + CompleteLoop(); + return *this; + } + + Chain &operator>>(const WriteAndRecycle &writer); + + // Chains are reusable. Call Wait to wait for everything to finish and free memory. + void Wait(bool release_memory = true); + + // Waits for the current chain to complete (if any) then starts again. + void Start(); + + bool Running() const { return !queues_.empty(); } + + private: + ChainPosition Complete(); + + ChainConfig config_; + + std::size_t block_size_; + + scoped_malloc memory_; + + boost::ptr_vector > queues_; + + bool complete_called_; + + boost::ptr_vector threads_; + + MultiProgress progress_; +}; + +// Create the link in the worker thread using the position token. +class Link { + public: + // Either default construct and Init or just construct all at once. + Link(); + void Init(const ChainPosition &position); + + explicit Link(const ChainPosition &position); + + ~Link(); + + Block &operator*() { return current_; } + const Block &operator*() const { return current_; } + + Block *operator->() { return ¤t_; } + const Block *operator->() const { return ¤t_; } + + Link &operator++(); + + operator bool() const { return current_; } + + void Poison(); + + private: + Block current_; + PCQueue *in_, *out_; + + bool poisoned_; + + WorkerProgress progress_; +}; + +inline Chain &operator>>(Chain &chain, Link &link) { + link.Init(chain.Add()); + return chain; +} + +} // namespace stream +} // namespace util + +#endif // UTIL_STREAM_CHAIN__ diff --git a/klm/util/stream/config.hh b/klm/util/stream/config.hh new file mode 100644 index 00000000..1eeb3a8a --- /dev/null +++ b/klm/util/stream/config.hh @@ -0,0 +1,32 @@ +#ifndef UTIL_STREAM_CONFIG__ +#define UTIL_STREAM_CONFIG__ + +#include +#include + +namespace util { namespace stream { + +struct ChainConfig { + ChainConfig() {} + + ChainConfig(std::size_t in_entry_size, std::size_t in_block_count, std::size_t in_total_memory) + : entry_size(in_entry_size), block_count(in_block_count), total_memory(in_total_memory) {} + + std::size_t entry_size; + std::size_t block_count; + // Chain's constructor will make this a multiple of entry_size. + std::size_t total_memory; +}; + +struct SortConfig { + std::string temp_prefix; + + // Size of each input/output buffer. + std::size_t buffer_size; + + // Total memory to use when running alone. + std::size_t total_memory; +}; + +}} // namespaces +#endif // UTIL_STREAM_CONFIG__ diff --git a/klm/util/stream/io.cc b/klm/util/stream/io.cc new file mode 100644 index 00000000..c7ad2980 --- /dev/null +++ b/klm/util/stream/io.cc @@ -0,0 +1,64 @@ +#include "util/stream/io.hh" + +#include "util/file.hh" +#include "util/stream/chain.hh" + +#include + +namespace util { +namespace stream { + +ReadSizeException::ReadSizeException() throw() {} +ReadSizeException::~ReadSizeException() throw() {} + +void Read::Run(const ChainPosition &position) { + const std::size_t block_size = position.GetChain().BlockSize(); + const std::size_t entry_size = position.GetChain().EntrySize(); + for (Link link(position); link; ++link) { + std::size_t got = util::ReadOrEOF(file_, link->Get(), block_size); + UTIL_THROW_IF(got % entry_size, ReadSizeException, "File ended with " << got << " bytes, not a multiple of " << entry_size << "."); + if (got == 0) { + link.Poison(); + return; + } else { + link->SetValidSize(got); + } + } +} + +void PRead::Run(const ChainPosition &position) { + scoped_fd owner; + if (own_) owner.reset(file_); + uint64_t size = SizeOrThrow(file_); + UTIL_THROW_IF(size % static_cast(position.GetChain().EntrySize()), ReadSizeException, "File size " << file_ << " size is " << size << " not a multiple of " << position.GetChain().EntrySize()); + std::size_t block_size = position.GetChain().BlockSize(); + Link link(position); + uint64_t offset = 0; + for (; offset + block_size < size; offset += block_size, ++link) { + PReadOrThrow(file_, link->Get(), block_size, offset); + link->SetValidSize(block_size); + } + if (size - offset) { + PReadOrThrow(file_, link->Get(), size - offset, offset); + link->SetValidSize(size - offset); + ++link; + } + link.Poison(); +} + +void Write::Run(const ChainPosition &position) { + for (Link link(position); link; ++link) { + WriteOrThrow(file_, link->Get(), link->ValidSize()); + } +} + +void WriteAndRecycle::Run(const ChainPosition &position) { + const std::size_t block_size = position.GetChain().BlockSize(); + for (Link link(position); link; ++link) { + WriteOrThrow(file_, link->Get(), link->ValidSize()); + link->SetValidSize(block_size); + } +} + +} // namespace stream +} // namespace util diff --git a/klm/util/stream/io.hh b/klm/util/stream/io.hh new file mode 100644 index 00000000..934b6b3f --- /dev/null +++ b/klm/util/stream/io.hh @@ -0,0 +1,76 @@ +#ifndef UTIL_STREAM_IO__ +#define UTIL_STREAM_IO__ + +#include "util/exception.hh" +#include "util/file.hh" + +namespace util { +namespace stream { + +class ChainPosition; + +class ReadSizeException : public util::Exception { + public: + ReadSizeException() throw(); + ~ReadSizeException() throw(); +}; + +class Read { + public: + explicit Read(int fd) : file_(fd) {} + void Run(const ChainPosition &position); + private: + int file_; +}; + +// Like read but uses pread so that the file can be accessed from multiple threads. +class PRead { + public: + explicit PRead(int fd, bool take_own = false) : file_(fd), own_(take_own) {} + void Run(const ChainPosition &position); + private: + int file_; + bool own_; +}; + +class Write { + public: + explicit Write(int fd) : file_(fd) {} + void Run(const ChainPosition &position); + private: + int file_; +}; + +class WriteAndRecycle { + public: + explicit WriteAndRecycle(int fd) : file_(fd) {} + void Run(const ChainPosition &position); + private: + int file_; +}; + +// Reuse the same file over and over again to buffer output. +class FileBuffer { + public: + explicit FileBuffer(int fd) : file_(fd) {} + + WriteAndRecycle Sink() const { + util::SeekOrThrow(file_.get(), 0); + return WriteAndRecycle(file_.get()); + } + + PRead Source() const { + return PRead(file_.get()); + } + + uint64_t Size() const { + return SizeOrThrow(file_.get()); + } + + private: + scoped_fd file_; +}; + +} // namespace stream +} // namespace util +#endif // UTIL_STREAM_IO__ diff --git a/klm/util/stream/io_test.cc b/klm/util/stream/io_test.cc new file mode 100644 index 00000000..82108335 --- /dev/null +++ b/klm/util/stream/io_test.cc @@ -0,0 +1,38 @@ +#include "util/stream/io.hh" + +#include "util/stream/chain.hh" +#include "util/file.hh" + +#define BOOST_TEST_MODULE IOTest +#include + +#include + +namespace util { namespace stream { namespace { + +BOOST_AUTO_TEST_CASE(CopyFile) { + std::string temps("io_test_temp"); + + scoped_fd in(MakeTemp(temps)); + for (uint64_t i = 0; i < 100000; ++i) { + WriteOrThrow(in.get(), &i, sizeof(uint64_t)); + } + SeekOrThrow(in.get(), 0); + scoped_fd out(MakeTemp(temps)); + + ChainConfig config; + config.entry_size = 8; + config.total_memory = 1024; + config.block_count = 10; + + Chain(config) >> PRead(in.get()) >> Write(out.get()); + + SeekOrThrow(out.get(), 0); + for (uint64_t i = 0; i < 100000; ++i) { + uint64_t got; + ReadOrThrow(out.get(), &got, sizeof(uint64_t)); + BOOST_CHECK_EQUAL(i, got); + } +} + +}}} // namespaces diff --git a/klm/util/stream/line_input.cc b/klm/util/stream/line_input.cc new file mode 100644 index 00000000..dafa5020 --- /dev/null +++ b/klm/util/stream/line_input.cc @@ -0,0 +1,52 @@ +#include "util/stream/line_input.hh" + +#include "util/exception.hh" +#include "util/file.hh" +#include "util/read_compressed.hh" +#include "util/stream/chain.hh" + +#include +#include + +namespace util { namespace stream { + +void LineInput::Run(const ChainPosition &position) { + ReadCompressed reader(fd_); + // Holding area for beginning of line to be placed in next block. + std::vector carry; + + for (Link block(position); ; ++block) { + char *to = static_cast(block->Get()); + char *begin = to; + char *end = to + position.GetChain().BlockSize(); + std::copy(carry.begin(), carry.end(), to); + to += carry.size(); + while (to != end) { + std::size_t got = reader.Read(to, end - to); + if (!got) { + // EOF + block->SetValidSize(to - begin); + ++block; + block.Poison(); + return; + } + to += got; + } + + // Find the last newline. + char *newline; + for (newline = to - 1; ; --newline) { + UTIL_THROW_IF(newline < begin, Exception, "Did not find a newline in " << position.GetChain().BlockSize() << " bytes of input of " << NameFromFD(fd_) << ". Is this a text file?"); + if (*newline == '\n') break; + } + + // Copy everything after the last newline to the carry. + carry.clear(); + carry.resize(to - (newline + 1)); + std::copy(newline + 1, to, &*carry.begin()); + + block->SetValidSize(newline + 1 - begin); + } +} + +}} // namespaces diff --git a/klm/util/stream/line_input.hh b/klm/util/stream/line_input.hh new file mode 100644 index 00000000..86db1dd0 --- /dev/null +++ b/klm/util/stream/line_input.hh @@ -0,0 +1,22 @@ +#ifndef UTIL_STREAM_LINE_INPUT__ +#define UTIL_STREAM_LINE_INPUT__ +namespace util {namespace stream { + +class ChainPosition; + +/* Worker that reads input into blocks, ensuring that blocks contain whole + * lines. Assumes that the maximum size of a line is less than the block size + */ +class LineInput { + public: + // Takes ownership upon thread execution. + explicit LineInput(int fd); + + void Run(const ChainPosition &position); + + private: + int fd_; +}; + +}} // namespaces +#endif // UTIL_STREAM_LINE_INPUT__ diff --git a/klm/util/stream/multi_progress.cc b/klm/util/stream/multi_progress.cc new file mode 100644 index 00000000..8ba10386 --- /dev/null +++ b/klm/util/stream/multi_progress.cc @@ -0,0 +1,86 @@ +#include "util/stream/multi_progress.hh" + +// TODO: merge some functionality with the simple progress bar? +#include "util/ersatz_progress.hh" + +#include +#include + +#include + +#if !defined(_WIN32) && !defined(_WIN64) +#include +#endif + +namespace util { namespace stream { + +namespace { +const char kDisplayCharacters[] = "-+*#0123456789"; + +uint64_t Next(unsigned char stone, uint64_t complete) { + return (static_cast(stone + 1) * complete + MultiProgress::kWidth - 1) / MultiProgress::kWidth; +} + +} // namespace + +MultiProgress::MultiProgress() : active_(false), complete_(std::numeric_limits::max()), character_handout_(0) {} + +MultiProgress::~MultiProgress() { + if (active_ && complete_ != std::numeric_limits::max()) + std::cerr << '\n'; +} + +void MultiProgress::Activate() { + active_ = +#if !defined(_WIN32) && !defined(_WIN64) + // Is stderr a terminal? + (isatty(2) == 1) +#else + true +#endif + ; +} + +void MultiProgress::SetTarget(uint64_t complete) { + if (!active_) return; + complete_ = complete; + if (!complete) complete_ = 1; + memset(display_, 0, sizeof(display_)); + character_handout_ = 0; + std::cerr << kProgressBanner; +} + +WorkerProgress MultiProgress::Add() { + if (!active_) + return WorkerProgress(std::numeric_limits::max(), *this, '\0'); + std::size_t character_index; + { + boost::unique_lock lock(mutex_); + character_index = character_handout_++; + if (character_handout_ == sizeof(kDisplayCharacters) - 1) + character_handout_ = 0; + } + return WorkerProgress(Next(0, complete_), *this, kDisplayCharacters[character_index]); +} + +void MultiProgress::Finished() { + if (!active_ || complete_ == std::numeric_limits::max()) return; + std::cerr << '\n'; + complete_ = std::numeric_limits::max(); +} + +void MultiProgress::Milestone(WorkerProgress &worker) { + if (!active_ || complete_ == std::numeric_limits::max()) return; + unsigned char stone = std::min(static_cast(kWidth), worker.current_ * kWidth / complete_); + for (char *i = &display_[worker.stone_]; i < &display_[stone]; ++i) { + *i = worker.character_; + } + worker.next_ = Next(stone, complete_); + worker.stone_ = stone; + { + boost::unique_lock lock(mutex_); + std::cerr << '\r' << display_ << std::flush; + } +} + +}} // namespaces diff --git a/klm/util/stream/multi_progress.hh b/klm/util/stream/multi_progress.hh new file mode 100644 index 00000000..c4dd45a9 --- /dev/null +++ b/klm/util/stream/multi_progress.hh @@ -0,0 +1,90 @@ +/* Progress bar suitable for chains of workers */ +#ifndef UTIL_MULTI_PROGRESS__ +#define UTIL_MULTI_PROGRESS__ + +#include + +#include + +#include + +namespace util { namespace stream { + +class WorkerProgress; + +class MultiProgress { + public: + static const unsigned char kWidth = 100; + + MultiProgress(); + + ~MultiProgress(); + + // Turns on showing (requires SetTarget too). + void Activate(); + + void SetTarget(uint64_t complete); + + WorkerProgress Add(); + + void Finished(); + + private: + friend class WorkerProgress; + void Milestone(WorkerProgress &worker); + + bool active_; + + uint64_t complete_; + + boost::mutex mutex_; + + // \0 at the end. + char display_[kWidth + 1]; + + std::size_t character_handout_; + + MultiProgress(const MultiProgress &); + MultiProgress &operator=(const MultiProgress &); +}; + +class WorkerProgress { + public: + // Default contrutor must be initialized with operator= later. + WorkerProgress() : parent_(NULL) {} + + // Not threadsafe for the same worker by default. + WorkerProgress &operator++() { + if (++current_ >= next_) { + parent_->Milestone(*this); + } + return *this; + } + + WorkerProgress &operator+=(uint64_t amount) { + current_ += amount; + if (current_ >= next_) { + parent_->Milestone(*this); + } + return *this; + } + + private: + friend class MultiProgress; + WorkerProgress(uint64_t next, MultiProgress &parent, char character) + : current_(0), next_(next), parent_(&parent), stone_(0), character_(character) {} + + uint64_t current_, next_; + + MultiProgress *parent_; + + // Previous milestone reached. + unsigned char stone_; + + // Character to display in bar. + char character_; +}; + +}} // namespaces + +#endif // UTIL_MULTI_PROGRESS__ diff --git a/klm/util/stream/sort.hh b/klm/util/stream/sort.hh new file mode 100644 index 00000000..be6c11ea --- /dev/null +++ b/klm/util/stream/sort.hh @@ -0,0 +1,542 @@ +/* Usage: + * Sort sorter(temp, compare); + * Chain(config) >> Read(file) >> sorter.Unsorted(); + * Stream stream; + * Chain chain(config) >> sorter.Sorted(internal_config, lazy_config) >> stream; + * + * Note that sorter must outlive any threads that use Unsorted or Sorted. + * + * Combiners take the form: + * bool operator()(void *into, const void *option, const Compare &compare) const + * which returns true iff a combination happened. The sorting algorithm + * guarantees compare(into, option). But it does not guarantee + * compare(option, into). + * Currently, combining is only done in merge steps, not during on-the-fly + * sort. Use a hash table for that. + */ + +#ifndef UTIL_STREAM_SORT__ +#define UTIL_STREAM_SORT__ + +#include "util/stream/chain.hh" +#include "util/stream/config.hh" +#include "util/stream/io.hh" +#include "util/stream/stream.hh" +#include "util/stream/timer.hh" + +#include "util/file.hh" +#include "util/scoped.hh" +#include "util/sized_iterator.hh" + +#include +#include +#include +#include + +namespace util { +namespace stream { + +struct NeverCombine { + template bool operator()(const void *, const void *, const Compare &) const { + return false; + } +}; + +// Manage the offsets of sorted blocks in a file. +class Offsets { + public: + explicit Offsets(int fd) : log_(fd) { + Reset(); + } + + int File() const { return log_; } + + void Append(uint64_t length) { + if (!length) return; + ++block_count_; + if (length == cur_.length) { + ++cur_.run; + return; + } + WriteOrThrow(log_, &cur_, sizeof(Entry)); + cur_.length = length; + cur_.run = 1; + } + + void FinishedAppending() { + WriteOrThrow(log_, &cur_, sizeof(Entry)); + SeekOrThrow(log_, sizeof(Entry)); // Skip 0,0 at beginning. + cur_.run = 0; + if (block_count_) { + ReadOrThrow(log_, &cur_, sizeof(Entry)); + assert(cur_.length); + assert(cur_.run); + } + } + + uint64_t RemainingBlocks() const { return block_count_; } + + uint64_t TotalOffset() const { return output_sum_; } + + uint64_t PeekSize() const { + return cur_.length; + } + + uint64_t NextSize() { + assert(block_count_); + uint64_t ret = cur_.length; + output_sum_ += ret; + + --cur_.run; + --block_count_; + if (!cur_.run && block_count_) { + ReadOrThrow(log_, &cur_, sizeof(Entry)); + assert(cur_.length); + assert(cur_.run); + } + return ret; + } + + void Reset() { + SeekOrThrow(log_, 0); + ResizeOrThrow(log_, 0); + cur_.length = 0; + cur_.run = 0; + block_count_ = 0; + output_sum_ = 0; + } + + private: + int log_; + + struct Entry { + uint64_t length; + uint64_t run; + }; + Entry cur_; + + uint64_t block_count_; + + uint64_t output_sum_; +}; + +// A priority queue of entries backed by file buffers +template class MergeQueue { + public: + MergeQueue(int fd, std::size_t buffer_size, std::size_t entry_size, const Compare &compare) + : queue_(Greater(compare)), in_(fd), buffer_size_(buffer_size), entry_size_(entry_size) {} + + void Push(void *base, uint64_t offset, uint64_t amount) { + queue_.push(Entry(base, in_, offset, amount, buffer_size_)); + } + + const void *Top() const { + return queue_.top().Current(); + } + + void Pop() { + Entry top(queue_.top()); + queue_.pop(); + if (top.Increment(in_, buffer_size_, entry_size_)) + queue_.push(top); + } + + std::size_t Size() const { + return queue_.size(); + } + + bool Empty() const { + return queue_.empty(); + } + + private: + // Priority queue contains these entries. + class Entry { + public: + Entry() {} + + Entry(void *base, int fd, uint64_t offset, uint64_t amount, std::size_t buf_size) { + offset_ = offset; + remaining_ = amount; + buffer_end_ = static_cast(base) + buf_size; + Read(fd, buf_size); + } + + bool Increment(int fd, std::size_t buf_size, std::size_t entry_size) { + current_ += entry_size; + if (current_ != buffer_end_) return true; + return Read(fd, buf_size); + } + + const void *Current() const { return current_; } + + private: + bool Read(int fd, std::size_t buf_size) { + current_ = buffer_end_ - buf_size; + std::size_t amount; + if (static_cast(buf_size) < remaining_) { + amount = buf_size; + } else if (!remaining_) { + return false; + } else { + amount = remaining_; + buffer_end_ = current_ + remaining_; + } + PReadOrThrow(fd, current_, amount, offset_); + offset_ += amount; + assert(current_ <= buffer_end_); + remaining_ -= amount; + return true; + } + + // Buffer + uint8_t *current_, *buffer_end_; + // File + uint64_t remaining_, offset_; + }; + + // Wrapper comparison function for queue entries. + class Greater : public std::binary_function { + public: + explicit Greater(const Compare &compare) : compare_(compare) {} + + bool operator()(const Entry &first, const Entry &second) const { + return compare_(second.Current(), first.Current()); + } + + private: + const Compare compare_; + }; + + typedef std::priority_queue, Greater> Queue; + Queue queue_; + + const int in_; + const std::size_t buffer_size_; + const std::size_t entry_size_; +}; + +/* A worker object that merges. If the number of pieces to merge exceeds the + * arity, it outputs multiple sorted blocks, recording to out_offsets. + * However, users will only every see a single sorted block out output because + * Sort::Sorted insures the arity is higher than the number of pieces before + * returning this. + */ +template class MergingReader { + public: + MergingReader(int in, Offsets *in_offsets, Offsets *out_offsets, std::size_t buffer_size, std::size_t total_memory, const Compare &compare, const Combine &combine) : + compare_(compare), combine_(combine), + in_(in), + in_offsets_(in_offsets), out_offsets_(out_offsets), + buffer_size_(buffer_size), total_memory_(total_memory) {} + + void Run(const ChainPosition &position) { + Run(position, false); + } + + void Run(const ChainPosition &position, bool assert_one) { + // Special case: nothing to read. + if (!in_offsets_->RemainingBlocks()) { + Link l(position); + l.Poison(); + return; + } + // If there's just one entry, just read. + if (in_offsets_->RemainingBlocks() == 1) { + // Sequencing is important. + uint64_t offset = in_offsets_->TotalOffset(); + uint64_t amount = in_offsets_->NextSize(); + ReadSingle(offset, amount, position); + if (out_offsets_) out_offsets_->Append(amount); + return; + } + + Stream str(position); + scoped_malloc buffer(MallocOrThrow(total_memory_)); + uint8_t *const buffer_end = static_cast(buffer.get()) + total_memory_; + + const std::size_t entry_size = position.GetChain().EntrySize(); + + while (in_offsets_->RemainingBlocks()) { + // Use bigger buffers if there's less remaining. + uint64_t per_buffer = std::max(buffer_size_, total_memory_ / in_offsets_->RemainingBlocks()); + per_buffer -= per_buffer % entry_size; + assert(per_buffer); + + // Populate queue. + MergeQueue queue(in_, per_buffer, entry_size, compare_); + for (uint8_t *buf = static_cast(buffer.get()); + in_offsets_->RemainingBlocks() && (buf + std::min(per_buffer, in_offsets_->PeekSize()) <= buffer_end);) { + uint64_t offset = in_offsets_->TotalOffset(); + uint64_t size = in_offsets_->NextSize(); + queue.Push(buf, offset, size); + buf += static_cast(std::min(size, per_buffer)); + } + // This shouldn't happen but it's probably better to die than loop indefinitely. + if (queue.Size() < 2 && in_offsets_->RemainingBlocks()) { + std::cerr << "Bug in sort implementation: not merging at least two stripes." << std::endl; + abort(); + } + if (assert_one && in_offsets_->RemainingBlocks()) { + std::cerr << "Bug in sort implementation: should only be one merge group for lazy sort" << std::endl; + abort(); + } + + uint64_t written = 0; + // Merge including combiner support. + memcpy(str.Get(), queue.Top(), entry_size); + for (queue.Pop(); !queue.Empty(); queue.Pop()) { + if (!combine_(str.Get(), queue.Top(), compare_)) { + ++written; ++str; + memcpy(str.Get(), queue.Top(), entry_size); + } + } + ++written; ++str; + if (out_offsets_) + out_offsets_->Append(written * entry_size); + } + str.Poison(); + } + + private: + void ReadSingle(uint64_t offset, const uint64_t size, const ChainPosition &position) { + // Special case: only one to read. + const uint64_t end = offset + size; + const uint64_t block_size = position.GetChain().BlockSize(); + Link l(position); + for (; offset + block_size < end; ++l, offset += block_size) { + PReadOrThrow(in_, l->Get(), block_size, offset); + l->SetValidSize(block_size); + } + PReadOrThrow(in_, l->Get(), end - offset, offset); + l->SetValidSize(end - offset); + (++l).Poison(); + return; + } + + Compare compare_; + Combine combine_; + + int in_; + + protected: + Offsets *in_offsets_; + + private: + Offsets *out_offsets_; + + std::size_t buffer_size_; + std::size_t total_memory_; +}; + +// The lazy step owns the remaining files. This keeps track of them. +template class OwningMergingReader : public MergingReader { + private: + typedef MergingReader P; + public: + OwningMergingReader(int data, const Offsets &offsets, std::size_t buffer, std::size_t lazy, const Compare &compare, const Combine &combine) + : P(data, NULL, NULL, buffer, lazy, compare, combine), + data_(data), + offsets_(offsets) {} + + void Run(const ChainPosition &position) { + P::in_offsets_ = &offsets_; + scoped_fd data(data_); + scoped_fd offsets_file(offsets_.File()); + P::Run(position, true); + } + + private: + int data_; + Offsets offsets_; +}; + +// Don't use this directly. Worker that sorts blocks. +template class BlockSorter { + public: + BlockSorter(Offsets &offsets, const Compare &compare) : + offsets_(&offsets), compare_(compare) {} + + void Run(const ChainPosition &position) { + const std::size_t entry_size = position.GetChain().EntrySize(); + for (Link link(position); link; ++link) { + // Record the size of each block in a separate file. + offsets_->Append(link->ValidSize()); + void *end = static_cast(link->Get()) + link->ValidSize(); + std::sort( + SizedIt(link->Get(), entry_size), + SizedIt(end, entry_size), + compare_); + } + offsets_->FinishedAppending(); + } + + private: + Offsets *offsets_; + SizedCompare compare_; +}; + +class BadSortConfig : public Exception { + public: + BadSortConfig() throw() {} + ~BadSortConfig() throw() {} +}; + +template class Sort { + public: + Sort(Chain &in, const SortConfig &config, const Compare &compare = Compare(), const Combine &combine = Combine()) + : config_(config), + data_(MakeTemp(config.temp_prefix)), + offsets_file_(MakeTemp(config.temp_prefix)), offsets_(offsets_file_.get()), + compare_(compare), combine_(combine), + entry_size_(in.EntrySize()) { + UTIL_THROW_IF(!entry_size_, BadSortConfig, "Sorting entries of size 0"); + // Make buffer_size a multiple of the entry_size. + config_.buffer_size -= config_.buffer_size % entry_size_; + UTIL_THROW_IF(!config_.buffer_size, BadSortConfig, "Sort buffer too small"); + UTIL_THROW_IF(config_.total_memory < config_.buffer_size * 4, BadSortConfig, "Sorting memory " << config_.total_memory << " is too small for four buffers (two read and two write)."); + in >> BlockSorter(offsets_, compare_) >> WriteAndRecycle(data_.get()); + } + + uint64_t Size() const { + return SizeOrThrow(data_.get()); + } + + // Do merge sort, terminating when lazy merge could be done with the + // specified memory. Return the minimum memory necessary to do lazy merge. + std::size_t Merge(std::size_t lazy_memory) { + if (offsets_.RemainingBlocks() <= 1) return 0; + const uint64_t lazy_arity = std::max(1, lazy_memory / config_.buffer_size); + uint64_t size = Size(); + /* No overflow because + * offsets_.RemainingBlocks() * config_.buffer_size <= lazy_memory || + * size < lazy_memory + */ + if (offsets_.RemainingBlocks() <= lazy_arity || size <= static_cast(lazy_memory)) + return std::min(size, offsets_.RemainingBlocks() * config_.buffer_size); + + scoped_fd data2(MakeTemp(config_.temp_prefix)); + int fd_in = data_.get(), fd_out = data2.get(); + scoped_fd offsets2_file(MakeTemp(config_.temp_prefix)); + Offsets offsets2(offsets2_file.get()); + Offsets *offsets_in = &offsets_, *offsets_out = &offsets2; + + // Double buffered writing. + ChainConfig chain_config; + chain_config.entry_size = entry_size_; + chain_config.block_count = 2; + chain_config.total_memory = config_.buffer_size * 2; + Chain chain(chain_config); + + while (offsets_in->RemainingBlocks() > lazy_arity) { + if (size <= static_cast(lazy_memory)) break; + std::size_t reading_memory = config_.total_memory - 2 * config_.buffer_size; + if (size < static_cast(reading_memory)) { + reading_memory = static_cast(size); + } + SeekOrThrow(fd_in, 0); + chain >> + MergingReader( + fd_in, + offsets_in, offsets_out, + config_.buffer_size, + reading_memory, + compare_, combine_) >> + WriteAndRecycle(fd_out); + chain.Wait(); + offsets_out->FinishedAppending(); + ResizeOrThrow(fd_in, 0); + offsets_in->Reset(); + std::swap(fd_in, fd_out); + std::swap(offsets_in, offsets_out); + size = SizeOrThrow(fd_in); + } + + SeekOrThrow(fd_in, 0); + if (fd_in == data2.get()) { + data_.reset(data2.release()); + offsets_file_.reset(offsets2_file.release()); + offsets_ = offsets2; + } + if (offsets_.RemainingBlocks() <= 1) return 0; + // No overflow because the while loop exited. + return std::min(size, offsets_.RemainingBlocks() * static_cast(config_.buffer_size)); + } + + // Output to chain, using this amount of memory, maximum, for lazy merge + // sort. + void Output(Chain &out, std::size_t lazy_memory) { + Merge(lazy_memory); + out.SetProgressTarget(Size()); + out >> OwningMergingReader(data_.get(), offsets_, config_.buffer_size, lazy_memory, compare_, combine_); + data_.release(); + offsets_file_.release(); + } + + /* If a pipeline step is reading sorted input and writing to a different + * sort order, then there's a trade-off between using RAM to read lazily + * (avoiding copying the file) and using RAM to increase block size and, + * therefore, decrease the number of merge sort passes in the next + * iteration. + * + * Merge sort takes log_{arity}(pieces) passes. Thus, each time the chain + * block size is multiplied by arity, the number of output passes decreases + * by one. Up to a constant, then, log_{arity}(chain) is the number of + * passes saved. Chain simply divides the memory evenly over all blocks. + * + * Lazy sort saves this many passes (up to a constant) + * log_{arity}((memory-lazy)/block_count) + 1 + * Non-lazy sort saves this many passes (up to the same constant): + * log_{arity}(memory/block_count) + * Add log_{arity}(block_count) to both: + * log_{arity}(memory-lazy) + 1 versus log_{arity}(memory) + * Take arity to the power of both sizes (arity > 1) + * (memory - lazy)*arity versus memory + * Solve for lazy + * lazy = memory * (arity - 1) / arity + */ + std::size_t DefaultLazy() { + float arity = static_cast(config_.total_memory / config_.buffer_size); + return static_cast(static_cast(config_.total_memory) * (arity - 1.0) / arity); + } + + // Same as Output with default lazy memory setting. + void Output(Chain &out) { + Output(out, DefaultLazy()); + } + + // Completely merge sort and transfer ownership to the caller. + int StealCompleted() { + // Merge all the way. + Merge(0); + SeekOrThrow(data_.get(), 0); + offsets_file_.reset(); + return data_.release(); + } + + private: + SortConfig config_; + + scoped_fd data_; + + scoped_fd offsets_file_; + Offsets offsets_; + + const Compare compare_; + const Combine combine_; + const std::size_t entry_size_; +}; + +// returns bytes to be read on demand. +template uint64_t BlockingSort(Chain &chain, const SortConfig &config, const Compare &compare = Compare(), const Combine &combine = NeverCombine()) { + Sort sorter(chain, config, compare, combine); + chain.Wait(true); + uint64_t size = sorter.Size(); + sorter.Output(chain); + return size; +} + +} // namespace stream +} // namespace util + +#endif // UTIL_STREAM_SORT__ diff --git a/klm/util/stream/sort_test.cc b/klm/util/stream/sort_test.cc new file mode 100644 index 00000000..fd7705cd --- /dev/null +++ b/klm/util/stream/sort_test.cc @@ -0,0 +1,62 @@ +#include "util/stream/sort.hh" + +#define BOOST_TEST_MODULE SortTest +#include + +#include + +#include + +namespace util { namespace stream { namespace { + +struct CompareUInt64 : public std::binary_function { + bool operator()(const void *first, const void *second) const { + return *static_cast(first) < *reinterpret_cast(second); + } +}; + +const uint64_t kSize = 100000; + +struct Putter { + Putter(std::vector &shuffled) : shuffled_(shuffled) {} + + void Run(const ChainPosition &position) { + Stream put_shuffled(position); + for (uint64_t i = 0; i < shuffled_.size(); ++i, ++put_shuffled) { + *static_cast(put_shuffled.Get()) = shuffled_[i]; + } + put_shuffled.Poison(); + } + std::vector &shuffled_; +}; + +BOOST_AUTO_TEST_CASE(FromShuffled) { + std::vector shuffled; + shuffled.reserve(kSize); + for (uint64_t i = 0; i < kSize; ++i) { + shuffled.push_back(i); + } + std::random_shuffle(shuffled.begin(), shuffled.end()); + + ChainConfig config; + config.entry_size = 8; + config.total_memory = 800; + config.block_count = 3; + + SortConfig merge_config; + merge_config.temp_prefix = "sort_test_temp"; + merge_config.buffer_size = 800; + merge_config.total_memory = 3300; + + Chain chain(config); + chain >> Putter(shuffled); + BlockingSort(chain, merge_config, CompareUInt64(), NeverCombine()); + Stream sorted; + chain >> sorted >> kRecycle; + for (uint64_t i = 0; i < kSize; ++i, ++sorted) { + BOOST_CHECK_EQUAL(i, *static_cast(sorted.Get())); + } + BOOST_CHECK(!sorted); +} + +}}} // namespaces diff --git a/klm/util/stream/stream.hh b/klm/util/stream/stream.hh new file mode 100644 index 00000000..6ff45b82 --- /dev/null +++ b/klm/util/stream/stream.hh @@ -0,0 +1,74 @@ +#ifndef UTIL_STREAM_STREAM__ +#define UTIL_STREAM_STREAM__ + +#include "util/stream/chain.hh" + +#include + +#include +#include + +namespace util { +namespace stream { + +class Stream : boost::noncopyable { + public: + Stream() : current_(NULL), end_(NULL) {} + + void Init(const ChainPosition &position) { + entry_size_ = position.GetChain().EntrySize(); + block_size_ = position.GetChain().BlockSize(); + block_it_.Init(position); + StartBlock(); + } + + explicit Stream(const ChainPosition &position) { + Init(position); + } + + operator bool() const { return current_ != NULL; } + bool operator!() const { return current_ == NULL; } + + const void *Get() const { return current_; } + void *Get() { return current_; } + + void Poison() { + block_it_->SetValidSize(current_ - static_cast(block_it_->Get())); + ++block_it_; + block_it_.Poison(); + } + + Stream &operator++() { + assert(*this); + assert(current_ < end_); + current_ += entry_size_; + if (current_ == end_) { + ++block_it_; + StartBlock(); + } + return *this; + } + + private: + void StartBlock() { + for (; block_it_ && !block_it_->ValidSize(); ++block_it_) {} + current_ = static_cast(block_it_->Get()); + end_ = current_ + block_it_->ValidSize(); + } + + uint8_t *current_, *end_; + + std::size_t entry_size_; + std::size_t block_size_; + + Link block_it_; +}; + +inline Chain &operator>>(Chain &chain, Stream &stream) { + stream.Init(chain.Add()); + return chain; +} + +} // namespace stream +} // namespace util +#endif // UTIL_STREAM_STREAM__ diff --git a/klm/util/stream/stream_test.cc b/klm/util/stream/stream_test.cc new file mode 100644 index 00000000..6575d50d --- /dev/null +++ b/klm/util/stream/stream_test.cc @@ -0,0 +1,35 @@ +#include "util/stream/io.hh" + +#include "util/stream/stream.hh" +#include "util/file.hh" + +#define BOOST_TEST_MODULE StreamTest +#include + +#include + +namespace util { namespace stream { namespace { + +BOOST_AUTO_TEST_CASE(StreamTest) { + scoped_fd in(MakeTemp("io_test_temp")); + for (uint64_t i = 0; i < 100000; ++i) { + WriteOrThrow(in.get(), &i, sizeof(uint64_t)); + } + SeekOrThrow(in.get(), 0); + + ChainConfig config; + config.entry_size = 8; + config.total_memory = 100; + config.block_count = 12; + + Stream s; + Chain chain(config); + chain >> Read(in.get()) >> s >> kRecycle; + uint64_t i = 0; + for (; s; ++s, ++i) { + BOOST_CHECK_EQUAL(i, *static_cast(s.Get())); + } + BOOST_CHECK_EQUAL(100000ULL, i); +} + +}}} // namespaces diff --git a/klm/util/stream/timer.hh b/klm/util/stream/timer.hh new file mode 100644 index 00000000..50e94fe8 --- /dev/null +++ b/klm/util/stream/timer.hh @@ -0,0 +1,14 @@ +#ifndef UTIL_STREAM_TIMER__ +#define UTIL_STREAM_TIMER__ + +#include + +#if BOOST_VERSION >= 104800 +#include +#define UTIL_TIMER(str) boost::timer::auto_cpu_timer timer(std::cerr, 1, (str)) +#else +//#warning Using Boost older than 1.48. Timing information will not be available. +#define UTIL_TIMER(str) +#endif + +#endif // UTIL_STREAM_TIMER__ diff --git a/klm/util/thread_pool.hh b/klm/util/thread_pool.hh new file mode 100644 index 00000000..84e257ea --- /dev/null +++ b/klm/util/thread_pool.hh @@ -0,0 +1,95 @@ +#ifndef UTIL_THREAD_POOL__ +#define UTIL_THREAD_POOL__ + +#include "util/pcqueue.hh" + +#include +#include +#include + +#include + +#include + +namespace util { + +template class Worker : boost::noncopyable { + public: + typedef HandlerT Handler; + typedef typename Handler::Request Request; + + template Worker(PCQueue &in, Construct &construct, Request &poison) + : in_(in), handler_(construct), thread_(boost::ref(*this)), poison_(poison) {} + + // Only call from thread. + void operator()() { + Request request; + while (1) { + in_.Consume(request); + if (request == poison_) return; + try { + (*handler_)(request); + } + catch(std::exception &e) { + std::cerr << "Handler threw " << e.what() << std::endl; + abort(); + } + catch(...) { + std::cerr << "Handler threw an exception, dropping request" << std::endl; + abort(); + } + } + } + + void Join() { + thread_.join(); + } + + private: + PCQueue &in_; + + boost::optional handler_; + + boost::thread thread_; + + Request poison_; +}; + +template class ThreadPool : boost::noncopyable { + public: + typedef HandlerT Handler; + typedef typename Handler::Request Request; + + template ThreadPool(size_t queue_length, size_t workers, Construct handler_construct, Request poison) : in_(queue_length), poison_(poison) { + for (size_t i = 0; i < workers; ++i) { + workers_.push_back(new Worker(in_, handler_construct, poison)); + } + } + + ~ThreadPool() { + for (size_t i = 0; i < workers_.size(); ++i) { + Produce(poison_); + } + for (typename boost::ptr_vector >::iterator i = workers_.begin(); i != workers_.end(); ++i) { + i->Join(); + } + } + + void Produce(const Request &request) { + in_.Produce(request); + } + + // For adding to the queue. + PCQueue &In() { return in_; } + + private: + PCQueue in_; + + boost::ptr_vector > workers_; + + Request poison_; +}; + +} // namespace util + +#endif // UTIL_THREAD_POOL__ diff --git a/klm/util/usage.cc b/klm/util/usage.cc index e5cf76f0..16a004bb 100644 --- a/klm/util/usage.cc +++ b/klm/util/usage.cc @@ -1,13 +1,17 @@ #include "util/usage.hh" +#include "util/exception.hh" + #include #include +#include #include #include #if !defined(_WIN32) && !defined(_WIN64) #include #include +#include #endif namespace util { @@ -43,4 +47,60 @@ void PrintUsage(std::ostream &out) { #endif } +uint64_t GuessPhysicalMemory() { +#if defined(_WIN32) || defined(_WIN64) + return 0; +#elif defined(_SC_PHYS_PAGES) && defined(_SC_PAGESIZE) + long pages = sysconf(_SC_PHYS_PAGES); + if (pages == -1) return 0; + long page_size = sysconf(_SC_PAGESIZE); + if (page_size == -1) return 0; + return static_cast(pages) * static_cast(page_size); +#else + return 0; +#endif +} + +namespace { +class SizeParseError : public Exception { + public: + explicit SizeParseError(const std::string &str) throw() { + *this << "Failed to parse " << str << " into a memory size "; + } +}; + +template uint64_t ParseNum(const std::string &arg) { + std::stringstream stream(arg); + Num value; + stream >> value; + UTIL_THROW_IF_ARG(!stream, SizeParseError, (arg), "for the leading number."); + std::string after; + stream >> after; + UTIL_THROW_IF_ARG(after.size() > 1, SizeParseError, (arg), "because there are more than two characters after the number."); + std::string throwaway; + UTIL_THROW_IF_ARG(stream >> throwaway, SizeParseError, (arg), "because there was more cruft " << throwaway << " after the number."); + + // Silly sort, using kilobytes as your default unit. + if (after.empty()) after == "K"; + if (after == "%") { + uint64_t mem = GuessPhysicalMemory(); + UTIL_THROW_IF_ARG(!mem, SizeParseError, (arg), "because % was specified but the physical memory size could not be determined."); + return static_cast(value) * static_cast(mem) / 100.0; + } + + std::string units("bKMGTPEZY"); + std::string::size_type index = units.find(after[0]); + UTIL_THROW_IF_ARG(index == std::string::npos, SizeParseError, (arg), "the allowed suffixes are " << units << "%."); + for (std::string::size_type i = 0; i < index; ++i) { + value *= 1024; + } + return value; +} + +} // namespace + +uint64_t ParseSize(const std::string &arg) { + return arg.find('.') == std::string::npos ? ParseNum(arg) : ParseNum(arg); +} + } // namespace util diff --git a/klm/util/usage.hh b/klm/util/usage.hh index d331ff74..e19eda7b 100644 --- a/klm/util/usage.hh +++ b/klm/util/usage.hh @@ -1,8 +1,18 @@ #ifndef UTIL_USAGE__ #define UTIL_USAGE__ +#include #include +#include + +#include namespace util { void PrintUsage(std::ostream &to); + +// Determine how much physical memory there is. Return 0 on failure. +uint64_t GuessPhysicalMemory(); + +// Parse a size like unix sort. Sadly, this means the default multiplier is K. +uint64_t ParseSize(const std::string &arg); } // namespace util #endif // UTIL_USAGE__ diff --git a/training/crf/Makefile.am b/training/crf/Makefile.am index d37b224c..621ca803 100644 --- a/training/crf/Makefile.am +++ b/training/crf/Makefile.am @@ -7,21 +7,21 @@ bin_PROGRAMS = \ mpi_online_optimize mpi_online_optimize_SOURCES = mpi_online_optimize.cc -mpi_online_optimize_LDADD = ../../training/utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a -lz +mpi_online_optimize_LDADD = ../../training/utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz mpi_flex_optimize_SOURCES = mpi_flex_optimize.cc -mpi_flex_optimize_LDADD = ../../training/utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a -lz +mpi_flex_optimize_LDADD = ../../training/utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz mpi_extract_reachable_SOURCES = mpi_extract_reachable.cc -mpi_extract_reachable_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a -lz +mpi_extract_reachable_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz mpi_extract_features_SOURCES = mpi_extract_features.cc -mpi_extract_features_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a -lz +mpi_extract_features_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz mpi_batch_optimize_SOURCES = mpi_batch_optimize.cc cllh_observer.cc cllh_observer.h -mpi_batch_optimize_LDADD = ../../training/utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a -lz +mpi_batch_optimize_LDADD = ../../training/utils/libtraining_utils.a ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz mpi_compute_cllh_SOURCES = mpi_compute_cllh.cc cllh_observer.cc cllh_observer.h -mpi_compute_cllh_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a -lz +mpi_compute_cllh_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lz AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -Wno-sign-compare -I$(top_srcdir)/training -I$(top_srcdir)/training/utils -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am index 8cf71078..844c790d 100644 --- a/training/dtrain/Makefile.am +++ b/training/dtrain/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = dtrain dtrain_SOURCES = dtrain.cc score.cc dtrain.h kbestget.h ksampler.h pairsampling.h score.h -dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a +dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/mira/Makefile.am b/training/mira/Makefile.am index 0084603d..fa4fb22d 100644 --- a/training/mira/Makefile.am +++ b/training/mira/Makefile.am @@ -1,6 +1,6 @@ bin_PROGRAMS = kbest_mira kbest_mira_SOURCES = kbest_mira.cc -kbest_mira_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a +kbest_mira_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval -- cgit v1.2.3 From 9e36263f64d6f5150f1b552dd77bde971d605376 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 19 Jan 2013 19:09:48 -0500 Subject: updated version of boost.m4 and automatically build kenneth's LM builder --- Makefile.am | 2 + configure.ac | 7 +- corpus/cut-corpus.pl | 2 +- klm/lm/builder/Makefile.am | 28 +++ klm/util/Makefile.am | 2 +- klm/util/double-conversion/Makefile.am | 2 +- klm/util/stream/Makefile.am | 20 ++ klm/util/stream/sort.hh | 3 +- m4/boost.m4 | 322 +++++++++++++++++++++++++-------- 9 files changed, 311 insertions(+), 77 deletions(-) create mode 100644 klm/lm/builder/Makefile.am create mode 100644 klm/util/stream/Makefile.am (limited to 'configure.ac') diff --git a/Makefile.am b/Makefile.am index c2444928..17190d27 100644 --- a/Makefile.am +++ b/Makefile.am @@ -5,8 +5,10 @@ SUBDIRS = \ utils \ mteval \ klm/util/double-conversion \ + klm/util/stream \ klm/util \ klm/lm \ + klm/lm/builder \ klm/search \ decoder \ training \ diff --git a/configure.ac b/configure.ac index d6030752..a1e5ad84 100644 --- a/configure.ac +++ b/configure.ac @@ -1,4 +1,4 @@ -AC_INIT([cdec],[2013-01-15]) +AC_INIT([cdec],[2013-01-19]) AC_CONFIG_SRCDIR([decoder/cdec.cc]) AM_INIT_AUTOMAKE AC_CONFIG_HEADERS(config.h) @@ -15,7 +15,10 @@ BOOST_REQUIRE([1.44]) BOOST_PROGRAM_OPTIONS BOOST_SYSTEM BOOST_SERIALIZATION +BOOST_CHRONO +BOOST_TIMER BOOST_TEST +BOOST_THREADS AM_PATH_PYTHON AC_CHECK_HEADER(dlfcn.h,AC_DEFINE(HAVE_DLFCN_H)) AC_CHECK_LIB(dl, dlopen) @@ -111,8 +114,10 @@ AC_CONFIG_FILES([word-aligner/Makefile]) # KenLM stuff AC_CONFIG_FILES([klm/util/double-conversion/Makefile]) +AC_CONFIG_FILES([klm/util/stream/Makefile]) AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) +AC_CONFIG_FILES([klm/lm/builder/Makefile]) AC_CONFIG_FILES([klm/search/Makefile]) # training stuff diff --git a/corpus/cut-corpus.pl b/corpus/cut-corpus.pl index 7daac0e2..0af3b23c 100755 --- a/corpus/cut-corpus.pl +++ b/corpus/cut-corpus.pl @@ -22,7 +22,7 @@ for my $ff (@ind) { while(<>) { chomp; - my @fields = split / \|\|\| /; + my @fields = split /\s*\|\|\|\s*/; my @sf; for my $i (@o) { my $y = $fields[$i]; diff --git a/klm/lm/builder/Makefile.am b/klm/lm/builder/Makefile.am new file mode 100644 index 00000000..00444256 --- /dev/null +++ b/klm/lm/builder/Makefile.am @@ -0,0 +1,28 @@ +bin_PROGRAMS = builder + +builder_SOURCES = \ + main.cc \ + adjust_counts.cc \ + adjust_counts.hh \ + corpus_count.cc \ + corpus_count.hh \ + discount.hh \ + header_info.hh \ + initial_probabilities.cc \ + initial_probabilities.hh \ + interpolate.cc \ + interpolate.hh \ + joint_order.hh \ + multi_stream.hh \ + ngram.hh \ + ngram_stream.hh \ + pipeline.cc \ + pipeline.hh \ + print.cc \ + print.hh \ + sort.hh + +builder_LDADD = ../libklm.a ../../util/double-conversion/libklm_util_double.a ../../util/stream/libklm_util_stream.a ../../util/libklm_util.a $(BOOST_TIMER_LIBS) $(BOOST_CHRONO_LIBS) $(BOOST_THREAD_LIBS) + +AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm + diff --git a/klm/util/Makefile.am b/klm/util/Makefile.am index 294ebc0a..248cc844 100644 --- a/klm/util/Makefile.am +++ b/klm/util/Makefile.am @@ -54,4 +54,4 @@ libklm_util_a_SOURCES = \ string_piece.cc \ usage.cc -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm -I$(top_srcdir)/klm/util/double-conversion +AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm -I$(top_srcdir)/klm/util/double-conversion diff --git a/klm/util/double-conversion/Makefile.am b/klm/util/double-conversion/Makefile.am index eb6616f7..dfcfb009 100644 --- a/klm/util/double-conversion/Makefile.am +++ b/klm/util/double-conversion/Makefile.am @@ -20,4 +20,4 @@ libklm_util_double_a_SOURCES = \ fixed-dtoa.cc \ strtod.cc -AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/klm -I$(top_srcdir)/klm/util/double-conversion +AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm -I$(top_srcdir)/klm/util/double-conversion diff --git a/klm/util/stream/Makefile.am b/klm/util/stream/Makefile.am new file mode 100644 index 00000000..f18cbedb --- /dev/null +++ b/klm/util/stream/Makefile.am @@ -0,0 +1,20 @@ +noinst_LIBRARIES = libklm_util_stream.a + +libklm_util_stream_a_SOURCES = \ + block.hh \ + chain.cc \ + chain.hh \ + config.hh \ + io.cc \ + io.hh \ + line_input.cc \ + line_input.hh \ + multi_progress.cc \ + multi_progress.hh \ + sort.hh \ + stream.hh \ + timer.hh + +AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm + +#-I$(top_srcdir)/klm/util/double-conversion diff --git a/klm/util/stream/sort.hh b/klm/util/stream/sort.hh index be6c11ea..df57fa41 100644 --- a/klm/util/stream/sort.hh +++ b/klm/util/stream/sort.hh @@ -259,7 +259,8 @@ template class MergingReader { while (in_offsets_->RemainingBlocks()) { // Use bigger buffers if there's less remaining. - uint64_t per_buffer = std::max(buffer_size_, total_memory_ / in_offsets_->RemainingBlocks()); + uint64_t per_buffer = std::max(static_cast(buffer_size_), + static_cast(total_memory_ / in_offsets_->RemainingBlocks())); per_buffer -= per_buffer % entry_size; assert(per_buffer); diff --git a/m4/boost.m4 b/m4/boost.m4 index 7e0ed075..027e039b 100644 --- a/m4/boost.m4 +++ b/m4/boost.m4 @@ -1,5 +1,5 @@ # boost.m4: Locate Boost headers and libraries for autoconf-based projects. -# Copyright (C) 2007, 2008, 2009 Benoit Sigoure +# Copyright (C) 2007, 2008, 2009, 2010, 2011 Benoit Sigoure # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by @@ -22,7 +22,7 @@ # along with this program. If not, see . m4_define([_BOOST_SERIAL], [m4_translit([ -# serial 12 +# serial 16 ], [# ], [])]) @@ -45,15 +45,19 @@ m4_define([_BOOST_SERIAL], [m4_translit([ # Note: THESE MACROS ASSUME THAT YOU USE LIBTOOL. If you don't, don't worry, # simply read the README, it will show you what to do step by step. -m4_pattern_forbid([^_?BOOST_]) +m4_pattern_forbid([^_?(BOOST|Boost)_]) # _BOOST_SED_CPP(SED-PROGRAM, PROGRAM, # [ACTION-IF-FOUND], [ACTION-IF-NOT-FOUND]) # -------------------------------------------------------- # Same as AC_EGREP_CPP, but leave the result in conftest.i. -# PATTERN is *not* overquoted, as in AC_EGREP_CPP. It could be useful -# to turn this into a macro which extracts the value of any macro. +# +# SED-PROGRAM is *not* overquoted, as in AC_EGREP_CPP. It is expanded +# in double-quotes, so escape your double quotes. +# +# It could be useful to turn this into a macro which extracts the +# value of any macro. m4_define([_BOOST_SED_CPP], [AC_LANG_PREPROC_REQUIRE()dnl AC_REQUIRE([AC_PROG_SED])dnl @@ -98,6 +102,7 @@ set x $boost_version_req 0 0 0 IFS=$boost_save_IFS shift boost_version_req=`expr "$[1]" '*' 100000 + "$[2]" '*' 100 + "$[3]"` +boost_version_req_string=$[1].$[2].$[3] AC_ARG_WITH([boost], [AS_HELP_STRING([--with-boost=DIR], [prefix of Boost $1 @<:@guess@:>@])])dnl @@ -113,9 +118,9 @@ if test x"$BOOST_ROOT" != x; then fi fi AC_SUBST([DISTCHECK_CONFIGURE_FLAGS], - ["$DISTCHECK_CONFIGURE_FLAGS '--with-boost=$with_boost'"]) + ["$DISTCHECK_CONFIGURE_FLAGS '--with-boost=$with_boost'"])dnl boost_save_CPPFLAGS=$CPPFLAGS - AC_CACHE_CHECK([for Boost headers version >= $boost_version_req], + AC_CACHE_CHECK([for Boost headers version >= $boost_version_req_string], [boost_cv_inc_path], [boost_cv_inc_path=no AC_LANG_PUSH([C++])dnl @@ -183,24 +188,25 @@ AC_LANG_POP([C++])dnl ]) case $boost_cv_inc_path in #( no) - boost_errmsg="cannot find Boost headers version >= $boost_version_req" + boost_errmsg="cannot find Boost headers version >= $boost_version_req_string" m4_if([$2], [], [AC_MSG_ERROR([$boost_errmsg])], [AC_MSG_NOTICE([$boost_errmsg])]) $2 ;;#( yes) BOOST_CPPFLAGS= - AC_DEFINE([HAVE_BOOST], [1], - [Defined if the requested minimum BOOST version is satisfied]) ;;#( *) - AC_SUBST([BOOST_CPPFLAGS], ["-I$boost_cv_inc_path"]) + AC_SUBST([BOOST_CPPFLAGS], ["-I$boost_cv_inc_path"])dnl ;; esac + if test x"$boost_cv_inc_path" != xno; then + AC_DEFINE([HAVE_BOOST], [1], + [Defined if the requested minimum BOOST version is satisfied]) AC_CACHE_CHECK([for Boost's header version], [boost_cv_lib_version], [m4_pattern_allow([^BOOST_LIB_VERSION$])dnl - _BOOST_SED_CPP([/^boost-lib-version = /{s///;s/\"//g;p;g;}], + _BOOST_SED_CPP([/^boost-lib-version = /{s///;s/\"//g;p;q;}], [#include boost-lib-version = BOOST_LIB_VERSION], [boost_cv_lib_version=`cat conftest.i`])]) @@ -211,6 +217,7 @@ boost-lib-version = BOOST_LIB_VERSION], AC_MSG_ERROR([invalid value: boost_major_version=$boost_major_version]) ;; esac +fi CPPFLAGS=$boost_save_CPPFLAGS ])# BOOST_REQUIRE @@ -220,7 +227,7 @@ CPPFLAGS=$boost_save_CPPFLAGS # on the command line, static versions of the libraries will be looked up. AC_DEFUN([BOOST_STATIC], [AC_ARG_ENABLE([static-boost], - [AC_HELP_STRING([--enable-static-boost], + [AS_HELP_STRING([--enable-static-boost], [Prefer the static boost libraries over the shared ones [no]])], [enable_static_boost=yes], [enable_static_boost=no])])# BOOST_STATIC @@ -290,6 +297,7 @@ dnl The else branch is huge and wasn't intended on purpose. AC_LANG_PUSH([C++])dnl AS_VAR_PUSHDEF([Boost_lib], [boost_cv_lib_$1])dnl AS_VAR_PUSHDEF([Boost_lib_LDFLAGS], [boost_cv_lib_$1_LDFLAGS])dnl +AS_VAR_PUSHDEF([Boost_lib_LDPATH], [boost_cv_lib_$1_LDPATH])dnl AS_VAR_PUSHDEF([Boost_lib_LIBS], [boost_cv_lib_$1_LIBS])dnl BOOST_FIND_HEADER([$3]) boost_save_CPPFLAGS=$CPPFLAGS @@ -371,8 +379,8 @@ for boost_rtopt_ in $boost_rtopt '' -d; do boost_tmp_lib=$with_boost test x"$with_boost" = x && boost_tmp_lib=${boost_cv_inc_path%/include} for boost_ldpath in "$boost_tmp_lib/lib" '' \ - /opt/local/lib /usr/local/lib /opt/lib /usr/lib \ - "$with_boost" C:/Boost/lib /lib /usr/lib64 /lib64 + /opt/local/lib* /usr/local/lib* /opt/lib* /usr/lib* \ + "$with_boost" C:/Boost/lib /lib* do test -e "$boost_ldpath" || continue boost_save_LDFLAGS=$LDFLAGS @@ -395,7 +403,16 @@ dnl generated only once above (before we start the for loops). LDFLAGS=$boost_save_LDFLAGS LIBS=$boost_save_LIBS if test x"$Boost_lib" = xyes; then - Boost_lib_LDFLAGS="-L$boost_ldpath -R$boost_ldpath" + # Because Boost is often installed in non-standard locations we want to + # hardcode the path to the library (with rpath). Here we assume that + # Libtool's macro was already invoked so we can steal its variable + # hardcode_libdir_flag_spec in order to get the right flags for ld. + boost_save_libdir=$libdir + libdir=$boost_ldpath + eval boost_rpath=\"$hardcode_libdir_flag_spec\" + libdir=$boost_save_libdir + Boost_lib_LDFLAGS="-L$boost_ldpath $boost_rpath" + Boost_lib_LDPATH="$boost_ldpath" break 6 else boost_failed_libs="$boost_failed_libs@$boost_lib@" @@ -410,14 +427,17 @@ rm -f conftest.$ac_objext ]) case $Boost_lib in #( no) _AC_MSG_LOG_CONFTEST - AC_MSG_ERROR([cannot not find the flags to link with Boost $1]) + AC_MSG_ERROR([cannot find the flags to link with Boost $1]) ;; esac -AC_SUBST(AS_TR_CPP([BOOST_$1_LDFLAGS]), [$Boost_lib_LDFLAGS]) -AC_SUBST(AS_TR_CPP([BOOST_$1_LIBS]), [$Boost_lib_LIBS]) +AC_SUBST(AS_TR_CPP([BOOST_$1_LDFLAGS]), [$Boost_lib_LDFLAGS])dnl +AC_SUBST(AS_TR_CPP([BOOST_$1_LDPATH]), [$Boost_lib_LDPATH])dnl +AC_SUBST([BOOST_LDPATH], [$Boost_lib_LDPATH])dnl +AC_SUBST(AS_TR_CPP([BOOST_$1_LIBS]), [$Boost_lib_LIBS])dnl CPPFLAGS=$boost_save_CPPFLAGS AS_VAR_POPDEF([Boost_lib])dnl AS_VAR_POPDEF([Boost_lib_LDFLAGS])dnl +AS_VAR_POPDEF([Boost_lib_LDPATH])dnl AS_VAR_POPDEF([Boost_lib_LIBS])dnl AC_LANG_POP([C++])dnl fi @@ -432,17 +452,31 @@ fi # The page http://beta.boost.org/doc/libs is useful: it gives the first release # version of each library (among other things). +# BOOST_DEFUN(LIBRARY, CODE) +# -------------------------- +# Define BOOST_ as a macro that runs CODE. +# +# Use indir to avoid the warning on underquoted macro name given to AC_DEFUN. +m4_define([BOOST_DEFUN], +[m4_indir([AC_DEFUN], + m4_toupper([BOOST_$1]), +[m4_pushdef([BOOST_Library], [$1])dnl +$2 +m4_popdef([BOOST_Library])dnl +]) +]) + # BOOST_ARRAY() # ------------- # Look for Boost.Array -AC_DEFUN([BOOST_ARRAY], +BOOST_DEFUN([Array], [BOOST_FIND_HEADER([boost/array.hpp])]) # BOOST_ASIO() # ------------ # Look for Boost.Asio (new in Boost 1.35). -AC_DEFUN([BOOST_ASIO], +BOOST_DEFUN([Asio], [AC_REQUIRE([BOOST_SYSTEM])dnl BOOST_FIND_HEADER([boost/asio.hpp])]) @@ -450,14 +484,41 @@ BOOST_FIND_HEADER([boost/asio.hpp])]) # BOOST_BIND() # ------------ # Look for Boost.Bind -AC_DEFUN([BOOST_BIND], +BOOST_DEFUN([Bind], [BOOST_FIND_HEADER([boost/bind.hpp])]) +# BOOST_CHRONO() +# ------------------ +# Look for Boost.Chrono +BOOST_DEFUN([Chrono], +[# Do we have to check for Boost.System? This link-time dependency was +# added as of 1.35.0. If we have a version <1.35, we must not attempt to +# find Boost.System as it didn't exist by then. +if test $boost_major_version -ge 135; then + BOOST_SYSTEM([$1]) +fi # end of the Boost.System check. +boost_system_save_LIBS=$LIBS +boost_system_save_LDFLAGS=$LDFLAGS +m4_pattern_allow([^BOOST_SYSTEM_(LIBS|LDFLAGS)$])dnl +LIBS="$LIBS $BOOST_SYSTEM_LIBS" +LDFLAGS="$LDFLAGS $BOOST_SYSTEM_LDFLAGS" +BOOST_FIND_LIB([chrono], [$1], + [boost/chrono.hpp], + [boost::chrono::system_clock::time_point d = boost::chrono::system_clock::now();]) +if test $enable_static_boost = yes && test $boost_major_version -ge 135; then + AC_SUBST([BOOST_SYSTEM_LIBS], ["$BOOST_SYSTEM_LIBS $BOOST_SYSTEM_LIBS"]) +fi +LIBS=$boost_system_save_LIBS +LDFLAGS=$boost_system_save_LDFLAGS + +])# BOOST_CHRONO + + # BOOST_CONVERSION() # ------------------ # Look for Boost.Conversion (cast / lexical_cast) -AC_DEFUN([BOOST_CONVERSION], +BOOST_DEFUN([Conversion], [BOOST_FIND_HEADER([boost/cast.hpp]) BOOST_FIND_HEADER([boost/lexical_cast.hpp]) ])# BOOST_CONVERSION @@ -467,12 +528,31 @@ BOOST_FIND_HEADER([boost/lexical_cast.hpp]) # ----------------------------------- # Look for Boost.Date_Time. For the documentation of PREFERRED-RT-OPT, see the # documentation of BOOST_FIND_LIB above. -AC_DEFUN([BOOST_DATE_TIME], +BOOST_DEFUN([Date_Time], [BOOST_FIND_LIB([date_time], [$1], [boost/date_time/posix_time/posix_time.hpp], [boost::posix_time::ptime t;]) ])# BOOST_DATE_TIME +# BOOST_TIMER([PREFERRED-RT-OPT]) +# ----------------------------------- +# Look for Boost.Timer. For the documentation of PREFERRED-RT-OPT, see the +# documentation of BOOST_FIND_LIB above. +BOOST_DEFUN([Timer], +[#check for Boost.System +BOOST_SYSTEM([$1]) +boost_system_save_LIBS=$LIBS +boost_system_save_LDFLAGS=$LDFLAGS +m4_pattern_allow([^BOOST_SYSTEM_(LIBS|LDFLAGS)$])dnl +LIBS="$LIBS $BOOST_SYSTEM_LIBS" +LDFLAGS="$LDFLAGS $BOOST_SYSTEM_LDFLAGS" +BOOST_FIND_LIB([timer], [$1], + [boost/timer/timer.hpp], + [boost::timer::auto_cpu_timer t;]) +AC_SUBST([BOOST_SYSTEM_LIBS], ["$BOOST_SYSTEM_LIBS $BOOST_SYSTEM_LIBS"]) +LIBS=$boost_system_save_LIBS +LDFLAGS=$boost_system_save_LDFLAGS +])# BOOST_TIMER # BOOST_FILESYSTEM([PREFERRED-RT-OPT]) # ------------------------------------ @@ -480,7 +560,7 @@ AC_DEFUN([BOOST_DATE_TIME], # the documentation of BOOST_FIND_LIB above. # Do not check for boost/filesystem.hpp because this file was introduced in # 1.34. -AC_DEFUN([BOOST_FILESYSTEM], +BOOST_DEFUN([Filesystem], [# Do we have to check for Boost.System? This link-time dependency was # added as of 1.35.0. If we have a version <1.35, we must not attempt to # find Boost.System as it didn't exist by then. @@ -494,6 +574,9 @@ LIBS="$LIBS $BOOST_SYSTEM_LIBS" LDFLAGS="$LDFLAGS $BOOST_SYSTEM_LDFLAGS" BOOST_FIND_LIB([filesystem], [$1], [boost/filesystem/path.hpp], [boost::filesystem::path p;]) +if test $enable_static_boost = yes && test $boost_major_version -ge 135; then + AC_SUBST([BOOST_FILESYSTEM_LIBS], ["$BOOST_FILESYSTEM_LIBS $BOOST_SYSTEM_LIBS"]) +fi LIBS=$boost_filesystem_save_LIBS LDFLAGS=$boost_filesystem_save_LDFLAGS ])# BOOST_FILESYSTEM @@ -502,7 +585,7 @@ LDFLAGS=$boost_filesystem_save_LDFLAGS # BOOST_FOREACH() # --------------- # Look for Boost.Foreach -AC_DEFUN([BOOST_FOREACH], +BOOST_DEFUN([Foreach], [BOOST_FIND_HEADER([boost/foreach.hpp])]) @@ -513,14 +596,14 @@ AC_DEFUN([BOOST_FOREACH], # standalone. It can't be compiled because it triggers the following error: # boost/format/detail/config_macros.hpp:88: error: 'locale' in namespace 'std' # does not name a type -AC_DEFUN([BOOST_FORMAT], +BOOST_DEFUN([Format], [BOOST_FIND_HEADER([boost/format.hpp])]) # BOOST_FUNCTION() # ---------------- # Look for Boost.Function -AC_DEFUN([BOOST_FUNCTION], +BOOST_DEFUN([Function], [BOOST_FIND_HEADER([boost/function.hpp])]) @@ -528,37 +611,60 @@ AC_DEFUN([BOOST_FUNCTION], # ------------------------------- # Look for Boost.Graphs. For the documentation of PREFERRED-RT-OPT, see the # documentation of BOOST_FIND_LIB above. -AC_DEFUN([BOOST_GRAPH], +BOOST_DEFUN([Graph], [BOOST_FIND_LIB([graph], [$1], [boost/graph/adjacency_list.hpp], [boost::adjacency_list<> g;]) ])# BOOST_GRAPH # BOOST_IOSTREAMS([PREFERRED-RT-OPT]) -# ------------------------------- +# ----------------------------------- # Look for Boost.IOStreams. For the documentation of PREFERRED-RT-OPT, see the # documentation of BOOST_FIND_LIB above. -AC_DEFUN([BOOST_IOSTREAMS], +BOOST_DEFUN([IOStreams], [BOOST_FIND_LIB([iostreams], [$1], [boost/iostreams/device/file_descriptor.hpp], - [boost::iostreams::file_descriptor fd(0); fd.close();]) + [boost::iostreams::file_descriptor fd; fd.close();]) ])# BOOST_IOSTREAMS # BOOST_HASH() # ------------ # Look for Boost.Functional/Hash -AC_DEFUN([BOOST_HASH], +BOOST_DEFUN([Hash], [BOOST_FIND_HEADER([boost/functional/hash.hpp])]) # BOOST_LAMBDA() # -------------- # Look for Boost.Lambda -AC_DEFUN([BOOST_LAMBDA], +BOOST_DEFUN([Lambda], [BOOST_FIND_HEADER([boost/lambda/lambda.hpp])]) +# BOOST_LOG([PREFERRED-RT-OPT]) +# ----------------------------- +# Look for Boost.Log For the documentation of PREFERRED-RT-OPT, see the +# documentation of BOOST_FIND_LIB above. +BOOST_DEFUN([Log], +[BOOST_FIND_LIB([log], [$1], + [boost/log/core/core.hpp], + [boost::log::attribute a; a.get_value();]) +])# BOOST_LOG + + +# BOOST_LOG_SETUP([PREFERRED-RT-OPT]) +# ----------------------------------- +# Look for Boost.Log For the documentation of PREFERRED-RT-OPT, see the +# documentation of BOOST_FIND_LIB above. +BOOST_DEFUN([Log_Setup], +[AC_REQUIRE([BOOST_LOG])dnl +BOOST_FIND_LIB([log_setup], [$1], + [boost/log/utility/init/from_settings.hpp], + [boost::log::basic_settings bs; bs.empty();]) +])# BOOST_LOG_SETUP + + # BOOST_MATH() # ------------ # Look for Boost.Math @@ -567,21 +673,21 @@ AC_DEFUN([BOOST_LAMBDA], # libboost_math_c99f, libboost_math_c99l, libboost_math_tr1, # libboost_math_tr1f, libboost_math_tr1l). This macro must be fixed to do the # right thing anyway. -AC_DEFUN([BOOST_MATH], +BOOST_DEFUN([Math], [BOOST_FIND_HEADER([boost/math/special_functions.hpp])]) # BOOST_MULTIARRAY() # ------------------ # Look for Boost.MultiArray -AC_DEFUN([BOOST_MULTIARRAY], +BOOST_DEFUN([MultiArray], [BOOST_FIND_HEADER([boost/multi_array.hpp])]) # BOOST_NUMERIC_CONVERSION() # -------------------------- # Look for Boost.NumericConversion (policy-based numeric conversion) -AC_DEFUN([BOOST_NUMERIC_CONVERSION], +BOOST_DEFUN([Numeric_Conversion], [BOOST_FIND_HEADER([boost/numeric/conversion/converter.hpp]) ])# BOOST_NUMERIC_CONVERSION @@ -589,32 +695,76 @@ AC_DEFUN([BOOST_NUMERIC_CONVERSION], # BOOST_OPTIONAL() # ---------------- # Look for Boost.Optional -AC_DEFUN([BOOST_OPTIONAL], +BOOST_DEFUN([Optional], [BOOST_FIND_HEADER([boost/optional.hpp])]) # BOOST_PREPROCESSOR() # -------------------- # Look for Boost.Preprocessor -AC_DEFUN([BOOST_PREPROCESSOR], +BOOST_DEFUN([Preprocessor], [BOOST_FIND_HEADER([boost/preprocessor/repeat.hpp])]) +# BOOST_UNORDERED() +# ----------------- +# Look for Boost.Unordered +BOOST_DEFUN([Unordered], +[BOOST_FIND_HEADER([boost/unordered_map.hpp])]) + + +# BOOST_UUID() +# ------------ +# Look for Boost.Uuid +BOOST_DEFUN([Uuid], +[BOOST_FIND_HEADER([boost/uuid/uuid.hpp])]) + + # BOOST_PROGRAM_OPTIONS([PREFERRED-RT-OPT]) # ----------------------------------------- -# Look for Boost.Program_options. For the documentation of PREFERRED-RT-OPT, see -# the documentation of BOOST_FIND_LIB above. -AC_DEFUN([BOOST_PROGRAM_OPTIONS], +# Look for Boost.Program_options. For the documentation of PREFERRED-RT-OPT, +# see the documentation of BOOST_FIND_LIB above. +BOOST_DEFUN([Program_Options], [BOOST_FIND_LIB([program_options], [$1], [boost/program_options.hpp], [boost::program_options::options_description d("test");]) ])# BOOST_PROGRAM_OPTIONS + +# _BOOST_PYTHON_CONFIG(VARIABLE, FLAG) +# ------------------------------------ +# Save VARIABLE, and define it via `python-config --FLAG`. +# Substitute BOOST_PYTHON_VARIABLE. +m4_define([_BOOST_PYTHON_CONFIG], +[AC_SUBST([BOOST_PYTHON_$1], + [`python-config --$2 2>/dev/null`])dnl +boost_python_save_$1=$$1 +$1="$$1 $BOOST_PYTHON_$1"]) + + +# BOOST_PYTHON([PREFERRED-RT-OPT]) +# -------------------------------- +# Look for Boost.Python. For the documentation of PREFERRED-RT-OPT, +# see the documentation of BOOST_FIND_LIB above. +BOOST_DEFUN([Python], +[_BOOST_PYTHON_CONFIG([CPPFLAGS], [includes]) +_BOOST_PYTHON_CONFIG([LDFLAGS], [ldflags]) +_BOOST_PYTHON_CONFIG([LIBS], [libs]) +m4_pattern_allow([^BOOST_PYTHON_MODULE$])dnl +BOOST_FIND_LIB([python], [$1], + [boost/python.hpp], + [], [BOOST_PYTHON_MODULE(empty) {}]) +CPPFLAGS=$boost_python_save_CPPFLAGS +LDFLAGS=$boost_python_save_LDFLAGS +LIBS=$boost_python_save_LIBS +])# BOOST_PYTHON + + # BOOST_REF() # ----------- # Look for Boost.Ref -AC_DEFUN([BOOST_REF], +BOOST_DEFUN([Ref], [BOOST_FIND_HEADER([boost/ref.hpp])]) @@ -622,7 +772,7 @@ AC_DEFUN([BOOST_REF], # ------------------------------- # Look for Boost.Regex. For the documentation of PREFERRED-RT-OPT, see the # documentation of BOOST_FIND_LIB above. -AC_DEFUN([BOOST_REGEX], +BOOST_DEFUN([Regex], [BOOST_FIND_LIB([regex], [$1], [boost/regex.hpp], [boost::regex exp("*"); boost::regex_match("foo", exp);]) @@ -633,19 +783,19 @@ AC_DEFUN([BOOST_REGEX], # --------------------------------------- # Look for Boost.Serialization. For the documentation of PREFERRED-RT-OPT, see # the documentation of BOOST_FIND_LIB above. -AC_DEFUN([BOOST_SERIALIZATION], +BOOST_DEFUN([Serialization], [BOOST_FIND_LIB([serialization], [$1], [boost/archive/text_oarchive.hpp], [std::ostream* o = 0; // Cheap way to get an ostream... boost::archive::text_oarchive t(*o);]) -])# BOOST_SIGNALS +])# BOOST_SERIALIZATION # BOOST_SIGNALS([PREFERRED-RT-OPT]) # --------------------------------- # Look for Boost.Signals. For the documentation of PREFERRED-RT-OPT, see the # documentation of BOOST_FIND_LIB above. -AC_DEFUN([BOOST_SIGNALS], +BOOST_DEFUN([Signals], [BOOST_FIND_LIB([signals], [$1], [boost/signal.hpp], [boost::signal s;]) @@ -655,7 +805,7 @@ AC_DEFUN([BOOST_SIGNALS], # BOOST_SMART_PTR() # ----------------- # Look for Boost.SmartPtr -AC_DEFUN([BOOST_SMART_PTR], +BOOST_DEFUN([Smart_Ptr], [BOOST_FIND_HEADER([boost/scoped_ptr.hpp]) BOOST_FIND_HEADER([boost/shared_ptr.hpp]) ]) @@ -664,14 +814,14 @@ BOOST_FIND_HEADER([boost/shared_ptr.hpp]) # BOOST_STATICASSERT() # -------------------- # Look for Boost.StaticAssert -AC_DEFUN([BOOST_STATICASSERT], +BOOST_DEFUN([StaticAssert], [BOOST_FIND_HEADER([boost/static_assert.hpp])]) # BOOST_STRING_ALGO() # ------------------- # Look for Boost.StringAlgo -AC_DEFUN([BOOST_STRING_ALGO], +BOOST_DEFUN([String_Algo], [BOOST_FIND_HEADER([boost/algorithm/string.hpp]) ]) @@ -681,7 +831,7 @@ AC_DEFUN([BOOST_STRING_ALGO], # Look for Boost.System. For the documentation of PREFERRED-RT-OPT, see the # documentation of BOOST_FIND_LIB above. This library was introduced in Boost # 1.35.0. -AC_DEFUN([BOOST_SYSTEM], +BOOST_DEFUN([System], [BOOST_FIND_LIB([system], [$1], [boost/system/error_code.hpp], [boost::system::error_code e; e.clear();]) @@ -692,7 +842,7 @@ AC_DEFUN([BOOST_SYSTEM], # ------------------------------ # Look for Boost.Test. For the documentation of PREFERRED-RT-OPT, see the # documentation of BOOST_FIND_LIB above. -AC_DEFUN([BOOST_TEST], +BOOST_DEFUN([Test], [m4_pattern_allow([^BOOST_CHECK$])dnl BOOST_FIND_LIB([unit_test_framework], [$1], [boost/test/unit_test.hpp], [BOOST_CHECK(2 == 2);], @@ -707,25 +857,49 @@ BOOST_FIND_LIB([unit_test_framework], [$1], # Look for Boost.Thread. For the documentation of PREFERRED-RT-OPT, see the # documentation of BOOST_FIND_LIB above. # FIXME: Provide an alias "BOOST_THREAD". -AC_DEFUN([BOOST_THREADS], +BOOST_DEFUN([Threads], [dnl Having the pthread flag is required at least on GCC3 where dnl boost/thread.hpp would complain if we try to compile without dnl -pthread on GNU/Linux. AC_REQUIRE([_BOOST_PTHREAD_FLAG])dnl boost_threads_save_LIBS=$LIBS +boost_threads_save_LDFLAGS=$LDFLAGS boost_threads_save_CPPFLAGS=$CPPFLAGS -LIBS="$LIBS $boost_cv_pthread_flag" +# Link-time dependency from thread to system was added as of 1.49.0. +if test $boost_major_version -ge 149; then +BOOST_SYSTEM([$1]) +fi # end of the Boost.System check. +m4_pattern_allow([^BOOST_SYSTEM_(LIBS|LDFLAGS)$])dnl +LIBS="$LIBS $BOOST_SYSTEM_LIBS $boost_cv_pthread_flag" +LDFLAGS="$LDFLAGS $BOOST_SYSTEM_LDFLAGS" # Yes, we *need* to put the -pthread thing in CPPFLAGS because with GCC3, # boost/thread.hpp will trigger a #error if -pthread isn't used: # boost/config/requires_threads.hpp:47:5: #error "Compiler threading support # is not turned on. Please set the correct command line options for # threading: -pthread (Linux), -pthreads (Solaris) or -mthreads (Mingw32)" CPPFLAGS="$CPPFLAGS $boost_cv_pthread_flag" -BOOST_FIND_LIB([thread], [$1], - [boost/thread.hpp], [boost::thread t; boost::mutex m;]) -BOOST_THREAD_LIBS="$BOOST_THREAD_LIBS $boost_cv_pthread_flag" + +# When compiling for the Windows platform, the threads library is named +# differently. +case $host_os in + (*mingw*) + BOOST_FIND_LIB([thread_win32], [$1], + [boost/thread.hpp], [boost::thread t; boost::mutex m;]) + BOOST_THREAD_LDFLAGS=$BOOST_THREAD_WIN32_LDFLAGS + BOOST_THREAD_LDPATH=$BOOST_THREAD_WIN32_LDPATH + BOOST_THREAD_LIBS=$BOOST_THREAD_WIN32_LIBS + ;; + (*) + BOOST_FIND_LIB([thread], [$1], + [boost/thread.hpp], [boost::thread t; boost::mutex m;]) + ;; +esac + +BOOST_THREAD_LIBS="$BOOST_THREAD_LIBS $BOOST_SYSTEM_LIBS $boost_cv_pthread_flag" +BOOST_THREAD_LDFLAGS="$BOOST_SYSTEM_LDFLAGS" BOOST_CPPFLAGS="$BOOST_CPPFLAGS $boost_cv_pthread_flag" LIBS=$boost_threads_save_LIBS +LDFLAGS=$boost_threads_save_LDFLAGS CPPFLAGS=$boost_threads_save_CPPFLAGS ])# BOOST_THREADS @@ -733,14 +907,14 @@ CPPFLAGS=$boost_threads_save_CPPFLAGS # BOOST_TOKENIZER() # ----------------- # Look for Boost.Tokenizer -AC_DEFUN([BOOST_TOKENIZER], +BOOST_DEFUN([Tokenizer], [BOOST_FIND_HEADER([boost/tokenizer.hpp])]) # BOOST_TRIBOOL() # --------------- # Look for Boost.Tribool -AC_DEFUN([BOOST_TRIBOOL], +BOOST_DEFUN([Tribool], [BOOST_FIND_HEADER([boost/logic/tribool_fwd.hpp]) BOOST_FIND_HEADER([boost/logic/tribool.hpp]) ]) @@ -749,14 +923,14 @@ BOOST_FIND_HEADER([boost/logic/tribool.hpp]) # BOOST_TUPLE() # ------------- # Look for Boost.Tuple -AC_DEFUN([BOOST_TUPLE], +BOOST_DEFUN([Tuple], [BOOST_FIND_HEADER([boost/tuple/tuple.hpp])]) # BOOST_TYPETRAITS() # -------------------- # Look for Boost.TypeTraits -AC_DEFUN([BOOST_TYPETRAITS], +BOOST_DEFUN([TypeTraits], [BOOST_FIND_HEADER([boost/type_traits.hpp])]) @@ -764,14 +938,14 @@ AC_DEFUN([BOOST_TYPETRAITS], # --------------- # Look for Boost.Utility (noncopyable, result_of, base-from-member idiom, # etc.) -AC_DEFUN([BOOST_UTILITY], +BOOST_DEFUN([Utility], [BOOST_FIND_HEADER([boost/utility.hpp])]) # BOOST_VARIANT() # --------------- # Look for Boost.Variant. -AC_DEFUN([BOOST_VARIANT], +BOOST_DEFUN([Variant], [BOOST_FIND_HEADER([boost/variant/variant_fwd.hpp]) BOOST_FIND_HEADER([boost/variant.hpp])]) @@ -782,15 +956,15 @@ BOOST_FIND_HEADER([boost/variant.hpp])]) # call BOOST_THREADS first. # Look for Boost.Wave. For the documentation of PREFERRED-RT-OPT, see the # documentation of BOOST_FIND_LIB above. -AC_DEFUN([BOOST_WAVE], +BOOST_DEFUN([Wave], [AC_REQUIRE([BOOST_FILESYSTEM])dnl AC_REQUIRE([BOOST_DATE_TIME])dnl boost_wave_save_LIBS=$LIBS boost_wave_save_LDFLAGS=$LDFLAGS m4_pattern_allow([^BOOST_((FILE)?SYSTEM|DATE_TIME|THREAD)_(LIBS|LDFLAGS)$])dnl -LIBS="$LIBS $BOOST_SYSTEM_LIBS $BOOST_FILESYSTEM_LIBS $BOOST_DATE_TIME_LIBS\ +LIBS="$LIBS $BOOST_SYSTEM_LIBS $BOOST_FILESYSTEM_LIBS $BOOST_DATE_TIME_LIBS \ $BOOST_THREAD_LIBS" -LDFLAGS="$LDFLAGS $BOOST_SYSTEM_LDFLAGS $BOOST_FILESYSTEM_LDFLAGS\ +LDFLAGS="$LDFLAGS $BOOST_SYSTEM_LDFLAGS $BOOST_FILESYSTEM_LDFLAGS \ $BOOST_DATE_TIME_LDFLAGS $BOOST_THREAD_LDFLAGS" BOOST_FIND_LIB([wave], [$1], [boost/wave.hpp], @@ -803,7 +977,7 @@ LDFLAGS=$boost_wave_save_LDFLAGS # BOOST_XPRESSIVE() # ----------------- # Look for Boost.Xpressive (new since 1.36.0). -AC_DEFUN([BOOST_XPRESSIVE], +BOOST_DEFUN([Xpressive], [BOOST_FIND_HEADER([boost/xpressive/xpressive.hpp])]) @@ -893,8 +1067,9 @@ AC_DEFUN([_BOOST_FIND_COMPILER_TAG], [AC_REQUIRE([AC_PROG_CXX])dnl AC_REQUIRE([AC_CANONICAL_HOST])dnl AC_CACHE_CHECK([for the toolset name used by Boost for $CXX], [boost_cv_lib_tag], -[AC_LANG_PUSH([C++])dnl - boost_cv_lib_tag=unknown +[boost_cv_lib_tag=unknown +if test x$boost_cv_inc_path != xno; then + AC_LANG_PUSH([C++])dnl # The following tests are mostly inspired by boost/config/auto_link.hpp # The list is sorted to most recent/common to oldest compiler (in order # to increase the likelihood of finding the right compiler with the @@ -908,8 +1083,12 @@ AC_CACHE_CHECK([for the toolset name used by Boost for $CXX], [boost_cv_lib_tag] # como, edg, kcc, bck, mp, sw, tru, xlc # I'm not sure about my test for `il' (be careful: Intel's ICC pre-defines # the same defines as GCC's). - # TODO: Move the test on GCC 4.4 up once it's released. for i in \ + _BOOST_gcc_test(4, 8) \ + _BOOST_gcc_test(4, 7) \ + _BOOST_gcc_test(4, 6) \ + _BOOST_gcc_test(4, 5) \ + _BOOST_gcc_test(4, 4) \ _BOOST_gcc_test(4, 3) \ _BOOST_gcc_test(4, 2) \ _BOOST_gcc_test(4, 1) \ @@ -929,7 +1108,6 @@ AC_CACHE_CHECK([for the toolset name used by Boost for $CXX], [boost_cv_lib_tag] "defined __ICC && (defined __unix || defined __unix__) @ il" \ "defined __ICL @ iw" \ "defined _MSC_VER && _MSC_VER == 1300 @ vc7" \ - _BOOST_gcc_test(4, 4) \ _BOOST_gcc_test(2, 95) \ "defined __MWERKS__ && __MWERKS__ <= 0x32FF @ cw9" \ "defined _MSC_VER && _MSC_VER < 1300 && !defined UNDER_CE @ vc6" \ @@ -969,7 +1147,7 @@ AC_LANG_POP([C++])dnl boost_cv_lib_tag= ;; esac -])dnl end of AC_CACHE_CHECK +fi])dnl end of AC_CACHE_CHECK ])# _BOOST_FIND_COMPILER_TAG -- cgit v1.2.3 From 333afa9a19a812ad0945fa3edaeb3f4314e57d42 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 19 Jan 2013 19:46:46 -0500 Subject: deal with chrono/timer --- .travis.yml | 1 + configure.ac | 8 +++++--- python/setup.py.in | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) (limited to 'configure.ac') diff --git a/.travis.yml b/.travis.yml index 14b625ab..38dbd2fa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,6 +3,7 @@ python: - "2.7" before_script: - sudo apt-get install libboost-program-options-dev + - sudo apt-cache search boost - sudo apt-get install libboost-serialization-dev - sudo apt-get install libboost-regex-dev - sudo apt-get install libboost-test-dev diff --git a/configure.ac b/configure.ac index a1e5ad84..c3e3251c 100644 --- a/configure.ac +++ b/configure.ac @@ -15,8 +15,10 @@ BOOST_REQUIRE([1.44]) BOOST_PROGRAM_OPTIONS BOOST_SYSTEM BOOST_SERIALIZATION -BOOST_CHRONO -BOOST_TIMER +if test $boost_major_version -ge 148; then + BOOST_CHRONO + BOOST_TIMER +fi BOOST_TEST BOOST_THREADS AM_PATH_PYTHON @@ -117,8 +119,8 @@ AC_CONFIG_FILES([klm/util/double-conversion/Makefile]) AC_CONFIG_FILES([klm/util/stream/Makefile]) AC_CONFIG_FILES([klm/util/Makefile]) AC_CONFIG_FILES([klm/lm/Makefile]) -AC_CONFIG_FILES([klm/lm/builder/Makefile]) AC_CONFIG_FILES([klm/search/Makefile]) +AC_CONFIG_FILES([klm/lm/builder/Makefile]) # training stuff AC_CONFIG_FILES([training/Makefile]) diff --git a/python/setup.py.in b/python/setup.py.in index fa8a9f5e..925cb196 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -3,7 +3,7 @@ from distutils.extension import Extension import re INC = ['..', 'src/', '../decoder', '../utils', '../mteval'] -LIB = ['../decoder', '../utils', '../mteval', '../training/utils', '../klm/lm', '../klm/util', '../klm/search'] +LIB = ['../decoder', '../utils', '../mteval', '../training/utils', '../klm/lm', '../klm/util', '../klm/util/double-conversion', '../klm/search'] # Set automatically by configure LIBS = re.findall('-l([^\s]+)', '@LIBS@') @@ -17,7 +17,7 @@ ext_modules = [ sources=['src/_cdec.cpp'], include_dirs=INC, library_dirs=LIB, - libraries=['cdec', 'utils', 'mteval', 'training_utils', 'klm', 'klm_util', 'ksearch'] + LIBS, + libraries=['cdec', 'utils', 'mteval', 'training_utils', 'klm', 'klm_util', 'klm_util_double', 'ksearch'] + LIBS, extra_compile_args=CPPFLAGS, extra_link_args=LDFLAGS), Extension(name='cdec.sa._sa', -- cgit v1.2.3 From ccf63227ee70fbdad365790dc763860463d2c9f3 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 20 Jan 2013 09:44:18 -0500 Subject: clean up a bit --- .gitignore | 2 ++ configure.ac | 2 +- word-aligner/Makefile.am | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) (limited to 'configure.ac') diff --git a/.gitignore b/.gitignore index 56372ad4..bde0f6a5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,9 @@ example_extff/ff_example.lo example_extff/libff_example.la mteval/meteor_jar.cc +training/utils/grammar_convert *.a +*.trs *.aux *.bbl *.blg diff --git a/configure.ac b/configure.ac index c3e3251c..c474c050 100644 --- a/configure.ac +++ b/configure.ac @@ -1,4 +1,4 @@ -AC_INIT([cdec],[2013-01-19]) +AC_INIT([cdec],[2013-01-20]) AC_CONFIG_SRCDIR([decoder/cdec.cc]) AM_INIT_AUTOMAKE AC_CONFIG_HEADERS(config.h) diff --git a/word-aligner/Makefile.am b/word-aligner/Makefile.am index a195cc5a..1f7f78ae 100644 --- a/word-aligner/Makefile.am +++ b/word-aligner/Makefile.am @@ -3,4 +3,6 @@ bin_PROGRAMS = fast_align fast_align_SOURCES = fast_align.cc ttables.cc da.h ttables.h fast_align_LDADD = ../utils/libutils.a +EXTRA_DIST = aligner.pl ortho-norm support makefiles stemmers + AM_CPPFLAGS = -W -Wall -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/training -- cgit v1.2.3 From 608886384da40aedfabd629c882b8ea9b3f6348e Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 20 Jan 2013 10:08:57 -0500 Subject: remove dependency on timer/chrono --- .travis.yml | 2 -- configure.ac | 4 ---- klm/lm/builder/Makefile.am | 2 +- 3 files changed, 1 insertion(+), 7 deletions(-) (limited to 'configure.ac') diff --git a/.travis.yml b/.travis.yml index c67c5b43..d2d25903 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,8 +7,6 @@ before_script: - sudo apt-get install libboost-regex1.48-dev - sudo apt-get install libboost-test1.48-dev - sudo apt-get install libboost-system1.48-dev - - sudo apt-get install libboost-timer1.48-dev - - sudo apt-get install libboost-chrono1.48-dev - sudo apt-get install libboost-thread1.48-dev - sudo apt-get install flex - autoreconf -ifv diff --git a/configure.ac b/configure.ac index c474c050..402ddd0a 100644 --- a/configure.ac +++ b/configure.ac @@ -15,10 +15,6 @@ BOOST_REQUIRE([1.44]) BOOST_PROGRAM_OPTIONS BOOST_SYSTEM BOOST_SERIALIZATION -if test $boost_major_version -ge 148; then - BOOST_CHRONO - BOOST_TIMER -fi BOOST_TEST BOOST_THREADS AM_PATH_PYTHON diff --git a/klm/lm/builder/Makefile.am b/klm/lm/builder/Makefile.am index 00444256..b5c147fd 100644 --- a/klm/lm/builder/Makefile.am +++ b/klm/lm/builder/Makefile.am @@ -22,7 +22,7 @@ builder_SOURCES = \ print.hh \ sort.hh -builder_LDADD = ../libklm.a ../../util/double-conversion/libklm_util_double.a ../../util/stream/libklm_util_stream.a ../../util/libklm_util.a $(BOOST_TIMER_LIBS) $(BOOST_CHRONO_LIBS) $(BOOST_THREAD_LIBS) +builder_LDADD = ../libklm.a ../../util/double-conversion/libklm_util_double.a ../../util/stream/libklm_util_stream.a ../../util/libklm_util.a $(BOOST_THREAD_LIBS) AM_CPPFLAGS = -W -Wall -I$(top_srcdir)/klm -- cgit v1.2.3 From bd65d6a4492e172a7840c010c5414ceb6f6acd56 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sat, 9 Mar 2013 00:05:40 -0500 Subject: bump release --- configure.ac | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'configure.ac') diff --git a/configure.ac b/configure.ac index 402ddd0a..98deac86 100644 --- a/configure.ac +++ b/configure.ac @@ -1,4 +1,4 @@ -AC_INIT([cdec],[2013-01-20]) +AC_INIT([cdec],[2013-03-08]) AC_CONFIG_SRCDIR([decoder/cdec.cc]) AM_INIT_AUTOMAKE AC_CONFIG_HEADERS(config.h) -- cgit v1.2.3