From eba1853a8f537bec4fa98309cc24021c684907cd Mon Sep 17 00:00:00 2001
From: Victor Chahuneau <vchahune@cs.cmu.edu>
Date: Sat, 23 Feb 2013 16:29:40 -0500
Subject: Add compression option to grammar extractor

---
 python/pkg/cdec/sa/extract.py | 44 +++++++++++++++++++++++--------------------
 1 file changed, 24 insertions(+), 20 deletions(-)

(limited to 'python')

diff --git a/python/pkg/cdec/sa/extract.py b/python/pkg/cdec/sa/extract.py
index 2e596bd3..782bed8b 100644
--- a/python/pkg/cdec/sa/extract.py
+++ b/python/pkg/cdec/sa/extract.py
@@ -1,22 +1,25 @@
 #!/usr/bin/env python
 import sys
 import os
+import re
+import gzip
 import argparse
 import logging
-import re
-import multiprocessing as mp
 import signal
+import multiprocessing as mp
 import cdec.sa
 
 extractor, prefix = None, None
-online = False
+online, compress = False, False
 
-def make_extractor(config, grammars, features):
-    global extractor, prefix, online
+def make_extractor(args):
+    global extractor, prefix, online, compress
     signal.signal(signal.SIGINT, signal.SIG_IGN) # Let parent process catch Ctrl+C
-    load_features(features)
-    extractor = cdec.sa.GrammarExtractor(config, online)
-    prefix = grammars
+    load_features(args.features)
+    extractor = cdec.sa.GrammarExtractor(args.config, online)
+    prefix = args.grammars
+    online = args.online
+    compress = args.compress
 
 def load_features(features):
     for featdef in features:
@@ -27,7 +30,7 @@ def load_features(features):
         sys.path.remove(prefix)
 
 def extract(inp):
-    global extractor, prefix, online
+    global extractor, prefix, online, compress
     i, sentence = inp
     sentence = sentence[:-1]
     fields = re.split('\s*\|\|\|\s*', sentence)
@@ -36,7 +39,7 @@ def extract(inp):
     if online:
         if len(fields) < 3:
             sys.stderr.write('Error: online mode requires references and alignments.'
-                    '  Not adding sentence to training data: {0}\n'.format(sentence))
+                    '  Not adding sentence to training data: {}\n'.format(sentence))
             sentence = fields[0]
         else:
             sentence, reference, alignment = fields[0:3]
@@ -46,18 +49,19 @@ def extract(inp):
         if len(fields) > 1:
             sentence = fields[0]
             suffix = ' ||| ' + ' ||| '.join(fields[1:])
-    grammar_file = os.path.join(prefix, 'grammar.{0}'.format(i))
-    with open(grammar_file, 'w') as output:
+
+    grammar_file = os.path.join(prefix, 'grammar.'+str(i))
+    if compress: grammar_file += '.gz'
+    with (gzip.open if compress else open)(grammar_file, 'w') as output:
         for rule in extractor.grammar(sentence):
             output.write(str(rule)+'\n')
     # Add training instance _after_ extracting grammars
     if online:
         extractor.add_instance(sentence, reference, alignment)
     grammar_file = os.path.abspath(grammar_file)
-    return '<seg grammar="{0}" id="{1}"> {2} </seg>{3}'.format(grammar_file, i, sentence, suffix)
+    return '<seg grammar="{}" id="{}">{}</seg>{}'.format(grammar_file, i, sentence, suffix)
 
 def main():
-    global online
     logging.basicConfig(level=logging.INFO)
     parser = argparse.ArgumentParser(description='Extract grammars from a compiled corpus.')
     parser.add_argument('-c', '--config', required=True,
@@ -70,30 +74,30 @@ def main():
                         help='number of sentences / chunk')
     parser.add_argument('-f', '--features', nargs='*', default=[],
                         help='additional feature definitions')
-    parser.add_argument('-o', '--online', action='store_true', default=False,
+    parser.add_argument('-o', '--online', action='store_true',
                         help='online grammar extraction')
+    parser.add_argument('-z', '--compress', action='store_true',
+                        help='compress grammars with gzip')
     args = parser.parse_args()
 
     if not os.path.exists(args.grammars):
         os.mkdir(args.grammars)
     for featdef in args.features:
         if not featdef.endswith('.py'):
-            sys.stderr.write('Error: feature definition file <{0}>'
+            sys.stderr.write('Error: feature definition file <{}>'
                     ' should be a python module\n'.format(featdef))
             sys.exit(1)
     
-    online = args.online
-    
     if args.jobs > 1:
         logging.info('Starting %d workers; chunk size: %d', args.jobs, args.chunksize)
-        pool = mp.Pool(args.jobs, make_extractor, (args.config, args.grammars, args.features))
+        pool = mp.Pool(args.jobs, make_extractor, (args,))
         try:
             for output in pool.imap(extract, enumerate(sys.stdin), args.chunksize):
                 print(output)
         except KeyboardInterrupt:
             pool.terminate()
     else:
-        make_extractor(args.config, args.grammars, args.features)
+        make_extractor(args)
         for output in map(extract, enumerate(sys.stdin)):
             print(output)
 
-- 
cgit v1.2.3