diff options
| author | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-09-04 10:21:25 +0100 | 
|---|---|---|
| committer | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-09-04 10:21:25 +0100 | 
| commit | b774a1ce6aced0e17d308d775cb32ba18ab755a8 (patch) | |
| tree | 5ac4e3edcbe3d7ad3d2283eb080e862a2f30091d | |
| parent | 063152d73f2814be32dfa8e927fa00caf1af1855 (diff) | |
Multi-processing grammar extraction
+ various surface fixes
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | python/examples/rampion.py | 11 | ||||
| -rw-r--r-- | python/pkg/cdec/sa/extract.py | 45 | ||||
| -rw-r--r-- | python/src/hypergraph.pxd | 2 | 
4 files changed, 41 insertions, 18 deletions
| @@ -117,6 +117,7 @@ phrasinator/gibbs_train_plm_notables  previous.sh  pro-train/mr_pro_map  pro-train/mr_pro_reduce +python/setup.py  rampion/rampion_cccp  rst_parser/mst_train  rst_parser/random_tree diff --git a/python/examples/rampion.py b/python/examples/rampion.py index 66d89a61..30244cf7 100644 --- a/python/examples/rampion.py +++ b/python/examples/rampion.py @@ -15,7 +15,7 @@ cost = lambda c: 10 * (1 - c.score) # cost definition  def rampion(decoder, sources, references):      # Empty k-best lists -    cs = [cdec.score.BLEU(refs).candidate_set() for refs in references] +    candidate_sets = [cdec.score.BLEU(refs).candidate_set() for refs in references]      # Weight vector -> sparse      w = decoder.weights.tosparse()      w0 = w.copy() @@ -25,7 +25,7 @@ def rampion(decoder, sources, references):          logging.info('Iteration {0}: translating...'.format(t+1))          # Get the hypergraphs and extend the k-best lists          hgs = [] -        for src, candidates in izip(sources, cs): +        for src, candidates in izip(sources, candidate_sets):              hg = decoder.translate(src)              hgs.append(hg)              candidates.add_kbest(hg, K) @@ -36,17 +36,16 @@ def rampion(decoder, sources, references):          for _ in range(T2):              # y_i^+, h_i^+; i=1..N              plus = [max(candidates, key=lambda c: w.dot(c.fmap) - cost(c)).fmap -                    for candidates in cs] +                    for candidates in candidate_sets]              for _ in range(T3): -                for fp, candidates in izip(plus, cs): +                for fp, candidates in izip(plus, candidate_sets):                      # y^-, h^-                      fm = max(candidates, key=lambda c: w.dot(c.fmap) + cost(c)).fmap                      # update weights (line 11-12)                      w += eta * ((fp - fm) - C/N * (w - w0))          logging.info('Updated weight vector: {0}'.format(dict(w)))          # Update decoder weights -        for fname, fval in w: -            decoder.weights[fname] = fval +        decoder.weights = w  def main():      logging.basicConfig(level=logging.INFO, format='%(message)s') diff --git a/python/pkg/cdec/sa/extract.py b/python/pkg/cdec/sa/extract.py index 875bf42e..39eac824 100644 --- a/python/pkg/cdec/sa/extract.py +++ b/python/pkg/cdec/sa/extract.py @@ -3,29 +3,52 @@ import sys  import os  import argparse  import logging +import multiprocessing as mp +import signal  import cdec.sa +extractor, prefix = None, None +def make_extractor(config, grammars): +    global extractor, prefix +    signal.signal(signal.SIGINT, signal.SIG_IGN) # Let parent process catch Ctrl+C +    extractor = cdec.sa.GrammarExtractor(config) +    prefix = grammars + +def extract(inp): +    global extractor, prefix +    i, sentence = inp +    sentence = sentence[:-1] +    grammar_file = os.path.join(prefix, 'grammar.{0}'.format(i)) +    with open(grammar_file, 'w') as output: +        for rule in extractor.grammar(sentence): +            output.write(str(rule)+'\n') +    grammar_file = os.path.abspath(grammar_file) +    return '<seg grammar="{0}" id="{1}">{2}</seg>'.format(grammar_file, i, sentence) + +  def main():      logging.basicConfig(level=logging.INFO)      parser = argparse.ArgumentParser(description='Extract grammars from a compiled corpus.')      parser.add_argument('-c', '--config', required=True, -                        help='Extractor configuration') +                        help='extractor configuration')      parser.add_argument('-g', '--grammars', required=True, -                        help='Grammar output path') +                        help='grammar output path') +    parser.add_argument('-j', '--jobs', type=int, default=1, +                        help='number of parallel extractors') +    parser.add_argument('-s', '--chunksize', type=int, default=10, +                        help='number of sentences / chunk')      args = parser.parse_args()      if not os.path.exists(args.grammars):          os.mkdir(args.grammars) -    extractor = cdec.sa.GrammarExtractor(args.config) -    for i, sentence in enumerate(sys.stdin): -        sentence = sentence[:-1] -        grammar_file = os.path.join(args.grammars, 'grammar.{0}'.format(i)) -        with open(grammar_file, 'w') as output: -            for rule in extractor.grammar(sentence): -                output.write(str(rule)+'\n') -        grammar_file = os.path.abspath(grammar_file) -        print('<seg grammar="{0}" id="{1}">{2}</seg>'.format(grammar_file, i, sentence)) +    logging.info('Starting %d workers; chunk size: %d', args.jobs, args.chunksize) +    pool = mp.Pool(args.jobs, make_extractor, (args.config, args.grammars)) +    try: +        for output in pool.imap(extract, enumerate(sys.stdin), args.chunksize): +            print(output) +    except KeyboardInterrupt: +        pool.terminate()  if __name__ == '__main__':      main() diff --git a/python/src/hypergraph.pxd b/python/src/hypergraph.pxd index acab7244..dd3d39cc 100644 --- a/python/src/hypergraph.pxd +++ b/python/src/hypergraph.pxd @@ -38,7 +38,7 @@ cdef extern from "decoder/hg.h":          int GoalNode()          double NumberOfPaths()          void Reweight(vector[weight_t]& weights) nogil -        void Reweight(FastSparseVector& weights) nogil +        void Reweight(FastSparseVector[weight_t]& weights) nogil          bint PruneInsideOutside(double beam_alpha,                                  double density,                                  EdgeMask* preserve_mask, | 
