summaryrefslogtreecommitdiff
path: root/realtime/rt
diff options
context:
space:
mode:
Diffstat (limited to 'realtime/rt')
-rw-r--r--realtime/rt/rt.py31
1 files changed, 25 insertions, 6 deletions
diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py
index f66d3a4d..7cc5bc10 100644
--- a/realtime/rt/rt.py
+++ b/realtime/rt/rt.py
@@ -5,8 +5,9 @@ import collections
import logging
import os
import shutil
-import sys
+import StringIO
import subprocess
+import sys
import tempfile
import threading
import time
@@ -311,28 +312,46 @@ class RealtimeTranslator:
os.remove(rm_grammar)
lock.release()
- def save_state(self, filename=None, ctx_name=None):
+ def save_state(self, file_or_stringio=None, ctx_name=None):
+ '''Write state (several lines terminated by EOF line) to file, buffer, or stdout'''
lock = self.ctx_locks[ctx_name]
lock.acquire()
self.lazy_ctx(ctx_name)
ctx_data = self.ctx_data[ctx_name]
- out = open(filename, 'w') if filename else sys.stdout
+ # Filename, StringIO or None (stdout)
+ if file_or_stringio:
+ if isinstance(file_or_stringio, StringIO.StringIO):
+ out = file_or_stringio
+ else:
+ out = open(file_or_stringio, 'w')
+ else:
+ out = sys.stdout
logger.info('({}) Saving state with {} sentences'.format(ctx_name, len(ctx_data)))
out.write('{}\n'.format(self.decoders[ctx_name].decoder.get_weights()))
for (source, target, alignment) in ctx_data:
out.write('{} ||| {} ||| {}\n'.format(source, target, alignment))
out.write('EOF\n')
- if filename:
+ # Close if file
+ if file_or_stringio and not isinstance(file_or_stringio, StringIO.StringIO):
out.close()
lock.release()
- def load_state(self, filename=None, ctx_name=None):
+ def load_state(self, file_or_stringio=None, ctx_name=None):
+ '''Load state (several lines terminated by EOF line) from file, buffer, or stdin.
+ Restarts context on any error.'''
lock = self.ctx_locks[ctx_name]
lock.acquire()
self.lazy_ctx(ctx_name)
ctx_data = self.ctx_data[ctx_name]
decoder = self.decoders[ctx_name]
- input = open(filename) if filename else sys.stdin
+ # Filename, StringIO, or None (stdin)
+ if file_or_stringio:
+ if isinstance(file_or_stringio, StringIO.StringIO):
+ input = file_or_stringio.getvalue()
+ else:
+ input = open(file_or_stringio)
+ else:
+ input = sys.stdin
# Non-initial load error
if ctx_data:
logger.info('({}) ERROR: Incremental data has already been added to context'.format(ctx_name))