diff options
Diffstat (limited to 'python/cdec/sa')
| -rw-r--r-- | python/cdec/sa/extract.py | 56 | 
1 files changed, 52 insertions, 4 deletions
| diff --git a/python/cdec/sa/extract.py b/python/cdec/sa/extract.py index b6502c52..b6c11f05 100644 --- a/python/cdec/sa/extract.py +++ b/python/cdec/sa/extract.py @@ -62,13 +62,48 @@ def extract(inp):      grammar_file = os.path.abspath(grammar_file)      return '<seg grammar="{}" id="{}">{}</seg>{}'.format(grammar_file, i, sentence, suffix) +def stream_extract(): +    global extractor, online, compress +    while True: +        line = sys.stdin.readline() +        if not line: +            break +        fields = re.split('\s*\|\|\|\s*', line.strip()) +        # context ||| cmd +        if len(fields) == 2: +            (context, cmd) = fields +            if cmd.lower() == 'drop': +                if online: +                    extractor.drop_ctx(context) +                    sys.stdout.write('drop {}\n'.format(context)) +                else: +                    sys.stdout.write('Error: online mode not set. Skipping line: {}\n'.format(line.strip())) +        # context ||| sentence ||| grammar_file +        elif len(fields) == 3: +            (context, sentence, grammar_file) = fields +            with (gzip.open if compress else open)(grammar_file, 'w') as output: +                for rule in extractor.grammar(sentence, context): +                    output.write(str(rule)+'\n') +            sys.stdout.write('{}\n'.format(grammar_file)) +        # context ||| sentence ||| reference ||| alignment +        elif len(fields) == 4: +            (context, sentence, reference, alignment) = fields +            if online: +                extractor.add_instance(sentence, reference, alignment, context) +                sys.stdout.write('learn {}\n'.format(context)) +            else: +                sys.stdout.write('Error: online mode not set. Skipping line: {}\n'.format(line.strip())) +        else: +            sys.stdout.write('Error: see README.md for stream mode usage.  Skipping line: {}\n'.format(line.strip())) +        sys.stdout.flush() +  def main():      global online      logging.basicConfig(level=logging.INFO)      parser = argparse.ArgumentParser(description='Extract grammars from a compiled corpus.')      parser.add_argument('-c', '--config', required=True,                          help='extractor configuration') -    parser.add_argument('-g', '--grammars', required=True, +    parser.add_argument('-g', '--grammars',                          help='grammar output path')      parser.add_argument('-j', '--jobs', type=int, default=1,                          help='number of parallel extractors') @@ -80,9 +115,15 @@ def main():                          help='online grammar extraction')      parser.add_argument('-z', '--compress', action='store_true',                          help='compress grammars with gzip') +    parser.add_argument('-t', '--stream', action='store_true', +                        help='stream mode (see README.md)')      args = parser.parse_args() -    if not os.path.exists(args.grammars): +    if not (args.grammars or args.stream): +        sys.stderr.write('Error: either -g/--grammars or -t/--stream required\n') +        sys.exit(1) + +    if args.grammars and not os.path.exists(args.grammars):          os.mkdir(args.grammars)      for featdef in args.features:          if not featdef.endswith('.py'): @@ -91,9 +132,13 @@ def main():              sys.exit(1)      online = args.online +    stream = args.stream      start_time = monitor_cpu()      if args.jobs > 1: +        if stream: +            sys.stderr.write('Error: stream mode incompatible with multiple jobs\n') +            sys.exit(1)          logging.info('Starting %d workers; chunk size: %d', args.jobs, args.chunksize)          pool = mp.Pool(args.jobs, make_extractor, (args,))          try: @@ -103,8 +148,11 @@ def main():              pool.terminate()      else:          make_extractor(args) -        for output in map(extract, enumerate(sys.stdin)): -            print(output) +        if stream: +            stream_extract() +        else: +            for output in map(extract, enumerate(sys.stdin)): +                print(output)      stop_time = monitor_cpu()      logging.info("Overall extraction step took %f seconds", stop_time - start_time) | 
