summaryrefslogtreecommitdiff
path: root/python/pkg/cdec/sa
diff options
context:
space:
mode:
authorChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-03-08 22:45:06 -0500
committerChris Dyer <cdyer@allegro.clab.cs.cmu.edu>2013-03-08 22:45:06 -0500
commit445ba0b0ba94284d17c88a7e75f81c5d3c54a001 (patch)
tree6f655d69d397089d7bc9e87d8cd104657d4c0967 /python/pkg/cdec/sa
parent3a162d28033d1b9d5241e31f32978dba4eba6296 (diff)
parent99538847039c06bdcc288e2c5dfcdb507ff879ca (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'python/pkg/cdec/sa')
-rw-r--r--python/pkg/cdec/sa/extract.py44
1 files changed, 24 insertions, 20 deletions
diff --git a/python/pkg/cdec/sa/extract.py b/python/pkg/cdec/sa/extract.py
index 2e596bd3..782bed8b 100644
--- a/python/pkg/cdec/sa/extract.py
+++ b/python/pkg/cdec/sa/extract.py
@@ -1,22 +1,25 @@
#!/usr/bin/env python
import sys
import os
+import re
+import gzip
import argparse
import logging
-import re
-import multiprocessing as mp
import signal
+import multiprocessing as mp
import cdec.sa
extractor, prefix = None, None
-online = False
+online, compress = False, False
-def make_extractor(config, grammars, features):
- global extractor, prefix, online
+def make_extractor(args):
+ global extractor, prefix, online, compress
signal.signal(signal.SIGINT, signal.SIG_IGN) # Let parent process catch Ctrl+C
- load_features(features)
- extractor = cdec.sa.GrammarExtractor(config, online)
- prefix = grammars
+ load_features(args.features)
+ extractor = cdec.sa.GrammarExtractor(args.config, online)
+ prefix = args.grammars
+ online = args.online
+ compress = args.compress
def load_features(features):
for featdef in features:
@@ -27,7 +30,7 @@ def load_features(features):
sys.path.remove(prefix)
def extract(inp):
- global extractor, prefix, online
+ global extractor, prefix, online, compress
i, sentence = inp
sentence = sentence[:-1]
fields = re.split('\s*\|\|\|\s*', sentence)
@@ -36,7 +39,7 @@ def extract(inp):
if online:
if len(fields) < 3:
sys.stderr.write('Error: online mode requires references and alignments.'
- ' Not adding sentence to training data: {0}\n'.format(sentence))
+ ' Not adding sentence to training data: {}\n'.format(sentence))
sentence = fields[0]
else:
sentence, reference, alignment = fields[0:3]
@@ -46,18 +49,19 @@ def extract(inp):
if len(fields) > 1:
sentence = fields[0]
suffix = ' ||| ' + ' ||| '.join(fields[1:])
- grammar_file = os.path.join(prefix, 'grammar.{0}'.format(i))
- with open(grammar_file, 'w') as output:
+
+ grammar_file = os.path.join(prefix, 'grammar.'+str(i))
+ if compress: grammar_file += '.gz'
+ with (gzip.open if compress else open)(grammar_file, 'w') as output:
for rule in extractor.grammar(sentence):
output.write(str(rule)+'\n')
# Add training instance _after_ extracting grammars
if online:
extractor.add_instance(sentence, reference, alignment)
grammar_file = os.path.abspath(grammar_file)
- return '<seg grammar="{0}" id="{1}"> {2} </seg>{3}'.format(grammar_file, i, sentence, suffix)
+ return '<seg grammar="{}" id="{}">{}</seg>{}'.format(grammar_file, i, sentence, suffix)
def main():
- global online
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description='Extract grammars from a compiled corpus.')
parser.add_argument('-c', '--config', required=True,
@@ -70,30 +74,30 @@ def main():
help='number of sentences / chunk')
parser.add_argument('-f', '--features', nargs='*', default=[],
help='additional feature definitions')
- parser.add_argument('-o', '--online', action='store_true', default=False,
+ parser.add_argument('-o', '--online', action='store_true',
help='online grammar extraction')
+ parser.add_argument('-z', '--compress', action='store_true',
+ help='compress grammars with gzip')
args = parser.parse_args()
if not os.path.exists(args.grammars):
os.mkdir(args.grammars)
for featdef in args.features:
if not featdef.endswith('.py'):
- sys.stderr.write('Error: feature definition file <{0}>'
+ sys.stderr.write('Error: feature definition file <{}>'
' should be a python module\n'.format(featdef))
sys.exit(1)
- online = args.online
-
if args.jobs > 1:
logging.info('Starting %d workers; chunk size: %d', args.jobs, args.chunksize)
- pool = mp.Pool(args.jobs, make_extractor, (args.config, args.grammars, args.features))
+ pool = mp.Pool(args.jobs, make_extractor, (args,))
try:
for output in pool.imap(extract, enumerate(sys.stdin), args.chunksize):
print(output)
except KeyboardInterrupt:
pool.terminate()
else:
- make_extractor(args.config, args.grammars, args.features)
+ make_extractor(args)
for output in map(extract, enumerate(sys.stdin)):
print(output)