diff options
author | Michael Denkowski <mdenkows@cs.cmu.edu> | 2013-09-30 16:00:12 -0700 |
---|---|---|
committer | Michael Denkowski <mdenkows@cs.cmu.edu> | 2013-09-30 16:00:12 -0700 |
commit | 3cc87778e8985ff7e1bcf3a4a12d071c2ddd5639 (patch) | |
tree | c860fbdf22c1ef0b26a8b244af9453db0c8b920e | |
parent | 046e726aa598c75ca8f5a22a74d83cd2c405b741 (diff) |
New commands, save/load context
-rwxr-xr-x | realtime/mkinput.py | 9 | ||||
-rw-r--r-- | realtime/rt/decoder.py | 12 | ||||
-rw-r--r-- | realtime/rt/rt.py | 121 |
3 files changed, 91 insertions, 51 deletions
diff --git a/realtime/mkinput.py b/realtime/mkinput.py index df434c76..a1b1256d 100755 --- a/realtime/mkinput.py +++ b/realtime/mkinput.py @@ -5,13 +5,14 @@ import sys def main(): - if len(sys.argv[1:]) != 2: - sys.stderr.write('usage: {} test.src test.ref >test.input\n'.format(sys.argv[0])) + if len(sys.argv[1:]) < 2: + sys.stderr.write('usage: {} test.src test.ref [ctx_name] >test.input\n'.format(sys.argv[0])) sys.exit(2) + ctx_name = ' {}'.format(sys.argv[3]) if len(sys.argv[1:]) > 2 else '' for (src, ref) in itertools.izip(open(sys.argv[1]), open(sys.argv[2])): - sys.stdout.write('TR ||| {}'.format(src)) - sys.stdout.write('LEARN ||| {} ||| {}'.format(src.strip(), ref)) + sys.stdout.write('TR{} ||| {}'.format(ctx_name, src)) + sys.stdout.write('LEARN{} ||| {} ||| {}'.format(ctx_name, src.strip(), ref)) if __name__ == '__main__': main() diff --git a/realtime/rt/decoder.py b/realtime/rt/decoder.py index da646f68..15f7db3f 100644 --- a/realtime/rt/decoder.py +++ b/realtime/rt/decoder.py @@ -55,8 +55,16 @@ class MIRADecoder(Decoder): def set_weights(self, w_line): '''Threadsafe, FIFO''' self.lock.acquire() - self.decoder.stdin.write('WEIGHTS ||| {}\n'.format(w_line)) - self.lock.release() + try: + # Check validity + for w_str in w_line.split(): + (k, v) = w_str.split('=') + float(v) + self.decoder.stdin.write('WEIGHTS ||| {}\n'.format(w_line)) + self.lock.release() + except: + raise Exception('Invalid weights line: {}'.format(w_line)) + def update(self, sentence, grammar, reference): '''Threadsafe, FIFO''' diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py index 4a31070f..40305f66 100644 --- a/realtime/rt/rt.py +++ b/realtime/rt/rt.py @@ -64,6 +64,8 @@ class RealtimeTranslator: 'LEARN': (self.learn, set((2,))), 'SAVE': (self.save_state, set((0, 1))), 'LOAD': (self.load_state, set((0, 1))), + 'DROP': (self.drop_ctx, set((0,))), + 'LIST': (self.list_ctx, set((0,))), } cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) @@ -115,7 +117,6 @@ class RealtimeTranslator: self.decoders = {} - # TODO: state # Load state if given if state: with open(state) as input: @@ -125,7 +126,8 @@ class RealtimeTranslator: return self def __exit__(self, ex_type, ex_value, ex_traceback): - self.close(ex_type is KeyboardInterrupt) + # Force shutdown on exception + self.close(ex_type is not None) def close(self, force=False): '''Cleanup''' @@ -166,6 +168,11 @@ class RealtimeTranslator: lock = self.ctx_locks[ctx_name] if not force: lock.acquire() + if ctx_name not in self.ctx_names: + logging.info('No context found, no action: {}'.format(ctx_name)) + if not force: + lock.release() + return logging.info('Dropping context: {}'.format(ctx_name)) self.ctx_names.remove(ctx_name) self.ctx_data.pop(ctx_name) @@ -176,7 +183,11 @@ class RealtimeTranslator: self.ctx_locks.pop(ctx_name) if not force: lock.release() - + + def list_ctx(self, ctx_name=None): + '''Return a string of active contexts''' + return 'ctx_name ||| {}'.format(' '.join(sorted(str(ctx_name) for ctx_name in self.ctx_names))) + def grammar(self, sentence, ctx_name=None): '''Extract a sentence-level grammar on demand (or return cached) Threadsafe wrt extractor but NOT decoder. Acquire ctx_name lock @@ -251,19 +262,23 @@ class RealtimeTranslator: return detok_line def command_line(self, line, ctx_name=None): + # COMMAND [ctx_name] ||| arg1 [||| arg2 ...] args = [f.strip() for f in line.split('|||')] - (command, nargs) = self.COMMANDS[args[0]] - # ctx_name provided - if len(args[1:]) + 1 in nargs: - logging.info('Context {}: {} ||| {}'.format(args[1], args[0], ' ||| '.join(args[2:]))) - return command(*args[2:], ctx_name=args[1]) - # No ctx_name, use default or passed - elif len(args[1:]) in nargs: - logging.info('Context {}: {} ||| {}'.format(ctx_name, args[0], ' ||| '.join(args[1:]))) - return command(*args[1:], ctx_name=ctx_name) - # nargs doesn't match - else: - logging.info('Command error: {}'.format(' ||| '.join(args))) + if args[-1] == '': + args = args[:-1] + if len(args) > 0: + cmd_name = args[0].split() + # ctx_name provided + if len(cmd_name) == 2: + (cmd_name, ctx_name) = cmd_name + # ctx_name default/passed + else: + cmd_name = cmd_name[0] + (command, nargs) = self.COMMANDS.get(cmd_name, (None, None)) + if command and len(args[1:]) in nargs: + logging.info('{} ({}) ||| {}'.format(cmd_name, ctx_name, ' ||| '.join(args[1:]))) + return command(*args[1:], ctx_name=ctx_name) + logging.info('ERROR: command: {}'.format(' ||| '.join(args))) def learn(self, source, target, ctx_name=None): '''Learn from training instance (inc extracting grammar if needed) @@ -272,7 +287,7 @@ class RealtimeTranslator: lock.acquire() self.lazy_ctx(ctx_name) if '' in (source.strip(), target.strip()): - logging.info('Error empty source or target: {} ||| {}'.format(source, target)) + logging.info('ERROR: empty source or target: {} ||| {}'.format(source, target)) lock.release() return if self.norm: @@ -300,49 +315,65 @@ class RealtimeTranslator: lock.release() def save_state(self, filename=None, ctx_name=None): - self.lazy_ctx(ctx_name) - out = open(filename, 'w') if filename else sys.stdout lock = self.ctx_locks[ctx_name] lock.acquire() + self.lazy_ctx(ctx_name) ctx_data = self.ctx_data[ctx_name] - logging.info('Saving state with {} sentences'.format(len(self.ctx_data))) + out = open(filename, 'w') if filename else sys.stdout + logging.info('Saving state for context ({}) with {} sentences'.format(ctx_name, len(ctx_data))) out.write('{}\n'.format(self.decoders[ctx_name].decoder.get_weights())) for (source, target, alignment) in ctx_data: out.write('{} ||| {} ||| {}\n'.format(source, target, alignment)) - lock.release() out.write('EOF\n') if filename: out.close() + lock.release() - def load_state(self, input=sys.stdin, ctx_name=None): - self.lazy_ctx(ctx_name) + def load_state(self, filename=None, ctx_name=None): lock = self.ctx_locks[ctx_name] lock.acquire() + self.lazy_ctx(ctx_name) ctx_data = self.ctx_data[ctx_name] decoder = self.decoders[ctx_name] - # Non-initial load error + input = open(filename) if filename else sys.stdin + # Non-initial load error if ctx_data: - logging.info('Error: Incremental data has already been added to decoder.') - logging.info(' State can only be loaded by a freshly started decoder.') + logging.info('ERROR: Incremental data has already been added to context ({})'.format(ctx_name)) + logging.info(' State can only be loaded to a new context.') + lock.release() return - # MIRA weights - line = input.readline().strip() - decoder.decoder.set_weights(line) - logging.info('Loading state...') - start_time = time.time() - # Lines source ||| target ||| alignment - while True: + # Many things can go wrong if bad state data is given + try: + # MIRA weights line = input.readline().strip() - if line == 'EOF': - break - (source, target, alignment) = line.split(' ||| ') - ctx_data.append((source, target, alignment)) - # Extractor - self.extractor.add_instance(source, target, alignment, ctx_name) - # HPYPLM - hyp = decoder.decoder.decode(LIKELY_OOV) - self.ref_fifo.write('{}\n'.format(target)) - self.ref_fifo.flush() - stop_time = time.time() - logging.info('Loaded state with {} sentences in {} seconds'.format(len(ctx_data), stop_time - start_time)) - lock.release() + # Throws exception if bad line + decoder.decoder.set_weights(line) + logging.info('Loading state...') + start_time = time.time() + # Lines source ||| target ||| alignment + while True: + line = input.readline() + if not line: + raise Exception('End of file before EOF line') + line = line.strip() + if line == 'EOF': + break + (source, target, alignment) = line.split(' ||| ') + ctx_data.append((source, target, alignment)) + # Extractor + self.extractor.add_instance(source, target, alignment, ctx_name) + # HPYPLM + hyp = decoder.decoder.decode(LIKELY_OOV) + decoder.ref_fifo.write('{}\n'.format(target)) + decoder.ref_fifo.flush() + stop_time = time.time() + logging.info('Loaded state for context ({}) with {} sentences in {} seconds'.format(ctx_name, len(ctx_data), stop_time - start_time)) + lock.release() + # Recover from bad load attempt by restarting context. + # Guaranteed not to cause data loss since only a new context can load state. + except: + logging.info('ERROR: could not load state, restarting context ({})'.format(ctx_name)) + # ctx_name is already owned and needs to be restarted before other blocking threads use + self.drop_ctx(ctx_name, force=True) + self.lazy_ctx(ctx_name) + lock.release() |