diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/fast_weaver.cc | 26 | ||||
-rw-r--r-- | src/grammar.hh | 334 | ||||
-rw-r--r-- | src/hypergraph.cc | 362 | ||||
-rw-r--r-- | src/hypergraph.hh | 102 | ||||
-rw-r--r-- | src/make_pak.cc | 103 | ||||
-rw-r--r-- | src/parse.hh | 301 | ||||
-rw-r--r-- | src/read_pak.cc | 26 | ||||
-rw-r--r-- | src/semiring.hh | 35 | ||||
-rw-r--r-- | src/sparse_vector.hh | 186 | ||||
-rw-r--r-- | src/test_grammar.cc | 19 | ||||
-rw-r--r-- | src/test_parse.cc | 19 | ||||
-rw-r--r-- | src/test_sparse_vector.cc | 36 | ||||
-rw-r--r-- | src/types.hh | 10 | ||||
-rw-r--r-- | src/util.hh | 47 |
14 files changed, 1606 insertions, 0 deletions
diff --git a/src/fast_weaver.cc b/src/fast_weaver.cc new file mode 100644 index 0000000..4854476 --- /dev/null +++ b/src/fast_weaver.cc @@ -0,0 +1,26 @@ +#include "hypergraph.hh" +#include <ctime> + +int +main(int argc, char** argv) +{ + Hg::Hypergraph hg; + G::Vocabulary y; + G::Grammar g; + Hg::io::read(hg, g.rules, y, argv[1]); + //Hg::io::manual(hg, g.rules); + clock_t begin = clock(); + Hg::Path p; + Hg::viterbi_path(hg, p); + vector<string> s; + Hg::derive(p, p.back()->head, s); + for (auto it: s) + cout << it << " "; + cout << endl; + clock_t end = clock(); + double elapsed_secs = double(end - begin) / CLOCKS_PER_SEC; + cout << elapsed_secs << " s" << endl; + + return 0; +} + diff --git a/src/grammar.hh b/src/grammar.hh new file mode 100644 index 0000000..c489ec5 --- /dev/null +++ b/src/grammar.hh @@ -0,0 +1,334 @@ +#pragma once + +#include <iostream> +#include <fstream> +#include <map> +#include <sstream> +#include <string> +#include <vector> +#include <set> + +#include <msgpack.hpp> + +#include "sparse_vector.hh" +#include "util.hh" + +using namespace std; + +namespace G { + +enum item_type { + UNDEFINED, + TERMINAL, + NON_TERMINAL +}; + +struct Item { + virtual size_t index() const { return 0; } + virtual symbol_t symbol() const { return ""; } + virtual item_type type() const { return UNDEFINED; } + + virtual ostream& repr(ostream& os) const { return os << "Item<>"; } + virtual ostream& escaped(ostream& os) const { return os << ""; } + + friend ostream& + operator<<(ostream& os, const Item& i) + { + return i.repr(os); + }; +}; + +struct NT : public Item { + symbol_t symbol_; + size_t index_; + + NT() {} + + NT(string const& s) + { + index_ = 0; // default + string t(s); + t.erase(0, 1); t.pop_back(); // remove '[' and ']' + istringstream ss(t); + if (ss >> index_) { // [i] + symbol_ = ""; + index_ = stoi(s); + return; + } else { + ss.clear(); + string buf; + size_t j = 0; + while (ss.good() && getline(ss, buf, ',')) { + if (j == 0) { + symbol_ = buf; + } else { + index_ = stoi(buf); + } + j++; + } + } + } + + virtual size_t index() const { return index_; } + virtual symbol_t symbol() const { return symbol_; } + virtual item_type type() const { return NON_TERMINAL; } + + virtual ostream& + repr(ostream& os) const + { + return os << "NT<" << symbol_ << "," << index_ << ">"; + } + + virtual ostream& + escaped(ostream& os) const + { + os << "[" << symbol_; + if (index_ > 0) + os << "," << index_; + os << "]"; + + return os; + } +}; + +struct T : public Item { + symbol_t symbol_; + + T(string const& s) + { + symbol_ = s; + } + + virtual symbol_t symbol() const { return symbol_; } + virtual item_type type() const { return TERMINAL; } + + virtual ostream& + repr(ostream& os) const + { + return os << "T<" << symbol_ << ">"; + } + + virtual ostream& + escaped(ostream& os) const + { + os << util::json_escape(symbol_); + } +}; + +struct Vocabulary +{ + unordered_map<symbol_t, size_t> m_; + vector<Item*> items_; + + bool + is_non_terminal(string const& s) + { + return s.front() == '[' && s.back() == ']'; + } + + Item* + get(symbol_t const& s) + { + if (is_non_terminal(s)) + return new NT(s); + if (m_.find(s) != m_.end()) + return items_[m_[s]]; + return add(s); + } + + Item* + add(symbol_t const& s) + { + size_t next_index_ = items_.size(); + T* item = new T(s); + items_.push_back(item); + m_[s] = next_index_; + + return item; + } +}; + +struct Rule { + NT* lhs; + vector<Item*> rhs; + vector<Item*> target; + size_t arity; +Sv::SparseVector<string, score_t>* f; + map<size_t, size_t> order; + string as_str_; + + Rule() {} + + Rule(string const& s, Vocabulary& vocab) { from_s(this, s, vocab); } + + static void + from_s(Rule* r, string const& s, Vocabulary& vocab) + { + istringstream ss(s); + string buf; + size_t j = 0, i = 1; + r->arity = 0; + vector<NT*> rhs_non_terminals; + r->f = new Sv::SparseVector<string, score_t>(); + while (ss >> buf) { + if (buf == "|||") { j++; continue; } + if (j == 0) { // left-hand side + r->lhs = new NT(buf); + } else if (j == 1) { // right-hand side + Item* item = vocab.get(buf); + r->rhs.push_back(item); + if (item->type() == NON_TERMINAL) { + r->arity++; + rhs_non_terminals.push_back(reinterpret_cast<NT*>(item)); + } + } else if (j == 2) { // target + Item* item = vocab.get(buf); + if (item->type() == NON_TERMINAL) { + r->order.insert(make_pair(i, item->index())); + i++; + if (item->symbol() == "") { // only [1], [2] ... on target + reinterpret_cast<NT*>(item)->symbol_ = \ + rhs_non_terminals[item->index()-1]->symbol(); + } + } + r->target.push_back(item); + } else if (j == 3) { // feature vector + Sv::SparseVector<string, score_t>::from_s(r->f, buf); + // FIXME: this is slow!!! ^^^ + } else if (j == 4) { // alignment + } else { + // error + } + if (j == 4) break; + } + } + + ostream& + repr(ostream& os) const + { + os << "Rule<lhs="; + lhs->repr(os); + os << ", rhs:{"; + for (auto it = rhs.begin(); it != rhs.end(); it++) { + (**it).repr(os); + if (next(it) != rhs.end()) os << " "; + } + os << "}, target:{"; + for (auto it = target.begin(); it != target.end(); it++) { + (**it).repr(os); + if (next(it) != target.end()) os << " "; + } + os << "}, f:"; + f->repr(os); + os << ", arity=" << arity << \ + ", order:{"; + for (auto it = order.begin(); it != order.end(); it++) { + os << it->first << "->" << it->second; + if (next(it) != order.end()) os << ", "; + } + os << "}>"; + + return os; + } + + ostream& + escaped(ostream& os) const + { + lhs->escaped(os); + os << " ||| "; + for (auto it = rhs.begin(); it != rhs.end(); it++) { + (**it).escaped(os); + if (next(it) != rhs.end()) os << " "; + } + os << " ||| "; + for (auto it = target.begin(); it != target.end(); it++) { + (**it).escaped(os); + if (next(it) != target.end()) os << " "; + } + os << " ||| "; + f->escaped(os); + os << " ||| " << \ + "TODO"; + + return os; + }; + + friend ostream& + operator<<(ostream& os, Rule const& r) + { + return r.repr(os); + }; + + // -- + void + prep_for_serialization_() + { + ostringstream os; + escaped(os); + as_str_ = os.str(); + }; + MSGPACK_DEFINE(as_str_); + // ^^^ FIXME +}; + +struct Grammar { + vector<Rule*> rules; + vector<Rule*> flat; + vector<Rule*> start_non_terminal; + vector<Rule*> start_terminal; + set<symbol_t> nts; + + Grammar() {} + + Grammar(string const& fn, Vocabulary& vocab) + { + ifstream ifs(fn); + string line; + while (getline(ifs, line)) { + G::Rule* r = new G::Rule(line, vocab); + rules.push_back(r); + nts.insert(r->lhs->symbol()); + if (r->arity == 0) + flat.push_back(r); + else if (r->rhs.front()->type() == NON_TERMINAL) + start_non_terminal.push_back(r); + else + start_terminal.push_back(r); + } + } + + void + add_glue(Vocabulary& vocab) + { + for (auto nt: nts) { + ostringstream oss_1; + oss_1 << "[S] ||| [" << nt << ",1] ||| [" << nt << ",1] ||| "; + cout << oss_1.str() << endl; + Rule* r1 = new Rule(oss_1.str(), vocab); + rules.push_back(r1); start_non_terminal.push_back(r1); + ostringstream oss_2; + oss_2 << "[S] ||| [S,1] [" << nt << ",2] ||| [S,1] [" << nt << ",2] ||| "; + cout << oss_2.str() << endl; + Rule* r2 = new Rule(oss_2.str(), vocab); + cout << *r2 << endl; + rules.push_back(r2); start_non_terminal.push_back(r2); + } + } + + void add_pass_through(const string& input); + // ^^^ TODO + + friend ostream& + operator<<(ostream& os, Grammar const& g) + { + for (const auto it: g.rules) { + it->repr(os); + os << endl; + } + + return os; + } +}; + +} // namespace G + diff --git a/src/hypergraph.cc b/src/hypergraph.cc new file mode 100644 index 0000000..40bcc64 --- /dev/null +++ b/src/hypergraph.cc @@ -0,0 +1,362 @@ +#include "hypergraph.hh" + +namespace Hg { + +template<typename Semiring> void +init(const list<Node*>& nodes, const list<Node*>::iterator root, const Semiring& semiring) +{ + for (const auto it: nodes) + it->score = semiring.null; + (**root).score = semiring.one; +} + +void +reset(const list<Node*> nodes, const vector<Edge*> edges) +{ + for (const auto it: nodes) + it->mark = 0; + for (auto it: edges) + it->mark = 0; +} + +void +topological_sort(list<Node*>& nodes, const list<Node*>::iterator root) +{ + auto p = root; + auto to = nodes.begin(); + while (to != nodes.end()) { + if ((**p).is_marked()) { + for (const auto e: (**p).outgoing) { // explore edges + e->mark++; + if (e->is_marked()) { + e->head->mark++; + } + } + } + if ((**p).is_marked()) { + nodes.splice(to, nodes, p); + to++; + p = to; + } else { + ++p; + } + } +} + +void +viterbi(Hypergraph& hg) +{ + list<Node*>::iterator root = \ + find_if(hg.nodes.begin(), hg.nodes.end(), \ + [](Node* n) { return n->incoming.size() == 0; }); + + Hg::topological_sort(hg.nodes, root); + Semiring::Viterbi<score_t> semiring; + Hg::init(hg.nodes, root, semiring); + + for (const auto n: hg.nodes) { + for (const auto e: n->incoming) { + score_t s = semiring.one; + for (const auto m: e->tails) { + s = semiring.multiply(s, m->score); + } + n->score = semiring.add(n->score, semiring.multiply(s, e->score)); + } + } +} + +void +viterbi_path(Hypergraph& hg, Path& p) +{ + list<Node*>::iterator root = \ + find_if(hg.nodes.begin(), hg.nodes.end(), \ + [](Node* n) { return n->incoming.size() == 0; }); + //list<Node*>::iterator root = hg.nodes.begin(); + + Hg::topological_sort(hg.nodes, root); + // ^^^ FIXME do I need to do this when reading from file? + Semiring::Viterbi<score_t> semiring; + Hg::init(hg.nodes, root, semiring); + + for (auto n: hg.nodes) { + Edge* best_edge; + bool best = false; + for (auto e: n->incoming) { + score_t s = semiring.one; + for (auto m: e->tails) { + s = semiring.multiply(s, m->score); + } + if (n->score < semiring.multiply(s, e->score)) { // find max + best_edge = e; + best = true; + } + n->score = semiring.add(n->score, semiring.multiply(s, e->score)); + } + if (best) + p.push_back(best_edge); + } +} + + +void +derive(const Path& p, const Node* cur, vector<string>& carry) +{ + Edge* next; + for (auto it: p) { + if (it->head->symbol == cur->symbol && + it->head->left == cur->left && + it->head->right == cur->right) { + next = it; + } + } // FIXME this is probably not so good + + unsigned j = 0; + for (auto it: next->rule->target) { + if (it->type() == G::NON_TERMINAL) { + derive(p, next->tails[next->rule->order[j]], carry); + j++; + } else { + carry.push_back(it->symbol()); + } + } +} + +namespace io { + +void +read(Hypergraph& hg, vector<G::Rule*>& rules, G::Vocabulary& vocab, const string& fn) +{ + ifstream ifs(fn); + size_t i = 0, r, n, e; + msgpack::unpacker pac; + while(true) { + pac.reserve_buffer(32*1024); + size_t bytes = ifs.readsome(pac.buffer(), pac.buffer_capacity()); + pac.buffer_consumed(bytes); + msgpack::unpacked result; + while(pac.next(&result)) { + msgpack::object o = result.get(); + if (i == 0) { + o.convert(&r); + } else if (i == 1) { + o.convert(&n); + } else if (i == 2) { + o.convert(&e); + } else if (i > 2 && i <= r+2) { + string s; + o.convert(&s); + G::Rule* rule = new G::Rule; + G::Rule::from_s(rule, s, vocab); + rules.push_back(rule); + } else if (i > r+2 && i <= r+n+2) { + Node* n = new Node; + o.convert(n); + hg.nodes.push_back(n); + hg.nodes_by_id[n->id] = n; + } else if (i > n+2 && i <= r+n+e+2) { + Edge* e = new Edge; + e->arity = 0; + o.convert(e); + e->head = hg.nodes_by_id[e->head_id_]; + hg.edges.push_back(e); + hg.nodes_by_id[e->head_id_]->incoming.push_back(e); + e->arity = 0; + for (auto it = e->tails_ids_.begin(); it != e->tails_ids_.end(); it++) { + hg.nodes_by_id[*it]->outgoing.push_back(e); + e->tails.push_back(hg.nodes_by_id[*it]); + e->arity++; + } + e->rule = rules[e->rule_id_]; + } else { + // ERROR + } + i++; + } + if (!bytes) break; + } +} + +void +write(Hypergraph& hg, vector<G::Rule*>& rules, const string& fn) // FIXME +{ + FILE* file = fopen(fn.c_str(), "wb"); + msgpack::fbuffer fbuf(file); + msgpack::pack(fbuf, rules.size()); + msgpack::pack(fbuf, hg.nodes.size()); + msgpack::pack(fbuf, hg.edges.size()); + for (auto it = rules.cbegin(); it != rules.cend(); it++) + msgpack::pack(fbuf, **it); + for (auto it = hg.nodes.cbegin(); it != hg.nodes.cend(); it++) + msgpack::pack(fbuf, **it); + for (auto it = hg.edges.cbegin(); it != hg.edges.cend(); it++) + msgpack::pack(fbuf, **it); + fclose(file); +} + +void +manual(Hypergraph& hg, vector<G::Rule*>& rules, G::Vocabulary& vocab) +{ + // nodes + Node* a = new Node; a->id = 0; a->symbol = "root"; a->left = -1; a->right = -1; a->mark = 0; + hg.nodes.push_back(a); hg.nodes_by_id[a->id] = a; + Node* b = new Node; b->id = 1; b->symbol = "NP"; b->left = 0; b->right = 1; b->mark = 0; + hg.nodes.push_back(b); hg.nodes_by_id[b->id] = b; + Node* c = new Node; c->id = 2; c->symbol = "V"; c->left = 1; c->right = 2; c->mark = 0; + hg.nodes.push_back(c); hg.nodes_by_id[c->id] = c; + Node* d = new Node; d->id = 3; d->symbol = "JJ"; d->left = 3; d->right = 4; d->mark = 0; + hg.nodes.push_back(d); hg.nodes_by_id[d->id] = d; + Node* e = new Node; e->id = 4; e->symbol = "NN"; e->left = 3; e->right = 5; e->mark = 0; + hg.nodes.push_back(e); hg.nodes_by_id[e->id] = e; + Node* f = new Node; f->id = 5; f->symbol = "NP"; f->left = 2; f->right = 5; f->mark = 0; + hg.nodes.push_back(f); hg.nodes_by_id[f->id] = f; + Node* g = new Node; g->id = 6; g->symbol = "NP"; g->left = 1; g->right = 5; g->mark = 0; + hg.nodes.push_back(g); hg.nodes_by_id[g->id] = g; + Node* h = new Node; h->id = 7; h->symbol = "S"; h->left = 0; h->right = 6; h->mark = 0; + hg.nodes.push_back(h); hg.nodes_by_id[h->id] = h; + + // rules + vector<string> rule_strs; + rule_strs.push_back("[NP] ||| ich ||| i ||| ||| "); + rule_strs.push_back("[V] ||| sah ||| saw ||| ||| "); + rule_strs.push_back("[JJ] ||| kleines ||| small ||| ||| "); + rule_strs.push_back("[JJ] ||| kleines ||| little ||| ||| "); + rule_strs.push_back("[NN] ||| kleines haus ||| small house ||| ||| "); + rule_strs.push_back("[NN] ||| kleines haus ||| little house ||| ||| "); + rule_strs.push_back("[NN] ||| [JJ,1] haus ||| [JJ,1] shell ||| ||| "); + rule_strs.push_back("[NN] ||| [JJ,1] haus ||| [JJ,1] house ||| ||| "); + rule_strs.push_back("[NP] ||| ein [NN,1] ||| a [NN,1] ||| ||| "); + rule_strs.push_back("[VP] ||| [V,1] [NP,2] ||| [V,1] [NP,2] ||| ||| "); + rule_strs.push_back("[S] ||| [NP,1] [VP,2] ||| [NP,1] [VP,2] ||| ||| "); + + for (auto it: rule_strs) { + rules.push_back(new G::Rule(it, vocab)); + rules.back()->f = new Sv::SparseVector<string, score_t>(); + } + + // edges + Edge* q = new Edge; q->head = hg.nodes_by_id[1]; q->tails.push_back(hg.nodes_by_id[0]); q->score = 0.367879441171; + q->arity = 1; q->mark = 0; + hg.edges.push_back(q); + hg.nodes_by_id[1]->incoming.push_back(q); + hg.nodes_by_id[0]->outgoing.push_back(q); + q->rule = rules[0]; + + Edge* p = new Edge; p->head = hg.nodes_by_id[2]; p->tails.push_back(hg.nodes_by_id[0]); p->score = 0.606530659713; + p->arity = 1; p->mark = 0; + hg.edges.push_back(p); + hg.nodes_by_id[2]->incoming.push_back(p); + hg.nodes_by_id[0]->outgoing.push_back(p); + p->rule = rules[1]; + + Edge* r = new Edge; r->head = hg.nodes_by_id[3]; r->tails.push_back(hg.nodes_by_id[0]); r->score = 1.0; + r->arity = 1; r->mark = 0; + hg.edges.push_back(r); + hg.nodes_by_id[3]->incoming.push_back(r); + hg.nodes_by_id[0]->outgoing.push_back(r); + r->rule = rules[2]; + + Edge* s = new Edge; s->head = hg.nodes_by_id[3]; s->tails.push_back(hg.nodes_by_id[0]); s->score = 1.0; + s->arity = 1; s->mark = 0; + hg.edges.push_back(s); + hg.nodes_by_id[3]->incoming.push_back(s); + hg.nodes_by_id[0]->outgoing.push_back(s); + s->rule = rules[3]; + + Edge* t = new Edge; t->head = hg.nodes_by_id[4]; t->tails.push_back(hg.nodes_by_id[0]); t->score = 1.0; + t->arity = 1; t->mark = 0; + hg.edges.push_back(t); + hg.nodes_by_id[4]->incoming.push_back(t); + hg.nodes_by_id[0]->outgoing.push_back(t); + t->rule = rules[4]; + + Edge* u = new Edge; u->head = hg.nodes_by_id[4]; u->tails.push_back(hg.nodes_by_id[0]); u->score = 1.0; + u->arity = 1; u->mark = 0; + hg.edges.push_back(u); + hg.nodes_by_id[4]->incoming.push_back(u); + hg.nodes_by_id[0]->outgoing.push_back(u); + u->rule = rules[5]; + + Edge* v = new Edge; v->head = hg.nodes_by_id[4]; v->tails.push_back(hg.nodes_by_id[3]); v->score = 1.0; + v->arity = 1; v->mark = 0; + hg.edges.push_back(v); + hg.nodes_by_id[4]->incoming.push_back(v); + hg.nodes_by_id[3]->outgoing.push_back(v); + v->rule = rules[6]; + + Edge* w = new Edge; w->head = hg.nodes_by_id[4]; w->tails.push_back(hg.nodes_by_id[3]); w->score = 2.71828182846; + w->arity = 1; w->mark = 0; + hg.edges.push_back(w); + hg.nodes_by_id[4]->incoming.push_back(w); + hg.nodes_by_id[3]->outgoing.push_back(w); + w->rule = rules[7]; + + Edge* x = new Edge; x->head = hg.nodes_by_id[5]; x->tails.push_back(hg.nodes_by_id[4]); x->score = 1.0; + x->arity = 1; x->mark = 0; + hg.edges.push_back(x); + hg.nodes_by_id[5]->incoming.push_back(x); + hg.nodes_by_id[4]->outgoing.push_back(x); + x->rule = rules[8]; + + Edge* y = new Edge; y->head = hg.nodes_by_id[6]; y->tails.push_back(hg.nodes_by_id[2]); y->tails.push_back(hg.nodes_by_id[5]); y->score = 1.0; + y->arity = 2; y->mark = 0; + hg.edges.push_back(y); + hg.nodes_by_id[6]->incoming.push_back(y); + hg.nodes_by_id[2]->outgoing.push_back(y); + hg.nodes_by_id[5]->outgoing.push_back(y); + y->rule = rules[9]; + + Edge* z = new Edge; z->head = hg.nodes_by_id[7]; z->tails.push_back(hg.nodes_by_id[1]); z->tails.push_back(hg.nodes_by_id[6]); z->score = 1.0; + z->arity = 2; z->mark = 0; + hg.edges.push_back(z); + hg.nodes_by_id[7]->incoming.push_back(z); + hg.nodes_by_id[1]->outgoing.push_back(z); + hg.nodes_by_id[6]->outgoing.push_back(z); + z->rule = rules[10]; +} + +} // namespace Hg::io + +/* + * Hg::Node + * + */ +ostream& +operator<<(ostream& os, const Node& n) +{ + os << \ + "Node<id=" << n.id << \ + ", symbol='" << n.symbol << "'" << \ + ", span=(" << n.left << "," << n.right << ")" \ + ", score=" << n.score << \ + ", incoming:" << n.incoming.size() << \ + ", outgoing:" << n.outgoing.size() << \ + ", mark=" << n.mark << ">"; + return os; +} + +/* + * Hg::Edge + * + */ +ostream& +operator<<(ostream& os, const Edge& e) +{ + ostringstream _; + for (auto it: e.tails) { + _ << it->id; + if (it != e.tails.back()) _ << ","; + } + os << \ + "Edge<head=" << e.head->id << \ + ", tails=[" << _.str() << "]" \ + ", score=" << e.score << \ + ", rule:'"; + e.rule->escaped(os); + os << "', f=" << "TODO" << \ + ", arity=" << e.arity << \ + ", mark=" << e.mark << ">"; + return os; +} + +} // namespace Hg + diff --git a/src/hypergraph.hh b/src/hypergraph.hh new file mode 100644 index 0000000..8e05e9f --- /dev/null +++ b/src/hypergraph.hh @@ -0,0 +1,102 @@ +#pragma once + +#include <algorithm> +#include <fstream> +#include <functional> +#include <iostream> +#include <iterator> +#include <list> +#include <msgpack.hpp> +#include <msgpack/fbuffer.hpp> +#include <sstream> +#include <string> +#include <unordered_map> +#include <vector> + +#include "grammar.hh" +#include "semiring.hh" +#include "sparse_vector.hh" +#include "types.hh" + +using namespace std; + +namespace Hg { + +struct Node; + +struct Edge { + Node* head; + vector<Node*> tails; + score_t score; + G::Rule* rule; + unsigned int arity = 0; + unsigned int mark = 0; + + inline bool is_marked() { return mark >= arity; } + friend ostream& operator<<(ostream& os, const Edge& e); + + size_t head_id_; + vector<size_t> tails_ids_; // node ids + size_t rule_id_; + + MSGPACK_DEFINE(head_id_, tails_ids_, rule_id_, score, arity); +}; + +struct Node { + size_t id; + string symbol; + short left; + short right; + score_t score; + vector<Edge*> incoming; + vector<Edge*> outgoing; + unsigned int mark; + + inline bool is_marked() { return mark >= incoming.size(); }; + friend ostream& operator<<(ostream& os, const Node& n); + + MSGPACK_DEFINE(id, symbol, left, right, score); +}; + +struct Hypergraph { + list<Node*> nodes; + vector<Edge*> edges; + unordered_map<size_t, Node*> nodes_by_id; + unsigned int arity; +}; + +template<typename Semiring> void +init(const list<Node*>& nodes, const list<Node*>::iterator root, const Semiring& semiring); + +void +reset(const list<Node*> nodes, const vector<Edge*> edges); + +void +topological_sort(list<Node*>& nodes, const list<Node*>::iterator root); + +void +viterbi(Hypergraph& hg); + +typedef vector<Edge*> Path; + +void +viterbi_path(Hypergraph& hg, Path& p); + +void +derive(const Path& p, const Node* cur, vector<string>& carry); + +namespace io { + +void +read(Hypergraph& hg, vector<G::Rule*>& rules, G::Vocabulary& vocab, const string& fn); // FIXME + +void +write(Hypergraph& hg, vector<G::Rule*>& rules, const string& fn); // FIXME + +void +manual(Hypergraph& hg, vector<G::Rule*>& rules); + +} // namespace + +} // namespace + diff --git a/src/make_pak.cc b/src/make_pak.cc new file mode 100644 index 0000000..db3a8a4 --- /dev/null +++ b/src/make_pak.cc @@ -0,0 +1,103 @@ +#include <iostream> +#include <fstream> +#include <msgpack.hpp> +#include <msgpack/fbuffer.hpp> +#include <string> + +#include "json-cpp/single_include/json-cpp.hpp" +#include "hypergraph.hh" +#include "types.hh" + +using namespace std; + +struct DummyNode { + size_t id; + string symbol; + vector<short> span; +}; + +struct DummyEdge { + size_t head_id; + size_t rule_id; + vector<size_t> tails_ids; + string f; + score_t score; +}; + +struct DummyHg { + vector<string> rules; + vector<DummyNode> nodes; + vector<DummyEdge> edges; +}; + +template<typename X> inline void +serialize(jsoncpp::Stream<X>& stream, DummyNode& o) +{ + fields(o, stream, "id", o.id, "symbol", o.symbol, "span", o.span); +} + +template<typename X> inline void +serialize(jsoncpp::Stream<X>& stream, DummyEdge& o) +{ + fields(o, stream, "head", o.head_id, "rule", o.rule_id, "tails", o.tails_ids, "score", o.score); +} + +template<typename X> inline void +serialize(jsoncpp::Stream<X>& stream, DummyHg& o) +{ + fields(o, stream, "rules", o.rules, "nodes", o.nodes, "edges", o.edges); +} + +int +main(int argc, char** argv) +{ + // read from json + ifstream ifs(argv[1]); + string json_str((istreambuf_iterator<char>(ifs) ), + (istreambuf_iterator<char>())); + DummyHg hg; + vector<string> rules; + hg.rules = rules; + vector<DummyNode> nodes; + hg.nodes = nodes; + vector<DummyEdge> edges; + hg.edges = edges; + jsoncpp::parse(hg, json_str); + + // convert to proper objects + vector<Hg::Node*> nodes_conv; + for (const auto it: hg.nodes) { + Hg::Node* n = new Hg::Node; + n->id = it.id; + n->symbol = it.symbol; + n->left = it.span[0]; + n->right = it.span[1]; + nodes_conv.push_back(n); + } + vector<Hg::Edge*> edges_conv; + for (const auto it: hg.edges) { + Hg::Edge* e = new Hg::Edge; + e->head_id_ = it.head_id; + e->tails_ids_ = it.tails_ids; + e->score = it.score; + e->rule_id_ = it.rule_id; + edges_conv.push_back(e); + } + + // write to msgpack + FILE* file = fopen(argv[2], "wb"); + msgpack::fbuffer fbuf(file); + msgpack::pack(fbuf, hg.rules.size()); + msgpack::pack(fbuf, hg.nodes.size()); + msgpack::pack(fbuf, hg.edges.size()); + for (const auto it: hg.rules) + msgpack::pack(fbuf, it); + for (const auto it: nodes_conv) + msgpack::pack(fbuf, *it); + for (const auto it: edges_conv) + msgpack::pack(fbuf, *it); + fclose(file); + + return 0; +} + diff --git a/src/parse.hh b/src/parse.hh new file mode 100644 index 0000000..0dd2fc0 --- /dev/null +++ b/src/parse.hh @@ -0,0 +1,301 @@ +#pragma once + +#include <vector> +#include <utility> +#include <sstream> +#include <unordered_map> +#include <set> + +#include "grammar.hh" +#include "util.hh" +#include "types.hh" + +using namespace std; + +typedef pair<size_t,size_t> Span; +namespace std { + template <> + struct hash<Span> + { + size_t + operator()(Span const& k) const + { + return ((hash<size_t>()(k.first) + ^ (hash<size_t>()(k.second) << 1)) >> 1); + } + }; +} + +namespace Parse { + +void visit(vector<Span>& p, + size_t i, size_t l, size_t r, size_t x=0) +{ + for (size_t s = i; s <= r-x; s++) { + for (size_t k = l; k <= r-s; k++) { + p.push_back(Span(k,k+s)); + } + } +} + +struct ChartItem +{ + Span span; + G::Rule const* rule; + vector<Span> tails_spans; + size_t dot; + + ChartItem() {} + + ChartItem(G::Rule* r) : rule(r), dot(0) {} + + ChartItem(G::Rule* r, Span s, size_t dot) + : rule(r), span(s), dot(dot) {} + + ChartItem(ChartItem const& o) + : span(o.span), + rule(o.rule), + tails_spans(o.tails_spans), + dot(o.dot) + { + } + + ostream& + repr(ostream& os) const + { + os << "ChartItem<"; + os << "span=(" << span.first << "," << span.second << "), lhs="; + rule->lhs->repr(os); + os << ", dot=" << dot; + os << ", tails=" << tails_spans.size() << ", "; + os << "rule="; + rule->repr(os); + os << ">"; + os << endl; + } + + friend ostream& + operator<<(ostream& os, ChartItem item) + { + item.repr(os); + + return os; + } +}; + +struct Chart +{ + size_t n_; + map<Span, vector<ChartItem*> > m_; + unordered_map<string,bool> b_; + + vector<ChartItem*>& at(Span s) + { + return m_[s]; + } + + string h(symbol_t sym, Span s) + { + ostringstream ss; + ss << sym; + ss << s.first; + ss << s.second; + + return ss.str(); + } + + bool + has_at(symbol_t sym, Span s) + { + return b_[h(sym, s)]; + } + + void add(ChartItem* item, Span s) + { + if (m_.count(s) > 0) + m_[s].push_back(item); + else { + m_.insert(make_pair(s, vector<ChartItem*>{item})); + } + b_[h(item->rule->lhs->symbol(), s)] = true; + } + + Chart(size_t n) : n_(n) {} + + friend ostream& + operator<<(ostream& os, Chart const& chart) + { + for (map<Span, vector<ChartItem*> >::const_iterator it = chart.m_.cbegin(); + it != chart.m_.cend(); it++) { + os << "(" << it->first.first << "," << it->first.second << ")" << endl; + for (auto jt: it->second) + jt->repr(os); os << endl; + } + + return os; + } +}; + +bool +scan(ChartItem* item, vector<symbol_t> in, size_t limit, Chart& passive) +{ + //cout << "S1" << endl; + while (item->dot < item->rule->rhs.size() && + item->rule->rhs[item->dot]->type() == G::TERMINAL) { + //cout << "S2" << endl; + if (item->span.second == limit) return false; + //cout << "S3" << endl; + if (item->rule->rhs[item->dot]->symbol() == in[item->span.second]) { + //cout << "S4" << endl; + item->dot++; + //cout << "S5" << endl; + item->span.second++; + //cout << "S6" << endl; + } else { + //cout << "S7" << endl; + return false; + } + } + //cout << "S8" << endl; + return true; +} + + +void +init(vector<symbol_t> const& in, size_t n, Chart& active, Chart& passive, G::Grammar const& g) +{ + for (auto rule: g.flat) { + size_t j = 0; + for (auto it: in) { + if (it == rule->rhs.front()->symbol()) { + cout << it << " " << j << j+rule->rhs.size() << endl; + Span span(j, j+rule->rhs.size()); + passive.add(new ChartItem(rule, span, rule->rhs.size()), span); + cout << "new passive item [1] " << *passive.at(span).back() << endl; + } + j++; + } + } +} + +void +parse(vector<symbol_t> const& in, size_t n, Chart& active, Chart& passive, G::Grammar const& g) +{ + vector<Span> spans; + Parse::visit(spans, 1, 0, n); + for (auto span: spans) { + + cout << "Span (" << span.first << "," << span.second << ")" << endl; + + for (auto it: g.start_terminal) { + ChartItem* item = new ChartItem(it, Span(span.first,span.first), 0); + if (scan(item, in, span.second, passive) + && span.first + item->rule->rhs.size() <= span.second) { + active.add(item, span); + cout << "new active item [1] " << *active.at(span).back(); + } + } + + for (auto it: g.start_non_terminal) { + if (it->rhs.size() > span.second-span.first) continue; + active.add(new ChartItem(it, Span(span.first,span.first), 0), span); + cout << "new active item [2] " << *active.at(span).back(); + } + + set<symbol_t> new_symbols; + vector<ChartItem*> remaining_items; + + while (true) { + cout << "active size at (" << span.first << "," << span.second << ") " << active.at(span).size() << endl; + cout << "passive size at (" << span.first << "," << span.second << ") " << passive.at(span).size() << endl; + if (active.at(span).empty()) break; + ChartItem* item = active.at(span).back(); + cout << "current item " << *item; + active.at(span).pop_back(); + bool advanced = false; + vector<Span> spans2; + Parse::visit(spans2, 1, span.first, span.second, 1); + for (auto span2: spans2) { + cout << "A" << endl; + if (item->rule->rhs[item->dot]->type() == G::NON_TERMINAL) { + cout << "B" << endl; + if (passive.has_at(item->rule->rhs[item->dot]->symbol(), span2)) { + cout << "C" << endl; + if (span2.first == item->span.second) { + cout << "D" << endl; + ChartItem* new_item = new ChartItem(*item); + cout << "D1" << endl; + new_item->span.second = span2.second; + cout << "D2" << endl; + new_item->dot++; + cout << "D3" << endl; + new_item->tails_spans.push_back(span2); + cout << "D4" << endl; + if (scan(new_item, in, span.second, passive)) { + cout << "E" << endl; + if (new_item->dot == new_item->rule->rhs.size()) { + cout << "F" << endl; + if (new_item->span.first == span.first && new_item->span.second == span.second) { + cout << "G" << endl; + cout << "H" << endl; + new_symbols.insert(new_item->rule->lhs->symbol()); + passive.add(new_item, span); + cout << "new passive item [2] " << *new_item; + advanced = true; + } + } else { + if (new_item->span.second+(new_item->rule->rhs.size()-new_item->dot) <= span.second) { + active.add(new_item, span); + cout << "new active item [3] " << *new_item; + } + } + } + cout << "I" << endl; + } + } + } + } + cout << "J" << endl; + if (!advanced) { + cout << "K" << endl; + remaining_items.push_back(item); + } + } + + for (auto new_sym: new_symbols) { + cout << "new sym " << new_sym << endl; + for (auto rem_item: remaining_items) { + if (rem_item->dot != 0 || + rem_item->rule->rhs[rem_item->dot]->type() != G::NON_TERMINAL) { + continue; + cout << "K1" << endl; + } + cout << "K2" << endl; + if (rem_item->rule->rhs[rem_item->dot]->symbol() == new_sym) { + cout << "K3" << endl; + ChartItem* new_item = new ChartItem(*rem_item); + cout << "K31" << endl; + //new_item->tails_spans[new_item->dot-1] = span; + new_item->tails_spans.push_back(span); + new_item->dot++; + cout << "K32" << endl; + if (new_item->dot == new_item->rule->rhs.size()) { + cout << "K4" << endl; + new_symbols.insert(new_item->rule->lhs->symbol()); + passive.add(new_item, span); + } + } + } + } + + cout << "L" << endl; + cout << "-------------------" << endl; + cout << endl; + } + + //cout << "ACTIVE" << endl << active << endl; + cout << "PASSIVE" << endl << passive << endl; +} + +} // + diff --git a/src/read_pak.cc b/src/read_pak.cc new file mode 100644 index 0000000..c894442 --- /dev/null +++ b/src/read_pak.cc @@ -0,0 +1,26 @@ +#include <iostream> +#include <fstream> +#include <msgpack.hpp> + +using namespace std; + +int +main(int argc, char** argv) +{ + ifstream ifs(argv[1]); + msgpack::unpacker pac; + while(true) { + pac.reserve_buffer(32*1024); + size_t bytes = ifs.readsome(pac.buffer(), pac.buffer_capacity()); + pac.buffer_consumed(bytes); + msgpack::unpacked result; + while(pac.next(&result)) { + msgpack::object o = result.get(); + cout << o << endl; + } + if (!bytes) break; + } + + return 0; +} + diff --git a/src/semiring.hh b/src/semiring.hh new file mode 100644 index 0000000..3f4ac08 --- /dev/null +++ b/src/semiring.hh @@ -0,0 +1,35 @@ +#pragma once + + +namespace Semiring { + +template<typename T> +struct Viterbi { + T one = 1.0; + T null = 0.0; + + T add(T x, T y); + T multiply(T x, T y); + T convert(T x); +}; + +template<typename T> T +Viterbi<T>::add(T x, T y) +{ + return max(x, y); +} + +template<typename T> T +Viterbi<T>::multiply(T x, T y) +{ + return x * y; +} + +template<typename T> T +Viterbi<T>::convert(T x) +{ + return (T)x; +} + +} // namespace + diff --git a/src/sparse_vector.hh b/src/sparse_vector.hh new file mode 100644 index 0000000..7fff338 --- /dev/null +++ b/src/sparse_vector.hh @@ -0,0 +1,186 @@ +#pragma once + +#include <iostream> +#include <sstream> +#include <string> +#include <unordered_map> +#include <vector> + +#include "util.hh" +#include "types.hh" + +using namespace std; + + +namespace Sv { + +template<typename K, typename V> +struct SparseVector { + unordered_map<K,V> m_; + V zero = 0.f; + + SparseVector() {}; + + SparseVector(string& s) + { + from_s(this, s); + }; + + void + insert(K k, V v) { m_[k] = v; }; + + V + dot(SparseVector& other) + { + V r; + unordered_map<K,V>* o = &m_; + auto b = m_.cbegin(); + auto e = m_.cend(); + if (other.size() < size()) { + b = other.m_.cbegin(); + e = other.m_.cend(); + o = &other.m_; + } + for (auto it = b; it != e; it++) + r += it->second * o->at(it->first); + + return r; + }; + + size_t + size() + { + return m_.size(); + } + + V& + operator[](const K& k) + { + return at(k); + }; + + const V& + at(const K& k) const + { + if (m_.find(k) == m_.end()) + return zero; + else + return m_.at(k); + } + + SparseVector + operator+(const SparseVector& other) const + { + SparseVector<K,V> v; + v.m_.insert(m_.cbegin(), m_.cend()); + v.m_.insert(other.m_.cbegin(), other.m_.cend()); + for (const auto it: v.m_) + v.m_[it.first] = this->at(it.first) + other.at(it.first); + + return v; + }; + + SparseVector& + operator+=(const SparseVector& other) + { + for (const auto it: other.m_) + m_[it.first] += it.second; + + return *this; + }; + + SparseVector + operator-(const SparseVector& other) const + { + SparseVector<K,V> v; + v.m_.insert(m_.cbegin(), m_.cend()); + v.m_.insert(other.m_.cbegin(), other.m_.cend()); + for (const auto it: v.m_) + v.m_[it.first] = this->at(it.first) - other.at(it.first); + + return v; + }; + + SparseVector& + operator-=(const SparseVector& other) + { + for (const auto it: other.m_) + m_[it.first] -= it.second; + + return *this; + }; + + SparseVector + operator*(V f) const + { + SparseVector<K,V> v; + for (const auto it: m_) + v.m_[it.first] = this->at(it.first) * f; + + return v; + }; + + SparseVector& + operator*=(V f) + { + for (const auto it: m_) + m_[it.first] *= f; + + return *this; + }; + + static void + from_s(SparseVector* w, const string& s) + { + stringstream ss(s); + while (!ss.eof()) { + string t; + ss >> t; + size_t eq = t.find_first_of("="); + if (eq == string::npos) { + return; + } + t.replace(eq, 1, " "); + stringstream tt(t); + K k; V v; + tt >> k >> v; + w->m_.emplace(k.substr(k.find_first_of("\"")+1, k.find_last_of("\"")-1), v); + } + } + + ostream& + repr(ostream& os) const + { + os << "SparseVector<{"; + for (auto it = m_.cbegin(); it != m_.cend(); it++) { + os << "'" << it->first << "'=" << it->second; + if (next(it) != m_.end()) + os << ", "; + } + os << "}>"; + + return os; + }; + + ostream& + escaped(ostream& os, bool quote_keys=false) const { + for (auto it = m_.cbegin(); it != m_.cend(); it++) { + if (quote_keys) os << '"'; + os << util::json_escape(it->first); + if (quote_keys) os << '"'; + os << "=" << it->second; + if (next(it) != m_.cend()) os << " "; + } + + return os; + }; + + friend ostream& + operator<<(ostream& os, const SparseVector& v) + { + return v.repr(os); + } +}; + +} // namespace + diff --git a/src/test_grammar.cc b/src/test_grammar.cc new file mode 100644 index 0000000..7fee79a --- /dev/null +++ b/src/test_grammar.cc @@ -0,0 +1,19 @@ +#include <fstream> + +#include "grammar.hh" + +using namespace std; + +int +main(int argc, char** argv) +{ + G::Vocabulary y; + G::Grammar g(argv[1], y); + for (auto it: g.rules) { + it->escaped(cout); + cout << endl; + } + + return 0; +} + diff --git a/src/test_parse.cc b/src/test_parse.cc new file mode 100644 index 0000000..2d51d44 --- /dev/null +++ b/src/test_parse.cc @@ -0,0 +1,19 @@ +#include "parse.hh" + +int main(int argc, char** argv) +{ + //string in("ich sah ein kleines haus"); + //string in("europa bildet den ersten oder zweiten markt für die zehn am häufigsten von indien exportierten produkte , erklärte der europäische kommissar weiter . die asiatischen und europäischen giganten tauschen jährlich güter im wert von 47 milliarden euro und dienstleistungen im wert von 10 milliarden euro aus , hatte diese woche daniéle smadja , vorsitzende der abordnung der europäischen kommission in neu delhi , erklärt , und bedauert , dass der gegenseitige handel sein potential noch nicht ausgeschöpft hat . die eu und indien treffen sich am freitag zu ihrem achten diplomatischen in neu delhi , bei dem premierminister manmohan singh und der präsident der europäischen kommission josé manuel durao barrosso anwesend sein werden ."); + //string in("aber schon bald nach seinem eintritt kam der erste große erfolg ."); + string in("lebensmittel schuld an europäischer inflation"); + vector<symbol_t> tok = util::tokenize(in); + size_t n = tok.size(); + G::Vocabulary v; + G::Grammar g(argv[1], v); + g.add_glue(v); + Parse::Chart active(n); + Parse::Chart passive(n); + init(tok, n, active, passive, g); + parse(tok, n, active, passive, g); +} + diff --git a/src/test_sparse_vector.cc b/src/test_sparse_vector.cc new file mode 100644 index 0000000..69aaa21 --- /dev/null +++ b/src/test_sparse_vector.cc @@ -0,0 +1,36 @@ +#include "sparse_vector.hh" + +int +main(void) +{ + Sv::SparseVector<string, score_t> a; + a.insert("1", 1); + a.insert("2", 2); + cout << "a:" << a << endl; + + Sv::SparseVector<string, score_t> b; + b.insert("2", 2); + cout << "b:" << b << endl; + + Sv::SparseVector<string, score_t> c = a + b; + cout << "a+b:" << c << endl; + + a += b; + cout << "a+=b:" << a << endl; + + a -= b; + cout << "a-=b:" << a << endl; + + cout << "a*2:" << a*2 << endl; + + a *= 2; + cout << "a*=2:" << a << endl; + + string s("\"a\"=2 \"b\"=3"); + Sv::SparseVector<string, score_t>* sv = new Sv::SparseVector<string, score_t>(s); + cout << *sv << endl; + cout << sv->dot(*sv) << endl; + + return 0; +} + diff --git a/src/types.hh b/src/types.hh new file mode 100644 index 0000000..e89b4dd --- /dev/null +++ b/src/types.hh @@ -0,0 +1,10 @@ +#pragma once + +#include <string> + +using namespace std; + + +typedef double score_t; +typedef string symbol_t; + diff --git a/src/util.hh b/src/util.hh new file mode 100644 index 0000000..93ea320 --- /dev/null +++ b/src/util.hh @@ -0,0 +1,47 @@ +#pragma once + +#include <string> + +#include "types.hh" + +using namespace std; + + +namespace util { + +inline string +json_escape(const string& s) +{ + ostringstream os; + for (auto it = s.cbegin(); it != s.cend(); it++) { + switch (*it) { + case '"': os << "\\\""; break; + case '\\': os << "\\\\"; break; + case '\b': os << "\\b"; break; + case '\f': os << "\\f"; break; + case '\n': os << "\\n"; break; + case '\r': os << "\\r"; break; + case '\t': os << "\\t"; break; + default: os << *it; break; + } + } + + return os.str(); +} + +inline vector<symbol_t> +tokenize(string s) +{ + istringstream ss(s); + vector<symbol_t> r; + while (ss.good()) { + string buf; + ss >> buf; + r.push_back(buf); + } + + return r; +} + +} // namespace util + |