summaryrefslogtreecommitdiff
path: root/realtime/rt/rt.py
diff options
context:
space:
mode:
Diffstat (limited to 'realtime/rt/rt.py')
-rw-r--r--realtime/rt/rt.py68
1 files changed, 62 insertions, 6 deletions
diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py
index 0ce05a56..fedc1fcf 100644
--- a/realtime/rt/rt.py
+++ b/realtime/rt/rt.py
@@ -15,14 +15,18 @@ import aligner
import decoder
import util
+LIKELY_OOV = '("OOV")'
+
class RealtimeDecoder:
- def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False):
+ def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False, state=None):
- self.commands = {'LEARN': self.learn}
+ self.commands = {'LEARN': self.learn, 'SAVE': self.save_state, 'LOAD': self.load_state}
cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ self.inc_data = [] # instances of (source, target)
+
# Temporary work dir
self.tmp = tempfile.mkdtemp(dir=tmpdir, prefix='realtime.')
logging.info('Using temp dir {}'.format(self.tmp))
@@ -68,6 +72,17 @@ class RealtimeDecoder:
decoder_weights = os.path.join(configdir, 'weights.final')
self.decoder = decoder.MIRADecoder(decoder_config_file, decoder_weights)
+ # Load state if given
+ if state:
+ with open(state) as input:
+ self.load_state(input)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
def close(self):
logging.info('Closing processes')
self.aligner.close()
@@ -128,9 +143,13 @@ class RealtimeDecoder:
self.detokenizer.stdin.write('{}\n'.format(line))
return self.detokenizer.stdout.readline().strip()
- def command(self, args):
+ def command_line(self, line):
+ args = [f.strip() for f in line.split('|||')]
try:
- self.commands[args[0]](*args[1:])
+ if len(args) == 2 and not args[1]:
+ self.commands[args[0]]()
+ else:
+ self.commands[args[0]](*args[1:])
except:
logging.info('Command error: {}'.format(' ||| '.join(args)))
@@ -145,9 +164,12 @@ class RealtimeDecoder:
grammar_file = self.grammar(source)
mira_log = self.decoder.update(source, grammar_file, target)
logging.info('MIRA: {}'.format(mira_log))
- # Add aligned sentence pair to grammar extractor
+ # Align instance
alignment = self.aligner.align(source, target)
- logging.info('Adding instance: {} ||| {} ||| {}'.format(source, target, alignment))
+ # Store incremental data for save/load
+ self.inc_data.append((source, target, alignment))
+ # Add aligned sentence pair to grammar extractor
+ logging.info('Adding to bitext: {} ||| {} ||| {}'.format(source, target, alignment))
self.extractor.add_instance(source, target, alignment)
# Clear (old) cached grammar
rm_grammar = self.grammar_dict.pop(source)
@@ -156,3 +178,37 @@ class RealtimeDecoder:
logging.info('Adding to HPYPLM: {}'.format(target))
self.ref_fifo.write('{}\n'.format(target))
self.ref_fifo.flush()
+
+ def save_state(self):
+ logging.info('Saving state with {} sentences'.format(len(self.inc_data)))
+ sys.stdout.write('{}\n'.format(self.decoder.get_weights()))
+ for (source, target, alignment) in self.inc_data:
+ sys.stdout.write('{} ||| {} ||| {}\n'.format(source, target, alignment))
+ sys.stdout.write('EOF\n')
+
+ def load_state(self, input=sys.stdin):
+ # Non-initial load error
+ if self.inc_data:
+ logging.info('Error: Incremental data has already been added to decoder.')
+ logging.info(' State can only be loaded by a freshly started decoder.')
+ return
+ # MIRA weights
+ line = input.readline().strip()
+ self.decoder.set_weights(line)
+ logging.info('Loading state...')
+ start_time = time.time()
+ # Lines source ||| target ||| alignment
+ while True:
+ line = input.readline().strip()
+ if line == 'EOF':
+ break
+ (source, target, alignment) = line.split(' ||| ')
+ self.inc_data.append((source, target, alignment))
+ # Extractor
+ self.extractor.add_instance(source, target, alignment)
+ # HPYPLM
+ hyp = self.decoder.decode(LIKELY_OOV)
+ self.ref_fifo.write('{}\n'.format(target))
+ self.ref_fifo.flush()
+ stop_time = time.time()
+ logging.info('Loaded state with {} sentences in {} seconds'.format(len(self.inc_data), stop_time - start_time))