diff options
Diffstat (limited to 'src/hypergraph.cc')
-rw-r--r-- | src/hypergraph.cc | 362 |
1 files changed, 362 insertions, 0 deletions
diff --git a/src/hypergraph.cc b/src/hypergraph.cc new file mode 100644 index 0000000..40bcc64 --- /dev/null +++ b/src/hypergraph.cc @@ -0,0 +1,362 @@ +#include "hypergraph.hh" + +namespace Hg { + +template<typename Semiring> void +init(const list<Node*>& nodes, const list<Node*>::iterator root, const Semiring& semiring) +{ + for (const auto it: nodes) + it->score = semiring.null; + (**root).score = semiring.one; +} + +void +reset(const list<Node*> nodes, const vector<Edge*> edges) +{ + for (const auto it: nodes) + it->mark = 0; + for (auto it: edges) + it->mark = 0; +} + +void +topological_sort(list<Node*>& nodes, const list<Node*>::iterator root) +{ + auto p = root; + auto to = nodes.begin(); + while (to != nodes.end()) { + if ((**p).is_marked()) { + for (const auto e: (**p).outgoing) { // explore edges + e->mark++; + if (e->is_marked()) { + e->head->mark++; + } + } + } + if ((**p).is_marked()) { + nodes.splice(to, nodes, p); + to++; + p = to; + } else { + ++p; + } + } +} + +void +viterbi(Hypergraph& hg) +{ + list<Node*>::iterator root = \ + find_if(hg.nodes.begin(), hg.nodes.end(), \ + [](Node* n) { return n->incoming.size() == 0; }); + + Hg::topological_sort(hg.nodes, root); + Semiring::Viterbi<score_t> semiring; + Hg::init(hg.nodes, root, semiring); + + for (const auto n: hg.nodes) { + for (const auto e: n->incoming) { + score_t s = semiring.one; + for (const auto m: e->tails) { + s = semiring.multiply(s, m->score); + } + n->score = semiring.add(n->score, semiring.multiply(s, e->score)); + } + } +} + +void +viterbi_path(Hypergraph& hg, Path& p) +{ + list<Node*>::iterator root = \ + find_if(hg.nodes.begin(), hg.nodes.end(), \ + [](Node* n) { return n->incoming.size() == 0; }); + //list<Node*>::iterator root = hg.nodes.begin(); + + Hg::topological_sort(hg.nodes, root); + // ^^^ FIXME do I need to do this when reading from file? + Semiring::Viterbi<score_t> semiring; + Hg::init(hg.nodes, root, semiring); + + for (auto n: hg.nodes) { + Edge* best_edge; + bool best = false; + for (auto e: n->incoming) { + score_t s = semiring.one; + for (auto m: e->tails) { + s = semiring.multiply(s, m->score); + } + if (n->score < semiring.multiply(s, e->score)) { // find max + best_edge = e; + best = true; + } + n->score = semiring.add(n->score, semiring.multiply(s, e->score)); + } + if (best) + p.push_back(best_edge); + } +} + + +void +derive(const Path& p, const Node* cur, vector<string>& carry) +{ + Edge* next; + for (auto it: p) { + if (it->head->symbol == cur->symbol && + it->head->left == cur->left && + it->head->right == cur->right) { + next = it; + } + } // FIXME this is probably not so good + + unsigned j = 0; + for (auto it: next->rule->target) { + if (it->type() == G::NON_TERMINAL) { + derive(p, next->tails[next->rule->order[j]], carry); + j++; + } else { + carry.push_back(it->symbol()); + } + } +} + +namespace io { + +void +read(Hypergraph& hg, vector<G::Rule*>& rules, G::Vocabulary& vocab, const string& fn) +{ + ifstream ifs(fn); + size_t i = 0, r, n, e; + msgpack::unpacker pac; + while(true) { + pac.reserve_buffer(32*1024); + size_t bytes = ifs.readsome(pac.buffer(), pac.buffer_capacity()); + pac.buffer_consumed(bytes); + msgpack::unpacked result; + while(pac.next(&result)) { + msgpack::object o = result.get(); + if (i == 0) { + o.convert(&r); + } else if (i == 1) { + o.convert(&n); + } else if (i == 2) { + o.convert(&e); + } else if (i > 2 && i <= r+2) { + string s; + o.convert(&s); + G::Rule* rule = new G::Rule; + G::Rule::from_s(rule, s, vocab); + rules.push_back(rule); + } else if (i > r+2 && i <= r+n+2) { + Node* n = new Node; + o.convert(n); + hg.nodes.push_back(n); + hg.nodes_by_id[n->id] = n; + } else if (i > n+2 && i <= r+n+e+2) { + Edge* e = new Edge; + e->arity = 0; + o.convert(e); + e->head = hg.nodes_by_id[e->head_id_]; + hg.edges.push_back(e); + hg.nodes_by_id[e->head_id_]->incoming.push_back(e); + e->arity = 0; + for (auto it = e->tails_ids_.begin(); it != e->tails_ids_.end(); it++) { + hg.nodes_by_id[*it]->outgoing.push_back(e); + e->tails.push_back(hg.nodes_by_id[*it]); + e->arity++; + } + e->rule = rules[e->rule_id_]; + } else { + // ERROR + } + i++; + } + if (!bytes) break; + } +} + +void +write(Hypergraph& hg, vector<G::Rule*>& rules, const string& fn) // FIXME +{ + FILE* file = fopen(fn.c_str(), "wb"); + msgpack::fbuffer fbuf(file); + msgpack::pack(fbuf, rules.size()); + msgpack::pack(fbuf, hg.nodes.size()); + msgpack::pack(fbuf, hg.edges.size()); + for (auto it = rules.cbegin(); it != rules.cend(); it++) + msgpack::pack(fbuf, **it); + for (auto it = hg.nodes.cbegin(); it != hg.nodes.cend(); it++) + msgpack::pack(fbuf, **it); + for (auto it = hg.edges.cbegin(); it != hg.edges.cend(); it++) + msgpack::pack(fbuf, **it); + fclose(file); +} + +void +manual(Hypergraph& hg, vector<G::Rule*>& rules, G::Vocabulary& vocab) +{ + // nodes + Node* a = new Node; a->id = 0; a->symbol = "root"; a->left = -1; a->right = -1; a->mark = 0; + hg.nodes.push_back(a); hg.nodes_by_id[a->id] = a; + Node* b = new Node; b->id = 1; b->symbol = "NP"; b->left = 0; b->right = 1; b->mark = 0; + hg.nodes.push_back(b); hg.nodes_by_id[b->id] = b; + Node* c = new Node; c->id = 2; c->symbol = "V"; c->left = 1; c->right = 2; c->mark = 0; + hg.nodes.push_back(c); hg.nodes_by_id[c->id] = c; + Node* d = new Node; d->id = 3; d->symbol = "JJ"; d->left = 3; d->right = 4; d->mark = 0; + hg.nodes.push_back(d); hg.nodes_by_id[d->id] = d; + Node* e = new Node; e->id = 4; e->symbol = "NN"; e->left = 3; e->right = 5; e->mark = 0; + hg.nodes.push_back(e); hg.nodes_by_id[e->id] = e; + Node* f = new Node; f->id = 5; f->symbol = "NP"; f->left = 2; f->right = 5; f->mark = 0; + hg.nodes.push_back(f); hg.nodes_by_id[f->id] = f; + Node* g = new Node; g->id = 6; g->symbol = "NP"; g->left = 1; g->right = 5; g->mark = 0; + hg.nodes.push_back(g); hg.nodes_by_id[g->id] = g; + Node* h = new Node; h->id = 7; h->symbol = "S"; h->left = 0; h->right = 6; h->mark = 0; + hg.nodes.push_back(h); hg.nodes_by_id[h->id] = h; + + // rules + vector<string> rule_strs; + rule_strs.push_back("[NP] ||| ich ||| i ||| ||| "); + rule_strs.push_back("[V] ||| sah ||| saw ||| ||| "); + rule_strs.push_back("[JJ] ||| kleines ||| small ||| ||| "); + rule_strs.push_back("[JJ] ||| kleines ||| little ||| ||| "); + rule_strs.push_back("[NN] ||| kleines haus ||| small house ||| ||| "); + rule_strs.push_back("[NN] ||| kleines haus ||| little house ||| ||| "); + rule_strs.push_back("[NN] ||| [JJ,1] haus ||| [JJ,1] shell ||| ||| "); + rule_strs.push_back("[NN] ||| [JJ,1] haus ||| [JJ,1] house ||| ||| "); + rule_strs.push_back("[NP] ||| ein [NN,1] ||| a [NN,1] ||| ||| "); + rule_strs.push_back("[VP] ||| [V,1] [NP,2] ||| [V,1] [NP,2] ||| ||| "); + rule_strs.push_back("[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2] ||| ||| "); + + for (auto it: rule_strs) { + rules.push_back(new G::Rule(it, vocab)); + rules.back()->f = new Sv::SparseVector<string, score_t>(); + } + + // edges + Edge* q = new Edge; q->head = hg.nodes_by_id[1]; q->tails.push_back(hg.nodes_by_id[0]); q->score = 0.367879441171; + q->arity = 1; q->mark = 0; + hg.edges.push_back(q); + hg.nodes_by_id[1]->incoming.push_back(q); + hg.nodes_by_id[0]->outgoing.push_back(q); + q->rule = rules[0]; + + Edge* p = new Edge; p->head = hg.nodes_by_id[2]; p->tails.push_back(hg.nodes_by_id[0]); p->score = 0.606530659713; + p->arity = 1; p->mark = 0; + hg.edges.push_back(p); + hg.nodes_by_id[2]->incoming.push_back(p); + hg.nodes_by_id[0]->outgoing.push_back(p); + p->rule = rules[1]; + + Edge* r = new Edge; r->head = hg.nodes_by_id[3]; r->tails.push_back(hg.nodes_by_id[0]); r->score = 1.0; + r->arity = 1; r->mark = 0; + hg.edges.push_back(r); + hg.nodes_by_id[3]->incoming.push_back(r); + hg.nodes_by_id[0]->outgoing.push_back(r); + r->rule = rules[2]; + + Edge* s = new Edge; s->head = hg.nodes_by_id[3]; s->tails.push_back(hg.nodes_by_id[0]); s->score = 1.0; + s->arity = 1; s->mark = 0; + hg.edges.push_back(s); + hg.nodes_by_id[3]->incoming.push_back(s); + hg.nodes_by_id[0]->outgoing.push_back(s); + s->rule = rules[3]; + + Edge* t = new Edge; t->head = hg.nodes_by_id[4]; t->tails.push_back(hg.nodes_by_id[0]); t->score = 1.0; + t->arity = 1; t->mark = 0; + hg.edges.push_back(t); + hg.nodes_by_id[4]->incoming.push_back(t); + hg.nodes_by_id[0]->outgoing.push_back(t); + t->rule = rules[4]; + + Edge* u = new Edge; u->head = hg.nodes_by_id[4]; u->tails.push_back(hg.nodes_by_id[0]); u->score = 1.0; + u->arity = 1; u->mark = 0; + hg.edges.push_back(u); + hg.nodes_by_id[4]->incoming.push_back(u); + hg.nodes_by_id[0]->outgoing.push_back(u); + u->rule = rules[5]; + + Edge* v = new Edge; v->head = hg.nodes_by_id[4]; v->tails.push_back(hg.nodes_by_id[3]); v->score = 1.0; + v->arity = 1; v->mark = 0; + hg.edges.push_back(v); + hg.nodes_by_id[4]->incoming.push_back(v); + hg.nodes_by_id[3]->outgoing.push_back(v); + v->rule = rules[6]; + + Edge* w = new Edge; w->head = hg.nodes_by_id[4]; w->tails.push_back(hg.nodes_by_id[3]); w->score = 2.71828182846; + w->arity = 1; w->mark = 0; + hg.edges.push_back(w); + hg.nodes_by_id[4]->incoming.push_back(w); + hg.nodes_by_id[3]->outgoing.push_back(w); + w->rule = rules[7]; + + Edge* x = new Edge; x->head = hg.nodes_by_id[5]; x->tails.push_back(hg.nodes_by_id[4]); x->score = 1.0; + x->arity = 1; x->mark = 0; + hg.edges.push_back(x); + hg.nodes_by_id[5]->incoming.push_back(x); + hg.nodes_by_id[4]->outgoing.push_back(x); + x->rule = rules[8]; + + Edge* y = new Edge; y->head = hg.nodes_by_id[6]; y->tails.push_back(hg.nodes_by_id[2]); y->tails.push_back(hg.nodes_by_id[5]); y->score = 1.0; + y->arity = 2; y->mark = 0; + hg.edges.push_back(y); + hg.nodes_by_id[6]->incoming.push_back(y); + hg.nodes_by_id[2]->outgoing.push_back(y); + hg.nodes_by_id[5]->outgoing.push_back(y); + y->rule = rules[9]; + + Edge* z = new Edge; z->head = hg.nodes_by_id[7]; z->tails.push_back(hg.nodes_by_id[1]); z->tails.push_back(hg.nodes_by_id[6]); z->score = 1.0; + z->arity = 2; z->mark = 0; + hg.edges.push_back(z); + hg.nodes_by_id[7]->incoming.push_back(z); + hg.nodes_by_id[1]->outgoing.push_back(z); + hg.nodes_by_id[6]->outgoing.push_back(z); + z->rule = rules[10]; +} + +} // namespace Hg::io + +/* + * Hg::Node + * + */ +ostream& +operator<<(ostream& os, const Node& n) +{ + os << \ + "Node<id=" << n.id << \ + ", symbol='" << n.symbol << "'" << \ + ", span=(" << n.left << "," << n.right << ")" \ + ", score=" << n.score << \ + ", incoming:" << n.incoming.size() << \ + ", outgoing:" << n.outgoing.size() << \ + ", mark=" << n.mark << ">"; + return os; +} + +/* + * Hg::Edge + * + */ +ostream& +operator<<(ostream& os, const Edge& e) +{ + ostringstream _; + for (auto it: e.tails) { + _ << it->id; + if (it != e.tails.back()) _ << ","; + } + os << \ + "Edge<head=" << e.head->id << \ + ", tails=[" << _.str() << "]" \ + ", score=" << e.score << \ + ", rule:'"; + e.rule->escaped(os); + os << "', f=" << "TODO" << \ + ", arity=" << e.arity << \ + ", mark=" << e.mark << ">"; + return os; +} + +} // namespace Hg + |