summaryrefslogtreecommitdiff
path: root/python/pkg/cdec/sa
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-09-20 21:51:31 -0400
committerChris Dyer <cdyer@cs.cmu.edu>2012-09-20 21:51:31 -0400
commit214f4714d95cb27d31ff976a11dec8a0c0eb438d (patch)
tree0970ab16db5260f128a65d60f1dc60caf831efc5 /python/pkg/cdec/sa
parent17d085055e24bf189a3b378af77e1071922893cc (diff)
parente26edac51cc47b2b2322fbb870308daa708cec8c (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'python/pkg/cdec/sa')
-rw-r--r--python/pkg/cdec/sa/__init__.py16
-rw-r--r--python/pkg/cdec/sa/extract.py23
-rw-r--r--python/pkg/cdec/sa/extractor.py13
3 files changed, 37 insertions, 15 deletions
diff --git a/python/pkg/cdec/sa/__init__.py b/python/pkg/cdec/sa/__init__.py
index cc532fb9..e0a344b7 100644
--- a/python/pkg/cdec/sa/__init__.py
+++ b/python/pkg/cdec/sa/__init__.py
@@ -1,10 +1,24 @@
-from cdec.sa._sa import sym_fromstring,\
+from cdec.sa._sa import make_lattice, decode_lattice, decode_sentence,\
SuffixArray, DataArray, LCP, Precomputation, Alignment, BiLex,\
HieroCachingRuleFactory, Sampler, Scorer
from cdec.sa.extractor import GrammarExtractor
_SA_FEATURES = []
+_SA_ANNOTATORS = {}
+_SA_CONFIGURE = []
def feature(fn):
_SA_FEATURES.append(fn)
return fn
+
+def annotator(fn):
+ _SA_ANNOTATORS[fn.__name__] = fn
+
+def annotate(sentence):
+ meta = {}
+ for name, fn in _SA_ANNOTATORS.iteritems():
+ meta[name] = fn(sentence)
+ return meta
+
+def configure(fn):
+ _SA_CONFIGURE.append(fn)
diff --git a/python/pkg/cdec/sa/extract.py b/python/pkg/cdec/sa/extract.py
index 472f128b..10a81556 100644
--- a/python/pkg/cdec/sa/extract.py
+++ b/python/pkg/cdec/sa/extract.py
@@ -11,16 +11,17 @@ extractor, prefix = None, None
def make_extractor(config, grammars, features):
global extractor, prefix
signal.signal(signal.SIGINT, signal.SIG_IGN) # Let parent process catch Ctrl+C
- if features: load_features(features)
+ load_features(features)
extractor = cdec.sa.GrammarExtractor(config)
prefix = grammars
def load_features(features):
- logging.info('Loading additional feature definitions from %s', features)
- prefix = os.path.dirname(features)
- sys.path.append(prefix)
- __import__(os.path.basename(features).replace('.py', ''))
- sys.path.remove(prefix)
+ for featdef in features:
+ logging.info('Loading additional feature definitions from %s', featdef)
+ prefix = os.path.dirname(featdef)
+ sys.path.append(prefix)
+ __import__(os.path.basename(featdef).replace('.py', ''))
+ sys.path.remove(prefix)
def extract(inp):
global extractor, prefix
@@ -44,15 +45,17 @@ def main():
help='number of parallel extractors')
parser.add_argument('-s', '--chunksize', type=int, default=10,
help='number of sentences / chunk')
- parser.add_argument('-f', '--features', type=str, default=None,
+ parser.add_argument('-f', '--features', nargs='*', default=[],
help='additional feature definitions')
args = parser.parse_args()
if not os.path.exists(args.grammars):
os.mkdir(args.grammars)
- if not (args.features is None or args.features.endswith('.py')):
- sys.stderr.write('Error: feature definition file should be a python module\n')
- sys.exit(1)
+ for featdef in args.features:
+ if not featdef.endswith('.py'):
+ sys.stderr.write('Error: feature definition file <{0}>'
+ ' should be a python module\n'.format(featdef))
+ sys.exit(1)
if args.jobs > 1:
logging.info('Starting %d workers; chunk size: %d', args.jobs, args.chunksize)
diff --git a/python/pkg/cdec/sa/extractor.py b/python/pkg/cdec/sa/extractor.py
index 89e35bf8..a5ce8a68 100644
--- a/python/pkg/cdec/sa/extractor.py
+++ b/python/pkg/cdec/sa/extractor.py
@@ -57,6 +57,8 @@ class GrammarExtractor:
# lexical weighting tables
tt = cdec.sa.BiLex(from_binary=config['lex_file'])
+ # TODO: use @cdec.sa.features decorator for standard features too
+ # + add a mask to disable features
scorer = cdec.sa.Scorer(EgivenFCoherent, SampleCountF, CountEF,
MaxLexFgivenE(tt), MaxLexEgivenF(tt), IsSingletonF, IsSingletonFE,
*cdec.sa._SA_FEATURES)
@@ -69,11 +71,14 @@ class GrammarExtractor:
sampler = cdec.sa.Sampler(300, fsarray)
self.factory.configure(fsarray, edarray, sampler, scorer)
+ # Initialize feature definitions with configuration
+ for fn in cdec.sa._SA_CONFIGURE:
+ fn(config)
def grammar(self, sentence):
if isinstance(sentence, unicode):
sentence = sentence.encode('utf8')
- cnet = chain(('<s>',), sentence.split(), ('</s>',))
- cnet = (cdec.sa.sym_fromstring(word, terminal=True) for word in cnet)
- cnet = tuple(((word, None, 1), ) for word in cnet)
- return self.factory.input(cnet)
+ words = tuple(chain(('<s>',), sentence.split(), ('</s>',)))
+ meta = cdec.sa.annotate(words)
+ cnet = cdec.sa.make_lattice(words)
+ return self.factory.input(cnet, meta)