diff options
-rw-r--r-- | realtime/rt/rt.py | 31 |
1 files changed, 25 insertions, 6 deletions
diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py index f66d3a4d..7cc5bc10 100644 --- a/realtime/rt/rt.py +++ b/realtime/rt/rt.py @@ -5,8 +5,9 @@ import collections import logging import os import shutil -import sys +import StringIO import subprocess +import sys import tempfile import threading import time @@ -311,28 +312,46 @@ class RealtimeTranslator: os.remove(rm_grammar) lock.release() - def save_state(self, filename=None, ctx_name=None): + def save_state(self, file_or_stringio=None, ctx_name=None): + '''Write state (several lines terminated by EOF line) to file, buffer, or stdout''' lock = self.ctx_locks[ctx_name] lock.acquire() self.lazy_ctx(ctx_name) ctx_data = self.ctx_data[ctx_name] - out = open(filename, 'w') if filename else sys.stdout + # Filename, StringIO or None (stdout) + if file_or_stringio: + if isinstance(file_or_stringio, StringIO.StringIO): + out = file_or_stringio + else: + out = open(file_or_stringio, 'w') + else: + out = sys.stdout logger.info('({}) Saving state with {} sentences'.format(ctx_name, len(ctx_data))) out.write('{}\n'.format(self.decoders[ctx_name].decoder.get_weights())) for (source, target, alignment) in ctx_data: out.write('{} ||| {} ||| {}\n'.format(source, target, alignment)) out.write('EOF\n') - if filename: + # Close if file + if file_or_stringio and not isinstance(file_or_stringio, StringIO.StringIO): out.close() lock.release() - def load_state(self, filename=None, ctx_name=None): + def load_state(self, file_or_stringio=None, ctx_name=None): + '''Load state (several lines terminated by EOF line) from file, buffer, or stdin. + Restarts context on any error.''' lock = self.ctx_locks[ctx_name] lock.acquire() self.lazy_ctx(ctx_name) ctx_data = self.ctx_data[ctx_name] decoder = self.decoders[ctx_name] - input = open(filename) if filename else sys.stdin + # Filename, StringIO, or None (stdin) + if file_or_stringio: + if isinstance(file_or_stringio, StringIO.StringIO): + input = file_or_stringio.getvalue() + else: + input = open(file_or_stringio) + else: + input = sys.stdin # Non-initial load error if ctx_data: logger.info('({}) ERROR: Incremental data has already been added to context'.format(ctx_name)) |