summaryrefslogtreecommitdiff
path: root/python/cdec/sa/extract.py
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-04-20 22:25:20 -0400
committerChris Dyer <redpony@gmail.com>2014-04-20 22:25:20 -0400
commit2af1d21b74343b568fbb87a2a6902ee01f19636a (patch)
tree39723f1f3a14ec0a877b89573265f04310384d92 /python/cdec/sa/extract.py
parent1748e9a095bcc3a1db8ab47eb7ac6a1f9568772b (diff)
parentbcf989fe0ea170c4f01383a70ba552e663d69109 (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'python/cdec/sa/extract.py')
-rw-r--r--python/cdec/sa/extract.py56
1 files changed, 52 insertions, 4 deletions
diff --git a/python/cdec/sa/extract.py b/python/cdec/sa/extract.py
index b6502c52..b6c11f05 100644
--- a/python/cdec/sa/extract.py
+++ b/python/cdec/sa/extract.py
@@ -62,13 +62,48 @@ def extract(inp):
grammar_file = os.path.abspath(grammar_file)
return '<seg grammar="{}" id="{}">{}</seg>{}'.format(grammar_file, i, sentence, suffix)
+def stream_extract():
+ global extractor, online, compress
+ while True:
+ line = sys.stdin.readline()
+ if not line:
+ break
+ fields = re.split('\s*\|\|\|\s*', line.strip())
+ # context ||| cmd
+ if len(fields) == 2:
+ (context, cmd) = fields
+ if cmd.lower() == 'drop':
+ if online:
+ extractor.drop_ctx(context)
+ sys.stdout.write('drop {}\n'.format(context))
+ else:
+ sys.stdout.write('Error: online mode not set. Skipping line: {}\n'.format(line.strip()))
+ # context ||| sentence ||| grammar_file
+ elif len(fields) == 3:
+ (context, sentence, grammar_file) = fields
+ with (gzip.open if compress else open)(grammar_file, 'w') as output:
+ for rule in extractor.grammar(sentence, context):
+ output.write(str(rule)+'\n')
+ sys.stdout.write('{}\n'.format(grammar_file))
+ # context ||| sentence ||| reference ||| alignment
+ elif len(fields) == 4:
+ (context, sentence, reference, alignment) = fields
+ if online:
+ extractor.add_instance(sentence, reference, alignment, context)
+ sys.stdout.write('learn {}\n'.format(context))
+ else:
+ sys.stdout.write('Error: online mode not set. Skipping line: {}\n'.format(line.strip()))
+ else:
+ sys.stdout.write('Error: see README.md for stream mode usage. Skipping line: {}\n'.format(line.strip()))
+ sys.stdout.flush()
+
def main():
global online
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description='Extract grammars from a compiled corpus.')
parser.add_argument('-c', '--config', required=True,
help='extractor configuration')
- parser.add_argument('-g', '--grammars', required=True,
+ parser.add_argument('-g', '--grammars',
help='grammar output path')
parser.add_argument('-j', '--jobs', type=int, default=1,
help='number of parallel extractors')
@@ -80,9 +115,15 @@ def main():
help='online grammar extraction')
parser.add_argument('-z', '--compress', action='store_true',
help='compress grammars with gzip')
+ parser.add_argument('-t', '--stream', action='store_true',
+ help='stream mode (see README.md)')
args = parser.parse_args()
- if not os.path.exists(args.grammars):
+ if not (args.grammars or args.stream):
+ sys.stderr.write('Error: either -g/--grammars or -t/--stream required\n')
+ sys.exit(1)
+
+ if args.grammars and not os.path.exists(args.grammars):
os.mkdir(args.grammars)
for featdef in args.features:
if not featdef.endswith('.py'):
@@ -91,9 +132,13 @@ def main():
sys.exit(1)
online = args.online
+ stream = args.stream
start_time = monitor_cpu()
if args.jobs > 1:
+ if stream:
+ sys.stderr.write('Error: stream mode incompatible with multiple jobs\n')
+ sys.exit(1)
logging.info('Starting %d workers; chunk size: %d', args.jobs, args.chunksize)
pool = mp.Pool(args.jobs, make_extractor, (args,))
try:
@@ -103,8 +148,11 @@ def main():
pool.terminate()
else:
make_extractor(args)
- for output in map(extract, enumerate(sys.stdin)):
- print(output)
+ if stream:
+ stream_extract()
+ else:
+ for output in map(extract, enumerate(sys.stdin)):
+ print(output)
stop_time = monitor_cpu()
logging.info("Overall extraction step took %f seconds", stop_time - start_time)