diff options
Diffstat (limited to 'fast/hypergraph.cc')
-rw-r--r-- | fast/hypergraph.cc | 216 |
1 files changed, 186 insertions, 30 deletions
diff --git a/fast/hypergraph.cc b/fast/hypergraph.cc index 44e060e..4e6601f 100644 --- a/fast/hypergraph.cc +++ b/fast/hypergraph.cc @@ -41,7 +41,7 @@ operator<<(std::ostream& os, const Edge& e) { ostringstream _; for (auto it = e.tails.begin(); it != e.tails.end(); ++it) { - _ << (*it)->id; if (*it != e.tails.back()) _ << ","; + _ << (**it).id; if (*it != e.tails.back()) _ << ","; } os << \ "Edge<head=" << e.head->id << \ @@ -55,19 +55,26 @@ operator<<(std::ostream& os, const Edge& e) } /* - * Hypergraph - * methods + * functions * */ void -Hypergraph::reset() +reset(list<Node*> nodes, vector<Edge*> edges) { + for (auto it = nodes.begin(); it != nodes.end(); ++it) + (**it).mark = 0; + for (auto it = edges.begin(); it != edges.end(); ++it) + (**it).mark = 0; +} + +template<class Semiring> void +init(list<Node*>& nodes, list<Node*>::iterator root, Semiring& semiring) +{ + for (auto it = nodes.begin(); it != nodes.end(); ++it) + (**it).score = semiring.null; + (**root).score = semiring.one; } -/* - * functions - * - */ void topological_sort(list<Node*>& nodes, list<Node*>::iterator root) { @@ -94,37 +101,186 @@ topological_sort(list<Node*>& nodes, list<Node*>::iterator root) } } -/*void -init(vector<Node*>& nodes, ViterbiSemiring<double>& semiring, Node* root) -{ - for (auto it = nodes.begin(); it != nodes.end(); ++it) - (*it)->score = semiring.null; - root->score = semiring.one; -} - void -viterbi(vector<Node*>& nodes, map<unsigned int, Hg::Node*> nodes_by_id, Node* root) +viterbi(Hypergraph& hg) { - vector<Node*> sorted = topological_sort(nodes); - ViterbiSemiring<double> semiring; - - init(sorted, semiring, root); + list<Node*>::iterator root = hg.nodes.begin(); // FIXME? + Hg::topological_sort(hg.nodes, root); + Semiring::Viterbi<double> semiring; + Hg::init(hg.nodes, root, semiring); - for (auto n_it = sorted.begin(); n_it != sorted.end(); ++n_it) { - for (auto e_it = (*n_it)->incoming.begin(); e_it != (*n_it)->incoming.end(); ++e_it) { - cout << (*e_it)->s() << endl; + for (auto n = hg.nodes.begin(); n != hg.nodes.end(); ++n) { + for (auto e = (**n).incoming.begin(); e != (**n).incoming.end(); ++e) { + cout << **e << endl; double s = semiring.one; - for (auto m_it = (*e_it)->tails.begin(); m_it != (*e_it)->tails.end(); m_it++) { - s = semiring.multiply(s, (*m_it)->score); + for (auto m = (**e).tails.begin(); m != (**e).tails.end(); ++m) { + s = semiring.multiply(s, (**m).score); } - (*n_it)->score = semiring.add((*n_it)->score, semiring.multiply(s, (*e_it)->score)); + (**n).score = semiring.add((**n).score, semiring.multiply(s, (**e).score)); } } - for (auto it = sorted.begin(); it != sorted.end(); ++it) { - cout << (*it)->id << " " << (*it)->score << endl; + for (auto it = hg.nodes.begin(); it != hg.nodes.end(); ++it) { + cout << (**it).id << " " << (**it).score << endl; } -}*/ +} + +namespace io { + +void +read(Hypergraph& hg, string fn) +{ + ifstream ifs(fn); + size_t i = 0, nn, ne; + msgpack::unpacker pac; + while(true) { + pac.reserve_buffer(32*1024); + size_t bytes = ifs.readsome(pac.buffer(), pac.buffer_capacity()); + pac.buffer_consumed(bytes); + msgpack::unpacked result; + while(pac.next(&result)) { + msgpack::object o = result.get(); + if (i == 0) { + o.convert(&nn); + nn += 1; + } else if (i == 1) { + o.convert(&ne); + ne += 1; + } else if (i > 1 && i <= nn) { + //cout << "N " << o << endl; + Node* n = new Node; + o.convert(n); + } else if (i > nn && i <= nn+ne+1) { + //cout << "E " << o << endl; + Edge* e = new Edge; + o.convert(e); + } + i++; + } + if (!bytes) break; + } +} + +void +write(Hypergraph& hg, string fn) +{ + /*FILE* file = fopen(argv[2], "wb"); + msgpack::fbuffer fbuf(file); + msgpack::pack(fbuf, hg.nodes.size()); + msgpack::pack(fbuf, hg.edges.size()); + msgpack::pack(fbuf, hg.weights); + for (auto it = hg.nodes.begin(); it != hg.nodes.end(); it++) + msgpack::pack(fbuf, *it); + for (auto it = hg.edges.begin(); it != hg.edges.end(); it++) + msgpack::pack(fbuf, *it); + + fclose(file);*/ +} + +void +manual(Hypergraph& hg) +{ + // nodes + Node* a = new Node; a->id = 0; a->symbol = "root"; a->left = false; a->right = false; a->mark = 0; + Node* b = new Node; b->id = 1; b->symbol = "NP"; b->left = 0; b->right = 1; b->mark = 0; + Node* c = new Node; c->id = 2; c->symbol = "V"; c->left = 1; c->right = 2; c->mark = 0; + Node* d = new Node; d->id = 3; d->symbol = "JJ"; d->left = 3; d->right = 4; d->mark = 0; + Node* e = new Node; e->id = 4; e->symbol = "NN"; e->left = 3; e->right = 5; e->mark = 0; + Node* f = new Node; f->id = 5; f->symbol = "NP"; f->left = 2; f->right = 5; f->mark = 0; + Node* g = new Node; g->id = 6; g->symbol = "NP"; g->left = 1; g->right = 5; g->mark = 0; + Node* h = new Node; h->id = 7; h->symbol = "S"; h->left = 0; h->right = 6; h->mark = 0; + + hg.add_node(a); + hg.add_node(h); + hg.add_node(g); + hg.add_node(c); + hg.add_node(d); + hg.add_node(f); + hg.add_node(b); + hg.add_node(e); + + // edges + Edge* q = new Edge; q->head = hg.nodes_by_id[1]; q->tails.push_back(hg.nodes_by_id[0]); q->score = 0.367879441171; + hg.nodes_by_id[1]->incoming.push_back(q); + hg.nodes_by_id[0]->outgoing.push_back(q); + q->arity = 1; + q->mark = 0; + hg.edges.push_back(q); + + Edge* p = new Edge; p->head = hg.nodes_by_id[2]; p->tails.push_back(hg.nodes_by_id[0]); p->score = 0.606530659713; + hg.nodes_by_id[2]->incoming.push_back(p); + hg.nodes_by_id[0]->outgoing.push_back(p); + p->arity = 1; + p->mark = 0; + hg.edges.push_back(p); + + Edge* r = new Edge; r->head = hg.nodes_by_id[3]; r->tails.push_back(hg.nodes_by_id[0]); r->score = 1.0; + hg.nodes_by_id[3]->incoming.push_back(r); + hg.nodes_by_id[0]->outgoing.push_back(r); + r->arity = 1; + r->mark = 0; + hg.edges.push_back(r); + + Edge* s = new Edge; s->head = hg.nodes_by_id[3]; s->tails.push_back(hg.nodes_by_id[0]); s->score = 1.0; + hg.nodes_by_id[3]->incoming.push_back(s); + hg.nodes_by_id[0]->outgoing.push_back(s); + s->arity = 1; + s->mark = 0; + hg.edges.push_back(s); + + Edge* t = new Edge; t->head = hg.nodes_by_id[4]; t->tails.push_back(hg.nodes_by_id[0]); t->score = 1.0; + hg.nodes_by_id[4]->incoming.push_back(t); + hg.nodes_by_id[0]->outgoing.push_back(t); + t->arity = 1; + t->mark = 0; + hg.edges.push_back(t); + + Edge* u = new Edge; u->head = hg.nodes_by_id[4]; u->tails.push_back(hg.nodes_by_id[0]); u->score = 1.0; + hg.nodes_by_id[4]->incoming.push_back(u); + hg.nodes_by_id[0]->outgoing.push_back(u); + u->arity = 1; + u->mark = 0; + hg.edges.push_back(u); + + Edge* v = new Edge; v->head = hg.nodes_by_id[4]; v->tails.push_back(hg.nodes_by_id[3]); v->score = 1.0; + hg.nodes_by_id[4]->incoming.push_back(v); + hg.nodes_by_id[3]->outgoing.push_back(v); + v->arity = 1; + v->mark = 0; + hg.edges.push_back(v); + + Edge* w = new Edge; w->head = hg.nodes_by_id[4]; w->tails.push_back(hg.nodes_by_id[3]); w->score = 2.71828182846; + hg.nodes_by_id[4]->incoming.push_back(w); + hg.nodes_by_id[3]->outgoing.push_back(w); + w->arity = 1; + w->mark = 0; + hg.edges.push_back(w); + + Edge* x = new Edge; x->head = hg.nodes_by_id[5]; x->tails.push_back(hg.nodes_by_id[4]); x->score = 1.0; + hg.nodes_by_id[5]->incoming.push_back(x); + hg.nodes_by_id[4]->outgoing.push_back(x); + x->arity = 1; + x->mark = 0; + hg.edges.push_back(x); + + Edge* y = new Edge; y->head = hg.nodes_by_id[6]; y->tails.push_back(hg.nodes_by_id[2]); y->tails.push_back(hg.nodes_by_id[5]); y->score = 1.0; + hg.nodes_by_id[6]->incoming.push_back(y); + hg.nodes_by_id[2]->outgoing.push_back(y); + hg.nodes_by_id[5]->outgoing.push_back(y); + y->arity = 2; + y->mark = 0; + hg.edges.push_back(y); + + Edge* z = new Edge; z->head = hg.nodes_by_id[7]; z->tails.push_back(hg.nodes_by_id[1]); z->tails.push_back(hg.nodes_by_id[6]); z->score = 1.0; + hg.nodes_by_id[7]->incoming.push_back(z); + hg.nodes_by_id[1]->outgoing.push_back(z); + hg.nodes_by_id[6]->outgoing.push_back(z); + z->arity = 2; + z->mark = 0; + hg.edges.push_back(z); +} + +} // namespace } // namespace |