summary refs log tree commit diff
path: root/python/pkg/cdec/sa
diff options
context:
space:
mode:
author Victor Chahuneau <vchahune@cs.cmu.edu> 2012-09-06 17:46:41 +0100
committer Victor Chahuneau <vchahune@cs.cmu.edu> 2012-09-06 17:46:41 +0100
commit 1cef9b6842fec7598a0a0571f69bf4caab8e4c91 (patch)
tree 89c3b5cf241c3a49688994ce19a07c7cdd01aa71 /python/pkg/cdec/sa
parent c1b77250f656c4cff9f0e532d6b6644cb0dc993c (diff)
[cdec.sa] Allow sentence annotation and initial configuration
Diffstat (limited to 'python/pkg/cdec/sa')
-rw-r--r-- python/pkg/cdec/sa/__init__.py 14
-rw-r--r-- python/pkg/cdec/sa/extractor.py 8
2 files changed, 20 insertions, 2 deletions
diff --git a/python/pkg/cdec/sa/__init__.py b/python/pkg/cdec/sa/__init__.py
index d4b94484..e0a344b7 100644
--- a/python/pkg/cdec/sa/__init__.py
+++ b/python/pkg/cdec/sa/__init__.py
@@ -4,7 +4,21 @@ from cdec.sa._sa import make_lattice, decode_lattice, decode_sentence,\
from cdec.sa.extractor import GrammarExtractor
_SA_FEATURES = []
+_SA_ANNOTATORS = {}
+_SA_CONFIGURE = []
def feature(fn):
_SA_FEATURES.append(fn)
return fn
+
+def annotator(fn):
+ _SA_ANNOTATORS[fn.__name__] = fn
+
+def annotate(sentence):
+ meta = {}
+ for name, fn in _SA_ANNOTATORS.iteritems():
+ meta[name] = fn(sentence)
+ return meta
+
+def configure(fn):
+ _SA_CONFIGURE.append(fn)
diff --git a/python/pkg/cdec/sa/extractor.py b/python/pkg/cdec/sa/extractor.py
index 94392c30..a5ce8a68 100644
--- a/python/pkg/cdec/sa/extractor.py
+++ b/python/pkg/cdec/sa/extractor.py
@@ -71,10 +71,14 @@ class GrammarExtractor:
sampler = cdec.sa.Sampler(300, fsarray)
self.factory.configure(fsarray, edarray, sampler, scorer)
+ # Initialize feature definitions with configuration
+ for fn in cdec.sa._SA_CONFIGURE:
+ fn(config)
def grammar(self, sentence):
if isinstance(sentence, unicode):
sentence = sentence.encode('utf8')
- words = chain(('<s>',), sentence.split(), ('</s>',))
+ words = tuple(chain(('<s>',), sentence.split(), ('</s>',)))
+ meta = cdec.sa.annotate(words)
cnet = cdec.sa.make_lattice(words)
- return self.factory.input(cnet)
+ return self.factory.input(cnet, meta)