diff options
Diffstat (limited to 'fast/hypergraph.cc')
-rw-r--r-- | fast/hypergraph.cc | 182 |
1 files changed, 86 insertions, 96 deletions
diff --git a/fast/hypergraph.cc b/fast/hypergraph.cc index 4e6601f..c3c587c 100644 --- a/fast/hypergraph.cc +++ b/fast/hypergraph.cc @@ -1,41 +1,30 @@ #include "hypergraph.hh" -namespace Hg { +namespace Hg { /* * Node * */ -bool -Node::is_marked() -{ - return mark >= incoming.size(); -} - std::ostream& operator<<(std::ostream& os, const Node& n) { os << \ "Node<id=" << n.id << \ - ", symbol=" << n.symbol << \ + ", symbol='" << n.symbol << "'" << \ ", span=(" << n.left << "," << n.right << ")" \ ", score=" << n.score << \ ", incoming:" << n.incoming.size() << \ ", outgoing:" << n.outgoing.size() << \ ", mark=" << n.mark << ">"; + return os; } /* * Edge * */ -bool -Edge::is_marked() -{ - return mark >= arity; -} - std::ostream& operator<<(std::ostream& os, const Edge& e) { @@ -45,10 +34,10 @@ operator<<(std::ostream& os, const Edge& e) } os << \ "Edge<head=" << e.head->id << \ - "', tails=[" << _.str() << "]" \ + ", tails=[" << _.str() << "]" \ ", score=" << e.score << \ - ", rule:'" << "TODO" << \ - " , f=" << "TODO" << \ + ", rule:'" << "TODO" << "'" << \ + ", f=" << "TODO" << \ ", arity=" << e.arity << \ ", mark=" << e.mark << ">"; return os; @@ -67,7 +56,7 @@ reset(list<Node*> nodes, vector<Edge*> edges) (**it).mark = 0; } -template<class Semiring> void +template<typename Semiring> void init(list<Node*>& nodes, list<Node*>::iterator root, Semiring& semiring) { for (auto it = nodes.begin(); it != nodes.end(); ++it) @@ -78,8 +67,10 @@ init(list<Node*>& nodes, list<Node*>::iterator root, Semiring& semiring) void topological_sort(list<Node*>& nodes, list<Node*>::iterator root) { + cout << "root " << **root << endl; + for (auto it = nodes.begin(); it != nodes.end(); it++) + cout << (**it).id << endl; auto p = root; - (**p).mark = 0; // is_marked()==true auto to = nodes.begin(); while (to != nodes.end()) { if ((**p).is_marked()) { @@ -97,21 +88,27 @@ topological_sort(list<Node*>& nodes, list<Node*>::iterator root) p = to; } else { ++p; + if (p == nodes.end()) { + p = next(to); + to = next(to); + } } } + cout << "---" << endl; + for (auto it = nodes.begin(); it != nodes.end(); it++) + cout << (**it).id << endl; } void viterbi(Hypergraph& hg) { - list<Node*>::iterator root = hg.nodes.begin(); // FIXME? + list<Node*>::iterator root = \ + find_if(hg.nodes.begin(), hg.nodes.end(), [](Node* n) { return n->incoming.size() == 0; }); Hg::topological_sort(hg.nodes, root); Semiring::Viterbi<double> semiring; Hg::init(hg.nodes, root, semiring); - for (auto n = hg.nodes.begin(); n != hg.nodes.end(); ++n) { for (auto e = (**n).incoming.begin(); e != (**n).incoming.end(); ++e) { - cout << **e << endl; double s = semiring.one; for (auto m = (**e).tails.begin(); m != (**e).tails.end(); ++m) { s = semiring.multiply(s, (**m).score); @@ -119,10 +116,6 @@ viterbi(Hypergraph& hg) (**n).score = semiring.add((**n).score, semiring.multiply(s, (**e).score)); } } - - for (auto it = hg.nodes.begin(); it != hg.nodes.end(); ++it) { - cout << (**it).id << " " << (**it).score << endl; - } } namespace io { @@ -134,30 +127,39 @@ read(Hypergraph& hg, string fn) size_t i = 0, nn, ne; msgpack::unpacker pac; while(true) { - pac.reserve_buffer(32*1024); - size_t bytes = ifs.readsome(pac.buffer(), pac.buffer_capacity()); - pac.buffer_consumed(bytes); - msgpack::unpacked result; - while(pac.next(&result)) { - msgpack::object o = result.get(); - if (i == 0) { - o.convert(&nn); - nn += 1; - } else if (i == 1) { - o.convert(&ne); - ne += 1; - } else if (i > 1 && i <= nn) { - //cout << "N " << o << endl; - Node* n = new Node; - o.convert(n); - } else if (i > nn && i <= nn+ne+1) { - //cout << "E " << o << endl; - Edge* e = new Edge; - o.convert(e); - } - i++; + pac.reserve_buffer(32*1024); + size_t bytes = ifs.readsome(pac.buffer(), pac.buffer_capacity()); + pac.buffer_consumed(bytes); + msgpack::unpacked result; + while(pac.next(&result)) { + msgpack::object o = result.get(); + if (i == 0) { + o.convert(&nn); + nn += 1; + } else if (i == 1) { + o.convert(&ne); + ne += 1; + } else if (i > 1 && i <= nn) { + Node* n = new Node; + o.convert(n); + hg.nodes.push_front(n); // FIXME + hg.nodes_by_id[n->id] = n; + } else if (i > nn && i <= nn+ne+1) { + Edge* e = new Edge; + e->arity = 0; + o.convert(e); + e->head = hg.nodes_by_id[e->head_id_]; + hg.edges.push_back(e); + hg.nodes_by_id[e->head_id_]->incoming.push_back(e); + for (auto it = e->tails_ids_.begin(); it != e->tails_ids_.end(); ++it) { + hg.nodes_by_id[*it]->outgoing.push_back(e); + e->tails.push_back(hg.nodes_by_id[*it]); + e->arity++; + } } - if (!bytes) break; + i++; + } + if (!bytes) break; } } @@ -181,107 +183,95 @@ void manual(Hypergraph& hg) { // nodes - Node* a = new Node; a->id = 0; a->symbol = "root"; a->left = false; a->right = false; a->mark = 0; - Node* b = new Node; b->id = 1; b->symbol = "NP"; b->left = 0; b->right = 1; b->mark = 0; - Node* c = new Node; c->id = 2; c->symbol = "V"; c->left = 1; c->right = 2; c->mark = 0; - Node* d = new Node; d->id = 3; d->symbol = "JJ"; d->left = 3; d->right = 4; d->mark = 0; - Node* e = new Node; e->id = 4; e->symbol = "NN"; e->left = 3; e->right = 5; e->mark = 0; - Node* f = new Node; f->id = 5; f->symbol = "NP"; f->left = 2; f->right = 5; f->mark = 0; - Node* g = new Node; g->id = 6; g->symbol = "NP"; g->left = 1; g->right = 5; g->mark = 0; - Node* h = new Node; h->id = 7; h->symbol = "S"; h->left = 0; h->right = 6; h->mark = 0; + Node* a = new Node; a->id = 0; a->symbol = "root"; a->left = -1; a->right = -1; a->mark = 0; + Node* b = new Node; b->id = 1; b->symbol = "NP"; b->left = 0; b->right = 1; b->mark = 0; + Node* c = new Node; c->id = 2; c->symbol = "V"; c->left = 1; c->right = 2; c->mark = 0; + Node* d = new Node; d->id = 3; d->symbol = "JJ"; d->left = 3; d->right = 4; d->mark = 0; + Node* e = new Node; e->id = 4; e->symbol = "NN"; e->left = 3; e->right = 5; e->mark = 0; + Node* f = new Node; f->id = 5; f->symbol = "NP"; f->left = 2; f->right = 5; f->mark = 0; + Node* g = new Node; g->id = 6; g->symbol = "NP"; g->left = 1; g->right = 5; g->mark = 0; + Node* h = new Node; h->id = 7; h->symbol = "S"; h->left = 0; h->right = 6; h->mark = 0; - hg.add_node(a); - hg.add_node(h); - hg.add_node(g); + hg.add_node(b); hg.add_node(c); hg.add_node(d); - hg.add_node(f); - hg.add_node(b); hg.add_node(e); + hg.add_node(f); + hg.add_node(g); + hg.add_node(h); + hg.add_node(a); // edges Edge* q = new Edge; q->head = hg.nodes_by_id[1]; q->tails.push_back(hg.nodes_by_id[0]); q->score = 0.367879441171; + q->arity = 1; q->mark = 0; + hg.edges.push_back(q); hg.nodes_by_id[1]->incoming.push_back(q); hg.nodes_by_id[0]->outgoing.push_back(q); - q->arity = 1; - q->mark = 0; - hg.edges.push_back(q); Edge* p = new Edge; p->head = hg.nodes_by_id[2]; p->tails.push_back(hg.nodes_by_id[0]); p->score = 0.606530659713; + p->arity = 1; p->mark = 0; + hg.edges.push_back(p); hg.nodes_by_id[2]->incoming.push_back(p); hg.nodes_by_id[0]->outgoing.push_back(p); - p->arity = 1; - p->mark = 0; - hg.edges.push_back(p); Edge* r = new Edge; r->head = hg.nodes_by_id[3]; r->tails.push_back(hg.nodes_by_id[0]); r->score = 1.0; + r->arity = 1; r->mark = 0; + hg.edges.push_back(r); hg.nodes_by_id[3]->incoming.push_back(r); hg.nodes_by_id[0]->outgoing.push_back(r); - r->arity = 1; - r->mark = 0; - hg.edges.push_back(r); Edge* s = new Edge; s->head = hg.nodes_by_id[3]; s->tails.push_back(hg.nodes_by_id[0]); s->score = 1.0; + s->arity = 1; s->mark = 0; + hg.edges.push_back(s); hg.nodes_by_id[3]->incoming.push_back(s); hg.nodes_by_id[0]->outgoing.push_back(s); - s->arity = 1; - s->mark = 0; - hg.edges.push_back(s); Edge* t = new Edge; t->head = hg.nodes_by_id[4]; t->tails.push_back(hg.nodes_by_id[0]); t->score = 1.0; + t->arity = 1; t->mark = 0; + hg.edges.push_back(t); hg.nodes_by_id[4]->incoming.push_back(t); hg.nodes_by_id[0]->outgoing.push_back(t); - t->arity = 1; - t->mark = 0; - hg.edges.push_back(t); Edge* u = new Edge; u->head = hg.nodes_by_id[4]; u->tails.push_back(hg.nodes_by_id[0]); u->score = 1.0; + u->arity = 1; u->mark = 0; + hg.edges.push_back(u); hg.nodes_by_id[4]->incoming.push_back(u); hg.nodes_by_id[0]->outgoing.push_back(u); - u->arity = 1; - u->mark = 0; - hg.edges.push_back(u); Edge* v = new Edge; v->head = hg.nodes_by_id[4]; v->tails.push_back(hg.nodes_by_id[3]); v->score = 1.0; + v->arity = 1; v->mark = 0; + hg.edges.push_back(v); hg.nodes_by_id[4]->incoming.push_back(v); hg.nodes_by_id[3]->outgoing.push_back(v); - v->arity = 1; - v->mark = 0; - hg.edges.push_back(v); Edge* w = new Edge; w->head = hg.nodes_by_id[4]; w->tails.push_back(hg.nodes_by_id[3]); w->score = 2.71828182846; + w->arity = 1; w->mark = 0; + hg.edges.push_back(w); hg.nodes_by_id[4]->incoming.push_back(w); hg.nodes_by_id[3]->outgoing.push_back(w); - w->arity = 1; - w->mark = 0; - hg.edges.push_back(w); Edge* x = new Edge; x->head = hg.nodes_by_id[5]; x->tails.push_back(hg.nodes_by_id[4]); x->score = 1.0; + x->arity = 1; x->mark = 0; + hg.edges.push_back(x); hg.nodes_by_id[5]->incoming.push_back(x); hg.nodes_by_id[4]->outgoing.push_back(x); - x->arity = 1; - x->mark = 0; - hg.edges.push_back(x); Edge* y = new Edge; y->head = hg.nodes_by_id[6]; y->tails.push_back(hg.nodes_by_id[2]); y->tails.push_back(hg.nodes_by_id[5]); y->score = 1.0; + y->arity = 2; y->mark = 0; + hg.edges.push_back(y); hg.nodes_by_id[6]->incoming.push_back(y); hg.nodes_by_id[2]->outgoing.push_back(y); hg.nodes_by_id[5]->outgoing.push_back(y); - y->arity = 2; - y->mark = 0; - hg.edges.push_back(y); Edge* z = new Edge; z->head = hg.nodes_by_id[7]; z->tails.push_back(hg.nodes_by_id[1]); z->tails.push_back(hg.nodes_by_id[6]); z->score = 1.0; + z->arity = 2; z->mark = 0; + hg.edges.push_back(z); hg.nodes_by_id[7]->incoming.push_back(z); hg.nodes_by_id[1]->outgoing.push_back(z); hg.nodes_by_id[6]->outgoing.push_back(z); - z->arity = 2; - z->mark = 0; - hg.edges.push_back(z); } } // namespace - } // namespace |