diff options
Diffstat (limited to 'fast')
-rw-r--r-- | fast/Makefile | 9 | ||||
-rw-r--r-- | fast/README.md | 2 | ||||
-rw-r--r-- | fast/dummyvector.h | 28 | ||||
-rw-r--r-- | fast/grammar.cc | 243 | ||||
-rw-r--r-- | fast/grammar.hh | 58 | ||||
-rw-r--r-- | fast/hypergraph.cc | 147 | ||||
-rw-r--r-- | fast/hypergraph.hh | 45 | ||||
-rw-r--r-- | fast/main.cc | 14 | ||||
-rw-r--r-- | fast/semiring.hh | 1 | ||||
-rw-r--r-- | fast/sparse_vector.hh | 106 | ||||
-rw-r--r-- | fast/test_grammar.cc | 4 | ||||
-rw-r--r-- | fast/test_sparse_vector.cc | 11 | ||||
-rw-r--r-- | fast/util.hh | 29 | ||||
-rw-r--r-- | fast/weaver.hh | 4 |
14 files changed, 453 insertions, 248 deletions
diff --git a/fast/Makefile b/fast/Makefile index 6d05fea..40ce0eb 100644 --- a/fast/Makefile +++ b/fast/Makefile @@ -2,17 +2,18 @@ COMPILER=clang CFLAGS=-std=c++11 -O3 -all: hypergraph.o main.cc +all: grammar.o hypergraph.o main.cc $(COMPILER) $(CFLAGS) -std=c++11 -lstdc++ -lm -lmsgpack grammar.o hypergraph.o main.cc -o fast_weaver -test: test_grammar test_sparse_vector -hypergraph.o: hypergraph.cc hypergraph.hh grammar.o semiring.hh +hypergraph.o: hypergraph.cc hypergraph.hh grammar.o semiring.hh sparse_vector.hh weaver.hh $(COMPILER) $(CFLAGS) -g -c hypergraph.cc -grammar.o: grammar.cc grammar.hh +grammar.o: grammar.cc grammar.hh sparse_vector.hh util.hh $(COMPILER) $(CFLAGS) -g -c grammar.cc +test: test_grammar test_sparse_vector + test_grammar: test_grammar.cc grammar.o $(COMPILER) $(CFLAGS) -lstdc++ -lm grammar.o test_grammar.cc -o test_grammar diff --git a/fast/README.md b/fast/README.md index 541f93f..a11bd85 100644 --- a/fast/README.md +++ b/fast/README.md @@ -30,3 +30,5 @@ http://bytes.com/topic/c/answers/702569-blas-vs-cblas-c http://www.netlib.org/lapack/#_standard_c_language_apis_for_lapack http://www.osl.iu.edu/research/mtl/download.php3 http://scicomp.stackexchange.com/questions/351/recommendations-for-a-usable-fast-c-matrix-library + +http://goog-perftools.sourceforge.net/doc/tcmalloc.html diff --git a/fast/dummyvector.h b/fast/dummyvector.h deleted file mode 100644 index 18e2121..0000000 --- a/fast/dummyvector.h +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#include <msgpack.hpp> - - -struct DummyVector { - double CountEF; - double EgivenFCoherent; - double Glue; - double IsSingletonF; - double IsSingletonFE; - double LanguageModel; - double LanguageModel_OOV; - double MaxLexFgivenE; - double MaxLexEgivenF; - double PassThrough; - double PassThrough_1; - double PassThrough_2; - double PassThrough_3; - double PassThrough_4; - double PassThrough_5; - double PassThrough_6; - double SampleCountF; - double WordPenalty; - - MSGPACK_DEFINE(CountEF, EgivenFCoherent, Glue, IsSingletonF, IsSingletonFE, LanguageModel, LanguageModel_OOV, MaxLexEgivenF, MaxLexFgivenE, PassThrough, PassThrough_1, PassThrough_2, PassThrough_3, PassThrough_4, PassThrough_5, PassThrough_6, SampleCountF, WordPenalty); -}; - diff --git a/fast/grammar.cc b/fast/grammar.cc index 7f2d506..558f6e6 100644 --- a/fast/grammar.cc +++ b/fast/grammar.cc @@ -1,170 +1,165 @@ #include "grammar.hh" -string -esc_str(const string& s) { // FIXME - ostringstream os; - for (auto it = s.cbegin(); it != s.cend(); it++) { - switch (*it) { - case '"': os << "\\\""; break; - case '\\': os << "\\\\"; break; - case '\b': os << "\\b"; break; - case '\f': os << "\\f"; break; - case '\n': os << "\\n"; break; - case '\r': os << "\\r"; break; - case '\t': os << "\\t"; break; - default: os << *it; break; - } - } - - return os.str(); -} - namespace G { +/* + * G::NT + * + */ NT::NT(string& s) { - s.erase(0, 1); - s.pop_back(); + s.erase(0, 1); s.pop_back(); // remove '[' and ']' stringstream ss(s); string buf; - size_t c = 0; - index = 0; + size_t j = 0; + index = 0; // default while (ss.good() && getline(ss, buf, ',')) { - if (c == 0) { + if (j == 0) { symbol = buf; } else { index = stoi(buf); } - c++; + j++; } } -T::T(string& s) +string +NT::repr() const { - word = s; + ostringstream os; + os << "NT<" << symbol << "," << index << ">"; + + return os.str(); } -Item::Item(string& s) +string +NT::escaped() const { - if (s.front() == '[' && s.back() == ']') { - type = NON_TERMINAL; - nt = new NT(s); - } else { - type = TERMINAL; - t = new T(s); - } + ostringstream os; + os << "[" << symbol; + if (index > 0) + os << "," << index; + os << "]"; + + return os.str(); } -Rule::Rule(string& s) +ostream& +operator<<(ostream& os, const NT& nt) { - stringstream ss(s); - size_t c = 0; - string buf; - while (ss >> buf) { - if (buf == "|||") { c++; continue; } - if (c == 0) { // LHS - lhs = new NT(buf); - } else if (c == 1) { // RHS - rhs.push_back(new Item(buf)); - if (rhs.back()->type == NON_TERMINAL) arity++; - } else if (c == 2) { // TARGET - target.push_back(new Item(buf)); - } else if (c == 3) { // F TODO - } else if (c == 4) { // A TODO - } else { // ERROR FIXME - } - if (c == 4) break; - } - arity = 0; + return os << nt.repr(); } -Grammar::Grammar(string fn) +/* + * G::T + * + */ +T::T(const string& s) { - ifstream ifs(fn); - string line; - while (getline(ifs, line)) { - G::Rule* r = new G::Rule(line); - rules.push_back(r); - if (r->arity == 0) - flat.push_back(r); - else if (r->rhs.front()->type == NON_TERMINAL) - start_nt.push_back(r); - else - start_t.push_back(r); - } + word = s; } string -Item::repr() const +T::repr() const { ostringstream os; - if (type == TERMINAL) - os << t->repr(); - else - os << nt->repr(); + os << "T<" << word << ">"; return os.str(); } string -Item::escaped() const +T::escaped() const { - ostringstream os; - if (type == TERMINAL) - os << t->escaped(); - else - os << nt->escaped(); - - return os.str(); + return util::json_escape(word); } ostream& -operator<<(ostream& os, const Item& i) +operator<<(ostream& os, const T& t) { - return os << i.repr(); + return os << t.repr(); } -string -NT::repr() const -{ - ostringstream os; - os << "NT<" << symbol << "," << index << ">"; - return os.str(); +/* + * G::Item + * + * Better solve this by inheritance + * -> rhs, target as vector<base class> ? + * + */ +Item::Item(string& s) +{ + if (s.front() == '[' && s.back() == ']') { + type = NON_TERMINAL; + nt = new NT(s); + } else { + type = TERMINAL; + t = new T(s); + } } string -NT::escaped() const +Item::repr() const { ostringstream os; - os << "[" << symbol; - if (index > 0) - os << "," << index; - os << "]"; + if (type == TERMINAL) + os << t->repr(); + else + os << nt->repr(); return os.str(); } -ostream& -operator<<(ostream& os, const NT& nt) -{ - return os << nt.repr(); -} - string -T::repr() const +Item::escaped() const { ostringstream os; - os << "T<" << word << ">"; + if (type == TERMINAL) + os << t->escaped(); + else + os << nt->escaped(); return os.str(); } ostream& -operator<<(ostream& os, const T& t) +operator<<(ostream& os, const Item& i) { - return os << t.repr(); + return os << i.repr(); +} + +/* + * G::Rule + * + */ +Rule::Rule(const string& s) +{ + stringstream ss(s); + size_t j = 0; + string buf; + arity = 0; + size_t index = 1; + while (ss >> buf) { + if (buf == "|||") { j++; continue; } + if (j == 0) { // LHS + lhs = new NT(buf); + } else if (j == 1) { // RHS + rhs.push_back(new Item(buf)); + if (rhs.back()->type == NON_TERMINAL) arity++; + } else if (j == 2) { // TARGET + target.push_back(new Item(buf)); + if (target.back()->type == NON_TERMINAL) { + order.insert(make_pair(index, target.back()->nt->index)); + index++; + } + } else if (j == 3) { // F TODO + } else if (j == 4) { // A TODO + } else { // ERROR + } + if (j == 4) break; + } } string @@ -183,7 +178,7 @@ Rule::repr() const if (next(it) != target.end()) os << " "; } os << "}" \ - ", f:" << "TODO" << \ + ", f:" << f->repr() << \ ", arity=" << arity << \ ", map:" << "TODO" << \ ">"; @@ -191,12 +186,6 @@ Rule::repr() const return os.str(); } -ostream& -operator<<(ostream& os, const Rule& r) -{ - return os << r.repr(); -} - string Rule::escaped() const { @@ -212,7 +201,7 @@ Rule::escaped() const if (next(it) != target.end()) os << " "; } os << " ||| "; - os << "TODO"; + os << f->escaped(); os << " ||| "; os << "TODO"; @@ -220,10 +209,36 @@ Rule::escaped() const } ostream& +operator<<(ostream& os, const Rule& r) +{ + return os << r.repr(); +} + +/* + * G::Grammmar + * + */ +Grammar::Grammar(const string& fn) +{ + ifstream ifs(fn); + string line; + while (getline(ifs, line)) { + G::Rule* r = new G::Rule(line); + rules.push_back(r); + if (r->arity == 0) + flat.push_back(r); + else if (r->rhs.front()->type == NON_TERMINAL) + start_nt.push_back(r); + else + start_t.push_back(r); + } +} + +ostream& operator<<(ostream& os, const Grammar& g) { - for (auto it = g.rules.begin(); it != g.rules.end(); it++) - os << (**it).repr() << endl; + for (const auto it: g.rules) + os << it->repr() << endl; return os; } diff --git a/fast/grammar.hh b/fast/grammar.hh index 51501cf..48a5116 100644 --- a/fast/grammar.hh +++ b/fast/grammar.hh @@ -1,38 +1,42 @@ #pragma once +#include <fstream> #include <iostream> -#include <string> #include <sstream> -#include <fstream> -#include <vector> +#include <string> #include <map> +#include <msgpack.hpp> +#include <vector> -#include "dummyvector.h" +#include "sparse_vector.hh" +#include "util.hh" using namespace std; -string esc_str(const string& s); // FIXME - namespace G { struct NT { - string symbol; - unsigned int index; + string symbol; + size_t index; NT() {}; NT(string& s); + string repr() const; string escaped() const; + friend ostream& operator<<(ostream& os, const NT& t); }; struct T { - string word; + string word; // use word ids instead? + + T(const string& s); - T(string& s); string repr() const; - string escaped() const { return esc_str(word); } + string escaped() const; + friend ostream& operator<<(ostream& os, const NT& nt); }; @@ -47,26 +51,33 @@ struct Item { T* t; Item(string& s); + string repr() const; string escaped() const; + friend ostream& operator<<(ostream& os, const Item& i); }; struct Rule { - NT* lhs; - vector<Item*> rhs; - vector<Item*> target; - //map<int,int> map; - size_t arity; - DummyVector f; + NT* lhs; + vector<Item*> rhs; + vector<Item*> target; + size_t arity; +Sv::SparseVector<string, score_t>* f; + map<size_t, size_t> order; + string as_str_; // FIXME Rule() {}; - Rule(string& s); + Rule(const string& s); + string repr() const; string escaped() const; + friend ostream& operator<<(ostream& os, const Rule& r); - MSGPACK_DEFINE(); + void prep_for_serialization_() { as_str_ = escaped(); }; // FIXME + + MSGPACK_DEFINE(as_str_); // TODO }; struct Grammar { @@ -75,9 +86,12 @@ struct Grammar { vector<Rule*> start_nt; vector<Rule*> start_t; - Grammar(string fn); - void add_glue(); - void add_pass_through(); + Grammar() {}; + Grammar(const string& fn); + + void add_glue(); // TODO + void add_pass_through(const string& input); // TODO + friend ostream& operator<<(ostream& os, const Grammar& g); }; diff --git a/fast/hypergraph.cc b/fast/hypergraph.cc index 6b7bd07..e1debb1 100644 --- a/fast/hypergraph.cc +++ b/fast/hypergraph.cc @@ -3,35 +3,34 @@ namespace Hg { -template<typename Semiring> void -init(list<Node*>& nodes, list<Node*>::iterator root, Semiring& semiring) +template<typename Semiring> void +init(const list<Node*>& nodes, const list<Node*>::iterator root, const Semiring& semiring) { - for (auto it = nodes.begin(); it != nodes.end(); it++) - (**it).score = semiring.null; + for (const auto it: nodes) + it->score = semiring.null; (**root).score = semiring.one; } void -reset(list<Node*> nodes, vector<Edge*> edges) +reset(const list<Node*> nodes, const vector<Edge*> edges) { - for (auto it = nodes.begin(); it != nodes.end(); it++) - (**it).mark = 0; - for (auto it = edges.begin(); it != edges.end(); it++) - (**it).mark = 0; + for (const auto it: nodes) + it->mark = 0; + for (auto it: edges) + it->mark = 0; } void -topological_sort(list<Node*>& nodes, list<Node*>::iterator root) +topological_sort(list<Node*>& nodes, const list<Node*>::iterator root) { auto p = root; auto to = nodes.begin(); while (to != nodes.end()) { if ((**p).is_marked()) { - // explore edges - for (auto e = (**p).outgoing.begin(); e!=(**p).outgoing.end(); e++) { - (**e).mark++; - if ((**e).is_marked()) { - (**e).head->mark++; + for (const auto e: (**p).outgoing) { // explore edges + e->mark++; + if (e->is_marked()) { + e->head->mark++; } } } @@ -51,16 +50,71 @@ viterbi(Hypergraph& hg) list<Node*>::iterator root = \ find_if(hg.nodes.begin(), hg.nodes.end(), \ [](Node* n) { return n->incoming.size() == 0; }); + + Hg::topological_sort(hg.nodes, root); + Semiring::Viterbi<score_t> semiring; + Hg::init(hg.nodes, root, semiring); + + for (const auto n: hg.nodes) { + for (const auto e: n->incoming) { + score_t s = semiring.one; + for (const auto m: e->tails) { + s = semiring.multiply(s, m->score); + } + n->score = semiring.add(n->score, semiring.multiply(s, e->score)); + } + } +} + +void +viterbi_path(Hypergraph& hg, Path& p) +{ + list<Node*>::iterator root = \ + find_if(hg.nodes.begin(), hg.nodes.end(), \ + [](Node* n) { return n->incoming.size() == 0; }); + Hg::topological_sort(hg.nodes, root); - Semiring::Viterbi<double> semiring; + Semiring::Viterbi<score_t> semiring; Hg::init(hg.nodes, root, semiring); - for (auto n = hg.nodes.begin(); n != hg.nodes.end(); n++) { - for (auto e = (**n).incoming.begin(); e != (**n).incoming.end(); e++) { - double s = semiring.one; - for (auto m = (**e).tails.begin(); m != (**e).tails.end(); m++) { - s = semiring.multiply(s, (**m).score); + + for (auto n: hg.nodes) { + Edge* best_edge; + bool best = false; + for (auto e: n->incoming) { + score_t s = semiring.one; + for (auto m: e->tails) { + s = semiring.multiply(s, m->score); + } + if (n->score < semiring.multiply(s, e->score)) { // find max + best_edge = e; + best = true; } - (**n).score = semiring.add((**n).score, semiring.multiply(s, (**e).score)); + n->score = semiring.add(n->score, semiring.multiply(s, e->score)); + } + if (best) + p.push_back(best_edge); + } +} + + +void +derive(const Path& p, const Node* cur, vector<string>& carry) +{ + Edge* next; + for (auto it: p) { + if (it->head->symbol == cur->symbol && + it->head->left == cur->left && + it->head->right == cur->right) { + next = it; + } + } + unsigned j = 0; + for (auto it: next->rule->target) { + if (it->type == G::NON_TERMINAL) { + derive(p, next->tails[next->rule->order[j]], carry); + j++; + } else { + carry.push_back(it->t->word); } } } @@ -68,7 +122,7 @@ viterbi(Hypergraph& hg) namespace io { void -read(Hypergraph& hg, vector<G::Rule*> rules, string fn) +read(Hypergraph& hg, vector<G::Rule*>& rules, const string& fn) // FIXME { ifstream ifs(fn); size_t i = 0, nr, nn, ne; @@ -112,7 +166,7 @@ read(Hypergraph& hg, vector<G::Rule*> rules, string fn) } void -write(Hypergraph& hg, vector<G::Rule*> rules, string fn) +write(Hypergraph& hg, vector<G::Rule*>& rules, const string& fn) // FIXME { FILE* file = fopen(fn.c_str(), "wb"); msgpack::fbuffer fbuf(file); @@ -129,7 +183,7 @@ write(Hypergraph& hg, vector<G::Rule*> rules, string fn) } void -manual(Hypergraph& hg) +manual(Hypergraph& hg, vector<G::Rule*>& rules) { // nodes Node* a = new Node; a->id = 0; a->symbol = "root"; a->left = -1; a->right = -1; a->mark = 0; @@ -149,60 +203,88 @@ manual(Hypergraph& hg) Node* h = new Node; h->id = 7; h->symbol = "S"; h->left = 0; h->right = 6; h->mark = 0; hg.nodes.push_back(h); hg.nodes_by_id[h->id] = h; + // rules + vector<string> rule_strs; + rule_strs.push_back("[NP] ||| ich ||| i ||| ||| "); + rule_strs.push_back("[V] ||| sah ||| saw ||| ||| "); + rule_strs.push_back("[JJ] ||| kleines ||| small ||| ||| "); + rule_strs.push_back("[JJ] ||| kleines ||| little ||| ||| "); + rule_strs.push_back("[NN] ||| kleines haus ||| small house ||| ||| "); + rule_strs.push_back("[NN] ||| kleines haus ||| little house ||| ||| "); + rule_strs.push_back("[NN] ||| [JJ,1] haus ||| [JJ,1] shell ||| ||| "); + rule_strs.push_back("[NN] ||| [JJ,1] haus ||| [JJ,1] house ||| ||| "); + rule_strs.push_back("[NP] ||| ein [NN,1] ||| a [NN,1] ||| ||| "); + rule_strs.push_back("[VP] ||| [V,1] [NP,2] ||| [V,1] [NP,2] ||| ||| "); + rule_strs.push_back("[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2] ||| ||| "); + + for (auto it: rule_strs) { + rules.push_back(new G::Rule(it)); + rules.back()->f = new Sv::SparseVector<string, score_t>(); + } + // edges Edge* q = new Edge; q->head = hg.nodes_by_id[1]; q->tails.push_back(hg.nodes_by_id[0]); q->score = 0.367879441171; q->arity = 1; q->mark = 0; hg.edges.push_back(q); hg.nodes_by_id[1]->incoming.push_back(q); hg.nodes_by_id[0]->outgoing.push_back(q); + q->rule = rules[0]; Edge* p = new Edge; p->head = hg.nodes_by_id[2]; p->tails.push_back(hg.nodes_by_id[0]); p->score = 0.606530659713; p->arity = 1; p->mark = 0; hg.edges.push_back(p); hg.nodes_by_id[2]->incoming.push_back(p); hg.nodes_by_id[0]->outgoing.push_back(p); + p->rule = rules[1]; Edge* r = new Edge; r->head = hg.nodes_by_id[3]; r->tails.push_back(hg.nodes_by_id[0]); r->score = 1.0; r->arity = 1; r->mark = 0; hg.edges.push_back(r); hg.nodes_by_id[3]->incoming.push_back(r); hg.nodes_by_id[0]->outgoing.push_back(r); + r->rule = rules[2]; Edge* s = new Edge; s->head = hg.nodes_by_id[3]; s->tails.push_back(hg.nodes_by_id[0]); s->score = 1.0; s->arity = 1; s->mark = 0; hg.edges.push_back(s); hg.nodes_by_id[3]->incoming.push_back(s); hg.nodes_by_id[0]->outgoing.push_back(s); + s->rule = rules[3]; Edge* t = new Edge; t->head = hg.nodes_by_id[4]; t->tails.push_back(hg.nodes_by_id[0]); t->score = 1.0; t->arity = 1; t->mark = 0; hg.edges.push_back(t); hg.nodes_by_id[4]->incoming.push_back(t); hg.nodes_by_id[0]->outgoing.push_back(t); + t->rule = rules[4]; Edge* u = new Edge; u->head = hg.nodes_by_id[4]; u->tails.push_back(hg.nodes_by_id[0]); u->score = 1.0; u->arity = 1; u->mark = 0; hg.edges.push_back(u); hg.nodes_by_id[4]->incoming.push_back(u); hg.nodes_by_id[0]->outgoing.push_back(u); + u->rule = rules[5]; Edge* v = new Edge; v->head = hg.nodes_by_id[4]; v->tails.push_back(hg.nodes_by_id[3]); v->score = 1.0; v->arity = 1; v->mark = 0; hg.edges.push_back(v); hg.nodes_by_id[4]->incoming.push_back(v); hg.nodes_by_id[3]->outgoing.push_back(v); + v->rule = rules[6]; Edge* w = new Edge; w->head = hg.nodes_by_id[4]; w->tails.push_back(hg.nodes_by_id[3]); w->score = 2.71828182846; w->arity = 1; w->mark = 0; hg.edges.push_back(w); hg.nodes_by_id[4]->incoming.push_back(w); hg.nodes_by_id[3]->outgoing.push_back(w); + w->rule = rules[7]; Edge* x = new Edge; x->head = hg.nodes_by_id[5]; x->tails.push_back(hg.nodes_by_id[4]); x->score = 1.0; x->arity = 1; x->mark = 0; hg.edges.push_back(x); hg.nodes_by_id[5]->incoming.push_back(x); hg.nodes_by_id[4]->outgoing.push_back(x); + x->rule = rules[8]; Edge* y = new Edge; y->head = hg.nodes_by_id[6]; y->tails.push_back(hg.nodes_by_id[2]); y->tails.push_back(hg.nodes_by_id[5]); y->score = 1.0; y->arity = 2; y->mark = 0; @@ -210,6 +292,7 @@ manual(Hypergraph& hg) hg.nodes_by_id[6]->incoming.push_back(y); hg.nodes_by_id[2]->outgoing.push_back(y); hg.nodes_by_id[5]->outgoing.push_back(y); + y->rule = rules[9]; Edge* z = new Edge; z->head = hg.nodes_by_id[7]; z->tails.push_back(hg.nodes_by_id[1]); z->tails.push_back(hg.nodes_by_id[6]); z->score = 1.0; z->arity = 2; z->mark = 0; @@ -217,10 +300,15 @@ manual(Hypergraph& hg) hg.nodes_by_id[7]->incoming.push_back(z); hg.nodes_by_id[1]->outgoing.push_back(z); hg.nodes_by_id[6]->outgoing.push_back(z); + z->rule = rules[10]; } } // namespace Hg::io +/* + * Hg::Node + * + */ ostream& operator<<(ostream& os, const Node& n) { @@ -235,12 +323,17 @@ operator<<(ostream& os, const Node& n) return os; } +/* + * Hg::Edge + * + */ ostream& operator<<(ostream& os, const Edge& e) { ostringstream _; - for (auto it = e.tails.begin(); it != e.tails.end(); it++) { - _ << (**it).id; if (*it != e.tails.back()) _ << ","; + for (auto it: e.tails) { + _ << it->id; + if (it != e.tails.back()) _ << ","; } os << \ "Edge<head=" << e.head->id << \ diff --git a/fast/hypergraph.hh b/fast/hypergraph.hh index 79ee97b..699bfdf 100644 --- a/fast/hypergraph.hh +++ b/fast/hypergraph.hh @@ -1,28 +1,25 @@ #pragma once -#include <iostream> -#include <string> -#include <sstream> -#include <vector> -#include <list> -#include <unordered_map> -#include <functional> #include <algorithm> -#include <iterator> #include <fstream> +#include <functional> +#include <iostream> +#include <iterator> +#include <list> #include <msgpack.hpp> #include <msgpack/fbuffer.hpp> +#include <sstream> +#include <string> +#include <unordered_map> +#include <vector> #include "grammar.hh" #include "semiring.hh" -#include "dummyvector.h" #include "sparse_vector.hh" +#include "weaver.hh" using namespace std; -typedef double score_t; -typedef double weight_t; - namespace Hg { @@ -69,28 +66,36 @@ struct Hypergraph { unsigned int arity; }; -void -reset(); - template<typename Semiring> void -init(list<Node*>& nodes, list<Node*>::iterator root, Semiring& semiring); +init(const list<Node*>& nodes, const list<Node*>::iterator root, const Semiring& semiring); void -topological_sort(list<Node*>& nodes, list<Node*>::iterator root); +reset(const list<Node*> nodes, const vector<Edge*> edges); + +void +topological_sort(list<Node*>& nodes, const list<Node*>::iterator root); void viterbi(Hypergraph& hg); +typedef vector<Edge*> Path; + +void +viterbi_path(Hypergraph& hg, Path& p); + +void +derive(const Path& p, const Node* cur, vector<string>& carry); + namespace io { void -read(Hypergraph& hg, vector<G::Rule*> rules, string fn); +read(Hypergraph& hg, vector<G::Rule*>& rules, const string& fn); // FIXME void -write(Hypergraph& hg, vector<G::Rule*> rules, string fn); +write(Hypergraph& hg, vector<G::Rule*>& rules, const string& fn); // TODO void -manual(Hypergraph& hg); +manual(Hypergraph& hg, vector<G::Rule*>& rules); } // namespace diff --git a/fast/main.cc b/fast/main.cc index 2a8676b..59e25d5 100644 --- a/fast/main.cc +++ b/fast/main.cc @@ -5,8 +5,18 @@ int main(int argc, char** argv) { Hg::Hypergraph hg; - Hg::io::read(hg, argv[1]); - Hg::viterbi(hg); + G::Grammar g; +//Hg::io::read(hg, g.rules, argv[1]); + Hg::io::manual(hg, g.rules); + + Hg::Path p; + Hg::viterbi_path(hg, p); + vector<string> s; + Hg::derive(p, p.back()->head, s); + for (auto it: s) + cout << it << " "; + cout << endl; + return 0; } diff --git a/fast/semiring.hh b/fast/semiring.hh index 3f4ac08..1c3ff1d 100644 --- a/fast/semiring.hh +++ b/fast/semiring.hh @@ -1,6 +1,7 @@ #pragma once +// TODO: others namespace Semiring { template<typename T> diff --git a/fast/sparse_vector.hh b/fast/sparse_vector.hh index dd7f3cf..e497769 100644 --- a/fast/sparse_vector.hh +++ b/fast/sparse_vector.hh @@ -1,11 +1,13 @@ #pragma once +#include <iostream> +#include <sstream> +#include <string> #include <unordered_map> #include <vector> -#include <sstream> -typedef double score_t; // FIXME -typedef double weight_t; +#include "util.hh" +#include "weaver.hh" using namespace std; @@ -14,17 +16,52 @@ namespace Sv { template<typename K, typename V> struct SparseVector { - unordered_map<K, V> m_; - V zero = 0.0; + unordered_map<K,V> m_; + V zero = 0.f; + + SparseVector() {}; + SparseVector(string& s) + { + stringstream ss(s); + while (!ss.eof()) { + string t; + ss >> t; + size_t eq = t.find_first_of("="); + t.replace(eq, 1, " "); + stringstream tt(t); + K k; V v; + tt >> k >> v; + m_.emplace(k.substr(k.find_first_of("\"")+1, k.find_last_of("\"")-1), v); + } + }; void insert(K k, V v) { m_[k] = v; }; - weight_t + V dot(SparseVector& other) { + V r; + unordered_map<K,V>* o = &m_; + auto b = m_.cbegin(); + auto e = m_.cend(); + if (other.size() < size()) { + b = other.m_.cbegin(); + e = other.m_.cend(); + o = &other.m_; + } + for (auto it = b; it != e; it++) + r += it->second * o->at(it->first); + + return r; }; + size_t + size() + { + return m_.size(); + } + V& operator[](const K& k) { @@ -44,18 +81,20 @@ struct SparseVector { operator+(const SparseVector& other) const { SparseVector<K,V> v; - v.m_.insert(m_.begin(), m_.end()); - v.m_.insert(other.m_.begin(), other.m_.end()); - for (auto it = v.m_.begin(); it != v.m_.end(); it++) - v.m_[it->first] = this->at(it->first) + other.at(it->first); + v.m_.insert(m_.cbegin(), m_.cend()); + v.m_.insert(other.m_.cbegin(), other.m_.cend()); + for (const auto it: v.m_) + v.m_[it.first] = this->at(it.first) + other.at(it.first); + return v; }; SparseVector& operator+=(const SparseVector& other) { - for (auto it = other.m_.begin(); it != other.m_.end(); it++) - m_[it->first] += it->second; + for (const auto it: other.m_) + m_[it.first] += it.second; + return *this; }; @@ -63,18 +102,20 @@ struct SparseVector { operator-(const SparseVector& other) const { SparseVector<K,V> v; - v.m_.insert(m_.begin(), m_.end()); - v.m_.insert(other.m_.begin(), other.m_.end()); - for (auto it = v.m_.begin(); it != v.m_.end(); it++) - v.m_[it->first] = this->at(it->first) - other.at(it->first); + v.m_.insert(m_.cbegin(), m_.cend()); + v.m_.insert(other.m_.cbegin(), other.m_.cend()); + for (const auto it: v.m_) + v.m_[it.first] = this->at(it.first) - other.at(it.first); + return v; }; SparseVector& operator-=(const SparseVector& other) { - for (auto it = other.m_.begin(); it != other.m_.end(); it++) - m_[it->first] -= it->second; + for (const auto it: other.m_) + m_[it.first] -= it.second; + return *this; }; @@ -82,35 +123,48 @@ struct SparseVector { operator*(V f) const { SparseVector<K,V> v; - for (auto it = m_.begin(); it != m_.end(); it++) - v.m_[it->first] = this->at(it->first) * f; + for (const auto it: m_) + v.m_[it.first] = this->at(it.first) * f; + return v; }; SparseVector& operator*=(V f) { - for (auto it = m_.begin(); it != m_.end(); it++) - m_[it->first] *= f; + for (const auto it: m_) + m_[it.first] *= f; + return *this; }; string repr() const { - ostringstream os; + ostringstream os; os << "SparseVector<{"; - for (auto it = m_.begin(); it != m_.end(); it ++) { + for (auto it = m_.cbegin(); it != m_.cend(); it++) { os << "'" << it->first << "'=" << it->second; if (next(it) != m_.end()) os << ", "; } os << "}>"; + + return os.str(); + }; + + string + escaped() const { + ostringstream os; + for (auto it = m_.cbegin(); it != m_.cend(); it++) { + os << '"' << util::json_escape(it->first) << '"' << "=" << it->second; + if (next(it) != m_.cend()) os << " "; + } + return os.str(); }; - friend ostream& - operator<<(ostream& os, const SparseVector& v) { return os << v.repr(); } + friend ostream& operator<<(ostream& os, const SparseVector& v) { return os << v.repr(); } }; } // namespace diff --git a/fast/test_grammar.cc b/fast/test_grammar.cc index 34a55ba..3263edd 100644 --- a/fast/test_grammar.cc +++ b/fast/test_grammar.cc @@ -9,8 +9,8 @@ int main(int argc, char** argv) { G::Grammar g(argv[1]); - for (auto it = g.rules.begin(); it != g.rules.end(); it++) - cout << (**it).escaped() << endl; + for (auto it: g.rules) + cout << it->escaped() << endl; return 0; } diff --git a/fast/test_sparse_vector.cc b/fast/test_sparse_vector.cc index f486486..426bed1 100644 --- a/fast/test_sparse_vector.cc +++ b/fast/test_sparse_vector.cc @@ -4,16 +4,16 @@ int main(void) { - Sv::SparseVector<string, weight_t> a; + Sv::SparseVector<string, score_t> a; a.insert("1", 1); a.insert("2", 2); cout << "a:" << a << endl; - Sv::SparseVector<string, weight_t> b; + Sv::SparseVector<string, score_t> b; b.insert("2", 2); cout << "b:" << b << endl; - Sv::SparseVector<string, weight_t> c = a + b; + Sv::SparseVector<string, score_t> c = a + b; cout << "a+b:" << c << endl; a += b; @@ -27,6 +27,11 @@ main(void) a *= 2; cout << "a*=2:" << a << endl; + string s("\"a\"=2 \"b\"=3"); + Sv::SparseVector<string, score_t>* sv = new Sv::SparseVector<string, score_t>(s); + cout << *sv << endl; + cout << sv->dot(*sv) << endl; + return 0; } diff --git a/fast/util.hh b/fast/util.hh new file mode 100644 index 0000000..2a28f16 --- /dev/null +++ b/fast/util.hh @@ -0,0 +1,29 @@ +#pragma once + +#include <string> + +using namespace std; + + +namespace util { + + inline string + json_escape(const string& s) { // FIXME: only inline? + ostringstream os; + for (auto it = s.cbegin(); it != s.cend(); it++) { + switch (*it) { + case '"': os << "\\\""; break; + case '\\': os << "\\\\"; break; + case '\b': os << "\\b"; break; + case '\f': os << "\\f"; break; + case '\n': os << "\\n"; break; + case '\r': os << "\\r"; break; + case '\t': os << "\\t"; break; + default: os << *it; break; + } + } + return os.str(); + }; + +} // namespace util + diff --git a/fast/weaver.hh b/fast/weaver.hh new file mode 100644 index 0000000..e7c3238 --- /dev/null +++ b/fast/weaver.hh @@ -0,0 +1,4 @@ +#pragma once + +typedef double score_t; + |