summaryrefslogtreecommitdiff
path: root/realtime/rt
diff options
context:
space:
mode:
authorMichael Denkowski <mdenkows@cs.cmu.edu>2013-09-17 22:05:32 -0700
committerMichael Denkowski <mdenkows@cs.cmu.edu>2013-09-17 22:05:32 -0700
commit6d427339b45f8aa74410437f91b5a01afc824120 (patch)
treeeb8914fe5fbd3d148610836f4f58fd660027230b /realtime/rt
parenta30a2e59ff117ad6bae80ece2bf535767daf7db6 (diff)
Save/load state in realtime
Diffstat (limited to 'realtime/rt')
-rw-r--r--realtime/rt/__init__.py2
-rw-r--r--realtime/rt/decoder.py11
-rw-r--r--realtime/rt/rt.py68
3 files changed, 72 insertions, 9 deletions
diff --git a/realtime/rt/__init__.py b/realtime/rt/__init__.py
index fbde8f4d..c76acc4d 100644
--- a/realtime/rt/__init__.py
+++ b/realtime/rt/__init__.py
@@ -9,7 +9,7 @@ except ImportError as ie:
sys.path.append(pycdec)
import cdec
except:
- sys.stderr.write('Error: cannot import pycdec. Please check the cdec/python is built.\n')
+ sys.stderr.write('Error: cannot import pycdec. Please check that cdec/python is built.\n')
raise ie
# Regular init imports
diff --git a/realtime/rt/decoder.py b/realtime/rt/decoder.py
index 57739d93..aa6db64d 100644
--- a/realtime/rt/decoder.py
+++ b/realtime/rt/decoder.py
@@ -9,8 +9,8 @@ class Decoder:
def close(self):
self.decoder.stdin.close()
- def decode(self, sentence, grammar):
- input = '<seg grammar="{g}">{s}</seg>\n'.format(s=sentence, g=grammar)
+ def decode(self, sentence, grammar=None):
+ input = '<seg grammar="{g}">{s}</seg>\n'.format(s=sentence, g=grammar) if grammar else '{}\n'.format(sentence)
self.decoder.stdin.write(input)
return self.decoder.stdout.readline().strip()
@@ -33,6 +33,13 @@ class MIRADecoder(Decoder):
logging.info('Executing: {}'.format(' '.join(mira_cmd)))
self.decoder = util.popen_io(mira_cmd)
+ def get_weights(self):
+ self.decoder.stdin.write('WEIGHTS ||| WRITE\n')
+ return self.decoder.stdout.readline().strip()
+
+ def set_weights(self, w_line):
+ self.decoder.stdin.write('WEIGHTS ||| {}\n'.format(w_line))
+
def update(self, sentence, grammar, reference):
input = 'LEARN ||| <seg grammar="{g}">{s}</seg> ||| {r}\n'.format(s=sentence, g=grammar, r=reference)
self.decoder.stdin.write(input)
diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py
index 0ce05a56..fedc1fcf 100644
--- a/realtime/rt/rt.py
+++ b/realtime/rt/rt.py
@@ -15,14 +15,18 @@ import aligner
import decoder
import util
+LIKELY_OOV = '("OOV")'
+
class RealtimeDecoder:
- def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False):
+ def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False, state=None):
- self.commands = {'LEARN': self.learn}
+ self.commands = {'LEARN': self.learn, 'SAVE': self.save_state, 'LOAD': self.load_state}
cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ self.inc_data = [] # instances of (source, target)
+
# Temporary work dir
self.tmp = tempfile.mkdtemp(dir=tmpdir, prefix='realtime.')
logging.info('Using temp dir {}'.format(self.tmp))
@@ -68,6 +72,17 @@ class RealtimeDecoder:
decoder_weights = os.path.join(configdir, 'weights.final')
self.decoder = decoder.MIRADecoder(decoder_config_file, decoder_weights)
+ # Load state if given
+ if state:
+ with open(state) as input:
+ self.load_state(input)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
def close(self):
logging.info('Closing processes')
self.aligner.close()
@@ -128,9 +143,13 @@ class RealtimeDecoder:
self.detokenizer.stdin.write('{}\n'.format(line))
return self.detokenizer.stdout.readline().strip()
- def command(self, args):
+ def command_line(self, line):
+ args = [f.strip() for f in line.split('|||')]
try:
- self.commands[args[0]](*args[1:])
+ if len(args) == 2 and not args[1]:
+ self.commands[args[0]]()
+ else:
+ self.commands[args[0]](*args[1:])
except:
logging.info('Command error: {}'.format(' ||| '.join(args)))
@@ -145,9 +164,12 @@ class RealtimeDecoder:
grammar_file = self.grammar(source)
mira_log = self.decoder.update(source, grammar_file, target)
logging.info('MIRA: {}'.format(mira_log))
- # Add aligned sentence pair to grammar extractor
+ # Align instance
alignment = self.aligner.align(source, target)
- logging.info('Adding instance: {} ||| {} ||| {}'.format(source, target, alignment))
+ # Store incremental data for save/load
+ self.inc_data.append((source, target, alignment))
+ # Add aligned sentence pair to grammar extractor
+ logging.info('Adding to bitext: {} ||| {} ||| {}'.format(source, target, alignment))
self.extractor.add_instance(source, target, alignment)
# Clear (old) cached grammar
rm_grammar = self.grammar_dict.pop(source)
@@ -156,3 +178,37 @@ class RealtimeDecoder:
logging.info('Adding to HPYPLM: {}'.format(target))
self.ref_fifo.write('{}\n'.format(target))
self.ref_fifo.flush()
+
+ def save_state(self):
+ logging.info('Saving state with {} sentences'.format(len(self.inc_data)))
+ sys.stdout.write('{}\n'.format(self.decoder.get_weights()))
+ for (source, target, alignment) in self.inc_data:
+ sys.stdout.write('{} ||| {} ||| {}\n'.format(source, target, alignment))
+ sys.stdout.write('EOF\n')
+
+ def load_state(self, input=sys.stdin):
+ # Non-initial load error
+ if self.inc_data:
+ logging.info('Error: Incremental data has already been added to decoder.')
+ logging.info(' State can only be loaded by a freshly started decoder.')
+ return
+ # MIRA weights
+ line = input.readline().strip()
+ self.decoder.set_weights(line)
+ logging.info('Loading state...')
+ start_time = time.time()
+ # Lines source ||| target ||| alignment
+ while True:
+ line = input.readline().strip()
+ if line == 'EOF':
+ break
+ (source, target, alignment) = line.split(' ||| ')
+ self.inc_data.append((source, target, alignment))
+ # Extractor
+ self.extractor.add_instance(source, target, alignment)
+ # HPYPLM
+ hyp = self.decoder.decode(LIKELY_OOV)
+ self.ref_fifo.write('{}\n'.format(target))
+ self.ref_fifo.flush()
+ stop_time = time.time()
+ logging.info('Loaded state with {} sentences in {} seconds'.format(len(self.inc_data), stop_time - start_time))