summaryrefslogtreecommitdiff
path: root/python/pkg/cdec/sa
diff options
context:
space:
mode:
authorVictor Chahuneau <vchahune@cs.cmu.edu>2012-09-06 17:46:41 +0100
committerVictor Chahuneau <vchahune@cs.cmu.edu>2012-09-06 17:46:41 +0100
commit8249f6445ed28c3dc902f0eb10b1f6283058c553 (patch)
tree50a1735defa16b4af7e1de5e0e6f12e4d7c4a5c6 /python/pkg/cdec/sa
parent28194c2d099b9ea039b60ac35393626ce26d326c (diff)
[cdec.sa] Allow sentence annotation and initial configuration
Diffstat (limited to 'python/pkg/cdec/sa')
-rw-r--r--python/pkg/cdec/sa/__init__.py14
-rw-r--r--python/pkg/cdec/sa/extractor.py8
2 files changed, 20 insertions, 2 deletions
diff --git a/python/pkg/cdec/sa/__init__.py b/python/pkg/cdec/sa/__init__.py
index d4b94484..e0a344b7 100644
--- a/python/pkg/cdec/sa/__init__.py
+++ b/python/pkg/cdec/sa/__init__.py
@@ -4,7 +4,21 @@ from cdec.sa._sa import make_lattice, decode_lattice, decode_sentence,\
from cdec.sa.extractor import GrammarExtractor
_SA_FEATURES = []
+_SA_ANNOTATORS = {}
+_SA_CONFIGURE = []
def feature(fn):
_SA_FEATURES.append(fn)
return fn
+
+def annotator(fn):
+ _SA_ANNOTATORS[fn.__name__] = fn
+
+def annotate(sentence):
+ meta = {}
+ for name, fn in _SA_ANNOTATORS.iteritems():
+ meta[name] = fn(sentence)
+ return meta
+
+def configure(fn):
+ _SA_CONFIGURE.append(fn)
diff --git a/python/pkg/cdec/sa/extractor.py b/python/pkg/cdec/sa/extractor.py
index 94392c30..a5ce8a68 100644
--- a/python/pkg/cdec/sa/extractor.py
+++ b/python/pkg/cdec/sa/extractor.py
@@ -71,10 +71,14 @@ class GrammarExtractor:
sampler = cdec.sa.Sampler(300, fsarray)
self.factory.configure(fsarray, edarray, sampler, scorer)
+ # Initialize feature definitions with configuration
+ for fn in cdec.sa._SA_CONFIGURE:
+ fn(config)
def grammar(self, sentence):
if isinstance(sentence, unicode):
sentence = sentence.encode('utf8')
- words = chain(('<s>',), sentence.split(), ('</s>',))
+ words = tuple(chain(('<s>',), sentence.split(), ('</s>',)))
+ meta = cdec.sa.annotate(words)
cnet = cdec.sa.make_lattice(words)
- return self.factory.input(cnet)
+ return self.factory.input(cnet, meta)