diff options
author | Michael Denkowski <mdenkows@cs.cmu.edu> | 2013-09-17 22:05:32 -0700 |
---|---|---|
committer | Michael Denkowski <mdenkows@cs.cmu.edu> | 2013-09-17 22:05:32 -0700 |
commit | 6d427339b45f8aa74410437f91b5a01afc824120 (patch) | |
tree | eb8914fe5fbd3d148610836f4f58fd660027230b /realtime | |
parent | a30a2e59ff117ad6bae80ece2bf535767daf7db6 (diff) |
Save/load state in realtime
Diffstat (limited to 'realtime')
-rwxr-xr-x | realtime/realtime.py | 48 | ||||
-rw-r--r-- | realtime/rt/__init__.py | 2 | ||||
-rw-r--r-- | realtime/rt/decoder.py | 11 | ||||
-rw-r--r-- | realtime/rt/rt.py | 68 |
4 files changed, 97 insertions, 32 deletions
diff --git a/realtime/realtime.py b/realtime/realtime.py index 554c52ca..a6ed1511 100755 --- a/realtime/realtime.py +++ b/realtime/realtime.py @@ -15,8 +15,9 @@ class Parser(argparse.ArgumentParser): def main(): - parser = Parser(description='Real-time adaptive translation with cdec.') - parser.add_argument('-c', '--config', required=True, help='Config directory (see README.md)') + parser = Parser(description='Real-time adaptive translation with cdec. (See README.md)') + parser.add_argument('-c', '--config', required=True, help='Config directory') + parser.add_argument('-s', '--state', help='Load state file (saved incremental data)') parser.add_argument('-n', '--normalize', help='Normalize text (tokenize, translate, detokenize)', action='store_true') parser.add_argument('-T', '--temp', help='Temp directory (default /tmp)', default='/tmp') parser.add_argument('-a', '--cache', help='Grammar cache size (default 5)', default='5') @@ -26,27 +27,28 @@ def main(): if args.verbose: logging.basicConfig(level=logging.INFO) - rtd = rt.RealtimeDecoder(args.config, tmpdir=args.temp, cache_size=int(args.cache), norm=args.normalize) - - try: - while True: - line = sys.stdin.readline() - if not line: - break - input = [f.strip() for f in line.split('|||')] - if len(input) == 1: - hyp = rtd.decode(input[0]) - sys.stdout.write('{}\n'.format(hyp)) - sys.stdout.flush() - else: - rtd.command(input) - - # Clean exit on ctrl+c - except KeyboardInterrupt: - logging.info('Caught KeyboardInterrupt, exiting') - - # Cleanup - rtd.close() + with rt.RealtimeDecoder(args.config, tmpdir=args.temp, cache_size=int(args.cache), norm=args.normalize) as rtd: + + try: + # Load state if given + if args.state: + rtd.load_state(args.state) + # Read lines and commands + while True: + line = sys.stdin.readline() + if not line: + break + line = line.strip() + if '|||' in line: + rtd.command_line(line) + else: + hyp = rtd.decode(line) + sys.stdout.write('{}\n'.format(hyp)) + sys.stdout.flush() + + # Clean exit on ctrl+c + except KeyboardInterrupt: + logging.info('Caught KeyboardInterrupt, exiting') if __name__ == '__main__': main() diff --git a/realtime/rt/__init__.py b/realtime/rt/__init__.py index fbde8f4d..c76acc4d 100644 --- a/realtime/rt/__init__.py +++ b/realtime/rt/__init__.py @@ -9,7 +9,7 @@ except ImportError as ie: sys.path.append(pycdec) import cdec except: - sys.stderr.write('Error: cannot import pycdec. Please check the cdec/python is built.\n') + sys.stderr.write('Error: cannot import pycdec. Please check that cdec/python is built.\n') raise ie # Regular init imports diff --git a/realtime/rt/decoder.py b/realtime/rt/decoder.py index 57739d93..aa6db64d 100644 --- a/realtime/rt/decoder.py +++ b/realtime/rt/decoder.py @@ -9,8 +9,8 @@ class Decoder: def close(self): self.decoder.stdin.close() - def decode(self, sentence, grammar): - input = '<seg grammar="{g}">{s}</seg>\n'.format(s=sentence, g=grammar) + def decode(self, sentence, grammar=None): + input = '<seg grammar="{g}">{s}</seg>\n'.format(s=sentence, g=grammar) if grammar else '{}\n'.format(sentence) self.decoder.stdin.write(input) return self.decoder.stdout.readline().strip() @@ -33,6 +33,13 @@ class MIRADecoder(Decoder): logging.info('Executing: {}'.format(' '.join(mira_cmd))) self.decoder = util.popen_io(mira_cmd) + def get_weights(self): + self.decoder.stdin.write('WEIGHTS ||| WRITE\n') + return self.decoder.stdout.readline().strip() + + def set_weights(self, w_line): + self.decoder.stdin.write('WEIGHTS ||| {}\n'.format(w_line)) + def update(self, sentence, grammar, reference): input = 'LEARN ||| <seg grammar="{g}">{s}</seg> ||| {r}\n'.format(s=sentence, g=grammar, r=reference) self.decoder.stdin.write(input) diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py index 0ce05a56..fedc1fcf 100644 --- a/realtime/rt/rt.py +++ b/realtime/rt/rt.py @@ -15,14 +15,18 @@ import aligner import decoder import util +LIKELY_OOV = '("OOV")' + class RealtimeDecoder: - def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False): + def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False, state=None): - self.commands = {'LEARN': self.learn} + self.commands = {'LEARN': self.learn, 'SAVE': self.save_state, 'LOAD': self.load_state} cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + self.inc_data = [] # instances of (source, target) + # Temporary work dir self.tmp = tempfile.mkdtemp(dir=tmpdir, prefix='realtime.') logging.info('Using temp dir {}'.format(self.tmp)) @@ -68,6 +72,17 @@ class RealtimeDecoder: decoder_weights = os.path.join(configdir, 'weights.final') self.decoder = decoder.MIRADecoder(decoder_config_file, decoder_weights) + # Load state if given + if state: + with open(state) as input: + self.load_state(input) + + def __enter__(self): + return self + + def __exit__(self, type, value, traceback): + self.close() + def close(self): logging.info('Closing processes') self.aligner.close() @@ -128,9 +143,13 @@ class RealtimeDecoder: self.detokenizer.stdin.write('{}\n'.format(line)) return self.detokenizer.stdout.readline().strip() - def command(self, args): + def command_line(self, line): + args = [f.strip() for f in line.split('|||')] try: - self.commands[args[0]](*args[1:]) + if len(args) == 2 and not args[1]: + self.commands[args[0]]() + else: + self.commands[args[0]](*args[1:]) except: logging.info('Command error: {}'.format(' ||| '.join(args))) @@ -145,9 +164,12 @@ class RealtimeDecoder: grammar_file = self.grammar(source) mira_log = self.decoder.update(source, grammar_file, target) logging.info('MIRA: {}'.format(mira_log)) - # Add aligned sentence pair to grammar extractor + # Align instance alignment = self.aligner.align(source, target) - logging.info('Adding instance: {} ||| {} ||| {}'.format(source, target, alignment)) + # Store incremental data for save/load + self.inc_data.append((source, target, alignment)) + # Add aligned sentence pair to grammar extractor + logging.info('Adding to bitext: {} ||| {} ||| {}'.format(source, target, alignment)) self.extractor.add_instance(source, target, alignment) # Clear (old) cached grammar rm_grammar = self.grammar_dict.pop(source) @@ -156,3 +178,37 @@ class RealtimeDecoder: logging.info('Adding to HPYPLM: {}'.format(target)) self.ref_fifo.write('{}\n'.format(target)) self.ref_fifo.flush() + + def save_state(self): + logging.info('Saving state with {} sentences'.format(len(self.inc_data))) + sys.stdout.write('{}\n'.format(self.decoder.get_weights())) + for (source, target, alignment) in self.inc_data: + sys.stdout.write('{} ||| {} ||| {}\n'.format(source, target, alignment)) + sys.stdout.write('EOF\n') + + def load_state(self, input=sys.stdin): + # Non-initial load error + if self.inc_data: + logging.info('Error: Incremental data has already been added to decoder.') + logging.info(' State can only be loaded by a freshly started decoder.') + return + # MIRA weights + line = input.readline().strip() + self.decoder.set_weights(line) + logging.info('Loading state...') + start_time = time.time() + # Lines source ||| target ||| alignment + while True: + line = input.readline().strip() + if line == 'EOF': + break + (source, target, alignment) = line.split(' ||| ') + self.inc_data.append((source, target, alignment)) + # Extractor + self.extractor.add_instance(source, target, alignment) + # HPYPLM + hyp = self.decoder.decode(LIKELY_OOV) + self.ref_fifo.write('{}\n'.format(target)) + self.ref_fifo.flush() + stop_time = time.time() + logging.info('Loaded state with {} sentences in {} seconds'.format(len(self.inc_data), stop_time - start_time)) |