diff options
Diffstat (limited to 'fast/hypergraph.cc')
-rw-r--r-- | fast/hypergraph.cc | 147 |
1 files changed, 120 insertions, 27 deletions
diff --git a/fast/hypergraph.cc b/fast/hypergraph.cc index 6b7bd07..e1debb1 100644 --- a/fast/hypergraph.cc +++ b/fast/hypergraph.cc @@ -3,35 +3,34 @@ namespace Hg { -template<typename Semiring> void -init(list<Node*>& nodes, list<Node*>::iterator root, Semiring& semiring) +template<typename Semiring> void +init(const list<Node*>& nodes, const list<Node*>::iterator root, const Semiring& semiring) { - for (auto it = nodes.begin(); it != nodes.end(); it++) - (**it).score = semiring.null; + for (const auto it: nodes) + it->score = semiring.null; (**root).score = semiring.one; } void -reset(list<Node*> nodes, vector<Edge*> edges) +reset(const list<Node*> nodes, const vector<Edge*> edges) { - for (auto it = nodes.begin(); it != nodes.end(); it++) - (**it).mark = 0; - for (auto it = edges.begin(); it != edges.end(); it++) - (**it).mark = 0; + for (const auto it: nodes) + it->mark = 0; + for (auto it: edges) + it->mark = 0; } void -topological_sort(list<Node*>& nodes, list<Node*>::iterator root) +topological_sort(list<Node*>& nodes, const list<Node*>::iterator root) { auto p = root; auto to = nodes.begin(); while (to != nodes.end()) { if ((**p).is_marked()) { - // explore edges - for (auto e = (**p).outgoing.begin(); e!=(**p).outgoing.end(); e++) { - (**e).mark++; - if ((**e).is_marked()) { - (**e).head->mark++; + for (const auto e: (**p).outgoing) { // explore edges + e->mark++; + if (e->is_marked()) { + e->head->mark++; } } } @@ -51,16 +50,71 @@ viterbi(Hypergraph& hg) list<Node*>::iterator root = \ find_if(hg.nodes.begin(), hg.nodes.end(), \ [](Node* n) { return n->incoming.size() == 0; }); + + Hg::topological_sort(hg.nodes, root); + Semiring::Viterbi<score_t> semiring; + Hg::init(hg.nodes, root, semiring); + + for (const auto n: hg.nodes) { + for (const auto e: n->incoming) { + score_t s = semiring.one; + for (const auto m: e->tails) { + s = semiring.multiply(s, m->score); + } + n->score = semiring.add(n->score, semiring.multiply(s, e->score)); + } + } +} + +void +viterbi_path(Hypergraph& hg, Path& p) +{ + list<Node*>::iterator root = \ + find_if(hg.nodes.begin(), hg.nodes.end(), \ + [](Node* n) { return n->incoming.size() == 0; }); + Hg::topological_sort(hg.nodes, root); - Semiring::Viterbi<double> semiring; + Semiring::Viterbi<score_t> semiring; Hg::init(hg.nodes, root, semiring); - for (auto n = hg.nodes.begin(); n != hg.nodes.end(); n++) { - for (auto e = (**n).incoming.begin(); e != (**n).incoming.end(); e++) { - double s = semiring.one; - for (auto m = (**e).tails.begin(); m != (**e).tails.end(); m++) { - s = semiring.multiply(s, (**m).score); + + for (auto n: hg.nodes) { + Edge* best_edge; + bool best = false; + for (auto e: n->incoming) { + score_t s = semiring.one; + for (auto m: e->tails) { + s = semiring.multiply(s, m->score); + } + if (n->score < semiring.multiply(s, e->score)) { // find max + best_edge = e; + best = true; } - (**n).score = semiring.add((**n).score, semiring.multiply(s, (**e).score)); + n->score = semiring.add(n->score, semiring.multiply(s, e->score)); + } + if (best) + p.push_back(best_edge); + } +} + + +void +derive(const Path& p, const Node* cur, vector<string>& carry) +{ + Edge* next; + for (auto it: p) { + if (it->head->symbol == cur->symbol && + it->head->left == cur->left && + it->head->right == cur->right) { + next = it; + } + } + unsigned j = 0; + for (auto it: next->rule->target) { + if (it->type == G::NON_TERMINAL) { + derive(p, next->tails[next->rule->order[j]], carry); + j++; + } else { + carry.push_back(it->t->word); } } } @@ -68,7 +122,7 @@ viterbi(Hypergraph& hg) namespace io { void -read(Hypergraph& hg, vector<G::Rule*> rules, string fn) +read(Hypergraph& hg, vector<G::Rule*>& rules, const string& fn) // FIXME { ifstream ifs(fn); size_t i = 0, nr, nn, ne; @@ -112,7 +166,7 @@ read(Hypergraph& hg, vector<G::Rule*> rules, string fn) } void -write(Hypergraph& hg, vector<G::Rule*> rules, string fn) +write(Hypergraph& hg, vector<G::Rule*>& rules, const string& fn) // FIXME { FILE* file = fopen(fn.c_str(), "wb"); msgpack::fbuffer fbuf(file); @@ -129,7 +183,7 @@ write(Hypergraph& hg, vector<G::Rule*> rules, string fn) } void -manual(Hypergraph& hg) +manual(Hypergraph& hg, vector<G::Rule*>& rules) { // nodes Node* a = new Node; a->id = 0; a->symbol = "root"; a->left = -1; a->right = -1; a->mark = 0; @@ -149,60 +203,88 @@ manual(Hypergraph& hg) Node* h = new Node; h->id = 7; h->symbol = "S"; h->left = 0; h->right = 6; h->mark = 0; hg.nodes.push_back(h); hg.nodes_by_id[h->id] = h; + // rules + vector<string> rule_strs; + rule_strs.push_back("[NP] ||| ich ||| i ||| ||| "); + rule_strs.push_back("[V] ||| sah ||| saw ||| ||| "); + rule_strs.push_back("[JJ] ||| kleines ||| small ||| ||| "); + rule_strs.push_back("[JJ] ||| kleines ||| little ||| ||| "); + rule_strs.push_back("[NN] ||| kleines haus ||| small house ||| ||| "); + rule_strs.push_back("[NN] ||| kleines haus ||| little house ||| ||| "); + rule_strs.push_back("[NN] ||| [JJ,1] haus ||| [JJ,1] shell ||| ||| "); + rule_strs.push_back("[NN] ||| [JJ,1] haus ||| [JJ,1] house ||| ||| "); + rule_strs.push_back("[NP] ||| ein [NN,1] ||| a [NN,1] ||| ||| "); + rule_strs.push_back("[VP] ||| [V,1] [NP,2] ||| [V,1] [NP,2] ||| ||| "); + rule_strs.push_back("[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2] ||| ||| "); + + for (auto it: rule_strs) { + rules.push_back(new G::Rule(it)); + rules.back()->f = new Sv::SparseVector<string, score_t>(); + } + // edges Edge* q = new Edge; q->head = hg.nodes_by_id[1]; q->tails.push_back(hg.nodes_by_id[0]); q->score = 0.367879441171; q->arity = 1; q->mark = 0; hg.edges.push_back(q); hg.nodes_by_id[1]->incoming.push_back(q); hg.nodes_by_id[0]->outgoing.push_back(q); + q->rule = rules[0]; Edge* p = new Edge; p->head = hg.nodes_by_id[2]; p->tails.push_back(hg.nodes_by_id[0]); p->score = 0.606530659713; p->arity = 1; p->mark = 0; hg.edges.push_back(p); hg.nodes_by_id[2]->incoming.push_back(p); hg.nodes_by_id[0]->outgoing.push_back(p); + p->rule = rules[1]; Edge* r = new Edge; r->head = hg.nodes_by_id[3]; r->tails.push_back(hg.nodes_by_id[0]); r->score = 1.0; r->arity = 1; r->mark = 0; hg.edges.push_back(r); hg.nodes_by_id[3]->incoming.push_back(r); hg.nodes_by_id[0]->outgoing.push_back(r); + r->rule = rules[2]; Edge* s = new Edge; s->head = hg.nodes_by_id[3]; s->tails.push_back(hg.nodes_by_id[0]); s->score = 1.0; s->arity = 1; s->mark = 0; hg.edges.push_back(s); hg.nodes_by_id[3]->incoming.push_back(s); hg.nodes_by_id[0]->outgoing.push_back(s); + s->rule = rules[3]; Edge* t = new Edge; t->head = hg.nodes_by_id[4]; t->tails.push_back(hg.nodes_by_id[0]); t->score = 1.0; t->arity = 1; t->mark = 0; hg.edges.push_back(t); hg.nodes_by_id[4]->incoming.push_back(t); hg.nodes_by_id[0]->outgoing.push_back(t); + t->rule = rules[4]; Edge* u = new Edge; u->head = hg.nodes_by_id[4]; u->tails.push_back(hg.nodes_by_id[0]); u->score = 1.0; u->arity = 1; u->mark = 0; hg.edges.push_back(u); hg.nodes_by_id[4]->incoming.push_back(u); hg.nodes_by_id[0]->outgoing.push_back(u); + u->rule = rules[5]; Edge* v = new Edge; v->head = hg.nodes_by_id[4]; v->tails.push_back(hg.nodes_by_id[3]); v->score = 1.0; v->arity = 1; v->mark = 0; hg.edges.push_back(v); hg.nodes_by_id[4]->incoming.push_back(v); hg.nodes_by_id[3]->outgoing.push_back(v); + v->rule = rules[6]; Edge* w = new Edge; w->head = hg.nodes_by_id[4]; w->tails.push_back(hg.nodes_by_id[3]); w->score = 2.71828182846; w->arity = 1; w->mark = 0; hg.edges.push_back(w); hg.nodes_by_id[4]->incoming.push_back(w); hg.nodes_by_id[3]->outgoing.push_back(w); + w->rule = rules[7]; Edge* x = new Edge; x->head = hg.nodes_by_id[5]; x->tails.push_back(hg.nodes_by_id[4]); x->score = 1.0; x->arity = 1; x->mark = 0; hg.edges.push_back(x); hg.nodes_by_id[5]->incoming.push_back(x); hg.nodes_by_id[4]->outgoing.push_back(x); + x->rule = rules[8]; Edge* y = new Edge; y->head = hg.nodes_by_id[6]; y->tails.push_back(hg.nodes_by_id[2]); y->tails.push_back(hg.nodes_by_id[5]); y->score = 1.0; y->arity = 2; y->mark = 0; @@ -210,6 +292,7 @@ manual(Hypergraph& hg) hg.nodes_by_id[6]->incoming.push_back(y); hg.nodes_by_id[2]->outgoing.push_back(y); hg.nodes_by_id[5]->outgoing.push_back(y); + y->rule = rules[9]; Edge* z = new Edge; z->head = hg.nodes_by_id[7]; z->tails.push_back(hg.nodes_by_id[1]); z->tails.push_back(hg.nodes_by_id[6]); z->score = 1.0; z->arity = 2; z->mark = 0; @@ -217,10 +300,15 @@ manual(Hypergraph& hg) hg.nodes_by_id[7]->incoming.push_back(z); hg.nodes_by_id[1]->outgoing.push_back(z); hg.nodes_by_id[6]->outgoing.push_back(z); + z->rule = rules[10]; } } // namespace Hg::io +/* + * Hg::Node + * + */ ostream& operator<<(ostream& os, const Node& n) { @@ -235,12 +323,17 @@ operator<<(ostream& os, const Node& n) return os; } +/* + * Hg::Edge + * + */ ostream& operator<<(ostream& os, const Edge& e) { ostringstream _; - for (auto it = e.tails.begin(); it != e.tails.end(); it++) { - _ << (**it).id; if (*it != e.tails.back()) _ << ","; + for (auto it: e.tails) { + _ << it->id; + if (it != e.tails.back()) _ << ","; } os << \ "Edge<head=" << e.head->id << \ |