From 81166355dfa0ccbae5413fd1ee896c43ce2f5d96 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 9 Oct 2014 13:59:21 -0400 Subject: patch for using ax_pthread --- training/utils/Makefile.am | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'training') diff --git a/training/utils/Makefile.am b/training/utils/Makefile.am index 27c6e344..edaaf3d4 100644 --- a/training/utils/Makefile.am +++ b/training/utils/Makefile.am @@ -12,10 +12,12 @@ noinst_PROGRAMS = \ EXTRA_DIST = decode-and-evaluate.pl libcall.pl parallelize.pl sentserver_SOURCES = sentserver.cc -sentserver_LDFLAGS = -pthread +sentserver_LDFLAGS = $(PTHREAD_LIBS) +sentserver_CXXFLAGS = $(PTHREAD_CFLAGS) sentclient_SOURCES = sentclient.cc -sentclient_LDFLAGS = -pthread +sentclient_LDFLAGS = $(PTHREAD_LIBS) +sentclient_CXXFLAGS = $(PTHREAD_CFLAGS) TESTS = lbfgs_test optimize_test -- cgit v1.2.3 From 011a87cfe6d9cc702cb4a8a6d9a765556e460af9 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 19 Oct 2014 05:24:21 -0400 Subject: stop switch to boost serialization for hypergraph IO --- decoder/decoder.cc | 4 +- decoder/forest_writer.cc | 4 +- decoder/hg.h | 52 +++++++++++ decoder/hg_io.cc | 101 +++------------------ decoder/hg_io.h | 5 +- decoder/hg_test.cc | 39 +++++--- decoder/rule_lexer.ll | 1 + python/cdec/hypergraph.pxd | 3 +- training/dpmert/mr_dpmert_generate_mapper_input.cc | 2 +- training/dpmert/mr_dpmert_map.cc | 4 +- training/minrisk/minrisk_optimize.cc | 2 +- training/pro/mr_pro_map.cc | 2 +- training/rampion/rampion_cccp.cc | 2 +- training/utils/grammar_convert.cc | 5 +- utils/small_vector.h | 16 ++++ utils/small_vector_test.cc | 30 ++++++ 16 files changed, 156 insertions(+), 116 deletions(-) (limited to 'training') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 3cc77d27..f8214f5f 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -930,7 +930,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { Hypergraph new_hg; { ReadFile rf(writer.fname_); - bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); + bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg); if (!succeeded) abort(); } HG::Union(forest, &new_hg); @@ -1023,7 +1023,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { Hypergraph new_hg; { ReadFile rf(writer.fname_); - bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg); + bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg); if (!succeeded) abort(); } HG::Union(forest, &new_hg); diff --git a/decoder/forest_writer.cc b/decoder/forest_writer.cc index 6e4cccb3..c072e599 100644 --- a/decoder/forest_writer.cc +++ b/decoder/forest_writer.cc @@ -11,13 +11,13 @@ using namespace std; ForestWriter::ForestWriter(const std::string& path, int num) : - fname_(path + '/' + boost::lexical_cast(num) + ".json.gz"), used_(false) {} + fname_(path + '/' + boost::lexical_cast(num) + ".bin.gz"), used_(false) {} bool ForestWriter::Write(const Hypergraph& forest, bool minimal_rules) { assert(!used_); used_ = true; cerr << " Writing forest to " << fname_ << endl; WriteFile wf(fname_); - return HypergraphIO::WriteToJSON(forest, minimal_rules, wf.stream()); + return HypergraphIO::WriteToBinary(forest, wf.stream()); } diff --git a/decoder/hg.h b/decoder/hg.h index 256f650f..124eab86 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -18,6 +18,7 @@ #include #include #include +#include #include "feature_vector.h" #include "small_vector.h" @@ -69,6 +70,18 @@ namespace HG { short int j_; short int prev_i_; short int prev_j_; + template + void serialize(Archive & ar, const unsigned int version) { + ar & head_node_; + ar & tail_nodes_; + ar & rule_; + ar & feature_values_; + ar & i_; + ar & j_; + ar & prev_i_; + ar & prev_j_; + ar & id_; + } void show(std::ostream &o,unsigned mask=SPAN|RULE) const { o<<'{'; if (mask&CATEGORY) @@ -149,6 +162,24 @@ namespace HG { WordID NT() const { return -cat_; } EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_ EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_ + template + void save(Archive & ar, const unsigned int version) const { + ar & node_hash; + ar & id_; + ar & TD::Convert(-cat_); + ar & in_edges_; + ar & out_edges_; + } + template + void load(Archive & ar, const unsigned int version) { + ar & node_hash; + ar & id_; + std::string cat; ar & cat; + cat_ = -TD::Convert(cat); + ar & in_edges_; + ar & out_edges_; + } + BOOST_SERIALIZATION_SPLIT_MEMBER() void copy_fixed(Node const& o) { // nonstructural fields only - structural ones are managed by sorting/pruning/subsetting node_hash = o.node_hash; cat_=o.cat_; @@ -492,6 +523,27 @@ public: void set_ids(); // resync edge,node .id_ void check_ids() const; // assert that .id_ have been kept in sync + template + void save(Archive & ar, const unsigned int version) const { + unsigned ns = nodes_.size(); ar & ns; + unsigned es = edges_.size(); ar & es; + for (auto& n : nodes_) ar & n; + for (auto& e : edges_) ar & e; + int x; + x = edges_topo_; ar & x; + x = is_linear_chain_; ar & x; + } + template + void load(Archive & ar, const unsigned int version) { + unsigned ns; ar & ns; nodes_.resize(ns); + unsigned es; ar & es; edges_.resize(es); + for (auto& n : nodes_) ar & n; + for (auto& e : edges_) ar & e; + int x; + ar & x; edges_topo_ = x; + ar & x; is_linear_chain_ = x; + } + BOOST_SERIALIZATION_SPLIT_MEMBER() private: Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {} }; diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index eb0be3d4..67760fb1 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -6,6 +6,10 @@ #include #include +#include +#include +#include + #include "fast_lexical_cast.hpp" #include "tdict.h" @@ -271,97 +275,16 @@ bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) { return reader.Parse(in); } -static void WriteRule(const TRule& r, ostream* out) { - if (!r.lhs_) { (*out) << "[X] ||| "; } - JSONParser::WriteEscapedString(r.AsString(), out); +bool HypergraphIO::ReadFromBinary(istream* in, Hypergraph* hg) { + boost::archive::binary_iarchive oa(*in); + hg->clear(); + oa >> *hg; + return true; } -bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) { - if (hg.empty()) { *out << "{}\n"; return true; } - map rid; - ostream& o = *out; - rid[NULL] = 0; - o << '{'; - if (!remove_rules) { - o << "\"rules\":["; - for (int i = 0; i < hg.edges_.size(); ++i) { - const TRule* r = hg.edges_[i].rule_.get(); - int &id = rid[r]; - if (!id) { - id=rid.size() - 1; - if (id > 1) o << ','; - o << id << ','; - WriteRule(*r, &o); - }; - } - o << "],"; - } - const bool use_fdict = FD::NumFeats() < 1000; - if (use_fdict) { - o << "\"features\":["; - for (int i = 1; i < FD::NumFeats(); ++i) { - o << (i==1 ? "":","); - JSONParser::WriteEscapedString(FD::Convert(i), &o); - } - o << "],"; - } - vector edgemap(hg.edges_.size(), -1); // edges may be in non-topo order - int edge_count = 0; - for (int i = 0; i < hg.nodes_.size(); ++i) { - const Hypergraph::Node& node = hg.nodes_[i]; - if (i > 0) { o << ","; } - o << "\"edges\":["; - for (int j = 0; j < node.in_edges_.size(); ++j) { - const Hypergraph::Edge& edge = hg.edges_[node.in_edges_[j]]; - edgemap[edge.id_] = edge_count; - ++edge_count; - o << (j == 0 ? "" : ",") << "{"; - - o << "\"tail\":["; - for (int k = 0; k < edge.tail_nodes_.size(); ++k) { - o << (k > 0 ? "," : "") << edge.tail_nodes_[k]; - } - o << "],"; - - o << "\"spans\":[" << edge.i_ << "," << edge.j_ << "," << edge.prev_i_ << "," << edge.prev_j_ << "],"; - - o << "\"feats\":["; - bool first = true; - for (SparseVector::const_iterator it = edge.feature_values_.begin(); it != edge.feature_values_.end(); ++it) { - if (!it->second) continue; // don't write features that have a zero value - if (!it->first) continue; // if the feature set was frozen this might happen - if (!first) o << ','; - if (use_fdict) - o << (it->first - 1); - else { - JSONParser::WriteEscapedString(FD::Convert(it->first), &o); - } - o << ',' << it->second; - first = false; - } - o << "]"; - if (!remove_rules) { o << ",\"rule\":" << rid[edge.rule_.get()]; } - o << "}"; - } - o << "],"; - - o << "\"node\":{\"in_edges\":["; - for (int j = 0; j < node.in_edges_.size(); ++j) { - int mapped_edge = edgemap[node.in_edges_[j]]; - assert(mapped_edge >= 0); - o << (j == 0 ? "" : ",") << mapped_edge; - } - o << "]"; - if (node.cat_ < 0) { - o << ",\"cat\":"; - JSONParser::WriteEscapedString(TD::Convert(node.cat_ * -1), &o); - } - char buf[48]; - sprintf(buf, "%016lX", node.node_hash); - o << ",\"node_hash\":\"" << buf << "\""; - o << "}"; - } - o << "}\n"; +bool HypergraphIO::WriteToBinary(const Hypergraph& hg, ostream* out) { + boost::archive::binary_oarchive oa(*out); + oa << hg; return true; } diff --git a/decoder/hg_io.h b/decoder/hg_io.h index 5a2bd808..5ba86f69 100644 --- a/decoder/hg_io.h +++ b/decoder/hg_io.h @@ -18,10 +18,11 @@ struct HypergraphIO { // see test_data/small.json.gz for an email encoding static bool ReadFromJSON(std::istream* in, Hypergraph* out); + static bool ReadFromBinary(std::istream* in, Hypergraph* out); + static bool WriteToBinary(const Hypergraph& hg, std::ostream* out); + // if remove_rules is used, the hypergraph is serialized without rule information // (so it only contains structure and feature information) - static bool WriteToJSON(const Hypergraph& hg, bool remove_rules, std::ostream* out); - static void WriteAsCFG(const Hypergraph& hg); // Write only the target size information in bottom-up order. diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 5cb8626a..25eddcec 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -1,6 +1,11 @@ #define BOOST_TEST_MODULE hg_test #include #include +#include +#include +#include +#include +#include #include #include "tdict.h" @@ -427,19 +432,29 @@ BOOST_AUTO_TEST_CASE(TestGenericKBest) { } } -BOOST_AUTO_TEST_CASE(TestReadWriteHG) { +BOOST_AUTO_TEST_CASE(TestReadWriteHG_Boost) { std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); - Hypergraph hg,hg2; - CreateHG(path, &hg); - hg.edges_.front().j_ = 23; - hg.edges_.back().prev_i_ = 99; - ostringstream os; - HypergraphIO::WriteToJSON(hg, false, &os); - istringstream is(os.str()); - HypergraphIO::ReadFromJSON(&is, &hg2); - BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths()); - BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23); - BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99); + Hypergraph hg; + Hypergraph hg2; + std::string out; + { + CreateHG(path, &hg); + hg.edges_.front().j_ = 23; + hg.edges_.back().prev_i_ = 99; + ostringstream os; + boost::archive::text_oarchive oa(os); + oa << hg; + out = os.str(); + } + { + cerr << out << endl; + istringstream is(out); + boost::archive::text_iarchive ia(is); + ia >> hg2; + BOOST_CHECK_EQUAL(hg2.NumberOfPaths(), hg.NumberOfPaths()); + BOOST_CHECK_EQUAL(hg2.edges_.front().j_, 23); + BOOST_CHECK_EQUAL(hg2.edges_.back().prev_i_, 99); + } } BOOST_AUTO_TEST_SUITE_END() diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll index d4a8d86b..8b48ab7b 100644 --- a/decoder/rule_lexer.ll +++ b/decoder/rule_lexer.ll @@ -356,6 +356,7 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const void RuleLexer::ReadRule(const std::string& srule, RuleCallback func, bool mono, void* extra) { init_default_feature_names(); + scfglex_fname = srule; lex_mono_rules = mono; lex_line = 1; rule_callback_extra = extra; diff --git a/python/cdec/hypergraph.pxd b/python/cdec/hypergraph.pxd index 1e150bbc..9780cf8b 100644 --- a/python/cdec/hypergraph.pxd +++ b/python/cdec/hypergraph.pxd @@ -63,7 +63,8 @@ cdef extern from "decoder/viterbi.h": cdef extern from "decoder/hg_io.h" namespace "HypergraphIO": # Hypergraph JSON I/O bint ReadFromJSON(istream* inp, Hypergraph* out) - bint WriteToJSON(Hypergraph& hg, bint remove_rules, ostream* out) + bint ReadFromBinary(istream* inp, Hypergraph* out) + bint WriteToBinary(Hypergraph& hg, ostream* out) # Hypergraph PLF I/O void ReadFromPLF(string& inp, Hypergraph* out) string AsPLF(Hypergraph& hg, bint include_global_parentheses) diff --git a/training/dpmert/mr_dpmert_generate_mapper_input.cc b/training/dpmert/mr_dpmert_generate_mapper_input.cc index 199cd23a..3fa2f476 100644 --- a/training/dpmert/mr_dpmert_generate_mapper_input.cc +++ b/training/dpmert/mr_dpmert_generate_mapper_input.cc @@ -70,7 +70,7 @@ int main(int argc, char** argv) { unsigned dev_set_size = conf["dev_set_size"].as(); for (unsigned i = 0; i < dev_set_size; ++i) { for (unsigned j = 0; j < directions.size(); ++j) { - cout << forest_repository << '/' << i << ".json.gz " << i << ' '; + cout << forest_repository << '/' << i << ".bin.gz " << i << ' '; print(cout, origin, "=", ";"); cout << ' '; print(cout, directions[j], "=", ";"); diff --git a/training/dpmert/mr_dpmert_map.cc b/training/dpmert/mr_dpmert_map.cc index d1efcf96..2bf3f8fc 100644 --- a/training/dpmert/mr_dpmert_map.cc +++ b/training/dpmert/mr_dpmert_map.cc @@ -83,7 +83,7 @@ int main(int argc, char** argv) { istringstream is(line); int sent_id; string file, s_origin, s_direction; - // path-to-file (JSON) sent_ed starting-point search-direction + // path-to-file sent_ed starting-point search-direction is >> file >> sent_id >> s_origin >> s_direction; SparseVector origin; ReadSparseVectorString(s_origin, &origin); @@ -93,7 +93,7 @@ int main(int argc, char** argv) { if (last_file != file) { last_file = file; ReadFile rf(file); - HypergraphIO::ReadFromJSON(rf.stream(), &hg); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); } const ConvexHullWeightFunction wf(origin, direction); const ConvexHull hull = Inside(hg, NULL, wf); diff --git a/training/minrisk/minrisk_optimize.cc b/training/minrisk/minrisk_optimize.cc index da8b5260..a2938fb0 100644 --- a/training/minrisk/minrisk_optimize.cc +++ b/training/minrisk/minrisk_optimize.cc @@ -178,7 +178,7 @@ int main(int argc, char** argv) { ReadFile rf(file); if (kis.size() % 5 == 0) { cerr << '.'; } if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; } - HypergraphIO::ReadFromJSON(rf.stream(), &hg); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); hg.Reweight(weights); curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]); if (kbest_file.size()) diff --git a/training/pro/mr_pro_map.cc b/training/pro/mr_pro_map.cc index da58cd24..b142fd05 100644 --- a/training/pro/mr_pro_map.cc +++ b/training/pro/mr_pro_map.cc @@ -203,7 +203,7 @@ int main(int argc, char** argv) { const string kbest_file = os.str(); if (FileExists(kbest_file)) J_i.ReadFromFile(kbest_file); - HypergraphIO::ReadFromJSON(rf.stream(), &hg); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); hg.Reweight(weights); J_i.AddKBestCandidates(hg, kbest_size, ds[sent_id]); J_i.WriteToFile(kbest_file); diff --git a/training/rampion/rampion_cccp.cc b/training/rampion/rampion_cccp.cc index 1e36dc51..1c45bac5 100644 --- a/training/rampion/rampion_cccp.cc +++ b/training/rampion/rampion_cccp.cc @@ -136,7 +136,7 @@ int main(int argc, char** argv) { ReadFile rf(file); if (kis.size() % 5 == 0) { cerr << '.'; } if (kis.size() % 200 == 0) { cerr << " [" << kis.size() << "]\n"; } - HypergraphIO::ReadFromJSON(rf.stream(), &hg); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); hg.Reweight(weights); curkbest.AddKBestCandidates(hg, kbest_size, ds[sent_id]); if (kbest_file.size()) diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc index 5c1b4d4a..000f2a26 100644 --- a/training/utils/grammar_convert.cc +++ b/training/utils/grammar_convert.cc @@ -43,7 +43,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::notify(*conf); if (conf->count("help") || conf->count("input") == 0) { - cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into JSON hypergraph.\n"; + cerr << "\nUsage: grammar_convert [-options]\n\nConverts a grammar file (in Hiero format) into serialized hypergraph.\n"; cerr << dcmdline_options << endl; exit(1); } @@ -254,7 +254,8 @@ void ProcessHypergraph(const vector& w, const po::variables_map& conf, c if (w.size() > 0) { hg->Reweight(w); } if (conf.count("collapse_weights")) CollapseWeights(hg); if (conf["output"].as() == "json") { - HypergraphIO::WriteToJSON(*hg, false, &cout); + cerr << "NOT IMPLEMENTED ... talk to cdyer if you need this functionality\n"; + // HypergraphIO::WriteToBinary(*hg, &cout); if (!ref.empty()) { cerr << "REF: " << ref << endl; } } else { vector onebest; diff --git a/utils/small_vector.h b/utils/small_vector.h index c8cbcb2c..f16bc898 100644 --- a/utils/small_vector.h +++ b/utils/small_vector.h @@ -15,6 +15,7 @@ #include #include #include +#include //sizeof(T)/sizeof(T*)>1?sizeof(T)/sizeof(T*):1 @@ -297,6 +298,21 @@ public: return hash_range(data_.ptr,data_.ptr+size_); } + template + void save(Archive & ar, const unsigned int) const { + ar & size_; + for (unsigned i = 0; i < size_; ++i) + ar & (*this)[i]; + } + template + void load(Archive & ar, const unsigned int) { + uint16_t s; + ar & s; + this->resize(s); + for (unsigned i = 0; i < size_; ++i) + ar & (*this)[i]; + } + BOOST_SERIALIZATION_SPLIT_MEMBER() private: union StorageType { T vals[SV_MAX]; diff --git a/utils/small_vector_test.cc b/utils/small_vector_test.cc index a4eb89ae..9e1a148d 100644 --- a/utils/small_vector_test.cc +++ b/utils/small_vector_test.cc @@ -3,6 +3,10 @@ #define BOOST_TEST_MODULE svTest #include #include +#include +#include +#include +#include #include #include @@ -128,3 +132,29 @@ BOOST_AUTO_TEST_CASE(Small) { cerr << sizeof(SmallVectorInt) << endl; cerr << sizeof(vector) << endl; } + +BOOST_AUTO_TEST_CASE(Serialize) { + std::string in; + { + SmallVectorInt v; + v.push_back(0); + v.push_back(1); + v.push_back(-2); + ostringstream os; + boost::archive::text_oarchive oa(os); + oa << v; + in = os.str(); + cerr << in; + } + { + istringstream is(in); + boost::archive::text_iarchive ia(is); + SmallVectorInt v; + ia >> v; + BOOST_CHECK_EQUAL(v.size(), 3); + BOOST_CHECK_EQUAL(v[0], 0); + BOOST_CHECK_EQUAL(v[1], 1); + BOOST_CHECK_EQUAL(v[2], -2); + } +} + -- cgit v1.2.3 From b2d5b3da636cfbef53830245f1f5281add2a4b62 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 19 Oct 2014 15:23:31 -0400 Subject: remove json hypergraph format --- decoder/JSON_parser.c | 1012 --------------------- decoder/JSON_parser.h | 152 ---- decoder/Makefile.am | 6 +- decoder/aligner.h | 2 +- decoder/fst_translator.cc | 10 +- decoder/hg.h | 10 +- decoder/hg_io.cc | 258 ------ decoder/hg_io.h | 9 - decoder/hg_test.cc | 11 - decoder/hg_test.h | 32 +- decoder/json_parse.cc | 50 - decoder/json_parse.h | 58 -- decoder/rescore_translator.cc | 23 +- decoder/test_data/perro.json.gz | Bin 608 -> 0 bytes decoder/test_data/small.json.gz | Bin 1733 -> 0 bytes decoder/test_data/urdu.json.gz | Bin 253497 -> 0 bytes decoder/trule.h | 4 +- tests/system_tests/cfg_rescore/input.txt | 2 +- tests/system_tests/ftrans/input.txt | 2 +- tests/system_tests/ftrans/input0.hg.bin.gz | Bin 0 -> 225 bytes training/dpmert/lo_test.cc | 22 +- training/dpmert/test_data/0.bin.gz | Bin 0 -> 24904 bytes training/dpmert/test_data/0.json.gz | Bin 13709 -> 0 bytes training/dpmert/test_data/1.bin.gz | Bin 0 -> 339220 bytes training/dpmert/test_data/1.json.gz | Bin 204803 -> 0 bytes training/dpmert/test_data/test-ch-inside.bin.gz | Bin 0 -> 340 bytes training/dpmert/test_data/test-zero-origin.bin.gz | Bin 0 -> 923 bytes 27 files changed, 59 insertions(+), 1604 deletions(-) delete mode 100644 decoder/JSON_parser.c delete mode 100644 decoder/JSON_parser.h delete mode 100644 decoder/json_parse.cc delete mode 100644 decoder/json_parse.h delete mode 100644 decoder/test_data/perro.json.gz delete mode 100644 decoder/test_data/small.json.gz delete mode 100644 decoder/test_data/urdu.json.gz create mode 100644 tests/system_tests/ftrans/input0.hg.bin.gz create mode 100644 training/dpmert/test_data/0.bin.gz delete mode 100644 training/dpmert/test_data/0.json.gz create mode 100644 training/dpmert/test_data/1.bin.gz delete mode 100644 training/dpmert/test_data/1.json.gz create mode 100644 training/dpmert/test_data/test-ch-inside.bin.gz create mode 100644 training/dpmert/test_data/test-zero-origin.bin.gz (limited to 'training') diff --git a/decoder/JSON_parser.c b/decoder/JSON_parser.c deleted file mode 100644 index 5e392bc6..00000000 --- a/decoder/JSON_parser.c +++ /dev/null @@ -1,1012 +0,0 @@ -/* JSON_parser.c */ - -/* 2007-08-24 */ - -/* -Copyright (c) 2005 JSON.org - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -The Software shall be used for Good, not Evil. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -*/ - -/* - Callbacks, comments, Unicode handling by Jean Gressmann (jean@0x42.de), 2007-2009. - - For the added features the license above applies also. - - Changelog: - 2009-05-17 - Incorporated benrudiak@googlemail.com fix for UTF16 decoding. - - 2009-05-14 - Fixed float parsing bug related to a locale being set that didn't - use '.' as decimal point character (charles@transmissionbt.com). - - 2008-10-14 - Renamed states.IN to states.IT to avoid name clash which IN macro - defined in windef.h (alexey.pelykh@gmail.com) - - 2008-07-19 - Removed some duplicate code & debugging variable (charles@transmissionbt.com) - - 2008-05-28 - Made JSON_value structure ansi C compliant. This bug was report by - trisk@acm.jhu.edu - - 2008-05-20 - Fixed bug reported by charles@transmissionbt.com where the switching - from static to dynamic parse buffer did not copy the static parse - buffer's content. -*/ - - - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "JSON_parser.h" - -#ifdef _MSC_VER -# if _MSC_VER >= 1400 /* Visual Studio 2005 and up */ -# pragma warning(disable:4996) // unsecure sscanf -# endif -#endif - - -#define true 1 -#define false 0 -#define __ -1 /* the universal error code */ - -/* values chosen so that the object size is approx equal to one page (4K) */ -#ifndef JSON_PARSER_STACK_SIZE -# define JSON_PARSER_STACK_SIZE 128 -#endif - -#ifndef JSON_PARSER_PARSE_BUFFER_SIZE -# define JSON_PARSER_PARSE_BUFFER_SIZE 3500 -#endif - -typedef unsigned short UTF16; - -struct JSON_parser_struct { - JSON_parser_callback callback; - void* ctx; - signed char state, before_comment_state, type, escaped, comment, allow_comments, handle_floats_manually; - UTF16 utf16_high_surrogate; - long depth; - long top; - signed char* stack; - long stack_capacity; - char decimal_point; - char* parse_buffer; - size_t parse_buffer_capacity; - size_t parse_buffer_count; - size_t comment_begin_offset; - signed char static_stack[JSON_PARSER_STACK_SIZE]; - char static_parse_buffer[JSON_PARSER_PARSE_BUFFER_SIZE]; -}; - -#define COUNTOF(x) (sizeof(x)/sizeof(x[0])) - -/* - Characters are mapped into these character classes. This allows for - a significant reduction in the size of the state transition table. -*/ - - - -enum classes { - C_SPACE, /* space */ - C_WHITE, /* other whitespace */ - C_LCURB, /* { */ - C_RCURB, /* } */ - C_LSQRB, /* [ */ - C_RSQRB, /* ] */ - C_COLON, /* : */ - C_COMMA, /* , */ - C_QUOTE, /* " */ - C_BACKS, /* \ */ - C_SLASH, /* / */ - C_PLUS, /* + */ - C_MINUS, /* - */ - C_POINT, /* . */ - C_ZERO , /* 0 */ - C_DIGIT, /* 123456789 */ - C_LOW_A, /* a */ - C_LOW_B, /* b */ - C_LOW_C, /* c */ - C_LOW_D, /* d */ - C_LOW_E, /* e */ - C_LOW_F, /* f */ - C_LOW_L, /* l */ - C_LOW_N, /* n */ - C_LOW_R, /* r */ - C_LOW_S, /* s */ - C_LOW_T, /* t */ - C_LOW_U, /* u */ - C_ABCDF, /* ABCDF */ - C_E, /* E */ - C_ETC, /* everything else */ - C_STAR, /* * */ - NR_CLASSES -}; - -static int ascii_class[128] = { -/* - This array maps the 128 ASCII characters into character classes. - The remaining Unicode characters should be mapped to C_ETC. - Non-whitespace control characters are errors. -*/ - __, __, __, __, __, __, __, __, - __, C_WHITE, C_WHITE, __, __, C_WHITE, __, __, - __, __, __, __, __, __, __, __, - __, __, __, __, __, __, __, __, - - C_SPACE, C_ETC, C_QUOTE, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - C_ETC, C_ETC, C_STAR, C_PLUS, C_COMMA, C_MINUS, C_POINT, C_SLASH, - C_ZERO, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, C_DIGIT, - C_DIGIT, C_DIGIT, C_COLON, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - - C_ETC, C_ABCDF, C_ABCDF, C_ABCDF, C_ABCDF, C_E, C_ABCDF, C_ETC, - C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, C_ETC, - C_ETC, C_ETC, C_ETC, C_LSQRB, C_BACKS, C_RSQRB, C_ETC, C_ETC, - - C_ETC, C_LOW_A, C_LOW_B, C_LOW_C, C_LOW_D, C_LOW_E, C_LOW_F, C_ETC, - C_ETC, C_ETC, C_ETC, C_ETC, C_LOW_L, C_ETC, C_LOW_N, C_ETC, - C_ETC, C_ETC, C_LOW_R, C_LOW_S, C_LOW_T, C_LOW_U, C_ETC, C_ETC, - C_ETC, C_ETC, C_ETC, C_LCURB, C_ETC, C_RCURB, C_ETC, C_ETC -}; - - -/* - The state codes. -*/ -enum states { - GO, /* start */ - OK, /* ok */ - OB, /* object */ - KE, /* key */ - CO, /* colon */ - VA, /* value */ - AR, /* array */ - ST, /* string */ - ES, /* escape */ - U1, /* u1 */ - U2, /* u2 */ - U3, /* u3 */ - U4, /* u4 */ - MI, /* minus */ - ZE, /* zero */ - IT, /* integer */ - FR, /* fraction */ - E1, /* e */ - E2, /* ex */ - E3, /* exp */ - T1, /* tr */ - T2, /* tru */ - T3, /* true */ - F1, /* fa */ - F2, /* fal */ - F3, /* fals */ - F4, /* false */ - N1, /* nu */ - N2, /* nul */ - N3, /* null */ - C1, /* / */ - C2, /* / * */ - C3, /* * */ - FX, /* *.* *eE* */ - D1, /* second UTF-16 character decoding started by \ */ - D2, /* second UTF-16 character proceeded by u */ - NR_STATES -}; - -enum actions -{ - CB = -10, /* comment begin */ - CE = -11, /* comment end */ - FA = -12, /* false */ - TR = -13, /* false */ - NU = -14, /* null */ - DE = -15, /* double detected by exponent e E */ - DF = -16, /* double detected by fraction . */ - SB = -17, /* string begin */ - MX = -18, /* integer detected by minus */ - ZX = -19, /* integer detected by zero */ - IX = -20, /* integer detected by 1-9 */ - EX = -21, /* next char is escaped */ - UC = -22 /* Unicode character read */ -}; - - -static int state_transition_table[NR_STATES][NR_CLASSES] = { -/* - The state transition table takes the current state and the current symbol, - and returns either a new state or an action. An action is represented as a - negative number. A JSON text is accepted if at the end of the text the - state is OK and if the mode is MODE_DONE. - - white 1-9 ABCDF etc - space | { } [ ] : , " \ / + - . 0 | a b c d e f l n r s t u | E | * */ -/*start GO*/ {GO,GO,-6,__,-5,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*ok OK*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*object OB*/ {OB,OB,__,-9,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*key KE*/ {KE,KE,__,__,__,__,__,__,SB,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*colon CO*/ {CO,CO,__,__,__,__,-2,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*value VA*/ {VA,VA,-6,__,-5,__,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__}, -/*array AR*/ {AR,AR,-6,__,-5,-7,__,__,SB,__,CB,__,MX,__,ZX,IX,__,__,__,__,__,FA,__,NU,__,__,TR,__,__,__,__,__}, -/*string ST*/ {ST,__,ST,ST,ST,ST,ST,ST,-4,EX,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST,ST}, -/*escape ES*/ {__,__,__,__,__,__,__,__,ST,ST,ST,__,__,__,__,__,__,ST,__,__,__,ST,__,ST,ST,__,ST,U1,__,__,__,__}, -/*u1 U1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U2,U2,U2,U2,U2,U2,U2,U2,__,__,__,__,__,__,U2,U2,__,__}, -/*u2 U2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U3,U3,U3,U3,U3,U3,U3,U3,__,__,__,__,__,__,U3,U3,__,__}, -/*u3 U3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,U4,U4,U4,U4,U4,U4,U4,U4,__,__,__,__,__,__,U4,U4,__,__}, -/*u4 U4*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,UC,UC,UC,UC,UC,UC,UC,UC,__,__,__,__,__,__,UC,UC,__,__}, -/*minus MI*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,ZE,IT,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*zero ZE*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*int IT*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,DF,IT,IT,__,__,__,__,DE,__,__,__,__,__,__,__,__,DE,__,__}, -/*frac FR*/ {OK,OK,__,-8,__,-7,__,-3,__,__,CB,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__}, -/*e E1*/ {__,__,__,__,__,__,__,__,__,__,__,E2,E2,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*ex E2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*exp E3*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,E3,E3,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*tr T1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T2,__,__,__,__,__,__,__}, -/*tru T2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,T3,__,__,__,__}, -/*true T3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__}, -/*fa F1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*fal F2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F3,__,__,__,__,__,__,__,__,__}, -/*fals F3*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,F4,__,__,__,__,__,__}, -/*false F4*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__,__,__}, -/*nu N1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N2,__,__,__,__}, -/*nul N2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,N3,__,__,__,__,__,__,__,__,__}, -/*null N3*/ {__,__,__,__,__,__,__,__,__,__,CB,__,__,__,__,__,__,__,__,__,__,__,OK,__,__,__,__,__,__,__,__,__}, -/*/ C1*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,C2}, -/*/* C2*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3}, -/** C3*/ {C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,CE,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C2,C3}, -/*_. FX*/ {OK,OK,__,-8,__,-7,__,-3,__,__,__,__,__,__,FR,FR,__,__,__,__,E1,__,__,__,__,__,__,__,__,E1,__,__}, -/*\ D1*/ {__,__,__,__,__,__,__,__,__,D2,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__}, -/*\ D2*/ {__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,__,U1,__,__,__,__}, -}; - - -/* - These modes can be pushed on the stack. -*/ -enum modes { - MODE_ARRAY = 1, - MODE_DONE = 2, - MODE_KEY = 3, - MODE_OBJECT = 4 -}; - -static int -push(JSON_parser jc, int mode) -{ -/* - Push a mode onto the stack. Return false if there is overflow. -*/ - jc->top += 1; - if (jc->depth < 0) { - if (jc->top >= jc->stack_capacity) { - size_t bytes_to_allocate; - jc->stack_capacity *= 2; - bytes_to_allocate = jc->stack_capacity * sizeof(jc->static_stack[0]); - if (jc->stack == &jc->static_stack[0]) { - jc->stack = (signed char*)malloc(bytes_to_allocate); - memcpy(jc->stack, jc->static_stack, sizeof(jc->static_stack)); - } else { - jc->stack = (signed char*)realloc(jc->stack, bytes_to_allocate); - } - } - } else { - if (jc->top >= jc->depth) { - return false; - } - } - - jc->stack[jc->top] = mode; - return true; -} - - -static int -pop(JSON_parser jc, int mode) -{ -/* - Pop the stack, assuring that the current mode matches the expectation. - Return false if there is underflow or if the modes mismatch. -*/ - if (jc->top < 0 || jc->stack[jc->top] != mode) { - return false; - } - jc->top -= 1; - return true; -} - - -#define parse_buffer_clear(jc) \ - do {\ - jc->parse_buffer_count = 0;\ - jc->parse_buffer[0] = 0;\ - } while (0) - -#define parse_buffer_pop_back_char(jc)\ - do {\ - assert(jc->parse_buffer_count >= 1);\ - --jc->parse_buffer_count;\ - jc->parse_buffer[jc->parse_buffer_count] = 0;\ - } while (0) - -void delete_JSON_parser(JSON_parser jc) -{ - if (jc) { - if (jc->stack != &jc->static_stack[0]) { - free((void*)jc->stack); - } - if (jc->parse_buffer != &jc->static_parse_buffer[0]) { - free((void*)jc->parse_buffer); - } - free((void*)jc); - } -} - - -JSON_parser -new_JSON_parser(JSON_config* config) -{ -/* - new_JSON_parser starts the checking process by constructing a JSON_parser - object. It takes a depth parameter that restricts the level of maximum - nesting. - - To continue the process, call JSON_parser_char for each character in the - JSON text, and then call JSON_parser_done to obtain the final result. - These functions are fully reentrant. -*/ - - int depth = 0; - JSON_config default_config; - - JSON_parser jc = (JSON_parser)malloc(sizeof(struct JSON_parser_struct)); - - memset(jc, 0, sizeof(*jc)); - - - /* initialize configuration */ - init_JSON_config(&default_config); - - /* set to default configuration if none was provided */ - if (config == NULL) { - config = &default_config; - } - - depth = config->depth; - - /* We need to be able to push at least one object */ - if (depth == 0) { - depth = 1; - } - - jc->state = GO; - jc->top = -1; - - /* Do we want non-bound stack? */ - if (depth > 0) { - jc->stack_capacity = depth; - jc->depth = depth; - if (depth <= (int)COUNTOF(jc->static_stack)) { - jc->stack = &jc->static_stack[0]; - } else { - jc->stack = (signed char*)malloc(jc->stack_capacity * sizeof(jc->static_stack[0])); - } - } else { - jc->stack_capacity = COUNTOF(jc->static_stack); - jc->depth = -1; - jc->stack = &jc->static_stack[0]; - } - - /* set parser to start */ - push(jc, MODE_DONE); - - /* set up the parse buffer */ - jc->parse_buffer = &jc->static_parse_buffer[0]; - jc->parse_buffer_capacity = COUNTOF(jc->static_parse_buffer); - parse_buffer_clear(jc); - - /* set up callback, comment & float handling */ - jc->callback = config->callback; - jc->ctx = config->callback_ctx; - jc->allow_comments = config->allow_comments != 0; - jc->handle_floats_manually = config->handle_floats_manually != 0; - - /* set up decimal point */ - jc->decimal_point = *localeconv()->decimal_point; - - return jc; -} - -static void grow_parse_buffer(JSON_parser jc) -{ - size_t bytes_to_allocate; - jc->parse_buffer_capacity *= 2; - bytes_to_allocate = jc->parse_buffer_capacity * sizeof(jc->parse_buffer[0]); - if (jc->parse_buffer == &jc->static_parse_buffer[0]) { - jc->parse_buffer = (char*)malloc(bytes_to_allocate); - memcpy(jc->parse_buffer, jc->static_parse_buffer, jc->parse_buffer_count); - } else { - jc->parse_buffer = (char*)realloc(jc->parse_buffer, bytes_to_allocate); - } -} - -#define parse_buffer_push_back_char(jc, c)\ - do {\ - if (jc->parse_buffer_count + 1 >= jc->parse_buffer_capacity) grow_parse_buffer(jc);\ - jc->parse_buffer[jc->parse_buffer_count++] = c;\ - jc->parse_buffer[jc->parse_buffer_count] = 0;\ - } while (0) - -#define assert_is_non_container_type(jc) \ - assert( \ - jc->type == JSON_T_NULL || \ - jc->type == JSON_T_FALSE || \ - jc->type == JSON_T_TRUE || \ - jc->type == JSON_T_FLOAT || \ - jc->type == JSON_T_INTEGER || \ - jc->type == JSON_T_STRING) - - -static int parse_parse_buffer(JSON_parser jc) -{ - if (jc->callback) { - JSON_value value, *arg = NULL; - - if (jc->type != JSON_T_NONE) { - assert_is_non_container_type(jc); - - switch(jc->type) { - case JSON_T_FLOAT: - arg = &value; - if (jc->handle_floats_manually) { - value.vu.str.value = jc->parse_buffer; - value.vu.str.length = jc->parse_buffer_count; - } else { - /*sscanf(jc->parse_buffer, "%Lf", &value.vu.float_value);*/ - - /* not checking with end pointer b/c there may be trailing ws */ - value.vu.float_value = strtod(jc->parse_buffer, NULL); - } - break; - case JSON_T_INTEGER: - arg = &value; - sscanf(jc->parse_buffer, JSON_PARSER_INTEGER_SSCANF_TOKEN, &value.vu.integer_value); - break; - case JSON_T_STRING: - arg = &value; - value.vu.str.value = jc->parse_buffer; - value.vu.str.length = jc->parse_buffer_count; - break; - } - - if (!(*jc->callback)(jc->ctx, jc->type, arg)) { - return false; - } - } - } - - parse_buffer_clear(jc); - - return true; -} - -#define IS_HIGH_SURROGATE(uc) (((uc) & 0xFC00) == 0xD800) -#define IS_LOW_SURROGATE(uc) (((uc) & 0xFC00) == 0xDC00) -#define DECODE_SURROGATE_PAIR(hi,lo) ((((hi) & 0x3FF) << 10) + ((lo) & 0x3FF) + 0x10000) -static unsigned char utf8_lead_bits[4] = { 0x00, 0xC0, 0xE0, 0xF0 }; - -static int decode_unicode_char(JSON_parser jc) -{ - int i; - unsigned uc = 0; - char* p; - int trail_bytes; - - assert(jc->parse_buffer_count >= 6); - - p = &jc->parse_buffer[jc->parse_buffer_count - 4]; - - for (i = 12; i >= 0; i -= 4, ++p) { - unsigned x = *p; - - if (x >= 'a') { - x -= ('a' - 10); - } else if (x >= 'A') { - x -= ('A' - 10); - } else { - x &= ~0x30u; - } - - assert(x < 16); - - uc |= x << i; - } - - /* clear UTF-16 char from buffer */ - jc->parse_buffer_count -= 6; - jc->parse_buffer[jc->parse_buffer_count] = 0; - - /* attempt decoding ... */ - if (jc->utf16_high_surrogate) { - if (IS_LOW_SURROGATE(uc)) { - uc = DECODE_SURROGATE_PAIR(jc->utf16_high_surrogate, uc); - trail_bytes = 3; - jc->utf16_high_surrogate = 0; - } else { - /* high surrogate without a following low surrogate */ - return false; - } - } else { - if (uc < 0x80) { - trail_bytes = 0; - } else if (uc < 0x800) { - trail_bytes = 1; - } else if (IS_HIGH_SURROGATE(uc)) { - /* save the high surrogate and wait for the low surrogate */ - jc->utf16_high_surrogate = uc; - return true; - } else if (IS_LOW_SURROGATE(uc)) { - /* low surrogate without a preceding high surrogate */ - return false; - } else { - trail_bytes = 2; - } - } - - jc->parse_buffer[jc->parse_buffer_count++] = (char) ((uc >> (trail_bytes * 6)) | utf8_lead_bits[trail_bytes]); - - for (i = trail_bytes * 6 - 6; i >= 0; i -= 6) { - jc->parse_buffer[jc->parse_buffer_count++] = (char) (((uc >> i) & 0x3F) | 0x80); - } - - jc->parse_buffer[jc->parse_buffer_count] = 0; - - return true; -} - -static int add_escaped_char_to_parse_buffer(JSON_parser jc, int next_char) -{ - jc->escaped = 0; - /* remove the backslash */ - parse_buffer_pop_back_char(jc); - switch(next_char) { - case 'b': - parse_buffer_push_back_char(jc, '\b'); - break; - case 'f': - parse_buffer_push_back_char(jc, '\f'); - break; - case 'n': - parse_buffer_push_back_char(jc, '\n'); - break; - case 'r': - parse_buffer_push_back_char(jc, '\r'); - break; - case 't': - parse_buffer_push_back_char(jc, '\t'); - break; - case '"': - parse_buffer_push_back_char(jc, '"'); - break; - case '\\': - parse_buffer_push_back_char(jc, '\\'); - break; - case '/': - parse_buffer_push_back_char(jc, '/'); - break; - case 'u': - parse_buffer_push_back_char(jc, '\\'); - parse_buffer_push_back_char(jc, 'u'); - break; - default: - return false; - } - - return true; -} - -#define add_char_to_parse_buffer(jc, next_char, next_class) \ - do { \ - if (jc->escaped) { \ - if (!add_escaped_char_to_parse_buffer(jc, next_char)) \ - return false; \ - } else if (!jc->comment) { \ - if ((jc->type != JSON_T_NONE) | !((next_class == C_SPACE) | (next_class == C_WHITE)) /* non-white-space */) { \ - parse_buffer_push_back_char(jc, (char)next_char); \ - } \ - } \ - } while (0) - - -#define assert_type_isnt_string_null_or_bool(jc) \ - assert(jc->type != JSON_T_FALSE); \ - assert(jc->type != JSON_T_TRUE); \ - assert(jc->type != JSON_T_NULL); \ - assert(jc->type != JSON_T_STRING) - - -int -JSON_parser_char(JSON_parser jc, int next_char) -{ -/* - After calling new_JSON_parser, call this function for each character (or - partial character) in your JSON text. It can accept UTF-8, UTF-16, or - UTF-32. It returns true if things are looking ok so far. If it rejects the - text, it returns false. -*/ - int next_class, next_state; - -/* - Determine the character's class. -*/ - if (next_char < 0) { - return false; - } - if (next_char >= 128) { - next_class = C_ETC; - } else { - next_class = ascii_class[next_char]; - if (next_class <= __) { - return false; - } - } - - add_char_to_parse_buffer(jc, next_char, next_class); - -/* - Get the next state from the state transition table. -*/ - next_state = state_transition_table[jc->state][next_class]; - if (next_state >= 0) { -/* - Change the state. -*/ - jc->state = next_state; - } else { -/* - Or perform one of the actions. -*/ - switch (next_state) { -/* Unicode character */ - case UC: - if(!decode_unicode_char(jc)) { - return false; - } - /* check if we need to read a second UTF-16 char */ - if (jc->utf16_high_surrogate) { - jc->state = D1; - } else { - jc->state = ST; - } - break; -/* escaped char */ - case EX: - jc->escaped = 1; - jc->state = ES; - break; -/* integer detected by minus */ - case MX: - jc->type = JSON_T_INTEGER; - jc->state = MI; - break; -/* integer detected by zero */ - case ZX: - jc->type = JSON_T_INTEGER; - jc->state = ZE; - break; -/* integer detected by 1-9 */ - case IX: - jc->type = JSON_T_INTEGER; - jc->state = IT; - break; - -/* floating point number detected by exponent*/ - case DE: - assert_type_isnt_string_null_or_bool(jc); - jc->type = JSON_T_FLOAT; - jc->state = E1; - break; - -/* floating point number detected by fraction */ - case DF: - assert_type_isnt_string_null_or_bool(jc); - if (!jc->handle_floats_manually) { -/* - Some versions of strtod (which underlies sscanf) don't support converting - C-locale formated floating point values. -*/ - assert(jc->parse_buffer[jc->parse_buffer_count-1] == '.'); - jc->parse_buffer[jc->parse_buffer_count-1] = jc->decimal_point; - } - jc->type = JSON_T_FLOAT; - jc->state = FX; - break; -/* string begin " */ - case SB: - parse_buffer_clear(jc); - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_STRING; - jc->state = ST; - break; - -/* n */ - case NU: - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_NULL; - jc->state = N1; - break; -/* f */ - case FA: - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_FALSE; - jc->state = F1; - break; -/* t */ - case TR: - assert(jc->type == JSON_T_NONE); - jc->type = JSON_T_TRUE; - jc->state = T1; - break; - -/* closing comment */ - case CE: - jc->comment = 0; - assert(jc->parse_buffer_count == 0); - assert(jc->type == JSON_T_NONE); - jc->state = jc->before_comment_state; - break; - -/* opening comment */ - case CB: - if (!jc->allow_comments) { - return false; - } - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - assert(jc->parse_buffer_count == 0); - assert(jc->type != JSON_T_STRING); - switch (jc->stack[jc->top]) { - case MODE_ARRAY: - case MODE_OBJECT: - switch(jc->state) { - case VA: - case AR: - jc->before_comment_state = jc->state; - break; - default: - jc->before_comment_state = OK; - break; - } - break; - default: - jc->before_comment_state = jc->state; - break; - } - jc->type = JSON_T_NONE; - jc->state = C1; - jc->comment = 1; - break; -/* empty } */ - case -9: - parse_buffer_clear(jc); - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) { - return false; - } - if (!pop(jc, MODE_KEY)) { - return false; - } - jc->state = OK; - break; - -/* } */ case -8: - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_END, NULL)) { - return false; - } - if (!pop(jc, MODE_OBJECT)) { - return false; - } - jc->type = JSON_T_NONE; - jc->state = OK; - break; - -/* ] */ case -7: - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_END, NULL)) { - return false; - } - if (!pop(jc, MODE_ARRAY)) { - return false; - } - - jc->type = JSON_T_NONE; - jc->state = OK; - break; - -/* { */ case -6: - parse_buffer_pop_back_char(jc); - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_OBJECT_BEGIN, NULL)) { - return false; - } - if (!push(jc, MODE_KEY)) { - return false; - } - assert(jc->type == JSON_T_NONE); - jc->state = OB; - break; - -/* [ */ case -5: - parse_buffer_pop_back_char(jc); - if (jc->callback && !(*jc->callback)(jc->ctx, JSON_T_ARRAY_BEGIN, NULL)) { - return false; - } - if (!push(jc, MODE_ARRAY)) { - return false; - } - assert(jc->type == JSON_T_NONE); - jc->state = AR; - break; - -/* string end " */ case -4: - parse_buffer_pop_back_char(jc); - switch (jc->stack[jc->top]) { - case MODE_KEY: - assert(jc->type == JSON_T_STRING); - jc->type = JSON_T_NONE; - jc->state = CO; - - if (jc->callback) { - JSON_value value; - value.vu.str.value = jc->parse_buffer; - value.vu.str.length = jc->parse_buffer_count; - if (!(*jc->callback)(jc->ctx, JSON_T_KEY, &value)) { - return false; - } - } - parse_buffer_clear(jc); - break; - case MODE_ARRAY: - case MODE_OBJECT: - assert(jc->type == JSON_T_STRING); - if (!parse_parse_buffer(jc)) { - return false; - } - jc->type = JSON_T_NONE; - jc->state = OK; - break; - default: - return false; - } - break; - -/* , */ case -3: - parse_buffer_pop_back_char(jc); - if (!parse_parse_buffer(jc)) { - return false; - } - switch (jc->stack[jc->top]) { - case MODE_OBJECT: -/* - A comma causes a flip from object mode to key mode. -*/ - if (!pop(jc, MODE_OBJECT) || !push(jc, MODE_KEY)) { - return false; - } - assert(jc->type != JSON_T_STRING); - jc->type = JSON_T_NONE; - jc->state = KE; - break; - case MODE_ARRAY: - assert(jc->type != JSON_T_STRING); - jc->type = JSON_T_NONE; - jc->state = VA; - break; - default: - return false; - } - break; - -/* : */ case -2: -/* - A colon causes a flip from key mode to object mode. -*/ - parse_buffer_pop_back_char(jc); - if (!pop(jc, MODE_KEY) || !push(jc, MODE_OBJECT)) { - return false; - } - assert(jc->type == JSON_T_NONE); - jc->state = VA; - break; -/* - Bad action. -*/ - default: - return false; - } - } - return true; -} - - -int -JSON_parser_done(JSON_parser jc) -{ - const int result = jc->state == OK && pop(jc, MODE_DONE); - - return result; -} - - -int JSON_parser_is_legal_white_space_string(const char* s) -{ - int c, char_class; - - if (s == NULL) { - return false; - } - - for (; *s; ++s) { - c = *s; - - if (c < 0 || c >= 128) { - return false; - } - - char_class = ascii_class[c]; - - if (char_class != C_SPACE && char_class != C_WHITE) { - return false; - } - } - - return true; -} - - - -void init_JSON_config(JSON_config* config) -{ - if (config) { - memset(config, 0, sizeof(*config)); - - config->depth = JSON_PARSER_STACK_SIZE - 1; - } -} diff --git a/decoder/JSON_parser.h b/decoder/JSON_parser.h deleted file mode 100644 index de980072..00000000 --- a/decoder/JSON_parser.h +++ /dev/null @@ -1,152 +0,0 @@ -#ifndef JSON_PARSER_H -#define JSON_PARSER_H - -/* JSON_parser.h */ - - -#include - -/* Windows DLL stuff */ -#ifdef _WIN32 -# ifdef JSON_PARSER_DLL_EXPORTS -# define JSON_PARSER_DLL_API __declspec(dllexport) -# else -# define JSON_PARSER_DLL_API __declspec(dllimport) -# endif -#else -# define JSON_PARSER_DLL_API -#endif - -/* Determine the integer type use to parse non-floating point numbers */ -#if __STDC_VERSION__ >= 199901L || HAVE_LONG_LONG == 1 -typedef long long JSON_int_t; -#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%lld" -#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%lld" -#else -typedef long JSON_int_t; -#define JSON_PARSER_INTEGER_SSCANF_TOKEN "%ld" -#define JSON_PARSER_INTEGER_SPRINTF_TOKEN "%ld" -#endif - - -#ifdef __cplusplus -extern "C" { -#endif - -typedef enum -{ - JSON_T_NONE = 0, - JSON_T_ARRAY_BEGIN, // 1 - JSON_T_ARRAY_END, // 2 - JSON_T_OBJECT_BEGIN, // 3 - JSON_T_OBJECT_END, // 4 - JSON_T_INTEGER, // 5 - JSON_T_FLOAT, // 6 - JSON_T_NULL, // 7 - JSON_T_TRUE, // 8 - JSON_T_FALSE, // 9 - JSON_T_STRING, // 10 - JSON_T_KEY, // 11 - JSON_T_MAX // 12 -} JSON_type; - -typedef struct JSON_value_struct { - union { - JSON_int_t integer_value; - - double float_value; - - struct { - const char* value; - size_t length; - } str; - } vu; -} JSON_value; - -typedef struct JSON_parser_struct* JSON_parser; - -/*! \brief JSON parser callback - - \param ctx The pointer passed to new_JSON_parser. - \param type An element of JSON_type but not JSON_T_NONE. - \param value A representation of the parsed value. This parameter is NULL for - JSON_T_ARRAY_BEGIN, JSON_T_ARRAY_END, JSON_T_OBJECT_BEGIN, JSON_T_OBJECT_END, - JSON_T_NULL, JSON_T_TRUE, and SON_T_FALSE. String values are always returned - as zero-terminated C strings. - - \return Non-zero if parsing should continue, else zero. -*/ -typedef int (*JSON_parser_callback)(void* ctx, int type, const struct JSON_value_struct* value); - - -/*! \brief The structure used to configure a JSON parser object - - \param depth If negative, the parser can parse arbitrary levels of JSON, otherwise - the depth is the limit - \param Pointer to a callback. This parameter may be NULL. In this case the input is merely checked for validity. - \param Callback context. This parameter may be NULL. - \param depth. Specifies the levels of nested JSON to allow. Negative numbers yield unlimited nesting. - \param allowComments. To allow C style comments in JSON, set to non-zero. - \param handleFloatsManually. To decode floating point numbers manually set this parameter to non-zero. - - \return The parser object. -*/ -typedef struct { - JSON_parser_callback callback; - void* callback_ctx; - int depth; - int allow_comments; - int handle_floats_manually; -} JSON_config; - - -/*! \brief Initializes the JSON parser configuration structure to default values. - - The default configuration is - - 127 levels of nested JSON (depends on JSON_PARSER_STACK_SIZE, see json_parser.c) - - no parsing, just checking for JSON syntax - - no comments - - \param config. Used to configure the parser. -*/ -JSON_PARSER_DLL_API void init_JSON_config(JSON_config* config); - -/*! \brief Create a JSON parser object - - \param config. Used to configure the parser. Set to NULL to use the default configuration. - See init_JSON_config - - \return The parser object. -*/ -JSON_PARSER_DLL_API extern JSON_parser new_JSON_parser(JSON_config* config); - -/*! \brief Destroy a previously created JSON parser object. */ -JSON_PARSER_DLL_API extern void delete_JSON_parser(JSON_parser jc); - -/*! \brief Parse a character. - - \return Non-zero, if all characters passed to this function are part of are valid JSON. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_char(JSON_parser jc, int next_char); - -/*! \brief Finalize parsing. - - Call this method once after all input characters have been consumed. - - \return Non-zero, if all parsed characters are valid JSON, zero otherwise. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_done(JSON_parser jc); - -/*! \brief Determine if a given string is valid JSON white space - - \return Non-zero if the string is valid, zero otherwise. -*/ -JSON_PARSER_DLL_API extern int JSON_parser_is_legal_white_space_string(const char* s); - - -#ifdef __cplusplus -} -#endif - - -#endif /* JSON_PARSER_H */ diff --git a/decoder/Makefile.am b/decoder/Makefile.am index e46a7120..b56e4c72 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -33,7 +33,6 @@ noinst_LIBRARIES = libcdec.a EXTRA_DIST = test_data rule_lexer.ll libcdec_a_SOURCES = \ - JSON_parser.h \ aligner.h \ apply_models.h \ bottom_up_parser.h \ @@ -80,7 +79,6 @@ libcdec_a_SOURCES = \ hg_union.h \ incremental.h \ inside_outside.h \ - json_parse.h \ kbest.h \ lattice.h \ lexalign.h \ @@ -141,7 +139,6 @@ libcdec_a_SOURCES = \ hg_sampler.cc \ hg_union.cc \ incremental.cc \ - json_parse.cc \ lattice.cc \ lexalign.cc \ lextrans.cc \ @@ -157,5 +154,4 @@ libcdec_a_SOURCES = \ tagger.cc \ translator.cc \ trule.cc \ - viterbi.cc \ - JSON_parser.c + viterbi.cc diff --git a/decoder/aligner.h b/decoder/aligner.h index a34795c9..d68ceefc 100644 --- a/decoder/aligner.h +++ b/decoder/aligner.h @@ -1,4 +1,4 @@ -#ifndef _ALIGNER_H_ +#ifndef ALIGNER_H #include #include diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc index 50e6adcc..fe28f4c6 100644 --- a/decoder/fst_translator.cc +++ b/decoder/fst_translator.cc @@ -27,11 +27,15 @@ struct FSTTranslatorImpl { const vector& weights, Hypergraph* forest) { bool composed = false; - if (input.find("{\"rules\"") == 0) { + if (input.find("::forest::") == 0) { istringstream is(input); + string header, fname; + is >> header >> fname; + ReadFile rf(fname); + if (!rf) { cerr << "Failed to open " << fname << endl; abort(); } Hypergraph src_cfg_hg; - if (!HypergraphIO::ReadFromJSON(&is, &src_cfg_hg)) { - cerr << "Failed to read HG from JSON.\n"; + if (!HypergraphIO::ReadFromBinary(rf.stream(), &src_cfg_hg)) { + cerr << "Failed to read HG.\n"; abort(); } if (add_pass_through_rules) { diff --git a/decoder/hg.h b/decoder/hg.h index 124eab86..c756012e 100644 --- a/decoder/hg.h +++ b/decoder/hg.h @@ -71,7 +71,7 @@ namespace HG { short int prev_i_; short int prev_j_; template - void serialize(Archive & ar, const unsigned int version) { + void serialize(Archive & ar, const unsigned int /*version*/) { ar & head_node_; ar & tail_nodes_; ar & rule_; @@ -163,7 +163,7 @@ namespace HG { EdgesVector in_edges_; // an in edge is an edge with this node as its head. (in edges come from the bottom up to us) indices in edges_ EdgesVector out_edges_; // an out edge is an edge with this node as its tail. (out edges leave us up toward the top/goal). indices in edges_ template - void save(Archive & ar, const unsigned int version) const { + void save(Archive & ar, const unsigned int /*version*/) const { ar & node_hash; ar & id_; ar & TD::Convert(-cat_); @@ -171,7 +171,7 @@ namespace HG { ar & out_edges_; } template - void load(Archive & ar, const unsigned int version) { + void load(Archive & ar, const unsigned int /*version*/) { ar & node_hash; ar & id_; std::string cat; ar & cat; @@ -524,7 +524,7 @@ public: void check_ids() const; // assert that .id_ have been kept in sync template - void save(Archive & ar, const unsigned int version) const { + void save(Archive & ar, const unsigned int /*version*/) const { unsigned ns = nodes_.size(); ar & ns; unsigned es = edges_.size(); ar & es; for (auto& n : nodes_) ar & n; @@ -534,7 +534,7 @@ public: x = is_linear_chain_; ar & x; } template - void load(Archive & ar, const unsigned int version) { + void load(Archive & ar, const unsigned int /*version*/) { unsigned ns; ar & ns; nodes_.resize(ns); unsigned es; ar & es; edges_.resize(es); for (auto& n : nodes_) ar & n; diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index 67760fb1..626b2954 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -13,268 +13,10 @@ #include "fast_lexical_cast.hpp" #include "tdict.h" -#include "json_parse.h" #include "hg.h" using namespace std; -struct HGReader : public JSONParser { - HGReader(Hypergraph* g) : rp("[X] ||| "), state(-1), hg(*g), nodes_needed(true), edges_needed(true) { nodes = 0; edges = 0; } - - void CreateNode(const string& cat, const string& shash, const vector& in_edges) { - WordID c = TD::Convert("X") * -1; - if (!cat.empty()) c = TD::Convert(cat) * -1; - Hypergraph::Node* node = hg.AddNode(c); - char* dend; - if (shash.size()) - node->node_hash = strtoull(shash.c_str(), &dend, 16); - else - node->node_hash = 0; - for (int i = 0; i < in_edges.size(); ++i) { - if (in_edges[i] >= hg.edges_.size()) { - cerr << "JSONParser: in_edges[" << i << "]=" << in_edges[i] - << ", but hg only has " << hg.edges_.size() << " edges!\n"; - abort(); - } - hg.ConnectEdgeToHeadNode(&hg.edges_[in_edges[i]], node); - } - } - void CreateEdge(const TRulePtr& rule, SparseVector* feats, const SmallVectorUnsigned& tail) { - Hypergraph::Edge* edge = hg.AddEdge(rule, tail); - feats->swap(edge->feature_values_); - edge->i_ = spans[0]; - edge->j_ = spans[1]; - edge->prev_i_ = spans[2]; - edge->prev_j_ = spans[3]; - } - - bool HandleJSONEvent(int type, const JSON_value* value) { - switch(state) { - case -1: - assert(type == JSON_T_OBJECT_BEGIN); - state = 0; - break; - case 0: - if (type == JSON_T_OBJECT_END) { - //cerr << "HG created\n"; // TODO, signal some kind of callback - } else if (type == JSON_T_KEY) { - string val = value->vu.str.value; - if (val == "features") { assert(fdict.empty()); state = 1; } - else if (val == "is_sorted") { state = 3; } - else if (val == "rules") { assert(rules.empty()); state = 4; } - else if (val == "node") { state = 8; } - else if (val == "edges") { state = 13; } - else { cerr << "Unexpected key: " << val << endl; return false; } - } - break; - - // features - case 1: - if(type == JSON_T_NULL) { state = 0; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 2; - break; - case 2: - if(type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_STRING); - fdict.push_back(FD::Convert(value->vu.str.value)); - assert(fdict.back() > 0); - break; - - // is_sorted - case 3: - assert(type == JSON_T_TRUE || type == JSON_T_FALSE); - is_sorted = (type == JSON_T_TRUE); - if (!is_sorted) { cerr << "[WARNING] is_sorted flag is ignored\n"; } - state = 0; - break; - - // rules - case 4: - if(type == JSON_T_NULL) { state = 0; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 5; - break; - case 5: - if(type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_INTEGER); - state = 6; - rule_id = value->vu.integer_value; - break; - case 6: - assert(type == JSON_T_STRING); - rules[rule_id] = TRulePtr(new TRule(value->vu.str.value)); - state = 5; - break; - - // Nodes - case 8: - assert(type == JSON_T_OBJECT_BEGIN); - ++nodes; - in_edges.clear(); - cat.clear(); - shash.clear(); - state = 9; break; - case 9: - if (type == JSON_T_OBJECT_END) { - //cerr << "Creating NODE\n"; - CreateNode(cat, shash, in_edges); - state = 0; break; - } - assert(type == JSON_T_KEY); - cur_key = value->vu.str.value; - if (cur_key == "cat") { assert(cat.empty()); state = 10; break; } - if (cur_key == "in_edges") { assert(in_edges.empty()); state = 11; break; } - if (cur_key == "node_hash") { assert(shash.empty()); state = 24; break; } - cerr << "Syntax error: unexpected key " << cur_key << " in node specification.\n"; - return false; - case 10: - assert(type == JSON_T_STRING || type == JSON_T_NULL); - cat = value->vu.str.value; - state = 9; break; - case 11: - if (type == JSON_T_NULL) { state = 9; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 12; break; - case 12: - if (type == JSON_T_ARRAY_END) { state = 9; break; } - assert(type == JSON_T_INTEGER); - //cerr << "in_edges: " << value->vu.integer_value << endl; - in_edges.push_back(value->vu.integer_value); - break; - - // "edges": [ { "tail": null, "feats" : [0,1.63,1,-0.54], "rule": 12}, - // { "tail": null, "feats" : [0,0.87,1,0.02], "spans":[1,2,3,4], "rule": 17}, - // { "tail": [0], "feats" : [1,2.3,2,15.3,"ExtraFeature",1.2], "rule": 13}] - case 13: - assert(type == JSON_T_ARRAY_BEGIN); - state = 14; - break; - case 14: - if (type == JSON_T_ARRAY_END) { state = 0; break; } - assert(type == JSON_T_OBJECT_BEGIN); - //cerr << "New edge\n"; - ++edges; - cur_rule.reset(); feats.clear(); tail.clear(); - state = 15; break; - case 15: - if (type == JSON_T_OBJECT_END) { - CreateEdge(cur_rule, &feats, tail); - state = 14; break; - } - assert(type == JSON_T_KEY); - cur_key = value->vu.str.value; - //cerr << "edge key " << cur_key << endl; - if (cur_key == "rule") { assert(!cur_rule); state = 16; break; } - if (cur_key == "spans") { assert(!cur_rule); state = 22; break; } - if (cur_key == "feats") { assert(feats.empty()); state = 17; break; } - if (cur_key == "tail") { assert(tail.empty()); state = 20; break; } - cerr << "Unexpected key " << cur_key << " in edge specification\n"; - return false; - case 16: // edge.rule - if (type == JSON_T_INTEGER) { - int rule_id = value->vu.integer_value; - if (rules.find(rule_id) == rules.end()) { - // rules list must come before the edge definitions! - cerr << "Rule_id " << rule_id << " given but only loaded " << rules.size() << " rules\n"; - return false; - } - cur_rule = rules[rule_id]; - } else if (type == JSON_T_STRING) { - cur_rule.reset(new TRule(value->vu.str.value)); - } else { - cerr << "Rule must be either a rule id or a rule string" << endl; - return false; - } - // cerr << "Edge: rule=" << cur_rule->AsString() << endl; - state = 15; - break; - case 17: // edge.feats - if (type == JSON_T_NULL) { state = 15; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 18; break; - case 18: - if (type == JSON_T_ARRAY_END) { state = 15; break; } - if (type != JSON_T_INTEGER && type != JSON_T_STRING) { - cerr << "Unexpected feature id type\n"; return false; - } - if (type == JSON_T_INTEGER) { - fid = value->vu.integer_value; - assert(fid < fdict.size()); - fid = fdict[fid]; - } else if (JSON_T_STRING) { - fid = FD::Convert(value->vu.str.value); - } else { abort(); } - state = 19; - break; - case 19: - { - assert(type == JSON_T_INTEGER || type == JSON_T_FLOAT); - double val = (type == JSON_T_INTEGER ? static_cast(value->vu.integer_value) : - strtod(value->vu.str.value, NULL)); - feats.set_value(fid, val); - state = 18; - break; - } - case 20: // edge.tail - if (type == JSON_T_NULL) { state = 15; break; } - assert(type == JSON_T_ARRAY_BEGIN); - state = 21; break; - case 21: - if (type == JSON_T_ARRAY_END) { state = 15; break; } - assert(type == JSON_T_INTEGER); - tail.push_back(value->vu.integer_value); - break; - case 22: // edge.spans - assert(type == JSON_T_ARRAY_BEGIN); - state = 23; - spans[0] = spans[1] = spans[2] = spans[3] = -1; - spanc = 0; - break; - case 23: - if (type == JSON_T_ARRAY_END) { state = 15; break; } - assert(type == JSON_T_INTEGER); - assert(spanc < 4); - spans[spanc] = value->vu.integer_value; - ++spanc; - break; - case 24: // read node hash - assert(type == JSON_T_STRING); - shash = value->vu.str.value; - state = 9; - break; - } - return true; - } - string rp; - string cat; - SmallVectorUnsigned tail; - vector in_edges; - string shash; - TRulePtr cur_rule; - map rules; - vector fdict; - SparseVector feats; - int state; - int fid; - int nodes; - int edges; - int spans[4]; - int spanc; - string cur_key; - Hypergraph& hg; - int rule_id; - bool nodes_needed; - bool edges_needed; - bool is_sorted; -}; - -bool HypergraphIO::ReadFromJSON(istream* in, Hypergraph* hg) { - hg->clear(); - HGReader reader(hg); - return reader.Parse(in); -} - bool HypergraphIO::ReadFromBinary(istream* in, Hypergraph* hg) { boost::archive::binary_iarchive oa(*in); hg->clear(); diff --git a/decoder/hg_io.h b/decoder/hg_io.h index 5ba86f69..93a9e280 100644 --- a/decoder/hg_io.h +++ b/decoder/hg_io.h @@ -9,15 +9,6 @@ class Hypergraph; struct HypergraphIO { - // the format is basically a list of nodes and edges in topological order - // any edge you read, you must have already read its tail nodes - // any node you read, you must have already read its incoming edges - // this may make writing a bit more challenging if your forest is not - // topologically sorted (but that probably doesn't happen very often), - // but it makes reading much more memory efficient. - // see test_data/small.json.gz for an email encoding - static bool ReadFromJSON(std::istream* in, Hypergraph* out); - static bool ReadFromBinary(std::istream* in, Hypergraph* out); static bool WriteToBinary(const Hypergraph& hg, std::ostream* out); diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 25eddcec..366b269d 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -9,7 +9,6 @@ #include #include "tdict.h" -#include "json_parse.h" #include "hg_intersect.h" #include "hg_union.h" #include "viterbi.h" @@ -399,16 +398,6 @@ BOOST_AUTO_TEST_CASE(Small) { BOOST_CHECK_CLOSE(2.1431036, log(c2), 1e-4); } -BOOST_AUTO_TEST_CASE(JSONTest) { - std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); - ostringstream os; - JSONParser::WriteEscapedString("\"I don't know\", she said.", &os); - BOOST_CHECK_EQUAL("\"\\\"I don't know\\\", she said.\"", os.str()); - ostringstream os2; - JSONParser::WriteEscapedString("yes", &os2); - BOOST_CHECK_EQUAL("\"yes\"", os2.str()); -} - BOOST_AUTO_TEST_CASE(TestGenericKBest) { std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; diff --git a/decoder/hg_test.h b/decoder/hg_test.h index b7bab3c2..70c2c97d 100644 --- a/decoder/hg_test.h +++ b/decoder/hg_test.h @@ -12,11 +12,11 @@ namespace { typedef char const* Name; -Name urdu_json="urdu.json.gz"; +Name urdu_json="urdu.bin.gz"; Name urdu_wts="Arity_0 1.70741473606976 Arity_1 1.12426238048012 Arity_2 1.14986187839554 Glue -0.04589037041388 LanguageModel 1.09051 PassThrough -3.66226367902928 PhraseModel_0 -1.94633451863252 PhraseModel_1 -0.1475347695476 PhraseModel_2 -1.614818994946 WordPenalty -3.0 WordPenaltyFsa -0.56028442964748 ShorterThanPrev -10 LongerThanPrev -10"; -Name small_json="small.json.gz"; +Name small_json="small.bin.gz"; Name small_wts="Model_0 -2 Model_1 -.5 Model_2 -1.1 Model_3 -1 Model_4 -1 Model_5 .5 Model_6 .2 Model_7 -.3"; -Name perro_json="perro.json.gz"; +Name perro_json="perro.bin.gz"; Name perro_wts="SameFirstLetter 1 LongerThanPrev 1 ShorterThanPrev 1 GlueTop 0.0 Glue -1.0 EgivenF -0.5 FgivenE -0.5 LexEgivenF -0.5 LexFgivenE -0.5 LM 1"; } @@ -32,7 +32,7 @@ struct HGSetup { static void JsonFile(Hypergraph *hg,std::string f) { ReadFile rf(f); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } static void JsonTestFile(Hypergraph *hg,std::string path,std::string n) { JsonFile(hg,path + "/"+n); @@ -48,35 +48,35 @@ void AddNullEdge(Hypergraph* hg) { } void HGSetup::CreateTinyLatticeHG(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.tiny_lattice"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.tiny_lattice.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); AddNullEdge(hg); } void HGSetup::CreateLatticeHG(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.lattice"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.lattice.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); AddNullEdge(hg); } void HGSetup::CreateHG_tiny(const std::string& path, Hypergraph* hg) { - ReadFile rf(path + "/hg_test.tiny"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.tiny.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } void HGSetup::CreateHG_int(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.hg_int"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.hg_int.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } void HGSetup::CreateHG(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.hg"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.hg.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } void HGSetup::CreateHGBalanced(const std::string& path,Hypergraph* hg) { - ReadFile rf(path + "/hg_test.hg_balanced"); - HypergraphIO::ReadFromJSON(rf.stream(), hg); + ReadFile rf(path + "/hg_test.hg_balanced.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), hg); } #endif diff --git a/decoder/json_parse.cc b/decoder/json_parse.cc deleted file mode 100644 index f6fdfea8..00000000 --- a/decoder/json_parse.cc +++ /dev/null @@ -1,50 +0,0 @@ -#include "json_parse.h" - -#include -#include - -using namespace std; - -static const char *json_hex_chars = "0123456789abcdef"; - -void JSONParser::WriteEscapedString(const string& in, ostream* out) { - int pos = 0; - int start_offset = 0; - unsigned char c = 0; - (*out) << '"'; - while(pos < in.size()) { - c = in[pos]; - switch(c) { - case '\b': - case '\n': - case '\r': - case '\t': - case '"': - case '\\': - case '/': - if(pos - start_offset > 0) - (*out) << in.substr(start_offset, pos - start_offset); - if(c == '\b') (*out) << "\\b"; - else if(c == '\n') (*out) << "\\n"; - else if(c == '\r') (*out) << "\\r"; - else if(c == '\t') (*out) << "\\t"; - else if(c == '"') (*out) << "\\\""; - else if(c == '\\') (*out) << "\\\\"; - else if(c == '/') (*out) << "\\/"; - start_offset = ++pos; - break; - default: - if(c < ' ') { - cerr << "Warning, bad character (" << static_cast(c) << ") in string\n"; - if(pos - start_offset > 0) - (*out) << in.substr(start_offset, pos - start_offset); - (*out) << "\\u00" << json_hex_chars[c >> 4] << json_hex_chars[c & 0xf]; - start_offset = ++pos; - } else pos++; - } - } - if(pos - start_offset > 0) - (*out) << in.substr(start_offset, pos - start_offset); - (*out) << '"'; -} - diff --git a/decoder/json_parse.h b/decoder/json_parse.h deleted file mode 100644 index 85e2eff1..00000000 --- a/decoder/json_parse.h +++ /dev/null @@ -1,58 +0,0 @@ -#ifndef JSON_WRAPPER_H_ -#define JSON_WRAPPER_H_ - -#include -#include -#include "JSON_parser.h" - -class JSONParser { - public: - JSONParser() { - init_JSON_config(&config); - hack.mf = &JSONParser::Callback; - config.depth = 10; - config.callback_ctx = reinterpret_cast(this); - config.callback = hack.cb; - config.allow_comments = 1; - config.handle_floats_manually = 1; - jc = new_JSON_parser(&config); - } - virtual ~JSONParser() { - delete_JSON_parser(jc); - } - bool Parse(std::istream* in) { - int count = 0; - int lc = 1; - for (; in ; ++count) { - int next_char = in->get(); - if (!in->good()) break; - if (lc == '\n') { ++lc; } - if (!JSON_parser_char(jc, next_char)) { - std::cerr << "JSON_parser_char: syntax error, line " << lc << " (byte " << count << ")" << std::endl; - return false; - } - } - if (!JSON_parser_done(jc)) { - std::cerr << "JSON_parser_done: syntax error\n"; - return false; - } - return true; - } - static void WriteEscapedString(const std::string& in, std::ostream* out); - protected: - virtual bool HandleJSONEvent(int type, const JSON_value* value) = 0; - private: - int Callback(int type, const JSON_value* value) { - if (HandleJSONEvent(type, value)) return 1; - return 0; - } - JSON_parser_struct* jc; - JSON_config config; - typedef int (JSONParser::* MF)(int type, const struct JSON_value_struct* value); - union CBHack { - JSON_parser_callback cb; - MF mf; - } hack; -}; - -#endif diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc index 18c83c56..2c5fa9c4 100644 --- a/decoder/rescore_translator.cc +++ b/decoder/rescore_translator.cc @@ -3,6 +3,7 @@ #include #include +#include "filelib.h" #include "sentence_metadata.h" #include "hg.h" #include "hg_io.h" @@ -20,16 +21,18 @@ struct RescoreTranslatorImpl { bool Translate(const string& input, const vector& weights, Hypergraph* forest) { - if (input == "{}") return false; - if (input.find("{\"rules\"") == 0) { - istringstream is(input); - Hypergraph src_cfg_hg; - if (!HypergraphIO::ReadFromJSON(&is, forest)) { - cerr << "Parse error while reading HG from JSON.\n"; - abort(); - } - } else { - cerr << "Can only read HG input from JSON: use training/grammar_convert\n"; + istringstream is(input); + string header, fname; + is >> header >> fname; + if (header != "::forest::") { + cerr << "RescoreTranslator: expected input lines of form ::forest:: filename.gz\n"; + abort(); + } + ReadFile rf(fname); + if (!rf) { cerr << "Can't read " << fname << endl; abort(); } + Hypergraph src_cfg_hg; + if (!HypergraphIO::ReadFromBinary(rf.stream(), forest)) { + cerr << "Parse error while reading HG.\n"; abort(); } Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); diff --git a/decoder/test_data/perro.json.gz b/decoder/test_data/perro.json.gz deleted file mode 100644 index 41de5758..00000000 Binary files a/decoder/test_data/perro.json.gz and /dev/null differ diff --git a/decoder/test_data/small.json.gz b/decoder/test_data/small.json.gz deleted file mode 100644 index f6f37293..00000000 Binary files a/decoder/test_data/small.json.gz and /dev/null differ diff --git a/decoder/test_data/urdu.json.gz b/decoder/test_data/urdu.json.gz deleted file mode 100644 index 84535402..00000000 Binary files a/decoder/test_data/urdu.json.gz and /dev/null differ diff --git a/decoder/trule.h b/decoder/trule.h index 85842bb5..7af46747 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -167,7 +167,7 @@ class TRule { friend class boost::serialization::access; template - void save(Archive & ar, const unsigned int version) const { + void save(Archive & ar, const unsigned int /*version*/) const { ar & TD::Convert(-lhs_); unsigned f_size = f_.size(); ar & f_size; @@ -195,7 +195,7 @@ class TRule { ar & scores_; } template - void load(Archive & ar, const unsigned int version) { + void load(Archive & ar, const unsigned int /*version*/) { std::string lhs; ar & lhs; lhs_ = -TD::Convert(lhs); unsigned f_size; ar & f_size; f_.resize(f_size); diff --git a/tests/system_tests/cfg_rescore/input.txt b/tests/system_tests/cfg_rescore/input.txt index 2999a5fb..99624d85 100644 --- a/tests/system_tests/cfg_rescore/input.txt +++ b/tests/system_tests/cfg_rescore/input.txt @@ -1 +1 @@ -{"rules":[1,"[S] ||| [NP1] [VP] ||| [1] [2] ||| Active=1",2,"[S] ||| [NP2] [VPSV] by [NP1] ||| [1] [2] by [3] ||| Passive=1",3,"[VP] ||| [V] [NP2] ||| [1] [2]",4,"[V] ||| ate ||| ate",5,"[VPSV] ||| was eaten ||| was eaten",6,"[NP1] ||| John ||| John",7,"[NP2] ||| broccoli ||| broccoli",8,"[NP2] ||| the broccoli ||| the broccoli ||| Definite=1",9,"[Goal] ||| [X] ||| [1]"],"features":["PhraseModel_0","PhraseModel_1","PhraseModel_2","PhraseModel_3","PhraseModel_4","PhraseModel_5","PhraseModel_6","PhraseModel_7","PhraseModel_8","PhraseModel_9","PhraseModel_10","PhraseModel_11","PhraseModel_12","PhraseModel_13","PhraseModel_14","PhraseModel_15","PhraseModel_16","PhraseModel_17","PhraseModel_18","PhraseModel_19","PhraseModel_20","PhraseModel_21","PhraseModel_22","PhraseModel_23","PhraseModel_24","PhraseModel_25","PhraseModel_26","PhraseModel_27","PhraseModel_28","PhraseModel_29","PhraseModel_30","PhraseModel_31","PhraseModel_32","PhraseModel_33","PhraseModel_34","PhraseModel_35","PhraseModel_36","PhraseModel_37","PhraseModel_38","PhraseModel_39","PhraseModel_40","PhraseModel_41","PhraseModel_42","PhraseModel_43","PhraseModel_44","PhraseModel_45","PhraseModel_46","PhraseModel_47","PhraseModel_48","PhraseModel_49","PhraseModel_50","PhraseModel_51","PhraseModel_52","PhraseModel_53","PhraseModel_54","PhraseModel_55","PhraseModel_56","PhraseModel_57","PhraseModel_58","PhraseModel_59","PhraseModel_60","PhraseModel_61","PhraseModel_62","PhraseModel_63","PhraseModel_64","PhraseModel_65","PhraseModel_66","PhraseModel_67","PhraseModel_68","PhraseModel_69","PhraseModel_70","PhraseModel_71","PhraseModel_72","PhraseModel_73","PhraseModel_74","PhraseModel_75","PhraseModel_76","PhraseModel_77","PhraseModel_78","PhraseModel_79","PhraseModel_80","PhraseModel_81","PhraseModel_82","PhraseModel_83","PhraseModel_84","PhraseModel_85","PhraseModel_86","PhraseModel_87","PhraseModel_88","PhraseModel_89","PhraseModel_90","PhraseModel_91","PhraseModel_92","PhraseModel_93","PhraseModel_94","PhraseModel_95","PhraseModel_96","PhraseModel_97","PhraseModel_98","PhraseModel_99","Active","Passive","Definite"],"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":6}],"node":{"in_edges":[0],"cat":"NP1","node_hash":"0000000000000006"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":4}],"node":{"in_edges":[1],"cat":"V","node_hash":"0000000000000004"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":7},{"tail":[],"spans":[-1,-1,-1,-1],"feats":[102,1],"rule":8}],"node":{"in_edges":[2,3],"cat":"NP2","node_hash":"0000000000000008"},"edges":[{"tail":[1,2],"spans":[-1,-1,-1,-1],"feats":[],"rule":3}],"node":{"in_edges":[4],"cat":"VP","node_hash":"0000000000000003"},"edges":[{"tail":[],"spans":[-1,-1,-1,-1],"feats":[],"rule":5}],"node":{"in_edges":[5],"cat":"VPSV","node_hash":"0000000000000005"},"edges":[{"tail":[0,3],"spans":[-1,-1,-1,-1],"feats":[100,1],"rule":1},{"tail":[2,4,0],"spans":[-1,-1,-1,-1],"feats":[101,1],"rule":2}],"node":{"in_edges":[6,7],"cat":"S","node_hash":"0000000000000002"},"edges":[{"tail":[5],"spans":[-1,-1,-1,-1],"feats":[],"rule":9}],"node":{"in_edges":[8],"cat":"Goal","node_hash":"000000000000003D"}} +::forest:: input0.hg.bin.gz diff --git a/tests/system_tests/ftrans/input.txt b/tests/system_tests/ftrans/input.txt index aa37b2e7..99624d85 100644 --- a/tests/system_tests/ftrans/input.txt +++ b/tests/system_tests/ftrans/input.txt @@ -1 +1 @@ -{"rules":[1,"[B] ||| b ||| b",2,"[C] ||| c ||| c",3,"[A] ||| [B,1] [C,2] ||| [1] [2] ||| Mono=1",4,"[A] ||| [C,1] [B,2] ||| [1] [2] ||| Inv=1",5,"[S] ||| [A,1] ||| [1]"],"features":["Mono","Inv"],"edges":[{"tail":[],"feats":[],"rule":1}],"node":{"in_edges":[0],"cat":"B"},"edges":[{"tail":[],"feats":[],"rule":2}],"node":{"in_edges":[1],"cat":"C"},"edges":[{"tail":[0,1],"feats":[0,1],"rule":3},{"tail":[1,0],"feats":[1,1],"rule":4}],"node":{"in_edges":[2,3],"cat":"A"},"edges":[{"tail":[2],"feats":[],"rule":5}],"node":{"in_edges":[4],"cat":"S"}} +::forest:: input0.hg.bin.gz diff --git a/tests/system_tests/ftrans/input0.hg.bin.gz b/tests/system_tests/ftrans/input0.hg.bin.gz new file mode 100644 index 00000000..210f4a44 Binary files /dev/null and b/tests/system_tests/ftrans/input0.hg.bin.gz differ diff --git a/training/dpmert/lo_test.cc b/training/dpmert/lo_test.cc index b8776169..69e5aa3f 100644 --- a/training/dpmert/lo_test.cc +++ b/training/dpmert/lo_test.cc @@ -56,10 +56,11 @@ BOOST_AUTO_TEST_CASE(TestConvexHull) { } BOOST_AUTO_TEST_CASE(TestConvexHullInside) { - const string json = "{\"rules\":[1,\"[X] ||| a ||| a\",2,\"[X] ||| A [X] ||| A [1]\",3,\"[X] ||| c ||| c\",4,\"[X] ||| C [X] ||| C [1]\",5,\"[X] ||| [X] B [X] ||| [1] B [2]\",6,\"[X] ||| [X] b [X] ||| [1] b [2]\",7,\"[X] ||| X [X] ||| X [1]\",8,\"[X] ||| Z [X] ||| Z [1]\"],\"features\":[\"f1\",\"f2\",\"Feature_1\",\"Feature_0\",\"Model_0\",\"Model_1\",\"Model_2\",\"Model_3\",\"Model_4\",\"Model_5\",\"Model_6\",\"Model_7\"],\"edges\":[{\"tail\":[],\"feats\":[],\"rule\":1}],\"node\":{\"in_edges\":[0]},\"edges\":[{\"tail\":[0],\"feats\":[0,-0.8,1,-0.1],\"rule\":2}],\"node\":{\"in_edges\":[1]},\"edges\":[{\"tail\":[],\"feats\":[1,-1],\"rule\":3}],\"node\":{\"in_edges\":[2]},\"edges\":[{\"tail\":[2],\"feats\":[0,-0.2,1,-0.1],\"rule\":4}],\"node\":{\"in_edges\":[3]},\"edges\":[{\"tail\":[1,3],\"feats\":[0,-1.2,1,-0.2],\"rule\":5},{\"tail\":[1,3],\"feats\":[0,-0.5,1,-1.3],\"rule\":6}],\"node\":{\"in_edges\":[4,5]},\"edges\":[{\"tail\":[4],\"feats\":[0,-0.5,1,-0.8],\"rule\":7},{\"tail\":[4],\"feats\":[0,-0.7,1,-0.9],\"rule\":8}],\"node\":{\"in_edges\":[6,7]}}"; + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - istringstream instr(json); - HypergraphIO::ReadFromJSON(&instr, &hg); + ReadFile rf(path + "/test-ch-inside.bin.gz"); + assert(rf); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); SparseVector wts; wts.set_value(FD::Convert("f1"), 0.4); wts.set_value(FD::Convert("f2"), 1.0); @@ -121,13 +122,13 @@ BOOST_AUTO_TEST_CASE( TestS1) { std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); Hypergraph hg; - ReadFile rf(path + "/0.json.gz"); - HypergraphIO::ReadFromJSON(rf.stream(), &hg); + ReadFile rf(path + "/0.bin.gz"); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); hg.Reweight(wts); Hypergraph hg2; - ReadFile rf2(path + "/1.json.gz"); - HypergraphIO::ReadFromJSON(rf2.stream(), &hg2); + ReadFile rf2(path + "/1.bin.gz"); + HypergraphIO::ReadFromBinary(rf2.stream(), &hg2); hg2.Reweight(wts); vector > refs1(4); @@ -193,10 +194,11 @@ BOOST_AUTO_TEST_CASE( TestS1) { } BOOST_AUTO_TEST_CASE(TestZeroOrigin) { - const string json = "{\"rules\":[1,\"[X7] ||| blA ||| without ||| LHSProb=3.92173 LexE2F=2.90799 LexF2E=1.85003 GenerativeProb=10.5381 RulePenalty=1 XFE=2.77259 XEF=0.441833 LabelledEF=2.63906 LabelledFE=4.96981 LogRuleCount=0.693147\",2,\"[X7] ||| blA ||| except ||| LHSProb=4.92173 LexE2F=3.90799 LexF2E=1.85003 GenerativeProb=11.5381 RulePenalty=1 XFE=2.77259 XEF=1.44183 LabelledEF=2.63906 LabelledFE=4.96981 LogRuleCount=1.69315\",3,\"[S] ||| [X7,1] ||| [1] ||| GlueTop=1\",4,\"[X28] ||| EnwAn ||| title ||| LHSProb=3.96802 LexE2F=2.22462 LexF2E=1.83258 GenerativeProb=10.0863 RulePenalty=1 XFE=0 XEF=1.20397 LabelledEF=1.20397 LabelledFE=-1.98341e-08 LogRuleCount=1.09861\",5,\"[X0] ||| EnwAn ||| funny ||| LHSProb=3.98479 LexE2F=1.79176 LexF2E=3.21888 GenerativeProb=11.1681 RulePenalty=1 XFE=0 XEF=2.30259 LabelledEF=2.30259 LabelledFE=0 LogRuleCount=0 SingletonRule=1\",6,\"[X8] ||| [X7,1] EnwAn ||| entitled [1] ||| LHSProb=3.82533 LexE2F=3.21888 LexF2E=2.52573 GenerativeProb=11.3276 RulePenalty=1 XFE=1.20397 XEF=1.20397 LabelledEF=2.30259 LabelledFE=2.30259 LogRuleCount=0 SingletonRule=1\",7,\"[S] ||| [S,1] [X28,2] ||| [1] [2] ||| Glue=1\",8,\"[S] ||| [S,1] [X0,2] ||| [1] [2] ||| Glue=1\",9,\"[S] ||| [X8,1] ||| [1] ||| GlueTop=1\",10,\"[Goal] ||| [S,1] ||| [1]\"],\"features\":[\"PassThrough\",\"Glue\",\"GlueTop\",\"LanguageModel\",\"WordPenalty\",\"LHSProb\",\"LexE2F\",\"LexF2E\",\"GenerativeProb\",\"RulePenalty\",\"XFE\",\"XEF\",\"LabelledEF\",\"LabelledFE\",\"LogRuleCount\",\"SingletonRule\"],\"edges\":[{\"tail\":[],\"spans\":[0,1,-1,-1],\"feats\":[5,3.92173,6,2.90799,7,1.85003,8,10.5381,9,1,10,2.77259,11,0.441833,12,2.63906,13,4.96981,14,0.693147],\"rule\":1},{\"tail\":[],\"spans\":[0,1,-1,-1],\"feats\":[5,4.92173,6,3.90799,7,1.85003,8,11.5381,9,1,10,2.77259,11,1.44183,12,2.63906,13,4.96981,14,1.69315],\"rule\":2}],\"node\":{\"in_edges\":[0,1],\"cat\":\"X7\"},\"edges\":[{\"tail\":[0],\"spans\":[0,1,-1,-1],\"feats\":[2,1],\"rule\":3}],\"node\":{\"in_edges\":[2],\"cat\":\"S\"},\"edges\":[{\"tail\":[],\"spans\":[1,2,-1,-1],\"feats\":[5,3.96802,6,2.22462,7,1.83258,8,10.0863,9,1,11,1.20397,12,1.20397,13,-1.98341e-08,14,1.09861],\"rule\":4}],\"node\":{\"in_edges\":[3],\"cat\":\"X28\"},\"edges\":[{\"tail\":[],\"spans\":[1,2,-1,-1],\"feats\":[5,3.98479,6,1.79176,7,3.21888,8,11.1681,9,1,11,2.30259,12,2.30259,15,1],\"rule\":5}],\"node\":{\"in_edges\":[4],\"cat\":\"X0\"},\"edges\":[{\"tail\":[0],\"spans\":[0,2,-1,-1],\"feats\":[5,3.82533,6,3.21888,7,2.52573,8,11.3276,9,1,10,1.20397,11,1.20397,12,2.30259,13,2.30259,15,1],\"rule\":6}],\"node\":{\"in_edges\":[5],\"cat\":\"X8\"},\"edges\":[{\"tail\":[1,2],\"spans\":[0,2,-1,-1],\"feats\":[1,1],\"rule\":7},{\"tail\":[1,3],\"spans\":[0,2,-1,-1],\"feats\":[1,1],\"rule\":8},{\"tail\":[4],\"spans\":[0,2,-1,-1],\"feats\":[2,1],\"rule\":9}],\"node\":{\"in_edges\":[6,7,8],\"cat\":\"S\"},\"edges\":[{\"tail\":[5],\"spans\":[0,2,-1,-1],\"feats\":[],\"rule\":10}],\"node\":{\"in_edges\":[9],\"cat\":\"Goal\"}}"; + std::string path(boost::unit_test::framework::master_test_suite().argc == 2 ? boost::unit_test::framework::master_test_suite().argv[1] : TEST_DATA); + ReadFile rf(path + "/test-zero-origin.bin.gz"); + assert(rf); Hypergraph hg; - istringstream instr(json); - HypergraphIO::ReadFromJSON(&instr, &hg); + HypergraphIO::ReadFromBinary(rf.stream(), &hg); SparseVector wts; wts.set_value(FD::Convert("PassThrough"), -0.929201533002898); hg.Reweight(wts); diff --git a/training/dpmert/test_data/0.bin.gz b/training/dpmert/test_data/0.bin.gz new file mode 100644 index 00000000..388298e9 Binary files /dev/null and b/training/dpmert/test_data/0.bin.gz differ diff --git a/training/dpmert/test_data/0.json.gz b/training/dpmert/test_data/0.json.gz deleted file mode 100644 index 30f8dd77..00000000 Binary files a/training/dpmert/test_data/0.json.gz and /dev/null differ diff --git a/training/dpmert/test_data/1.bin.gz b/training/dpmert/test_data/1.bin.gz new file mode 100644 index 00000000..44f9e0ff Binary files /dev/null and b/training/dpmert/test_data/1.bin.gz differ diff --git a/training/dpmert/test_data/1.json.gz b/training/dpmert/test_data/1.json.gz deleted file mode 100644 index c82cc179..00000000 Binary files a/training/dpmert/test_data/1.json.gz and /dev/null differ diff --git a/training/dpmert/test_data/test-ch-inside.bin.gz b/training/dpmert/test_data/test-ch-inside.bin.gz new file mode 100644 index 00000000..392f08c6 Binary files /dev/null and b/training/dpmert/test_data/test-ch-inside.bin.gz differ diff --git a/training/dpmert/test_data/test-zero-origin.bin.gz b/training/dpmert/test_data/test-zero-origin.bin.gz new file mode 100644 index 00000000..c641faaf Binary files /dev/null and b/training/dpmert/test_data/test-zero-origin.bin.gz differ -- cgit v1.2.3 From 800c6df6e7aac9032e1068fb6dc985344865854a Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 19 Oct 2014 15:31:11 -0400 Subject: remove json from grammar_convert --- training/utils/grammar_convert.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'training') diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc index 000f2a26..04f1eb77 100644 --- a/training/utils/grammar_convert.cc +++ b/training/utils/grammar_convert.cc @@ -316,11 +316,11 @@ int main(int argc, char **argv) { line = line.substr(0, pos + 2); } istringstream is(line); - if (HypergraphIO::ReadFromJSON(&is, &hg)) { + if (HypergraphIO::ReadFromBinary(&is, &hg)) { ProcessHypergraph(w, conf, ref, &hg); hg.clear(); } else { - cerr << "Error reading grammar from JSON: line " << lc << endl; + cerr << "Error reading grammar line " << lc << endl; exit(1); } } else { -- cgit v1.2.3 From f2d50c333d0dde8a5ef211bc31b4978a3d8911cf Mon Sep 17 00:00:00 2001 From: "Wu, Ke" Date: Wed, 17 Dec 2014 15:41:32 -0500 Subject: Move training routine out of ff_const_reorder_common.h --- decoder/ff_const_reorder_common.h | 93 ---------------------- training/const_reorder/Makefile.am | 8 +- training/const_reorder/argument_reorder_model.cc | 6 +- .../const_reorder/constituent_reorder_model.cc | 6 +- training/const_reorder/trainer.cc | 67 ++++++++++++++++ training/const_reorder/trainer.h | 12 +++ 6 files changed, 91 insertions(+), 101 deletions(-) create mode 100644 training/const_reorder/trainer.cc create mode 100644 training/const_reorder/trainer.h (limited to 'training') diff --git a/decoder/ff_const_reorder_common.h b/decoder/ff_const_reorder_common.h index 7c111de3..b124ce47 100644 --- a/decoder/ff_const_reorder_common.h +++ b/decoder/ff_const_reorder_common.h @@ -1091,99 +1091,6 @@ struct Tsuruoka_Maxent { if (m_pModel != NULL) delete m_pModel; } - void fnTrain(const char* pszInstanceFName, const char* pszAlgorithm, - const char* pszModelFName, int /*iNumIteration*/) { - assert(strcmp(pszAlgorithm, "l1") == 0 || strcmp(pszAlgorithm, "l2") == 0 || - strcmp(pszAlgorithm, "sgd") == 0 || - strcmp(pszAlgorithm, "SGD") == 0); - FILE* fpIn = fopen(pszInstanceFName, "r"); - - ME_Model* pModel = new ME_Model(); - - char* pszLine = new char[100001]; - int iNumInstances = 0; - int iLen; - while (!feof(fpIn)) { - pszLine[0] = '\0'; - fgets(pszLine, 20000, fpIn); - if (strlen(pszLine) == 0) { - continue; - } - - iLen = strlen(pszLine); - while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) { - pszLine[iLen - 1] = '\0'; - iLen--; - } - - iNumInstances++; - - ME_Sample* pmes = new ME_Sample(); - - char* p = strrchr(pszLine, ' '); - assert(p != NULL); - p[0] = '\0'; - p++; - std::vector vecContext; - SplitOnWhitespace(std::string(pszLine), &vecContext); - - pmes->label = std::string(p); - for (size_t i = 0; i < vecContext.size(); i++) - pmes->add_feature(vecContext[i]); - pModel->add_training_sample((*pmes)); - if (iNumInstances % 100000 == 0) - fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); - delete pmes; - } - fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); - fclose(fpIn); - - if (strcmp(pszAlgorithm, "l1") == 0) - pModel->use_l1_regularizer(1.0); - else if (strcmp(pszAlgorithm, "l2") == 0) - pModel->use_l2_regularizer(1.0); - else - pModel->use_SGD(); - - pModel->train(); - pModel->save_to_file(pszModelFName); - - delete pModel; - fprintf(stdout, "......Finished Training\n"); - fprintf(stdout, "......Model saved as %s\n", pszModelFName); - delete[] pszLine; - } - - double fnEval(const char* pszContext, const char* pszOutcome) const { - std::vector vecContext; - ME_Sample* pmes = new ME_Sample(); - SplitOnWhitespace(std::string(pszContext), &vecContext); - - for (size_t i = 0; i < vecContext.size(); i++) - pmes->add_feature(vecContext[i]); - std::vector vecProb = m_pModel->classify(*pmes); - delete pmes; - int iLableID = m_pModel->get_class_id(pszOutcome); - return vecProb[iLableID]; - } - void fnEval(const char* pszContext, - std::vector >& vecOutput) const { - std::vector vecContext; - ME_Sample* pmes = new ME_Sample(); - SplitOnWhitespace(std::string(pszContext), &vecContext); - - vecOutput.clear(); - - for (size_t i = 0; i < vecContext.size(); i++) - pmes->add_feature(vecContext[i]); - std::vector vecProb = m_pModel->classify(*pmes); - - for (size_t i = 0; i < vecProb.size(); i++) { - std::string label = m_pModel->get_class_label(i); - vecOutput.push_back(make_pair(label, vecProb[i])); - } - delete pmes; - } void fnEval(const char* pszContext, std::vector& vecOutput) const { std::vector vecContext; ME_Sample* pmes = new ME_Sample(); diff --git a/training/const_reorder/Makefile.am b/training/const_reorder/Makefile.am index 2e81e588..367ac904 100644 --- a/training/const_reorder/Makefile.am +++ b/training/const_reorder/Makefile.am @@ -1,8 +1,12 @@ +noinst_LIBRARIES = libtrainer.a + +libtrainer_a_SOURCES = trainer.h trainer.cc + bin_PROGRAMS = const_reorder_model_trainer argument_reorder_model_trainer AM_CPPFLAGS = -I$(top_srcdir) -I$(top_srcdir)/utils -I$(top_srcdir)/decoder const_reorder_model_trainer_SOURCES = constituent_reorder_model.cc -const_reorder_model_trainer_LDADD = ../../utils/libutils.a +const_reorder_model_trainer_LDADD = ../../utils/libutils.a libtrainer.a argument_reorder_model_trainer_SOURCES = argument_reorder_model.cc -argument_reorder_model_trainer_LDADD = ../../utils/libutils.a +argument_reorder_model_trainer_LDADD = ../../utils/libutils.a libtrainer.a diff --git a/training/const_reorder/argument_reorder_model.cc b/training/const_reorder/argument_reorder_model.cc index 54402436..87f2ce2f 100644 --- a/training/const_reorder/argument_reorder_model.cc +++ b/training/const_reorder/argument_reorder_model.cc @@ -14,7 +14,7 @@ #include "utils/filelib.h" -#include "decoder/ff_const_reorder_common.h" +#include "trainer.h" using namespace std; using namespace const_reorder; @@ -93,8 +93,8 @@ struct SArgumentReorderTrainer { strcpy(pszNewInstanceFName, pszInstanceFname); } - Tsuruoka_Maxent* pMaxent = new Tsuruoka_Maxent(NULL); - pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname, 300); + Tsuruoka_Maxent_Trainer* pMaxent = new Tsuruoka_Maxent_Trainer; + pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname); delete pMaxent; if (strcmp(pszNewInstanceFName, pszInstanceFname) != 0) { diff --git a/training/const_reorder/constituent_reorder_model.cc b/training/const_reorder/constituent_reorder_model.cc index 6bec3f0b..d3ad0f2b 100644 --- a/training/const_reorder/constituent_reorder_model.cc +++ b/training/const_reorder/constituent_reorder_model.cc @@ -12,7 +12,7 @@ #include "utils/filelib.h" -#include "decoder/ff_const_reorder_common.h" +#include "trainer.h" using namespace std; using namespace const_reorder; @@ -104,8 +104,8 @@ struct SConstReorderTrainer { pZhangleMaxent->fnTrain(pszInstanceFname, "lbfgs", pszModelFname, 100, 2.0); delete pZhangleMaxent;*/ - Tsuruoka_Maxent* pMaxent = new Tsuruoka_Maxent(NULL); - pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname, 300); + Tsuruoka_Maxent_Trainer* pMaxent = new Tsuruoka_Maxent_Trainer; + pMaxent->fnTrain(pszNewInstanceFName, "l1", pszModelFname); delete pMaxent; if (strcmp(pszNewInstanceFName, pszInstanceFname) != 0) { diff --git a/training/const_reorder/trainer.cc b/training/const_reorder/trainer.cc new file mode 100644 index 00000000..e22a8a66 --- /dev/null +++ b/training/const_reorder/trainer.cc @@ -0,0 +1,67 @@ +#include "trainer.h" + +Tsuruoka_Maxent_Trainer::Tsuruoka_Maxent_Trainer() + : const_reorder::Tsuruoka_Maxent(NULL) {} + +void Tsuruoka_Maxent_Trainer::fnTrain(const char* pszInstanceFName, + const char* pszAlgorithm, + const char* pszModelFName) { + assert(strcmp(pszAlgorithm, "l1") == 0 || strcmp(pszAlgorithm, "l2") == 0 || + strcmp(pszAlgorithm, "sgd") == 0 || strcmp(pszAlgorithm, "SGD") == 0); + FILE* fpIn = fopen(pszInstanceFName, "r"); + + ME_Model* pModel = new ME_Model(); + + char* pszLine = new char[100001]; + int iNumInstances = 0; + int iLen; + while (!feof(fpIn)) { + pszLine[0] = '\0'; + fgets(pszLine, 20000, fpIn); + if (strlen(pszLine) == 0) { + continue; + } + + iLen = strlen(pszLine); + while (iLen > 0 && pszLine[iLen - 1] > 0 && pszLine[iLen - 1] < 33) { + pszLine[iLen - 1] = '\0'; + iLen--; + } + + iNumInstances++; + + ME_Sample* pmes = new ME_Sample(); + + char* p = strrchr(pszLine, ' '); + assert(p != NULL); + p[0] = '\0'; + p++; + std::vector vecContext; + SplitOnWhitespace(std::string(pszLine), &vecContext); + + pmes->label = std::string(p); + for (size_t i = 0; i < vecContext.size(); i++) + pmes->add_feature(vecContext[i]); + pModel->add_training_sample((*pmes)); + if (iNumInstances % 100000 == 0) + fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); + delete pmes; + } + fprintf(stdout, "......Reading #Instances: %1d\n", iNumInstances); + fclose(fpIn); + + if (strcmp(pszAlgorithm, "l1") == 0) + pModel->use_l1_regularizer(1.0); + else if (strcmp(pszAlgorithm, "l2") == 0) + pModel->use_l2_regularizer(1.0); + else + pModel->use_SGD(); + + pModel->train(); + pModel->save_to_file(pszModelFName); + + delete pModel; + fprintf(stdout, "......Finished Training\n"); + fprintf(stdout, "......Model saved as %s\n", pszModelFName); + delete[] pszLine; +} diff --git a/training/const_reorder/trainer.h b/training/const_reorder/trainer.h new file mode 100644 index 00000000..e574a536 --- /dev/null +++ b/training/const_reorder/trainer.h @@ -0,0 +1,12 @@ +#ifndef TRAINING_CONST_REORDER_TRAINER_H_ +#define TRAINING_CONST_REORDER_TRAINER_H_ + +#include "decoder/ff_const_reorder_common.h" + +struct Tsuruoka_Maxent_Trainer : const_reorder::Tsuruoka_Maxent { + Tsuruoka_Maxent_Trainer(); + void fnTrain(const char* pszInstanceFName, const char* pszAlgorithm, + const char* pszModelFName); +}; + +#endif // TRAINING_CONST_REORDER_TRAINER_H_ -- cgit v1.2.3 From bd9308e22b5434aa220cc57d82ee867464a011f1 Mon Sep 17 00:00:00 2001 From: "Wu, Ke" Date: Wed, 17 Dec 2014 16:00:04 -0500 Subject: Combine everything related to maxent to a single file --- decoder/ff_const_reorder_common.h | 6 +- training/const_reorder/trainer.cc | 4 +- utils/Makefile.am | 5 - utils/lbfgs.cpp | 108 ---------- utils/lbfgs.h | 20 -- utils/mathvec.h | 87 -------- utils/maxent.cpp | 427 +++++++++++++++++++++++++++++++++++++- utils/maxent.h | 95 ++++++++- utils/owlqn.cpp | 127 ------------ utils/sgd.cpp | 193 ----------------- 10 files changed, 516 insertions(+), 556 deletions(-) delete mode 100644 utils/lbfgs.cpp delete mode 100644 utils/lbfgs.h delete mode 100644 utils/mathvec.h delete mode 100644 utils/owlqn.cpp delete mode 100644 utils/sgd.cpp (limited to 'training') diff --git a/decoder/ff_const_reorder_common.h b/decoder/ff_const_reorder_common.h index b124ce47..755fd948 100644 --- a/decoder/ff_const_reorder_common.h +++ b/decoder/ff_const_reorder_common.h @@ -1081,7 +1081,7 @@ typedef std::unordered_map::iterator Iterator; struct Tsuruoka_Maxent { Tsuruoka_Maxent(const char* pszModelFName) { if (pszModelFName != NULL) { - m_pModel = new ME_Model(); + m_pModel = new maxent::ME_Model(); m_pModel->load_from_file(pszModelFName); } else m_pModel = NULL; @@ -1093,7 +1093,7 @@ struct Tsuruoka_Maxent { void fnEval(const char* pszContext, std::vector& vecOutput) const { std::vector vecContext; - ME_Sample* pmes = new ME_Sample(); + maxent::ME_Sample* pmes = new maxent::ME_Sample(); SplitOnWhitespace(std::string(pszContext), &vecContext); vecOutput.clear(); @@ -1113,7 +1113,7 @@ struct Tsuruoka_Maxent { } private: - ME_Model* m_pModel; + maxent::ME_Model* m_pModel; }; // an argument item or a predicate item (the verb itself) diff --git a/training/const_reorder/trainer.cc b/training/const_reorder/trainer.cc index e22a8a66..89bd7479 100644 --- a/training/const_reorder/trainer.cc +++ b/training/const_reorder/trainer.cc @@ -10,7 +10,7 @@ void Tsuruoka_Maxent_Trainer::fnTrain(const char* pszInstanceFName, strcmp(pszAlgorithm, "sgd") == 0 || strcmp(pszAlgorithm, "SGD") == 0); FILE* fpIn = fopen(pszInstanceFName, "r"); - ME_Model* pModel = new ME_Model(); + maxent::ME_Model* pModel = new maxent::ME_Model(); char* pszLine = new char[100001]; int iNumInstances = 0; @@ -30,7 +30,7 @@ void Tsuruoka_Maxent_Trainer::fnTrain(const char* pszInstanceFName, iNumInstances++; - ME_Sample* pmes = new ME_Sample(); + maxent::ME_Sample* pmes = new maxent::ME_Sample(); char* p = strrchr(pszLine, ' '); assert(p != NULL); diff --git a/utils/Makefile.am b/utils/Makefile.am index fabb4454..e0221e64 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -38,11 +38,8 @@ libutils_a_SOURCES = \ have_64_bits.h \ indices_after.h \ kernel_string_subseq.h \ - lbfgs.h \ - lbfgs.cpp \ logval.h \ m.h \ - mathvec.h \ maxent.h \ maxent.cpp \ murmur_hash3.h \ @@ -50,8 +47,6 @@ libutils_a_SOURCES = \ named_enum.h \ null_deleter.h \ null_traits.h \ - owlqn.cpp \ - sgd.cpp \ perfect_hash.h \ prob.h \ sampler.h \ diff --git a/utils/lbfgs.cpp b/utils/lbfgs.cpp deleted file mode 100644 index bd26f048..00000000 --- a/utils/lbfgs.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include -#include -#include -#include -#include "mathvec.h" -#include "lbfgs.h" -#include "maxent.h" - -using namespace std; - -const static int M = LBFGS_M; -const static double LINE_SEARCH_ALPHA = 0.1; -const static double LINE_SEARCH_BETA = 0.5; - -// stopping criteria -int LBFGS_MAX_ITER = 300; -const static double MIN_GRAD_NORM = 0.0001; - -double ME_Model::backtracking_line_search(const Vec& x0, const Vec& grad0, - const double f0, const Vec& dx, - Vec& x, Vec& grad1) { - double t = 1.0 / LINE_SEARCH_BETA; - - double f; - do { - t *= LINE_SEARCH_BETA; - x = x0 + t * dx; - f = FunctionGradient(x.STLVec(), grad1.STLVec()); - // cout << "*"; - } while (f > f0 + LINE_SEARCH_ALPHA * t * dot_product(dx, grad0)); - - return f; -} - -// -// Jorge Nocedal, "Updating Quasi-Newton Matrices With Limited Storage", -// Mathematics of Computation, Vol. 35, No. 151, pp. 773-782, 1980. -// -Vec approximate_Hg(const int iter, const Vec& grad, const Vec s[], - const Vec y[], const double z[]) { - int offset, bound; - if (iter <= M) { - offset = 0; - bound = iter; - } else { - offset = iter - M; - bound = M; - } - - Vec q = grad; - double alpha[M], beta[M]; - for (int i = bound - 1; i >= 0; i--) { - const int j = (i + offset) % M; - alpha[i] = z[j] * dot_product(s[j], q); - q += -alpha[i] * y[j]; - } - if (iter > 0) { - const int j = (iter - 1) % M; - const double gamma = ((1.0 / z[j]) / dot_product(y[j], y[j])); - // static double gamma; - // if (gamma == 0) gamma = ((1.0 / z[j]) / dot_product(y[j], y[j])); - q *= gamma; - } - for (int i = 0; i <= bound - 1; i++) { - const int j = (i + offset) % M; - beta[i] = z[j] * dot_product(y[j], q); - q += s[j] * (alpha[i] - beta[i]); - } - - return q; -} - -vector ME_Model::perform_LBFGS(const vector& x0) { - const size_t dim = x0.size(); - Vec x = x0; - - Vec grad(dim), dx(dim); - double f = FunctionGradient(x.STLVec(), grad.STLVec()); - - Vec s[M], y[M]; - double z[M]; // rho - - for (int iter = 0; iter < LBFGS_MAX_ITER; iter++) { - - fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, -f, _train_error); - if (_nheldout > 0) { - const double heldout_logl = heldout_likelihood(); - fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, - _heldout_error); - } - fprintf(stderr, "\n"); - - if (sqrt(dot_product(grad, grad)) < MIN_GRAD_NORM) break; - - dx = -1 * approximate_Hg(iter, grad, s, y, z); - - Vec x1(dim), grad1(dim); - f = backtracking_line_search(x, grad, f, dx, x1, grad1); - - s[iter % M] = x1 - x; - y[iter % M] = grad1 - grad; - z[iter % M] = 1.0 / dot_product(y[iter % M], s[iter % M]); - x = x1; - grad = grad1; - } - - return x.STLVec(); -} diff --git a/utils/lbfgs.h b/utils/lbfgs.h deleted file mode 100644 index 4d706f7a..00000000 --- a/utils/lbfgs.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _LBFGS_H_ -#define _LBFGS_H_ - -#include - -// template -// std::vector -// perform_LBFGS(FuncGrad func_grad, const std::vector & x0); - -std::vector perform_LBFGS( - double (*func_grad)(const std::vector &, std::vector &), - const std::vector &x0); - -std::vector perform_OWLQN( - double (*func_grad)(const std::vector &, std::vector &), - const std::vector &x0, const double C); - -const int LBFGS_M = 10; - -#endif diff --git a/utils/mathvec.h b/utils/mathvec.h deleted file mode 100644 index f8c60e5d..00000000 --- a/utils/mathvec.h +++ /dev/null @@ -1,87 +0,0 @@ -#ifndef _MATH_VECTOR_H_ -#define _MATH_VECTOR_H_ - -#include -#include -#include - -class Vec { - private: - std::vector _v; - - public: - Vec(const size_t n = 0, const double val = 0) { _v.resize(n, val); } - Vec(const std::vector& v) : _v(v) {} - const std::vector& STLVec() const { return _v; } - std::vector& STLVec() { return _v; } - size_t Size() const { return _v.size(); } - double& operator[](int i) { return _v[i]; } - const double& operator[](int i) const { return _v[i]; } - Vec& operator+=(const Vec& b) { - assert(b.Size() == _v.size()); - for (size_t i = 0; i < _v.size(); i++) { - _v[i] += b[i]; - } - return *this; - } - Vec& operator*=(const double c) { - for (size_t i = 0; i < _v.size(); i++) { - _v[i] *= c; - } - return *this; - } - void Project(const Vec& y) { - for (size_t i = 0; i < _v.size(); i++) { - // if (sign(_v[i]) != sign(y[i])) _v[i] = 0; - if (_v[i] * y[i] <= 0) _v[i] = 0; - } - } -}; - -inline double dot_product(const Vec& a, const Vec& b) { - double sum = 0; - for (size_t i = 0; i < a.Size(); i++) { - sum += a[i] * b[i]; - } - return sum; -} - -inline std::ostream& operator<<(std::ostream& s, const Vec& a) { - s << "("; - for (size_t i = 0; i < a.Size(); i++) { - if (i != 0) s << ", "; - s << a[i]; - } - s << ")"; - return s; -} - -inline const Vec operator+(const Vec& a, const Vec& b) { - Vec v(a.Size()); - assert(a.Size() == b.Size()); - for (size_t i = 0; i < a.Size(); i++) { - v[i] = a[i] + b[i]; - } - return v; -} - -inline const Vec operator-(const Vec& a, const Vec& b) { - Vec v(a.Size()); - assert(a.Size() == b.Size()); - for (size_t i = 0; i < a.Size(); i++) { - v[i] = a[i] - b[i]; - } - return v; -} - -inline const Vec operator*(const Vec& a, const double c) { - Vec v(a.Size()); - for (size_t i = 0; i < a.Size(); i++) { - v[i] = a[i] * c; - } - return v; -} - -inline const Vec operator*(const double c, const Vec& a) { return a * c; } - -#endif diff --git a/utils/maxent.cpp b/utils/maxent.cpp index 0f49ee9d..fd772e08 100644 --- a/utils/maxent.cpp +++ b/utils/maxent.cpp @@ -3,12 +3,15 @@ */ #include "maxent.h" + +#include +#include #include #include -#include "lbfgs.h" using namespace std; +namespace maxent { double ME_Model::FunctionGradient(const vector& x, vector& grad) { assert((int)_fb.Size() == x.size()); @@ -601,6 +604,428 @@ vector ME_Model::classify(ME_Sample& mes) const { return vp; } +// template +// std::vector +// perform_LBFGS(FuncGrad func_grad, const std::vector & x0); + +std::vector perform_LBFGS( + double (*func_grad)(const std::vector &, std::vector &), + const std::vector &x0); + +std::vector perform_OWLQN( + double (*func_grad)(const std::vector &, std::vector &), + const std::vector &x0, const double C); + +const int LBFGS_M = 10; + +const static int M = LBFGS_M; +const static double LINE_SEARCH_ALPHA = 0.1; +const static double LINE_SEARCH_BETA = 0.5; + +// stopping criteria +int LBFGS_MAX_ITER = 300; +const static double MIN_GRAD_NORM = 0.0001; + +// LBFGS + +double ME_Model::backtracking_line_search(const Vec& x0, const Vec& grad0, + const double f0, const Vec& dx, + Vec& x, Vec& grad1) { + double t = 1.0 / LINE_SEARCH_BETA; + + double f; + do { + t *= LINE_SEARCH_BETA; + x = x0 + t * dx; + f = FunctionGradient(x.STLVec(), grad1.STLVec()); + // cout << "*"; + } while (f > f0 + LINE_SEARCH_ALPHA * t * dot_product(dx, grad0)); + + return f; +} + +// +// Jorge Nocedal, "Updating Quasi-Newton Matrices With Limited Storage", +// Mathematics of Computation, Vol. 35, No. 151, pp. 773-782, 1980. +// +Vec approximate_Hg(const int iter, const Vec& grad, const Vec s[], + const Vec y[], const double z[]) { + int offset, bound; + if (iter <= M) { + offset = 0; + bound = iter; + } else { + offset = iter - M; + bound = M; + } + + Vec q = grad; + double alpha[M], beta[M]; + for (int i = bound - 1; i >= 0; i--) { + const int j = (i + offset) % M; + alpha[i] = z[j] * dot_product(s[j], q); + q += -alpha[i] * y[j]; + } + if (iter > 0) { + const int j = (iter - 1) % M; + const double gamma = ((1.0 / z[j]) / dot_product(y[j], y[j])); + // static double gamma; + // if (gamma == 0) gamma = ((1.0 / z[j]) / dot_product(y[j], y[j])); + q *= gamma; + } + for (int i = 0; i <= bound - 1; i++) { + const int j = (i + offset) % M; + beta[i] = z[j] * dot_product(y[j], q); + q += s[j] * (alpha[i] - beta[i]); + } + + return q; +} + +vector ME_Model::perform_LBFGS(const vector& x0) { + const size_t dim = x0.size(); + Vec x = x0; + + Vec grad(dim), dx(dim); + double f = FunctionGradient(x.STLVec(), grad.STLVec()); + + Vec s[M], y[M]; + double z[M]; // rho + + for (int iter = 0; iter < LBFGS_MAX_ITER; iter++) { + + fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, -f, _train_error); + if (_nheldout > 0) { + const double heldout_logl = heldout_likelihood(); + fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, + _heldout_error); + } + fprintf(stderr, "\n"); + + if (sqrt(dot_product(grad, grad)) < MIN_GRAD_NORM) break; + + dx = -1 * approximate_Hg(iter, grad, s, y, z); + + Vec x1(dim), grad1(dim); + f = backtracking_line_search(x, grad, f, dx, x1, grad1); + + s[iter % M] = x1 - x; + y[iter % M] = grad1 - grad; + z[iter % M] = 1.0 / dot_product(y[iter % M], s[iter % M]); + x = x1; + grad = grad1; + } + + return x.STLVec(); +} + +// OWLQN + +// stopping criteria +int OWLQN_MAX_ITER = 300; + +Vec approximate_Hg(const int iter, const Vec& grad, const Vec s[], + const Vec y[], const double z[]); + +inline int sign(double x) { + if (x > 0) return 1; + if (x < 0) return -1; + return 0; +}; + +static Vec pseudo_gradient(const Vec& x, const Vec& grad0, const double C) { + Vec grad = grad0; + for (size_t i = 0; i < x.Size(); i++) { + if (x[i] != 0) { + grad[i] += C * sign(x[i]); + continue; + } + const double gm = grad0[i] - C; + if (gm > 0) { + grad[i] = gm; + continue; + } + const double gp = grad0[i] + C; + if (gp < 0) { + grad[i] = gp; + continue; + } + grad[i] = 0; + } + + return grad; +} + +double ME_Model::regularized_func_grad(const double C, const Vec& x, + Vec& grad) { + double f = FunctionGradient(x.STLVec(), grad.STLVec()); + for (size_t i = 0; i < x.Size(); i++) { + f += C * fabs(x[i]); + } + + return f; +} + +double ME_Model::constrained_line_search(double C, const Vec& x0, + const Vec& grad0, const double f0, + const Vec& dx, Vec& x, Vec& grad1) { + // compute the orthant to explore + Vec orthant = x0; + for (size_t i = 0; i < orthant.Size(); i++) { + if (orthant[i] == 0) orthant[i] = -grad0[i]; + } + + double t = 1.0 / LINE_SEARCH_BETA; + + double f; + do { + t *= LINE_SEARCH_BETA; + x = x0 + t * dx; + x.Project(orthant); + // for (size_t i = 0; i < x.Size(); i++) { + // if (x0[i] != 0 && sign(x[i]) != sign(x0[i])) x[i] = 0; + // } + + f = regularized_func_grad(C, x, grad1); + // cout << "*"; + } while (f > f0 + LINE_SEARCH_ALPHA * dot_product(x - x0, grad0)); + + return f; +} + +vector ME_Model::perform_OWLQN(const vector& x0, + const double C) { + const size_t dim = x0.size(); + Vec x = x0; + + Vec grad(dim), dx(dim); + double f = regularized_func_grad(C, x, grad); + + Vec s[M], y[M]; + double z[M]; // rho + + for (int iter = 0; iter < OWLQN_MAX_ITER; iter++) { + Vec pg = pseudo_gradient(x, grad, C); + + fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, -f, _train_error); + if (_nheldout > 0) { + const double heldout_logl = heldout_likelihood(); + fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, + _heldout_error); + } + fprintf(stderr, "\n"); + + if (sqrt(dot_product(pg, pg)) < MIN_GRAD_NORM) break; + + dx = -1 * approximate_Hg(iter, pg, s, y, z); + if (dot_product(dx, pg) >= 0) dx.Project(-1 * pg); + + Vec x1(dim), grad1(dim); + f = constrained_line_search(C, x, pg, f, dx, x1, grad1); + + s[iter % M] = x1 - x; + y[iter % M] = grad1 - grad; + z[iter % M] = 1.0 / dot_product(y[iter % M], s[iter % M]); + + x = x1; + grad = grad1; + } + + return x.STLVec(); +} + +// SGD + +// const double SGD_ETA0 = 1; +// const double SGD_ITER = 30; +// const double SGD_ALPHA = 0.85; + +//#define FOLOS_NAIVE +//#define FOLOS_LAZY +#define SGD_CP + +inline void apply_l1_penalty(const int i, const double u, vector& _vl, + vector& q) { + double& w = _vl[i]; + const double z = w; + double& qi = q[i]; + if (w > 0) { + w = max(0.0, w - (u + qi)); + } else if (w < 0) { + w = min(0.0, w + (u - qi)); + } + qi += w - z; +} + +static double l1norm(const vector& v) { + double sum = 0; + for (size_t i = 0; i < v.size(); i++) sum += abs(v[i]); + return sum; +} + +inline void update_folos_lazy(const int iter_sample, const int k, + vector& _vl, + const vector& sum_eta, + vector& last_updated) { + const double penalty = sum_eta[iter_sample] - sum_eta[last_updated[k]]; + double& x = _vl[k]; + if (x > 0) + x = max(0.0, x - penalty); + else + x = min(0.0, x + penalty); + last_updated[k] = iter_sample; +} + +int ME_Model::perform_SGD() { + if (_l2reg > 0) { + cerr << "error: L2 regularization is currently not supported in SGD mode." + << endl; + exit(1); + } + + cerr << "performing SGD" << endl; + + const double l1param = _l1reg; + + const int d = _fb.Size(); + + vector ri(_vs.size()); + for (size_t i = 0; i < ri.size(); i++) ri[i] = i; + + vector grad(d); + int iter_sample = 0; + const double eta0 = SGD_ETA0; + + // cerr << "l1param = " << l1param << endl; + cerr << "eta0 = " << eta0 << " alpha = " << SGD_ALPHA << endl; + + double u = 0; + vector q(d, 0); + vector last_updated(d, 0); + vector sum_eta; + sum_eta.push_back(0); + + for (int iter = 0; iter < SGD_ITER; iter++) { + + random_shuffle(ri.begin(), ri.end()); + + double logl = 0; + int ncorrect = 0, ntotal = 0; + for (size_t i = 0; i < _vs.size(); i++, ntotal++, iter_sample++) { + const Sample& s = _vs[ri[i]]; + +#ifdef FOLOS_LAZY + for (vector::const_iterator j = s.positive_features.begin(); + j != s.positive_features.end(); j++) { + for (vector::const_iterator k = _feature2mef[*j].begin(); + k != _feature2mef[*j].end(); k++) { + update_folos_lazy(iter_sample, *k, _vl, sum_eta, last_updated); + } + } +#endif + + vector membp(_num_classes); + const int max_label = conditional_probability(s, membp); + + const double eta = + eta0 * pow(SGD_ALPHA, + (double)iter_sample / _vs.size()); // exponential decay + // const double eta = eta0 / (1.0 + (double)iter_sample / + // _vs.size()); + + // if (iter_sample % _vs.size() == 0) cerr << "eta = " << eta << + // endl; + u += eta * l1param; + + sum_eta.push_back(sum_eta.back() + eta * l1param); + + logl += log(membp[s.label]); + if (max_label == s.label) ncorrect++; + + // binary features + for (vector::const_iterator j = s.positive_features.begin(); + j != s.positive_features.end(); j++) { + for (vector::const_iterator k = _feature2mef[*j].begin(); + k != _feature2mef[*j].end(); k++) { + const double me = membp[_fb.Feature(*k).label()]; + const double ee = (_fb.Feature(*k).label() == s.label ? 1.0 : 0); + const double grad = (me - ee); + _vl[*k] -= eta * grad; +#ifdef SGD_CP + apply_l1_penalty(*k, u, _vl, q); +#endif + } + } + // real-valued features + for (vector >::const_iterator j = s.rvfeatures.begin(); + j != s.rvfeatures.end(); j++) { + for (vector::const_iterator k = _feature2mef[j->first].begin(); + k != _feature2mef[j->first].end(); k++) { + const double me = membp[_fb.Feature(*k).label()]; + const double ee = (_fb.Feature(*k).label() == s.label ? 1.0 : 0); + const double grad = (me - ee) * j->second; + _vl[*k] -= eta * grad; +#ifdef SGD_CP + apply_l1_penalty(*k, u, _vl, q); +#endif + } + } + +#ifdef FOLOS_NAIVE + for (size_t j = 0; j < d; j++) { + double& x = _vl[j]; + if (x > 0) + x = max(0.0, x - eta * l1param); + else + x = min(0.0, x + eta * l1param); + } +#endif + } + logl /= _vs.size(); +// fprintf(stderr, "%4d logl = %8.3f acc = %6.4f ", iter, logl, +// (double)ncorrect / ntotal); + +#ifdef FOLOS_LAZY + if (l1param > 0) { + for (size_t j = 0; j < d; j++) + update_folos_lazy(iter_sample, j, _vl, sum_eta, last_updated); + } +#endif + + double f = logl; + if (l1param > 0) { + const double l1 = + l1norm(_vl); // this is not accurate when lazy update is used + // cerr << "f0 = " << update_model_expectation() - l1param * l1 << " + // "; + f -= l1param * l1; + int nonzero = 0; + for (int j = 0; j < d; j++) + if (_vl[j] != 0) nonzero++; + // cerr << " f = " << f << " l1 = " << l1 << " nonzero_features = " + // << nonzero << endl; + } + // fprintf(stderr, "%4d obj = %7.3f acc = %6.4f", iter+1, f, + // (double)ncorrect/ntotal); + // fprintf(stderr, "%4d obj = %f", iter+1, f); + fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, f, + 1 - (double)ncorrect / ntotal); + + if (_nheldout > 0) { + double heldout_logl = heldout_likelihood(); + // fprintf(stderr, " heldout_logl = %f acc = %6.4f\n", + // heldout_logl, 1 - _heldout_error); + fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, + _heldout_error); + } + fprintf(stderr, "\n"); + } + + return 0; +} + +} // namespace maxent + /* * $Log: maxent.cpp,v $ * Revision 1.1.1.1 2007/05/15 08:30:35 kyoshida diff --git a/utils/maxent.h b/utils/maxent.h index b1efd88e..74d13a6f 100644 --- a/utils/maxent.h +++ b/utils/maxent.h @@ -5,21 +5,95 @@ #ifndef __MAXENT_H_ #define __MAXENT_H_ -#include -#include -#include -#include #include #include +#include +#include #include +#include +#include + #include -#include "mathvec.h" -#define USE_HASH_MAP // if you encounter errors with hash, try commenting out - // this line. (the program will be a bit slower, though) -#ifdef USE_HASH_MAP -#include -#endif +namespace maxent { +class Vec { + private: + std::vector _v; + + public: + Vec(const size_t n = 0, const double val = 0) { _v.resize(n, val); } + Vec(const std::vector& v) : _v(v) {} + const std::vector& STLVec() const { return _v; } + std::vector& STLVec() { return _v; } + size_t Size() const { return _v.size(); } + double& operator[](int i) { return _v[i]; } + const double& operator[](int i) const { return _v[i]; } + Vec& operator+=(const Vec& b) { + assert(b.Size() == _v.size()); + for (size_t i = 0; i < _v.size(); i++) { + _v[i] += b[i]; + } + return *this; + } + Vec& operator*=(const double c) { + for (size_t i = 0; i < _v.size(); i++) { + _v[i] *= c; + } + return *this; + } + void Project(const Vec& y) { + for (size_t i = 0; i < _v.size(); i++) { + // if (sign(_v[i]) != sign(y[i])) _v[i] = 0; + if (_v[i] * y[i] <= 0) _v[i] = 0; + } + } +}; + +inline double dot_product(const Vec& a, const Vec& b) { + double sum = 0; + for (size_t i = 0; i < a.Size(); i++) { + sum += a[i] * b[i]; + } + return sum; +} + +inline std::ostream& operator<<(std::ostream& s, const Vec& a) { + s << "("; + for (size_t i = 0; i < a.Size(); i++) { + if (i != 0) s << ", "; + s << a[i]; + } + s << ")"; + return s; +} + +inline const Vec operator+(const Vec& a, const Vec& b) { + Vec v(a.Size()); + assert(a.Size() == b.Size()); + for (size_t i = 0; i < a.Size(); i++) { + v[i] = a[i] + b[i]; + } + return v; +} + +inline const Vec operator-(const Vec& a, const Vec& b) { + Vec v(a.Size()); + assert(a.Size() == b.Size()); + for (size_t i = 0; i < a.Size(); i++) { + v[i] = a[i] - b[i]; + } + return v; +} + +inline const Vec operator*(const Vec& a, const double c) { + Vec v(a.Size()); + for (size_t i = 0; i < a.Size(); i++) { + v[i] = a[i] * c; + } + return v; +} + +inline const Vec operator*(const double c, const Vec& a) { return a * c; } // // data format for each sample for training/testing @@ -309,6 +383,7 @@ class ME_Model { static double FunctionGradientWrapper(const std::vector& x, std::vector& grad); }; +} // namespace maxent #endif diff --git a/utils/owlqn.cpp b/utils/owlqn.cpp deleted file mode 100644 index c3a0f0da..00000000 --- a/utils/owlqn.cpp +++ /dev/null @@ -1,127 +0,0 @@ -#include -#include -#include -#include -#include "mathvec.h" -#include "lbfgs.h" -#include "maxent.h" - -using namespace std; - -const static int M = LBFGS_M; -const static double LINE_SEARCH_ALPHA = 0.1; -const static double LINE_SEARCH_BETA = 0.5; - -// stopping criteria -int OWLQN_MAX_ITER = 300; -const static double MIN_GRAD_NORM = 0.0001; - -Vec approximate_Hg(const int iter, const Vec& grad, const Vec s[], - const Vec y[], const double z[]); - -inline int sign(double x) { - if (x > 0) return 1; - if (x < 0) return -1; - return 0; -}; - -static Vec pseudo_gradient(const Vec& x, const Vec& grad0, const double C) { - Vec grad = grad0; - for (size_t i = 0; i < x.Size(); i++) { - if (x[i] != 0) { - grad[i] += C * sign(x[i]); - continue; - } - const double gm = grad0[i] - C; - if (gm > 0) { - grad[i] = gm; - continue; - } - const double gp = grad0[i] + C; - if (gp < 0) { - grad[i] = gp; - continue; - } - grad[i] = 0; - } - - return grad; -} - -double ME_Model::regularized_func_grad(const double C, const Vec& x, - Vec& grad) { - double f = FunctionGradient(x.STLVec(), grad.STLVec()); - for (size_t i = 0; i < x.Size(); i++) { - f += C * fabs(x[i]); - } - - return f; -} - -double ME_Model::constrained_line_search(double C, const Vec& x0, - const Vec& grad0, const double f0, - const Vec& dx, Vec& x, Vec& grad1) { - // compute the orthant to explore - Vec orthant = x0; - for (size_t i = 0; i < orthant.Size(); i++) { - if (orthant[i] == 0) orthant[i] = -grad0[i]; - } - - double t = 1.0 / LINE_SEARCH_BETA; - - double f; - do { - t *= LINE_SEARCH_BETA; - x = x0 + t * dx; - x.Project(orthant); - // for (size_t i = 0; i < x.Size(); i++) { - // if (x0[i] != 0 && sign(x[i]) != sign(x0[i])) x[i] = 0; - // } - - f = regularized_func_grad(C, x, grad1); - // cout << "*"; - } while (f > f0 + LINE_SEARCH_ALPHA * dot_product(x - x0, grad0)); - - return f; -} - -vector ME_Model::perform_OWLQN(const vector& x0, - const double C) { - const size_t dim = x0.size(); - Vec x = x0; - - Vec grad(dim), dx(dim); - double f = regularized_func_grad(C, x, grad); - - Vec s[M], y[M]; - double z[M]; // rho - - for (int iter = 0; iter < OWLQN_MAX_ITER; iter++) { - Vec pg = pseudo_gradient(x, grad, C); - - fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, -f, _train_error); - if (_nheldout > 0) { - const double heldout_logl = heldout_likelihood(); - fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, - _heldout_error); - } - fprintf(stderr, "\n"); - - if (sqrt(dot_product(pg, pg)) < MIN_GRAD_NORM) break; - - dx = -1 * approximate_Hg(iter, pg, s, y, z); - if (dot_product(dx, pg) >= 0) dx.Project(-1 * pg); - - Vec x1(dim), grad1(dim); - f = constrained_line_search(C, x, pg, f, dx, x1, grad1); - - s[iter % M] = x1 - x; - y[iter % M] = grad1 - grad; - z[iter % M] = 1.0 / dot_product(y[iter % M], s[iter % M]); - - x = x1; - grad = grad1; - } - - return x.STLVec(); -} diff --git a/utils/sgd.cpp b/utils/sgd.cpp deleted file mode 100644 index 8613edca..00000000 --- a/utils/sgd.cpp +++ /dev/null @@ -1,193 +0,0 @@ -#include "maxent.h" -#include -#include - -using namespace std; - -// const double SGD_ETA0 = 1; -// const double SGD_ITER = 30; -// const double SGD_ALPHA = 0.85; - -//#define FOLOS_NAIVE -//#define FOLOS_LAZY -#define SGD_CP - -inline void apply_l1_penalty(const int i, const double u, vector& _vl, - vector& q) { - double& w = _vl[i]; - const double z = w; - double& qi = q[i]; - if (w > 0) { - w = max(0.0, w - (u + qi)); - } else if (w < 0) { - w = min(0.0, w + (u - qi)); - } - qi += w - z; -} - -static double l1norm(const vector& v) { - double sum = 0; - for (size_t i = 0; i < v.size(); i++) sum += abs(v[i]); - return sum; -} - -inline void update_folos_lazy(const int iter_sample, const int k, - vector& _vl, - const vector& sum_eta, - vector& last_updated) { - const double penalty = sum_eta[iter_sample] - sum_eta[last_updated[k]]; - double& x = _vl[k]; - if (x > 0) - x = max(0.0, x - penalty); - else - x = min(0.0, x + penalty); - last_updated[k] = iter_sample; -} - -int ME_Model::perform_SGD() { - if (_l2reg > 0) { - cerr << "error: L2 regularization is currently not supported in SGD mode." - << endl; - exit(1); - } - - cerr << "performing SGD" << endl; - - const double l1param = _l1reg; - - const int d = _fb.Size(); - - vector ri(_vs.size()); - for (size_t i = 0; i < ri.size(); i++) ri[i] = i; - - vector grad(d); - int iter_sample = 0; - const double eta0 = SGD_ETA0; - - // cerr << "l1param = " << l1param << endl; - cerr << "eta0 = " << eta0 << " alpha = " << SGD_ALPHA << endl; - - double u = 0; - vector q(d, 0); - vector last_updated(d, 0); - vector sum_eta; - sum_eta.push_back(0); - - for (int iter = 0; iter < SGD_ITER; iter++) { - - random_shuffle(ri.begin(), ri.end()); - - double logl = 0; - int ncorrect = 0, ntotal = 0; - for (size_t i = 0; i < _vs.size(); i++, ntotal++, iter_sample++) { - const Sample& s = _vs[ri[i]]; - -#ifdef FOLOS_LAZY - for (vector::const_iterator j = s.positive_features.begin(); - j != s.positive_features.end(); j++) { - for (vector::const_iterator k = _feature2mef[*j].begin(); - k != _feature2mef[*j].end(); k++) { - update_folos_lazy(iter_sample, *k, _vl, sum_eta, last_updated); - } - } -#endif - - vector membp(_num_classes); - const int max_label = conditional_probability(s, membp); - - const double eta = - eta0 * pow(SGD_ALPHA, - (double)iter_sample / _vs.size()); // exponential decay - // const double eta = eta0 / (1.0 + (double)iter_sample / - // _vs.size()); - - // if (iter_sample % _vs.size() == 0) cerr << "eta = " << eta << - // endl; - u += eta * l1param; - - sum_eta.push_back(sum_eta.back() + eta * l1param); - - logl += log(membp[s.label]); - if (max_label == s.label) ncorrect++; - - // binary features - for (vector::const_iterator j = s.positive_features.begin(); - j != s.positive_features.end(); j++) { - for (vector::const_iterator k = _feature2mef[*j].begin(); - k != _feature2mef[*j].end(); k++) { - const double me = membp[_fb.Feature(*k).label()]; - const double ee = (_fb.Feature(*k).label() == s.label ? 1.0 : 0); - const double grad = (me - ee); - _vl[*k] -= eta * grad; -#ifdef SGD_CP - apply_l1_penalty(*k, u, _vl, q); -#endif - } - } - // real-valued features - for (vector >::const_iterator j = s.rvfeatures.begin(); - j != s.rvfeatures.end(); j++) { - for (vector::const_iterator k = _feature2mef[j->first].begin(); - k != _feature2mef[j->first].end(); k++) { - const double me = membp[_fb.Feature(*k).label()]; - const double ee = (_fb.Feature(*k).label() == s.label ? 1.0 : 0); - const double grad = (me - ee) * j->second; - _vl[*k] -= eta * grad; -#ifdef SGD_CP - apply_l1_penalty(*k, u, _vl, q); -#endif - } - } - -#ifdef FOLOS_NAIVE - for (size_t j = 0; j < d; j++) { - double& x = _vl[j]; - if (x > 0) - x = max(0.0, x - eta * l1param); - else - x = min(0.0, x + eta * l1param); - } -#endif - } - logl /= _vs.size(); -// fprintf(stderr, "%4d logl = %8.3f acc = %6.4f ", iter, logl, -// (double)ncorrect / ntotal); - -#ifdef FOLOS_LAZY - if (l1param > 0) { - for (size_t j = 0; j < d; j++) - update_folos_lazy(iter_sample, j, _vl, sum_eta, last_updated); - } -#endif - - double f = logl; - if (l1param > 0) { - const double l1 = - l1norm(_vl); // this is not accurate when lazy update is used - // cerr << "f0 = " << update_model_expectation() - l1param * l1 << " - // "; - f -= l1param * l1; - int nonzero = 0; - for (int j = 0; j < d; j++) - if (_vl[j] != 0) nonzero++; - // cerr << " f = " << f << " l1 = " << l1 << " nonzero_features = " - // << nonzero << endl; - } - // fprintf(stderr, "%4d obj = %7.3f acc = %6.4f", iter+1, f, - // (double)ncorrect/ntotal); - // fprintf(stderr, "%4d obj = %f", iter+1, f); - fprintf(stderr, "%3d obj(err) = %f (%6.4f)", iter + 1, f, - 1 - (double)ncorrect / ntotal); - - if (_nheldout > 0) { - double heldout_logl = heldout_likelihood(); - // fprintf(stderr, " heldout_logl = %f acc = %6.4f\n", - // heldout_logl, 1 - _heldout_error); - fprintf(stderr, " heldout_logl(err) = %f (%6.4f)", heldout_logl, - _heldout_error); - } - fprintf(stderr, "\n"); - } - - return 0; -} -- cgit v1.2.3