diff options
-rw-r--r-- | fast/Makefile | 4 | ||||
-rw-r--r-- | fast/README.md | 3 | ||||
-rwxr-xr-x | fast/fast_weaver | bin | 88433 -> 0 bytes | |||
-rw-r--r-- | fast/hypergraph.hh | 7 | ||||
-rw-r--r-- | fast/hypergraph.o | bin | 72616 -> 0 bytes | |||
-rw-r--r-- | fast/semiring.hh | 1 | ||||
-rwxr-xr-x | util/cdec2json.py | 19 |
7 files changed, 26 insertions, 8 deletions
diff --git a/fast/Makefile b/fast/Makefile index 4aea1ac..a453fd3 100644 --- a/fast/Makefile +++ b/fast/Makefile @@ -1,8 +1,8 @@ all: hypergraph.o main.cc - clang -std=c++11 -lstdc++ main.cc hypergraph.o -o fast_weaver + clang -std=c++11 -lstdc++ main.cc hypergraph.o json/libjson.a -o fast_weaver hypergraph.o: hypergraph.cc hypergraph.hh grammar.o semiring.hh - clang -std=c++11 -lstdc++ -c hypergraph.cc grammar.o + clang -std=c++11 -lstdc++ -c hypergraph.cc -I./msgpack-c/include/ grammar.o ./msgpack-c/lib/libmsgpack.a grammar.o: grammar.cc grammar.hh clang -std=c++11 -lstdc++ -c grammar.cc diff --git a/fast/README.md b/fast/README.md index 112a7ae..3087bab 100644 --- a/fast/README.md +++ b/fast/README.md @@ -2,7 +2,8 @@ TODO * grammar * parser * other semirings - * sparse vector + * sparse vector (unordered_map) + * hg serialization? json/bson/msgpack/protocol buffers (no!) * hg: json input (jsoncpp?) * language model: kenlm diff --git a/fast/fast_weaver b/fast/fast_weaver Binary files differdeleted file mode 100755 index 7d349b3..0000000 --- a/fast/fast_weaver +++ /dev/null diff --git a/fast/hypergraph.hh b/fast/hypergraph.hh index 6e53045..24e63f5 100644 --- a/fast/hypergraph.hh +++ b/fast/hypergraph.hh @@ -11,6 +11,8 @@ #include <functional> #include <algorithm> +#include <msgpack.hpp> + using namespace std; typedef double score_t; @@ -35,6 +37,7 @@ class Hyperedge { bool is_marked(); string s(); + MSGPACK_DEFINE(head, tails, score, f, mark, arity_); }; @@ -49,6 +52,8 @@ class Node { vector<Hyperedge*> incoming; string s(); + + MSGPACK_DEFINE(id, symbol, left, right, score, outgoing, incoming); }; @@ -64,6 +69,8 @@ class Hypergraph { void reset(); string s(); string json_s(); + + MSGPACK_DEFINE(nodes, edges, arity_, nodes_by_id); }; vector<Node*> topological_sort(vector<Node*>& nodes); diff --git a/fast/hypergraph.o b/fast/hypergraph.o Binary files differdeleted file mode 100644 index 8fab348..0000000 --- a/fast/hypergraph.o +++ /dev/null diff --git a/fast/semiring.hh b/fast/semiring.hh index 1e40f48..2be19ea 100644 --- a/fast/semiring.hh +++ b/fast/semiring.hh @@ -1,5 +1,6 @@ #ifndef SEMIRING_HH #define SEMIRING_HH +//#pragma once template<typename T> diff --git a/util/cdec2json.py b/util/cdec2json.py index 6cebd70..ac468ca 100755 --- a/util/cdec2json.py +++ b/util/cdec2json.py @@ -2,6 +2,9 @@ import cdec import sys, argparse +import json +import gzip + #FIXME new format def hg2json(hg, weights): @@ -20,7 +23,7 @@ def hg2json(hg, weights): res += '"nodes":'+"\n" res += "[\n" a = [] - a.append( '{ "id":-1, "cat":"root", "span":[-1,-1] }' ) + a.append( '{ "id":0, "cat":"root", "span":[-1,-1] }' ) for i in hg.nodes: a.append('{ "id":%d, "cat":"%s", "span":[%d,%d] }'%(i.id, i.cat, i.span[0], i.span[1])) res += ",\n".join(a)+"\n" @@ -31,7 +34,7 @@ def hg2json(hg, weights): for i in hg.edges: s = "{" s += '"head":%d'%(i.head_node.id) - s += ', "rule":"%s"'%(i.trule) + s += ', "rule":%s'%(json.dumps(str(i.trule))) # f xs = ' "f":{' b = [] @@ -41,9 +44,9 @@ def hg2json(hg, weights): xs += "}," # tails if len(list(i.tail_nodes)) > 0: - s += ', "tails":[ %s ],'%(",".join([str(n.id) for n in i.tail_nodes])) + s += ', "tails":[ %s ],'%(",".join([str(n.id+1) for n in i.tail_nodes])) else: - s += ', "tails":[ -1 ],' + s += ', "tails":[ 0 ],' s += xs s += ' "weight":%s }'%(i.prob) a.append(s) @@ -56,13 +59,19 @@ def main(): parser = argparse.ArgumentParser(description='get a proper json representation of cdec hypergraphs') parser.add_argument('-c', '--config', required=True, help='decoder configuration') parser.add_argument('-w', '--weights', required=True, help='feature weights') + parser.add_argument('-g', '--grammar', required=False, help='grammar') args = parser.parse_args() with open(args.config) as config: config = config.read() decoder = cdec.Decoder(config) decoder.read_weights(args.weights) ins = sys.stdin.readline().strip() - hg = decoder.translate(ins) + if args.grammar: + with gzip.open(args.grammar) as grammar: + grammar = grammar.read() + hg = decoder.translate(ins, grammar=grammar) + else: + hg = decoder.translate(ins) sys.stderr.write( "input:\n '%s'\n"%(ins) ) sys.stderr.write( "viterbi translation:\n '%s'\n"%(hg.viterbi()) ) |