From 046e726aa598c75ca8f5a22a74d83cd2c405b741 Mon Sep 17 00:00:00 2001 From: Michael Denkowski Date: Mon, 30 Sep 2013 07:45:06 -0700 Subject: Command handling --- realtime/mkinput.py | 2 +- realtime/realtime.py | 16 ++++++---------- realtime/rt/rt.py | 37 ++++++++++++++++++++++--------------- 3 files changed, 29 insertions(+), 26 deletions(-) (limited to 'realtime') diff --git a/realtime/mkinput.py b/realtime/mkinput.py index 897b44fd..df434c76 100755 --- a/realtime/mkinput.py +++ b/realtime/mkinput.py @@ -10,7 +10,7 @@ def main(): sys.exit(2) for (src, ref) in itertools.izip(open(sys.argv[1]), open(sys.argv[2])): - sys.stdout.write('{}'.format(src)) + sys.stdout.write('TR ||| {}'.format(src)) sys.stdout.write('LEARN ||| {} ||| {}'.format(src.strip(), ref)) if __name__ == '__main__': diff --git a/realtime/realtime.py b/realtime/realtime.py index 38da4413..af3a3aba 100755 --- a/realtime/realtime.py +++ b/realtime/realtime.py @@ -16,11 +16,9 @@ class Parser(argparse.ArgumentParser): sys.exit(2) def handle_line(translator, line, output, ctx_name): - if '|||' in line: - translator.command_line(line, ctx_name) - else: - hyp = translator.decode(line, ctx_name) - output.write('{}\n'.format(hyp)) + res = translator.command_line(line, ctx_name) + if res: + output.write('{}\n'.format(res)) output.flush() def test1(translator, input, output, ctx_name): @@ -83,11 +81,9 @@ def main(): if not line: break line = line.strip() - if '|||' in line: - translator.command_line(line) - else: - hyp = translator.decode(line) - sys.stdout.write('{}\n'.format(hyp)) + res = translator.command_line(line) + if res: + sys.stdout.write('{}\n'.format(res)) sys.stdout.flush() if __name__ == '__main__': diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py index 5ace5d59..4a31070f 100644 --- a/realtime/rt/rt.py +++ b/realtime/rt/rt.py @@ -58,9 +58,14 @@ class RealtimeTranslator: def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False, state=None): - # TODO: save/load - self.commands = {'LEARN': self.learn, 'SAVE': self.save_state, 'LOAD': self.load_state} - + # name -> (method, set of possible nargs) + self.COMMANDS = { + 'TR': (self.translate, set((1,))), + 'LEARN': (self.learn, set((2,))), + 'SAVE': (self.save_state, set((0, 1))), + 'LOAD': (self.load_state, set((0, 1))), + } + cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) ### Single instance for all contexts @@ -203,13 +208,12 @@ class RealtimeTranslator: self.extractor_lock.release() return grammar_file - def decode(self, sentence, ctx_name=None): + def translate(self, sentence, ctx_name=None): '''Decode a sentence (inc extracting a grammar if needed) Threadsafe, FIFO''' lock = self.ctx_locks[ctx_name] lock.acquire() self.lazy_ctx(ctx_name) - logging.info('DECODE: {}'.format(sentence)) # Empty in, empty out if sentence.strip() == '': lock.release() @@ -246,16 +250,20 @@ class RealtimeTranslator: self.detokenizer_lock.release() return detok_line - # TODO def command_line(self, line, ctx_name=None): - args = [f.strip() for f in line.split('|||')] - #try: - if len(args) == 2 and not args[1]: - self.commands[args[0]](ctx_name) - else: - self.commands[args[0]](*args[1:], ctx_name=ctx_name) - #except: - # logging.info('Command error: {}'.format(' ||| '.join(args))) + args = [f.strip() for f in line.split('|||')] + (command, nargs) = self.COMMANDS[args[0]] + # ctx_name provided + if len(args[1:]) + 1 in nargs: + logging.info('Context {}: {} ||| {}'.format(args[1], args[0], ' ||| '.join(args[2:]))) + return command(*args[2:], ctx_name=args[1]) + # No ctx_name, use default or passed + elif len(args[1:]) in nargs: + logging.info('Context {}: {} ||| {}'.format(ctx_name, args[0], ' ||| '.join(args[1:]))) + return command(*args[1:], ctx_name=ctx_name) + # nargs doesn't match + else: + logging.info('Command error: {}'.format(' ||| '.join(args))) def learn(self, source, target, ctx_name=None): '''Learn from training instance (inc extracting grammar if needed) @@ -263,7 +271,6 @@ class RealtimeTranslator: lock = self.ctx_locks[ctx_name] lock.acquire() self.lazy_ctx(ctx_name) - logging.info('LEARN: {}'.format(source)) if '' in (source.strip(), target.strip()): logging.info('Error empty source or target: {} ||| {}'.format(source, target)) lock.release() -- cgit v1.2.3