summaryrefslogtreecommitdiff
path: root/fast/hypergraph.cc
diff options
context:
space:
mode:
Diffstat (limited to 'fast/hypergraph.cc')
-rw-r--r--fast/hypergraph.cc147
1 files changed, 120 insertions, 27 deletions
diff --git a/fast/hypergraph.cc b/fast/hypergraph.cc
index 6b7bd07..e1debb1 100644
--- a/fast/hypergraph.cc
+++ b/fast/hypergraph.cc
@@ -3,35 +3,34 @@
namespace Hg {
-template<typename Semiring> void
-init(list<Node*>& nodes, list<Node*>::iterator root, Semiring& semiring)
+template<typename Semiring> void
+init(const list<Node*>& nodes, const list<Node*>::iterator root, const Semiring& semiring)
{
- for (auto it = nodes.begin(); it != nodes.end(); it++)
- (**it).score = semiring.null;
+ for (const auto it: nodes)
+ it->score = semiring.null;
(**root).score = semiring.one;
}
void
-reset(list<Node*> nodes, vector<Edge*> edges)
+reset(const list<Node*> nodes, const vector<Edge*> edges)
{
- for (auto it = nodes.begin(); it != nodes.end(); it++)
- (**it).mark = 0;
- for (auto it = edges.begin(); it != edges.end(); it++)
- (**it).mark = 0;
+ for (const auto it: nodes)
+ it->mark = 0;
+ for (auto it: edges)
+ it->mark = 0;
}
void
-topological_sort(list<Node*>& nodes, list<Node*>::iterator root)
+topological_sort(list<Node*>& nodes, const list<Node*>::iterator root)
{
auto p = root;
auto to = nodes.begin();
while (to != nodes.end()) {
if ((**p).is_marked()) {
- // explore edges
- for (auto e = (**p).outgoing.begin(); e!=(**p).outgoing.end(); e++) {
- (**e).mark++;
- if ((**e).is_marked()) {
- (**e).head->mark++;
+ for (const auto e: (**p).outgoing) { // explore edges
+ e->mark++;
+ if (e->is_marked()) {
+ e->head->mark++;
}
}
}
@@ -51,16 +50,71 @@ viterbi(Hypergraph& hg)
list<Node*>::iterator root = \
find_if(hg.nodes.begin(), hg.nodes.end(), \
[](Node* n) { return n->incoming.size() == 0; });
+
+ Hg::topological_sort(hg.nodes, root);
+ Semiring::Viterbi<score_t> semiring;
+ Hg::init(hg.nodes, root, semiring);
+
+ for (const auto n: hg.nodes) {
+ for (const auto e: n->incoming) {
+ score_t s = semiring.one;
+ for (const auto m: e->tails) {
+ s = semiring.multiply(s, m->score);
+ }
+ n->score = semiring.add(n->score, semiring.multiply(s, e->score));
+ }
+ }
+}
+
+void
+viterbi_path(Hypergraph& hg, Path& p)
+{
+ list<Node*>::iterator root = \
+ find_if(hg.nodes.begin(), hg.nodes.end(), \
+ [](Node* n) { return n->incoming.size() == 0; });
+
Hg::topological_sort(hg.nodes, root);
- Semiring::Viterbi<double> semiring;
+ Semiring::Viterbi<score_t> semiring;
Hg::init(hg.nodes, root, semiring);
- for (auto n = hg.nodes.begin(); n != hg.nodes.end(); n++) {
- for (auto e = (**n).incoming.begin(); e != (**n).incoming.end(); e++) {
- double s = semiring.one;
- for (auto m = (**e).tails.begin(); m != (**e).tails.end(); m++) {
- s = semiring.multiply(s, (**m).score);
+
+ for (auto n: hg.nodes) {
+ Edge* best_edge;
+ bool best = false;
+ for (auto e: n->incoming) {
+ score_t s = semiring.one;
+ for (auto m: e->tails) {
+ s = semiring.multiply(s, m->score);
+ }
+ if (n->score < semiring.multiply(s, e->score)) { // find max
+ best_edge = e;
+ best = true;
}
- (**n).score = semiring.add((**n).score, semiring.multiply(s, (**e).score));
+ n->score = semiring.add(n->score, semiring.multiply(s, e->score));
+ }
+ if (best)
+ p.push_back(best_edge);
+ }
+}
+
+
+void
+derive(const Path& p, const Node* cur, vector<string>& carry)
+{
+ Edge* next;
+ for (auto it: p) {
+ if (it->head->symbol == cur->symbol &&
+ it->head->left == cur->left &&
+ it->head->right == cur->right) {
+ next = it;
+ }
+ }
+ unsigned j = 0;
+ for (auto it: next->rule->target) {
+ if (it->type == G::NON_TERMINAL) {
+ derive(p, next->tails[next->rule->order[j]], carry);
+ j++;
+ } else {
+ carry.push_back(it->t->word);
}
}
}
@@ -68,7 +122,7 @@ viterbi(Hypergraph& hg)
namespace io {
void
-read(Hypergraph& hg, vector<G::Rule*> rules, string fn)
+read(Hypergraph& hg, vector<G::Rule*>& rules, const string& fn) // FIXME
{
ifstream ifs(fn);
size_t i = 0, nr, nn, ne;
@@ -112,7 +166,7 @@ read(Hypergraph& hg, vector<G::Rule*> rules, string fn)
}
void
-write(Hypergraph& hg, vector<G::Rule*> rules, string fn)
+write(Hypergraph& hg, vector<G::Rule*>& rules, const string& fn) // FIXME
{
FILE* file = fopen(fn.c_str(), "wb");
msgpack::fbuffer fbuf(file);
@@ -129,7 +183,7 @@ write(Hypergraph& hg, vector<G::Rule*> rules, string fn)
}
void
-manual(Hypergraph& hg)
+manual(Hypergraph& hg, vector<G::Rule*>& rules)
{
// nodes
Node* a = new Node; a->id = 0; a->symbol = "root"; a->left = -1; a->right = -1; a->mark = 0;
@@ -149,60 +203,88 @@ manual(Hypergraph& hg)
Node* h = new Node; h->id = 7; h->symbol = "S"; h->left = 0; h->right = 6; h->mark = 0;
hg.nodes.push_back(h); hg.nodes_by_id[h->id] = h;
+ // rules
+ vector<string> rule_strs;
+ rule_strs.push_back("[NP] ||| ich ||| i ||| ||| ");
+ rule_strs.push_back("[V] ||| sah ||| saw ||| ||| ");
+ rule_strs.push_back("[JJ] ||| kleines ||| small ||| ||| ");
+ rule_strs.push_back("[JJ] ||| kleines ||| little ||| ||| ");
+ rule_strs.push_back("[NN] ||| kleines haus ||| small house ||| ||| ");
+ rule_strs.push_back("[NN] ||| kleines haus ||| little house ||| ||| ");
+ rule_strs.push_back("[NN] ||| [JJ,1] haus ||| [JJ,1] shell ||| ||| ");
+ rule_strs.push_back("[NN] ||| [JJ,1] haus ||| [JJ,1] house ||| ||| ");
+ rule_strs.push_back("[NP] ||| ein [NN,1] ||| a [NN,1] ||| ||| ");
+ rule_strs.push_back("[VP] ||| [V,1] [NP,2] ||| [V,1] [NP,2] ||| ||| ");
+ rule_strs.push_back("[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2] ||| ||| ");
+
+ for (auto it: rule_strs) {
+ rules.push_back(new G::Rule(it));
+ rules.back()->f = new Sv::SparseVector<string, score_t>();
+ }
+
// edges
Edge* q = new Edge; q->head = hg.nodes_by_id[1]; q->tails.push_back(hg.nodes_by_id[0]); q->score = 0.367879441171;
q->arity = 1; q->mark = 0;
hg.edges.push_back(q);
hg.nodes_by_id[1]->incoming.push_back(q);
hg.nodes_by_id[0]->outgoing.push_back(q);
+ q->rule = rules[0];
Edge* p = new Edge; p->head = hg.nodes_by_id[2]; p->tails.push_back(hg.nodes_by_id[0]); p->score = 0.606530659713;
p->arity = 1; p->mark = 0;
hg.edges.push_back(p);
hg.nodes_by_id[2]->incoming.push_back(p);
hg.nodes_by_id[0]->outgoing.push_back(p);
+ p->rule = rules[1];
Edge* r = new Edge; r->head = hg.nodes_by_id[3]; r->tails.push_back(hg.nodes_by_id[0]); r->score = 1.0;
r->arity = 1; r->mark = 0;
hg.edges.push_back(r);
hg.nodes_by_id[3]->incoming.push_back(r);
hg.nodes_by_id[0]->outgoing.push_back(r);
+ r->rule = rules[2];
Edge* s = new Edge; s->head = hg.nodes_by_id[3]; s->tails.push_back(hg.nodes_by_id[0]); s->score = 1.0;
s->arity = 1; s->mark = 0;
hg.edges.push_back(s);
hg.nodes_by_id[3]->incoming.push_back(s);
hg.nodes_by_id[0]->outgoing.push_back(s);
+ s->rule = rules[3];
Edge* t = new Edge; t->head = hg.nodes_by_id[4]; t->tails.push_back(hg.nodes_by_id[0]); t->score = 1.0;
t->arity = 1; t->mark = 0;
hg.edges.push_back(t);
hg.nodes_by_id[4]->incoming.push_back(t);
hg.nodes_by_id[0]->outgoing.push_back(t);
+ t->rule = rules[4];
Edge* u = new Edge; u->head = hg.nodes_by_id[4]; u->tails.push_back(hg.nodes_by_id[0]); u->score = 1.0;
u->arity = 1; u->mark = 0;
hg.edges.push_back(u);
hg.nodes_by_id[4]->incoming.push_back(u);
hg.nodes_by_id[0]->outgoing.push_back(u);
+ u->rule = rules[5];
Edge* v = new Edge; v->head = hg.nodes_by_id[4]; v->tails.push_back(hg.nodes_by_id[3]); v->score = 1.0;
v->arity = 1; v->mark = 0;
hg.edges.push_back(v);
hg.nodes_by_id[4]->incoming.push_back(v);
hg.nodes_by_id[3]->outgoing.push_back(v);
+ v->rule = rules[6];
Edge* w = new Edge; w->head = hg.nodes_by_id[4]; w->tails.push_back(hg.nodes_by_id[3]); w->score = 2.71828182846;
w->arity = 1; w->mark = 0;
hg.edges.push_back(w);
hg.nodes_by_id[4]->incoming.push_back(w);
hg.nodes_by_id[3]->outgoing.push_back(w);
+ w->rule = rules[7];
Edge* x = new Edge; x->head = hg.nodes_by_id[5]; x->tails.push_back(hg.nodes_by_id[4]); x->score = 1.0;
x->arity = 1; x->mark = 0;
hg.edges.push_back(x);
hg.nodes_by_id[5]->incoming.push_back(x);
hg.nodes_by_id[4]->outgoing.push_back(x);
+ x->rule = rules[8];
Edge* y = new Edge; y->head = hg.nodes_by_id[6]; y->tails.push_back(hg.nodes_by_id[2]); y->tails.push_back(hg.nodes_by_id[5]); y->score = 1.0;
y->arity = 2; y->mark = 0;
@@ -210,6 +292,7 @@ manual(Hypergraph& hg)
hg.nodes_by_id[6]->incoming.push_back(y);
hg.nodes_by_id[2]->outgoing.push_back(y);
hg.nodes_by_id[5]->outgoing.push_back(y);
+ y->rule = rules[9];
Edge* z = new Edge; z->head = hg.nodes_by_id[7]; z->tails.push_back(hg.nodes_by_id[1]); z->tails.push_back(hg.nodes_by_id[6]); z->score = 1.0;
z->arity = 2; z->mark = 0;
@@ -217,10 +300,15 @@ manual(Hypergraph& hg)
hg.nodes_by_id[7]->incoming.push_back(z);
hg.nodes_by_id[1]->outgoing.push_back(z);
hg.nodes_by_id[6]->outgoing.push_back(z);
+ z->rule = rules[10];
}
} // namespace Hg::io
+/*
+ * Hg::Node
+ *
+ */
ostream&
operator<<(ostream& os, const Node& n)
{
@@ -235,12 +323,17 @@ operator<<(ostream& os, const Node& n)
return os;
}
+/*
+ * Hg::Edge
+ *
+ */
ostream&
operator<<(ostream& os, const Edge& e)
{
ostringstream _;
- for (auto it = e.tails.begin(); it != e.tails.end(); it++) {
- _ << (**it).id; if (*it != e.tails.back()) _ << ",";
+ for (auto it: e.tails) {
+ _ << it->id;
+ if (it != e.tails.back()) _ << ",";
}
os << \
"Edge<head=" << e.head->id << \