summaryrefslogtreecommitdiff
path: root/python/src/lattice.pxi
diff options
context:
space:
mode:
authorVictor Chahuneau <vchahune@cs.cmu.edu>2012-06-23 11:59:48 -0400
committerVictor Chahuneau <vchahune@cs.cmu.edu>2012-06-23 11:59:48 -0400
commitb738e349be490c24d3604c224f44fc54e16d3d7b (patch)
tree5d435257ef3c0023daa2211eb7260c470dbb6cdc /python/src/lattice.pxi
parent0b27ea3f91d0ad2f2ed718839d308db3d1baf5ae (diff)
Support for sparse/dense vectors in the python extension
- SparseVector, DenseVector - improved Lattice - Lattice translation - Hypergraph reweighting, pruning
Diffstat (limited to 'python/src/lattice.pxi')
-rw-r--r--python/src/lattice.pxi56
1 files changed, 56 insertions, 0 deletions
diff --git a/python/src/lattice.pxi b/python/src/lattice.pxi
new file mode 100644
index 00000000..493c6dcd
--- /dev/null
+++ b/python/src/lattice.pxi
@@ -0,0 +1,56 @@
+cimport lattice
+
+cdef class Lattice:
+ cdef lattice.Lattice* lattice
+
+ def __init__(self, inp):
+ if isinstance(inp, tuple):
+ self.lattice = new lattice.Lattice(len(inp))
+ for i, arcs in enumerate(inp):
+ self[i] = arcs
+ else:
+ if isinstance(inp, unicode):
+ inp = inp.encode('utf8')
+ if not isinstance(inp, str):
+ raise TypeError('Cannot create lattice from %s' % type(inp))
+ self.lattice = new lattice.Lattice()
+ lattice.ConvertTextToLattice(string(<char *>inp), self.lattice)
+
+ def __getitem__(self, int index):
+ if not 0 <= index < len(self):
+ raise IndexError('lattice index out of range')
+ arcs = []
+ cdef vector[lattice.LatticeArc] arc_vector = self.lattice[0][index]
+ cdef lattice.LatticeArc* arc
+ cdef str label
+ cdef unsigned i
+ for i in range(arc_vector.size()):
+ arc = &arc_vector[i]
+ label = TDConvert(arc.label)
+ arcs.append((label.decode('utf8'), arc.cost, arc.dist2next))
+ return tuple(arcs)
+
+ def __setitem__(self, int index, tuple arcs):
+ if not 0 <= index < len(self):
+ raise IndexError('lattice index out of range')
+ cdef lattice.LatticeArc* arc
+ for (label, cost, dist2next) in arcs:
+ if isinstance(label, unicode):
+ label = label.encode('utf8')
+ arc = new lattice.LatticeArc(TDConvert(<char *>label), cost, dist2next)
+ self.lattice[0][index].push_back(arc[0])
+ del arc
+
+ def __len__(self):
+ return self.lattice.size()
+
+ def __str__(self):
+ return hypergraph.AsPLF(self.lattice[0], True).c_str()
+
+ def __iter__(self):
+ cdef unsigned i
+ for i in range(len(self)):
+ yield self[i]
+
+ def __dealloc__(self):
+ del self.lattice