summaryrefslogtreecommitdiff
path: root/fast
diff options
context:
space:
mode:
Diffstat (limited to 'fast')
-rw-r--r--fast/Makefile15
-rw-r--r--fast/dummyvector.h1
-rw-r--r--fast/grammar.obin2872 -> 34344 bytes
-rw-r--r--fast/hypergraph.cc182
-rw-r--r--fast/hypergraph.hh25
-rw-r--r--fast/main.cc10
-rw-r--r--fast/make_paks.cc24
-rw-r--r--fast/read_pak.cc5
8 files changed, 129 insertions, 133 deletions
diff --git a/fast/Makefile b/fast/Makefile
index 1d88446..55c4df7 100644
--- a/fast/Makefile
+++ b/fast/Makefile
@@ -1,19 +1,20 @@
+COMPILER=clang
+
all: hypergraph.o main.cc
- clang -std=c++11 -lstdc++ -lm -lmsgpack hypergraph.o main.cc -o fast_weaver
+ $(COMPILER) -std=c++11 -lstdc++ -lm -lmsgpack hypergraph.o main.cc -o fast_weaver
hypergraph.o: hypergraph.cc hypergraph.hh grammar.o semiring.hh
- clang -std=c++11 -lmsgpack -c hypergraph.cc
+ $(COMPILER) -g -std=c++11 -lmsgpack -c hypergraph.cc
grammar.o: grammar.cc grammar.hh
- clang -std=c++11 -c grammar.cc
+ $(COMPILER) -g -std=c++11 -c grammar.cc
make_paks: make_paks.cc
- g++ -std=c++11 -lmsgpack make_paks.cc -o make_paks
+ $(COMPILER) -std=c++11 -lstdc++ -lm -lmsgpack make_paks.cc -o make_paks
read_pak: read_pak.cc
- g++ -std=c++11 -lmsgpack read_pak.cc -o read_pak
-
+ $(COMPILER) -std=c++11 -lmsgpack read_pak.cc -o read_pak
clean:
- rm -f fast_weaver hypergraph.o grammar.o
+ rm -f fast_weaver hypergraph.o grammar.o make_paks read_pak
diff --git a/fast/dummyvector.h b/fast/dummyvector.h
index 09cf3f7..18e2121 100644
--- a/fast/dummyvector.h
+++ b/fast/dummyvector.h
@@ -1,4 +1,5 @@
#pragma once
+
#include <msgpack.hpp>
diff --git a/fast/grammar.o b/fast/grammar.o
index aae25db..9582624 100644
--- a/fast/grammar.o
+++ b/fast/grammar.o
Binary files differ
diff --git a/fast/hypergraph.cc b/fast/hypergraph.cc
index 4e6601f..c3c587c 100644
--- a/fast/hypergraph.cc
+++ b/fast/hypergraph.cc
@@ -1,41 +1,30 @@
#include "hypergraph.hh"
-namespace Hg {
+namespace Hg {
/*
* Node
*
*/
-bool
-Node::is_marked()
-{
- return mark >= incoming.size();
-}
-
std::ostream&
operator<<(std::ostream& os, const Node& n)
{
os << \
"Node<id=" << n.id << \
- ", symbol=" << n.symbol << \
+ ", symbol='" << n.symbol << "'" << \
", span=(" << n.left << "," << n.right << ")" \
", score=" << n.score << \
", incoming:" << n.incoming.size() << \
", outgoing:" << n.outgoing.size() << \
", mark=" << n.mark << ">";
+ return os;
}
/*
* Edge
*
*/
-bool
-Edge::is_marked()
-{
- return mark >= arity;
-}
-
std::ostream&
operator<<(std::ostream& os, const Edge& e)
{
@@ -45,10 +34,10 @@ operator<<(std::ostream& os, const Edge& e)
}
os << \
"Edge<head=" << e.head->id << \
- "', tails=[" << _.str() << "]" \
+ ", tails=[" << _.str() << "]" \
", score=" << e.score << \
- ", rule:'" << "TODO" << \
- " , f=" << "TODO" << \
+ ", rule:'" << "TODO" << "'" << \
+ ", f=" << "TODO" << \
", arity=" << e.arity << \
", mark=" << e.mark << ">";
return os;
@@ -67,7 +56,7 @@ reset(list<Node*> nodes, vector<Edge*> edges)
(**it).mark = 0;
}
-template<class Semiring> void
+template<typename Semiring> void
init(list<Node*>& nodes, list<Node*>::iterator root, Semiring& semiring)
{
for (auto it = nodes.begin(); it != nodes.end(); ++it)
@@ -78,8 +67,10 @@ init(list<Node*>& nodes, list<Node*>::iterator root, Semiring& semiring)
void
topological_sort(list<Node*>& nodes, list<Node*>::iterator root)
{
+ cout << "root " << **root << endl;
+ for (auto it = nodes.begin(); it != nodes.end(); it++)
+ cout << (**it).id << endl;
auto p = root;
- (**p).mark = 0; // is_marked()==true
auto to = nodes.begin();
while (to != nodes.end()) {
if ((**p).is_marked()) {
@@ -97,21 +88,27 @@ topological_sort(list<Node*>& nodes, list<Node*>::iterator root)
p = to;
} else {
++p;
+ if (p == nodes.end()) {
+ p = next(to);
+ to = next(to);
+ }
}
}
+ cout << "---" << endl;
+ for (auto it = nodes.begin(); it != nodes.end(); it++)
+ cout << (**it).id << endl;
}
void
viterbi(Hypergraph& hg)
{
- list<Node*>::iterator root = hg.nodes.begin(); // FIXME?
+ list<Node*>::iterator root = \
+ find_if(hg.nodes.begin(), hg.nodes.end(), [](Node* n) { return n->incoming.size() == 0; });
Hg::topological_sort(hg.nodes, root);
Semiring::Viterbi<double> semiring;
Hg::init(hg.nodes, root, semiring);
-
for (auto n = hg.nodes.begin(); n != hg.nodes.end(); ++n) {
for (auto e = (**n).incoming.begin(); e != (**n).incoming.end(); ++e) {
- cout << **e << endl;
double s = semiring.one;
for (auto m = (**e).tails.begin(); m != (**e).tails.end(); ++m) {
s = semiring.multiply(s, (**m).score);
@@ -119,10 +116,6 @@ viterbi(Hypergraph& hg)
(**n).score = semiring.add((**n).score, semiring.multiply(s, (**e).score));
}
}
-
- for (auto it = hg.nodes.begin(); it != hg.nodes.end(); ++it) {
- cout << (**it).id << " " << (**it).score << endl;
- }
}
namespace io {
@@ -134,30 +127,39 @@ read(Hypergraph& hg, string fn)
size_t i = 0, nn, ne;
msgpack::unpacker pac;
while(true) {
- pac.reserve_buffer(32*1024);
- size_t bytes = ifs.readsome(pac.buffer(), pac.buffer_capacity());
- pac.buffer_consumed(bytes);
- msgpack::unpacked result;
- while(pac.next(&result)) {
- msgpack::object o = result.get();
- if (i == 0) {
- o.convert(&nn);
- nn += 1;
- } else if (i == 1) {
- o.convert(&ne);
- ne += 1;
- } else if (i > 1 && i <= nn) {
- //cout << "N " << o << endl;
- Node* n = new Node;
- o.convert(n);
- } else if (i > nn && i <= nn+ne+1) {
- //cout << "E " << o << endl;
- Edge* e = new Edge;
- o.convert(e);
- }
- i++;
+ pac.reserve_buffer(32*1024);
+ size_t bytes = ifs.readsome(pac.buffer(), pac.buffer_capacity());
+ pac.buffer_consumed(bytes);
+ msgpack::unpacked result;
+ while(pac.next(&result)) {
+ msgpack::object o = result.get();
+ if (i == 0) {
+ o.convert(&nn);
+ nn += 1;
+ } else if (i == 1) {
+ o.convert(&ne);
+ ne += 1;
+ } else if (i > 1 && i <= nn) {
+ Node* n = new Node;
+ o.convert(n);
+ hg.nodes.push_front(n); // FIXME
+ hg.nodes_by_id[n->id] = n;
+ } else if (i > nn && i <= nn+ne+1) {
+ Edge* e = new Edge;
+ e->arity = 0;
+ o.convert(e);
+ e->head = hg.nodes_by_id[e->head_id_];
+ hg.edges.push_back(e);
+ hg.nodes_by_id[e->head_id_]->incoming.push_back(e);
+ for (auto it = e->tails_ids_.begin(); it != e->tails_ids_.end(); ++it) {
+ hg.nodes_by_id[*it]->outgoing.push_back(e);
+ e->tails.push_back(hg.nodes_by_id[*it]);
+ e->arity++;
+ }
}
- if (!bytes) break;
+ i++;
+ }
+ if (!bytes) break;
}
}
@@ -181,107 +183,95 @@ void
manual(Hypergraph& hg)
{
// nodes
- Node* a = new Node; a->id = 0; a->symbol = "root"; a->left = false; a->right = false; a->mark = 0;
- Node* b = new Node; b->id = 1; b->symbol = "NP"; b->left = 0; b->right = 1; b->mark = 0;
- Node* c = new Node; c->id = 2; c->symbol = "V"; c->left = 1; c->right = 2; c->mark = 0;
- Node* d = new Node; d->id = 3; d->symbol = "JJ"; d->left = 3; d->right = 4; d->mark = 0;
- Node* e = new Node; e->id = 4; e->symbol = "NN"; e->left = 3; e->right = 5; e->mark = 0;
- Node* f = new Node; f->id = 5; f->symbol = "NP"; f->left = 2; f->right = 5; f->mark = 0;
- Node* g = new Node; g->id = 6; g->symbol = "NP"; g->left = 1; g->right = 5; g->mark = 0;
- Node* h = new Node; h->id = 7; h->symbol = "S"; h->left = 0; h->right = 6; h->mark = 0;
+ Node* a = new Node; a->id = 0; a->symbol = "root"; a->left = -1; a->right = -1; a->mark = 0;
+ Node* b = new Node; b->id = 1; b->symbol = "NP"; b->left = 0; b->right = 1; b->mark = 0;
+ Node* c = new Node; c->id = 2; c->symbol = "V"; c->left = 1; c->right = 2; c->mark = 0;
+ Node* d = new Node; d->id = 3; d->symbol = "JJ"; d->left = 3; d->right = 4; d->mark = 0;
+ Node* e = new Node; e->id = 4; e->symbol = "NN"; e->left = 3; e->right = 5; e->mark = 0;
+ Node* f = new Node; f->id = 5; f->symbol = "NP"; f->left = 2; f->right = 5; f->mark = 0;
+ Node* g = new Node; g->id = 6; g->symbol = "NP"; g->left = 1; g->right = 5; g->mark = 0;
+ Node* h = new Node; h->id = 7; h->symbol = "S"; h->left = 0; h->right = 6; h->mark = 0;
- hg.add_node(a);
- hg.add_node(h);
- hg.add_node(g);
+ hg.add_node(b);
hg.add_node(c);
hg.add_node(d);
- hg.add_node(f);
- hg.add_node(b);
hg.add_node(e);
+ hg.add_node(f);
+ hg.add_node(g);
+ hg.add_node(h);
+ hg.add_node(a);
// edges
Edge* q = new Edge; q->head = hg.nodes_by_id[1]; q->tails.push_back(hg.nodes_by_id[0]); q->score = 0.367879441171;
+ q->arity = 1; q->mark = 0;
+ hg.edges.push_back(q);
hg.nodes_by_id[1]->incoming.push_back(q);
hg.nodes_by_id[0]->outgoing.push_back(q);
- q->arity = 1;
- q->mark = 0;
- hg.edges.push_back(q);
Edge* p = new Edge; p->head = hg.nodes_by_id[2]; p->tails.push_back(hg.nodes_by_id[0]); p->score = 0.606530659713;
+ p->arity = 1; p->mark = 0;
+ hg.edges.push_back(p);
hg.nodes_by_id[2]->incoming.push_back(p);
hg.nodes_by_id[0]->outgoing.push_back(p);
- p->arity = 1;
- p->mark = 0;
- hg.edges.push_back(p);
Edge* r = new Edge; r->head = hg.nodes_by_id[3]; r->tails.push_back(hg.nodes_by_id[0]); r->score = 1.0;
+ r->arity = 1; r->mark = 0;
+ hg.edges.push_back(r);
hg.nodes_by_id[3]->incoming.push_back(r);
hg.nodes_by_id[0]->outgoing.push_back(r);
- r->arity = 1;
- r->mark = 0;
- hg.edges.push_back(r);
Edge* s = new Edge; s->head = hg.nodes_by_id[3]; s->tails.push_back(hg.nodes_by_id[0]); s->score = 1.0;
+ s->arity = 1; s->mark = 0;
+ hg.edges.push_back(s);
hg.nodes_by_id[3]->incoming.push_back(s);
hg.nodes_by_id[0]->outgoing.push_back(s);
- s->arity = 1;
- s->mark = 0;
- hg.edges.push_back(s);
Edge* t = new Edge; t->head = hg.nodes_by_id[4]; t->tails.push_back(hg.nodes_by_id[0]); t->score = 1.0;
+ t->arity = 1; t->mark = 0;
+ hg.edges.push_back(t);
hg.nodes_by_id[4]->incoming.push_back(t);
hg.nodes_by_id[0]->outgoing.push_back(t);
- t->arity = 1;
- t->mark = 0;
- hg.edges.push_back(t);
Edge* u = new Edge; u->head = hg.nodes_by_id[4]; u->tails.push_back(hg.nodes_by_id[0]); u->score = 1.0;
+ u->arity = 1; u->mark = 0;
+ hg.edges.push_back(u);
hg.nodes_by_id[4]->incoming.push_back(u);
hg.nodes_by_id[0]->outgoing.push_back(u);
- u->arity = 1;
- u->mark = 0;
- hg.edges.push_back(u);
Edge* v = new Edge; v->head = hg.nodes_by_id[4]; v->tails.push_back(hg.nodes_by_id[3]); v->score = 1.0;
+ v->arity = 1; v->mark = 0;
+ hg.edges.push_back(v);
hg.nodes_by_id[4]->incoming.push_back(v);
hg.nodes_by_id[3]->outgoing.push_back(v);
- v->arity = 1;
- v->mark = 0;
- hg.edges.push_back(v);
Edge* w = new Edge; w->head = hg.nodes_by_id[4]; w->tails.push_back(hg.nodes_by_id[3]); w->score = 2.71828182846;
+ w->arity = 1; w->mark = 0;
+ hg.edges.push_back(w);
hg.nodes_by_id[4]->incoming.push_back(w);
hg.nodes_by_id[3]->outgoing.push_back(w);
- w->arity = 1;
- w->mark = 0;
- hg.edges.push_back(w);
Edge* x = new Edge; x->head = hg.nodes_by_id[5]; x->tails.push_back(hg.nodes_by_id[4]); x->score = 1.0;
+ x->arity = 1; x->mark = 0;
+ hg.edges.push_back(x);
hg.nodes_by_id[5]->incoming.push_back(x);
hg.nodes_by_id[4]->outgoing.push_back(x);
- x->arity = 1;
- x->mark = 0;
- hg.edges.push_back(x);
Edge* y = new Edge; y->head = hg.nodes_by_id[6]; y->tails.push_back(hg.nodes_by_id[2]); y->tails.push_back(hg.nodes_by_id[5]); y->score = 1.0;
+ y->arity = 2; y->mark = 0;
+ hg.edges.push_back(y);
hg.nodes_by_id[6]->incoming.push_back(y);
hg.nodes_by_id[2]->outgoing.push_back(y);
hg.nodes_by_id[5]->outgoing.push_back(y);
- y->arity = 2;
- y->mark = 0;
- hg.edges.push_back(y);
Edge* z = new Edge; z->head = hg.nodes_by_id[7]; z->tails.push_back(hg.nodes_by_id[1]); z->tails.push_back(hg.nodes_by_id[6]); z->score = 1.0;
+ z->arity = 2; z->mark = 0;
+ hg.edges.push_back(z);
hg.nodes_by_id[7]->incoming.push_back(z);
hg.nodes_by_id[1]->outgoing.push_back(z);
hg.nodes_by_id[6]->outgoing.push_back(z);
- z->arity = 2;
- z->mark = 0;
- hg.edges.push_back(z);
}
} // namespace
-
} // namespace
diff --git a/fast/hypergraph.hh b/fast/hypergraph.hh
index 2e30911..530fbe6 100644
--- a/fast/hypergraph.hh
+++ b/fast/hypergraph.hh
@@ -1,7 +1,5 @@
#pragma once
-#include "grammar.hh"
-#include "semiring.hh"
#include <iostream>
#include <string>
#include <sstream>
@@ -12,17 +10,19 @@
#include <algorithm>
#include <iterator>
#include <fstream>
+#include <msgpack.hpp>
+#include "grammar.hh"
+#include "semiring.hh"
#include "dummyvector.h"
-#include <msgpack.hpp>
using namespace std;
typedef double score_t;
typedef double weight_t;
-namespace Hg {
+namespace Hg {
struct Node;
@@ -31,11 +31,11 @@ struct Edge {
vector<Node*> tails;
score_t score;
string rule; //FIXME
- DummyVector f; //FIXME
+ DummyVector f; //FIXME
unsigned int arity;
- unsigned int mark;
+ unsigned int mark = 0;
- bool is_marked();
+ inline bool is_marked() { return mark >= arity; }
friend std::ostream& operator<<(std::ostream& os, const Edge& s);
size_t head_id_;
@@ -47,19 +47,17 @@ struct Edge {
struct Node {
size_t id;
string symbol;
- unsigned short left;
- unsigned short right;
+ short left;
+ short right;
score_t score;
vector<Edge*> incoming;
vector<Edge*> outgoing;
unsigned int mark;
- bool is_marked();
+ inline bool is_marked() { return mark >= incoming.size(); };
friend std::ostream& operator<<(std::ostream& os, const Node& n);
- vector<size_t> incoming_ids_; // edge ids
- vector<size_t> outgoing_ids_; // edge ids
- MSGPACK_DEFINE(id, symbol, left, right, score, incoming_ids_, outgoing_ids_);
+ MSGPACK_DEFINE(id, symbol, left, right, score);
};
struct Hypergraph {
@@ -96,6 +94,5 @@ manual(Hypergraph& hg);
} // namespace
-
} // namespace
diff --git a/fast/main.cc b/fast/main.cc
index a7b5837..9c64976 100644
--- a/fast/main.cc
+++ b/fast/main.cc
@@ -7,6 +7,14 @@ main(int argc, char** argv)
Hg::Hypergraph hg;
//Hg::io::manual(hg);
Hg::io::read(hg, argv[1]);
- //Hg::viterbi(hg);
+ /*cout << "---" << endl;
+ for (auto it = hg.nodes.begin(); it!=hg.nodes.end(); it++)
+ cout << **it << endl;
+ for (auto it = hg.edges.begin(); it!=hg.edges.end(); it++)
+ cout << **it << endl;
+ cout << "---" << endl;*/
+ Hg::viterbi(hg);
+
+ return 0;
}
diff --git a/fast/make_paks.cc b/fast/make_paks.cc
index 6fe7fae..c0fee90 100644
--- a/fast/make_paks.cc
+++ b/fast/make_paks.cc
@@ -1,26 +1,25 @@
#include <iostream>
#include <fstream>
#include <string>
+#include <unordered_map>
#include <msgpack.hpp>
-#include <msgpack/fbuffer.h>
#include <msgpack/fbuffer.hpp>
-#include <unordered_map>
#include "json-cpp.hpp"
-#include "hypergraph.hh"
#include "dummyvector.h"
+#include "hypergraph.hh"
using namespace std;
struct DummyNode {
- int id;
+ size_t id;
string cat;
- vector<int> span;
+ vector<short> span;
};
struct DummyEdge {
- int head;
+ size_t head;
string rule;
vector<size_t> tails;
DummyVector f;
@@ -57,15 +56,13 @@ serialize(jsoncpp::Stream<X>& stream, DummyVector& o)
fields(o, stream, "EgivenFCoherent", o.EgivenFCoherent, "SampleCountF", o.SampleCountF, "CountEF", o.CountEF, "MaxLexFgivenE", o.MaxLexFgivenE, "MaxLexEgivenF", o.MaxLexEgivenF, "IsSingletonF", o.IsSingletonF, "IsSingletonFE", o.IsSingletonFE, "LanguageModel", o.LanguageModel, "LanguageModel_OOV", o.LanguageModel_OOV, "PassThrough", o.PassThrough, "PassThrough_1", o.PassThrough_1, "PassThrough_2", o.PassThrough_2, "PassThrough_3", o.PassThrough_3, "PassThrough_4", o.PassThrough_4, "PassThrough_5", o.PassThrough_5, "PassThrough_6", o.PassThrough_6, "WordPenalty", o.WordPenalty, "Glue", o.Glue);
}
-
-
int
main(int argc, char** argv)
{
+ // read from json
ifstream ifs(argv[1]);
string json_str((istreambuf_iterator<char>(ifs) ),
(istreambuf_iterator<char>()));
-
DummyHg hg;
vector<DummyNode> nodes;
hg.nodes = nodes;
@@ -75,19 +72,19 @@ main(int argc, char** argv)
hg.weights = w;
jsoncpp::parse(hg, json_str);
+ // convert objects
vector<Hg::Node*> nodes_;
for (auto it = hg.nodes.begin(); it != hg.nodes.end(); ++it) {
- Hg::Node* n = new Hg::Node;
- n->id = it->id;
+ Hg::Node* n = new Hg::Node;
+ n->id = it->id;
n->symbol = it->cat;
n->left = it->span[0];
n->right = it->span[1];
nodes_.push_back(n);
}
-
vector<Hg::Edge*> edges_;
for (auto it = hg.edges.begin(); it != hg.edges.end(); ++it) {
- Hg::Edge* e = new Hg::Edge;
+ Hg::Edge* e = new Hg::Edge;
e->head_id_ = it->head;
e->tails_ids_ = it->tails;
e->score = it->weight;
@@ -96,6 +93,7 @@ main(int argc, char** argv)
edges_.push_back(e);
}
+ // write to msgpack
FILE* file = fopen(argv[2], "wb");
msgpack::fbuffer fbuf(file);
msgpack::pack(fbuf, hg.nodes.size());
diff --git a/fast/read_pak.cc b/fast/read_pak.cc
index 81eed5d..c1cf761 100644
--- a/fast/read_pak.cc
+++ b/fast/read_pak.cc
@@ -1,6 +1,6 @@
-#include <msgpack.hpp>
#include <iostream>
#include <fstream>
+#include <msgpack.hpp>
using namespace std;
@@ -20,7 +20,8 @@ main(int argc, char** argv)
msgpack::object o = result.get();
cout << o << endl;
}
-
if (!bytes) break;
}
+
+ return 0;
}