From 38c38f707e58960f80a8dc216673ae0bb0796ade Mon Sep 17 00:00:00 2001
From: Victor Chahuneau <vchahune@cs.cmu.edu>
Date: Tue, 4 Sep 2012 10:21:25 +0100
Subject: Multi-processing grammar extraction

+ various surface fixes
---
 python/examples/rampion.py    | 11 +++++------
 python/pkg/cdec/sa/extract.py | 45 ++++++++++++++++++++++++++++++++-----------
 python/src/hypergraph.pxd     |  2 +-
 3 files changed, 40 insertions(+), 18 deletions(-)

(limited to 'python')

diff --git a/python/examples/rampion.py b/python/examples/rampion.py
index 66d89a61..30244cf7 100644
--- a/python/examples/rampion.py
+++ b/python/examples/rampion.py
@@ -15,7 +15,7 @@ cost = lambda c: 10 * (1 - c.score) # cost definition
 
 def rampion(decoder, sources, references):
     # Empty k-best lists
-    cs = [cdec.score.BLEU(refs).candidate_set() for refs in references]
+    candidate_sets = [cdec.score.BLEU(refs).candidate_set() for refs in references]
     # Weight vector -> sparse
     w = decoder.weights.tosparse()
     w0 = w.copy()
@@ -25,7 +25,7 @@ def rampion(decoder, sources, references):
         logging.info('Iteration {0}: translating...'.format(t+1))
         # Get the hypergraphs and extend the k-best lists
         hgs = []
-        for src, candidates in izip(sources, cs):
+        for src, candidates in izip(sources, candidate_sets):
             hg = decoder.translate(src)
             hgs.append(hg)
             candidates.add_kbest(hg, K)
@@ -36,17 +36,16 @@ def rampion(decoder, sources, references):
         for _ in range(T2):
             # y_i^+, h_i^+; i=1..N
             plus = [max(candidates, key=lambda c: w.dot(c.fmap) - cost(c)).fmap
-                    for candidates in cs]
+                    for candidates in candidate_sets]
             for _ in range(T3):
-                for fp, candidates in izip(plus, cs):
+                for fp, candidates in izip(plus, candidate_sets):
                     # y^-, h^-
                     fm = max(candidates, key=lambda c: w.dot(c.fmap) + cost(c)).fmap
                     # update weights (line 11-12)
                     w += eta * ((fp - fm) - C/N * (w - w0))
         logging.info('Updated weight vector: {0}'.format(dict(w)))
         # Update decoder weights
-        for fname, fval in w:
-            decoder.weights[fname] = fval
+        decoder.weights = w
 
 def main():
     logging.basicConfig(level=logging.INFO, format='%(message)s')
diff --git a/python/pkg/cdec/sa/extract.py b/python/pkg/cdec/sa/extract.py
index 875bf42e..39eac824 100644
--- a/python/pkg/cdec/sa/extract.py
+++ b/python/pkg/cdec/sa/extract.py
@@ -3,29 +3,52 @@ import sys
 import os
 import argparse
 import logging
+import multiprocessing as mp
+import signal
 import cdec.sa
 
+extractor, prefix = None, None
+def make_extractor(config, grammars):
+    global extractor, prefix
+    signal.signal(signal.SIGINT, signal.SIG_IGN) # Let parent process catch Ctrl+C
+    extractor = cdec.sa.GrammarExtractor(config)
+    prefix = grammars
+
+def extract(inp):
+    global extractor, prefix
+    i, sentence = inp
+    sentence = sentence[:-1]
+    grammar_file = os.path.join(prefix, 'grammar.{0}'.format(i))
+    with open(grammar_file, 'w') as output:
+        for rule in extractor.grammar(sentence):
+            output.write(str(rule)+'\n')
+    grammar_file = os.path.abspath(grammar_file)
+    return '<seg grammar="{0}" id="{1}">{2}</seg>'.format(grammar_file, i, sentence)
+
+
 def main():
     logging.basicConfig(level=logging.INFO)
     parser = argparse.ArgumentParser(description='Extract grammars from a compiled corpus.')
     parser.add_argument('-c', '--config', required=True,
-                        help='Extractor configuration')
+                        help='extractor configuration')
     parser.add_argument('-g', '--grammars', required=True,
-                        help='Grammar output path')
+                        help='grammar output path')
+    parser.add_argument('-j', '--jobs', type=int, default=1,
+                        help='number of parallel extractors')
+    parser.add_argument('-s', '--chunksize', type=int, default=10,
+                        help='number of sentences / chunk')
     args = parser.parse_args()
 
     if not os.path.exists(args.grammars):
         os.mkdir(args.grammars)
 
-    extractor = cdec.sa.GrammarExtractor(args.config)
-    for i, sentence in enumerate(sys.stdin):
-        sentence = sentence[:-1]
-        grammar_file = os.path.join(args.grammars, 'grammar.{0}'.format(i))
-        with open(grammar_file, 'w') as output:
-            for rule in extractor.grammar(sentence):
-                output.write(str(rule)+'\n')
-        grammar_file = os.path.abspath(grammar_file)
-        print('<seg grammar="{0}" id="{1}">{2}</seg>'.format(grammar_file, i, sentence))
+    logging.info('Starting %d workers; chunk size: %d', args.jobs, args.chunksize)
+    pool = mp.Pool(args.jobs, make_extractor, (args.config, args.grammars))
+    try:
+        for output in pool.imap(extract, enumerate(sys.stdin), args.chunksize):
+            print(output)
+    except KeyboardInterrupt:
+        pool.terminate()
 
 if __name__ == '__main__':
     main()
diff --git a/python/src/hypergraph.pxd b/python/src/hypergraph.pxd
index acab7244..dd3d39cc 100644
--- a/python/src/hypergraph.pxd
+++ b/python/src/hypergraph.pxd
@@ -38,7 +38,7 @@ cdef extern from "decoder/hg.h":
         int GoalNode()
         double NumberOfPaths()
         void Reweight(vector[weight_t]& weights) nogil
-        void Reweight(FastSparseVector& weights) nogil
+        void Reweight(FastSparseVector[weight_t]& weights) nogil
         bint PruneInsideOutside(double beam_alpha,
                                 double density,
                                 EdgeMask* preserve_mask,
-- 
cgit v1.2.3