summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-x realtime/mkinput.py 9
-rw-r--r-- realtime/rt/decoder.py 12
-rw-r--r-- realtime/rt/rt.py 121
3 files changed, 91 insertions, 51 deletions
diff --git a/realtime/mkinput.py b/realtime/mkinput.py
index df434c76..a1b1256d 100755
--- a/realtime/mkinput.py
+++ b/realtime/mkinput.py
@@ -5,13 +5,14 @@ import sys
def main():
- if len(sys.argv[1:]) != 2:
- sys.stderr.write('usage: {} test.src test.ref >test.input\n'.format(sys.argv[0]))
+ if len(sys.argv[1:]) < 2:
+ sys.stderr.write('usage: {} test.src test.ref [ctx_name] >test.input\n'.format(sys.argv[0]))
sys.exit(2)
+ ctx_name = ' {}'.format(sys.argv[3]) if len(sys.argv[1:]) > 2 else ''
for (src, ref) in itertools.izip(open(sys.argv[1]), open(sys.argv[2])):
- sys.stdout.write('TR ||| {}'.format(src))
- sys.stdout.write('LEARN ||| {} ||| {}'.format(src.strip(), ref))
+ sys.stdout.write('TR{} ||| {}'.format(ctx_name, src))
+ sys.stdout.write('LEARN{} ||| {} ||| {}'.format(ctx_name, src.strip(), ref))
if __name__ == '__main__':
main()
diff --git a/realtime/rt/decoder.py b/realtime/rt/decoder.py
index da646f68..15f7db3f 100644
--- a/realtime/rt/decoder.py
+++ b/realtime/rt/decoder.py
@@ -55,8 +55,16 @@ class MIRADecoder(Decoder):
def set_weights(self, w_line):
'''Threadsafe, FIFO'''
self.lock.acquire()
- self.decoder.stdin.write('WEIGHTS ||| {}\n'.format(w_line))
- self.lock.release()
+ try:
+ # Check validity
+ for w_str in w_line.split():
+ (k, v) = w_str.split('=')
+ float(v)
+ self.decoder.stdin.write('WEIGHTS ||| {}\n'.format(w_line))
+ self.lock.release()
+ except:
+ raise Exception('Invalid weights line: {}'.format(w_line))
+
def update(self, sentence, grammar, reference):
'''Threadsafe, FIFO'''
diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py
index 4a31070f..40305f66 100644
--- a/realtime/rt/rt.py
+++ b/realtime/rt/rt.py
@@ -64,6 +64,8 @@ class RealtimeTranslator:
'LEARN': (self.learn, set((2,))),
'SAVE': (self.save_state, set((0, 1))),
'LOAD': (self.load_state, set((0, 1))),
+ 'DROP': (self.drop_ctx, set((0,))),
+ 'LIST': (self.list_ctx, set((0,))),
}
cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -115,7 +117,6 @@ class RealtimeTranslator:
self.decoders = {}
- # TODO: state
# Load state if given
if state:
with open(state) as input:
@@ -125,7 +126,8 @@ class RealtimeTranslator:
return self
def __exit__(self, ex_type, ex_value, ex_traceback):
- self.close(ex_type is KeyboardInterrupt)
+ # Force shutdown on exception
+ self.close(ex_type is not None)
def close(self, force=False):
'''Cleanup'''
@@ -166,6 +168,11 @@ class RealtimeTranslator:
lock = self.ctx_locks[ctx_name]
if not force:
lock.acquire()
+ if ctx_name not in self.ctx_names:
+ logging.info('No context found, no action: {}'.format(ctx_name))
+ if not force:
+ lock.release()
+ return
logging.info('Dropping context: {}'.format(ctx_name))
self.ctx_names.remove(ctx_name)
self.ctx_data.pop(ctx_name)
@@ -176,7 +183,11 @@ class RealtimeTranslator:
self.ctx_locks.pop(ctx_name)
if not force:
lock.release()
-
+
+ def list_ctx(self, ctx_name=None):
+ '''Return a string of active contexts'''
+ return 'ctx_name ||| {}'.format(' '.join(sorted(str(ctx_name) for ctx_name in self.ctx_names)))
+
def grammar(self, sentence, ctx_name=None):
'''Extract a sentence-level grammar on demand (or return cached)
Threadsafe wrt extractor but NOT decoder. Acquire ctx_name lock
@@ -251,19 +262,23 @@ class RealtimeTranslator:
return detok_line
def command_line(self, line, ctx_name=None):
+ # COMMAND [ctx_name] ||| arg1 [||| arg2 ...]
args = [f.strip() for f in line.split('|||')]
- (command, nargs) = self.COMMANDS[args[0]]
- # ctx_name provided
- if len(args[1:]) + 1 in nargs:
- logging.info('Context {}: {} ||| {}'.format(args[1], args[0], ' ||| '.join(args[2:])))
- return command(*args[2:], ctx_name=args[1])
- # No ctx_name, use default or passed
- elif len(args[1:]) in nargs:
- logging.info('Context {}: {} ||| {}'.format(ctx_name, args[0], ' ||| '.join(args[1:])))
- return command(*args[1:], ctx_name=ctx_name)
- # nargs doesn't match
- else:
- logging.info('Command error: {}'.format(' ||| '.join(args)))
+ if args[-1] == '':
+ args = args[:-1]
+ if len(args) > 0:
+ cmd_name = args[0].split()
+ # ctx_name provided
+ if len(cmd_name) == 2:
+ (cmd_name, ctx_name) = cmd_name
+ # ctx_name default/passed
+ else:
+ cmd_name = cmd_name[0]
+ (command, nargs) = self.COMMANDS.get(cmd_name, (None, None))
+ if command and len(args[1:]) in nargs:
+ logging.info('{} ({}) ||| {}'.format(cmd_name, ctx_name, ' ||| '.join(args[1:])))
+ return command(*args[1:], ctx_name=ctx_name)
+ logging.info('ERROR: command: {}'.format(' ||| '.join(args)))
def learn(self, source, target, ctx_name=None):
'''Learn from training instance (inc extracting grammar if needed)
@@ -272,7 +287,7 @@ class RealtimeTranslator:
lock.acquire()
self.lazy_ctx(ctx_name)
if '' in (source.strip(), target.strip()):
- logging.info('Error empty source or target: {} ||| {}'.format(source, target))
+ logging.info('ERROR: empty source or target: {} ||| {}'.format(source, target))
lock.release()
return
if self.norm:
@@ -300,49 +315,65 @@ class RealtimeTranslator:
lock.release()
def save_state(self, filename=None, ctx_name=None):
- self.lazy_ctx(ctx_name)
- out = open(filename, 'w') if filename else sys.stdout
lock = self.ctx_locks[ctx_name]
lock.acquire()
+ self.lazy_ctx(ctx_name)
ctx_data = self.ctx_data[ctx_name]
- logging.info('Saving state with {} sentences'.format(len(self.ctx_data)))
+ out = open(filename, 'w') if filename else sys.stdout
+ logging.info('Saving state for context ({}) with {} sentences'.format(ctx_name, len(ctx_data)))
out.write('{}\n'.format(self.decoders[ctx_name].decoder.get_weights()))
for (source, target, alignment) in ctx_data:
out.write('{} ||| {} ||| {}\n'.format(source, target, alignment))
- lock.release()
out.write('EOF\n')
if filename:
out.close()
+ lock.release()
- def load_state(self, input=sys.stdin, ctx_name=None):
- self.lazy_ctx(ctx_name)
+ def load_state(self, filename=None, ctx_name=None):
lock = self.ctx_locks[ctx_name]
lock.acquire()
+ self.lazy_ctx(ctx_name)
ctx_data = self.ctx_data[ctx_name]
decoder = self.decoders[ctx_name]
- # Non-initial load error
+ input = open(filename) if filename else sys.stdin
+ # Non-initial load error
if ctx_data:
- logging.info('Error: Incremental data has already been added to decoder.')
- logging.info(' State can only be loaded by a freshly started decoder.')
+ logging.info('ERROR: Incremental data has already been added to context ({})'.format(ctx_name))
+ logging.info(' State can only be loaded to a new context.')
+ lock.release()
return
- # MIRA weights
- line = input.readline().strip()
- decoder.decoder.set_weights(line)
- logging.info('Loading state...')
- start_time = time.time()
- # Lines source ||| target ||| alignment
- while True:
+ # Many things can go wrong if bad state data is given
+ try:
+ # MIRA weights
line = input.readline().strip()
- if line == 'EOF':
- break
- (source, target, alignment) = line.split(' ||| ')
- ctx_data.append((source, target, alignment))
- # Extractor
- self.extractor.add_instance(source, target, alignment, ctx_name)
- # HPYPLM
- hyp = decoder.decoder.decode(LIKELY_OOV)
- self.ref_fifo.write('{}\n'.format(target))
- self.ref_fifo.flush()
- stop_time = time.time()
- logging.info('Loaded state with {} sentences in {} seconds'.format(len(ctx_data), stop_time - start_time))
- lock.release()
+ # Throws exception if bad line
+ decoder.decoder.set_weights(line)
+ logging.info('Loading state...')
+ start_time = time.time()
+ # Lines source ||| target ||| alignment
+ while True:
+ line = input.readline()
+ if not line:
+ raise Exception('End of file before EOF line')
+ line = line.strip()
+ if line == 'EOF':
+ break
+ (source, target, alignment) = line.split(' ||| ')
+ ctx_data.append((source, target, alignment))
+ # Extractor
+ self.extractor.add_instance(source, target, alignment, ctx_name)
+ # HPYPLM
+ hyp = decoder.decoder.decode(LIKELY_OOV)
+ decoder.ref_fifo.write('{}\n'.format(target))
+ decoder.ref_fifo.flush()
+ stop_time = time.time()
+ logging.info('Loaded state for context ({}) with {} sentences in {} seconds'.format(ctx_name, len(ctx_data), stop_time - start_time))
+ lock.release()
+ # Recover from bad load attempt by restarting context.
+ # Guaranteed not to cause data loss since only a new context can load state.
+ except:
+ logging.info('ERROR: could not load state, restarting context ({})'.format(ctx_name))
+ # ctx_name is already owned and needs to be restarted before other blocking threads use
+ self.drop_ctx(ctx_name, force=True)
+ self.lazy_ctx(ctx_name)
+ lock.release()