summaryrefslogtreecommitdiff
path: root/realtime
diff options
context:
space:
mode:
authorMichael Denkowski <mdenkows@cs.cmu.edu>2013-09-17 22:05:32 -0700
committerMichael Denkowski <mdenkows@cs.cmu.edu>2013-09-17 22:05:32 -0700
commit6d427339b45f8aa74410437f91b5a01afc824120 (patch)
treeeb8914fe5fbd3d148610836f4f58fd660027230b /realtime
parenta30a2e59ff117ad6bae80ece2bf535767daf7db6 (diff)
Save/load state in realtime
Diffstat (limited to 'realtime')
-rwxr-xr-xrealtime/realtime.py48
-rw-r--r--realtime/rt/__init__.py2
-rw-r--r--realtime/rt/decoder.py11
-rw-r--r--realtime/rt/rt.py68
4 files changed, 97 insertions, 32 deletions
diff --git a/realtime/realtime.py b/realtime/realtime.py
index 554c52ca..a6ed1511 100755
--- a/realtime/realtime.py
+++ b/realtime/realtime.py
@@ -15,8 +15,9 @@ class Parser(argparse.ArgumentParser):
def main():
- parser = Parser(description='Real-time adaptive translation with cdec.')
- parser.add_argument('-c', '--config', required=True, help='Config directory (see README.md)')
+ parser = Parser(description='Real-time adaptive translation with cdec. (See README.md)')
+ parser.add_argument('-c', '--config', required=True, help='Config directory')
+ parser.add_argument('-s', '--state', help='Load state file (saved incremental data)')
parser.add_argument('-n', '--normalize', help='Normalize text (tokenize, translate, detokenize)', action='store_true')
parser.add_argument('-T', '--temp', help='Temp directory (default /tmp)', default='/tmp')
parser.add_argument('-a', '--cache', help='Grammar cache size (default 5)', default='5')
@@ -26,27 +27,28 @@ def main():
if args.verbose:
logging.basicConfig(level=logging.INFO)
- rtd = rt.RealtimeDecoder(args.config, tmpdir=args.temp, cache_size=int(args.cache), norm=args.normalize)
-
- try:
- while True:
- line = sys.stdin.readline()
- if not line:
- break
- input = [f.strip() for f in line.split('|||')]
- if len(input) == 1:
- hyp = rtd.decode(input[0])
- sys.stdout.write('{}\n'.format(hyp))
- sys.stdout.flush()
- else:
- rtd.command(input)
-
- # Clean exit on ctrl+c
- except KeyboardInterrupt:
- logging.info('Caught KeyboardInterrupt, exiting')
-
- # Cleanup
- rtd.close()
+ with rt.RealtimeDecoder(args.config, tmpdir=args.temp, cache_size=int(args.cache), norm=args.normalize) as rtd:
+
+ try:
+ # Load state if given
+ if args.state:
+ rtd.load_state(args.state)
+ # Read lines and commands
+ while True:
+ line = sys.stdin.readline()
+ if not line:
+ break
+ line = line.strip()
+ if '|||' in line:
+ rtd.command_line(line)
+ else:
+ hyp = rtd.decode(line)
+ sys.stdout.write('{}\n'.format(hyp))
+ sys.stdout.flush()
+
+ # Clean exit on ctrl+c
+ except KeyboardInterrupt:
+ logging.info('Caught KeyboardInterrupt, exiting')
if __name__ == '__main__':
main()
diff --git a/realtime/rt/__init__.py b/realtime/rt/__init__.py
index fbde8f4d..c76acc4d 100644
--- a/realtime/rt/__init__.py
+++ b/realtime/rt/__init__.py
@@ -9,7 +9,7 @@ except ImportError as ie:
sys.path.append(pycdec)
import cdec
except:
- sys.stderr.write('Error: cannot import pycdec. Please check the cdec/python is built.\n')
+ sys.stderr.write('Error: cannot import pycdec. Please check that cdec/python is built.\n')
raise ie
# Regular init imports
diff --git a/realtime/rt/decoder.py b/realtime/rt/decoder.py
index 57739d93..aa6db64d 100644
--- a/realtime/rt/decoder.py
+++ b/realtime/rt/decoder.py
@@ -9,8 +9,8 @@ class Decoder:
def close(self):
self.decoder.stdin.close()
- def decode(self, sentence, grammar):
- input = '<seg grammar="{g}">{s}</seg>\n'.format(s=sentence, g=grammar)
+ def decode(self, sentence, grammar=None):
+ input = '<seg grammar="{g}">{s}</seg>\n'.format(s=sentence, g=grammar) if grammar else '{}\n'.format(sentence)
self.decoder.stdin.write(input)
return self.decoder.stdout.readline().strip()
@@ -33,6 +33,13 @@ class MIRADecoder(Decoder):
logging.info('Executing: {}'.format(' '.join(mira_cmd)))
self.decoder = util.popen_io(mira_cmd)
+ def get_weights(self):
+ self.decoder.stdin.write('WEIGHTS ||| WRITE\n')
+ return self.decoder.stdout.readline().strip()
+
+ def set_weights(self, w_line):
+ self.decoder.stdin.write('WEIGHTS ||| {}\n'.format(w_line))
+
def update(self, sentence, grammar, reference):
input = 'LEARN ||| <seg grammar="{g}">{s}</seg> ||| {r}\n'.format(s=sentence, g=grammar, r=reference)
self.decoder.stdin.write(input)
diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py
index 0ce05a56..fedc1fcf 100644
--- a/realtime/rt/rt.py
+++ b/realtime/rt/rt.py
@@ -15,14 +15,18 @@ import aligner
import decoder
import util
+LIKELY_OOV = '("OOV")'
+
class RealtimeDecoder:
- def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False):
+ def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False, state=None):
- self.commands = {'LEARN': self.learn}
+ self.commands = {'LEARN': self.learn, 'SAVE': self.save_state, 'LOAD': self.load_state}
cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ self.inc_data = [] # instances of (source, target)
+
# Temporary work dir
self.tmp = tempfile.mkdtemp(dir=tmpdir, prefix='realtime.')
logging.info('Using temp dir {}'.format(self.tmp))
@@ -68,6 +72,17 @@ class RealtimeDecoder:
decoder_weights = os.path.join(configdir, 'weights.final')
self.decoder = decoder.MIRADecoder(decoder_config_file, decoder_weights)
+ # Load state if given
+ if state:
+ with open(state) as input:
+ self.load_state(input)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
def close(self):
logging.info('Closing processes')
self.aligner.close()
@@ -128,9 +143,13 @@ class RealtimeDecoder:
self.detokenizer.stdin.write('{}\n'.format(line))
return self.detokenizer.stdout.readline().strip()
- def command(self, args):
+ def command_line(self, line):
+ args = [f.strip() for f in line.split('|||')]
try:
- self.commands[args[0]](*args[1:])
+ if len(args) == 2 and not args[1]:
+ self.commands[args[0]]()
+ else:
+ self.commands[args[0]](*args[1:])
except:
logging.info('Command error: {}'.format(' ||| '.join(args)))
@@ -145,9 +164,12 @@ class RealtimeDecoder:
grammar_file = self.grammar(source)
mira_log = self.decoder.update(source, grammar_file, target)
logging.info('MIRA: {}'.format(mira_log))
- # Add aligned sentence pair to grammar extractor
+ # Align instance
alignment = self.aligner.align(source, target)
- logging.info('Adding instance: {} ||| {} ||| {}'.format(source, target, alignment))
+ # Store incremental data for save/load
+ self.inc_data.append((source, target, alignment))
+ # Add aligned sentence pair to grammar extractor
+ logging.info('Adding to bitext: {} ||| {} ||| {}'.format(source, target, alignment))
self.extractor.add_instance(source, target, alignment)
# Clear (old) cached grammar
rm_grammar = self.grammar_dict.pop(source)
@@ -156,3 +178,37 @@ class RealtimeDecoder:
logging.info('Adding to HPYPLM: {}'.format(target))
self.ref_fifo.write('{}\n'.format(target))
self.ref_fifo.flush()
+
+ def save_state(self):
+ logging.info('Saving state with {} sentences'.format(len(self.inc_data)))
+ sys.stdout.write('{}\n'.format(self.decoder.get_weights()))
+ for (source, target, alignment) in self.inc_data:
+ sys.stdout.write('{} ||| {} ||| {}\n'.format(source, target, alignment))
+ sys.stdout.write('EOF\n')
+
+ def load_state(self, input=sys.stdin):
+ # Non-initial load error
+ if self.inc_data:
+ logging.info('Error: Incremental data has already been added to decoder.')
+ logging.info(' State can only be loaded by a freshly started decoder.')
+ return
+ # MIRA weights
+ line = input.readline().strip()
+ self.decoder.set_weights(line)
+ logging.info('Loading state...')
+ start_time = time.time()
+ # Lines source ||| target ||| alignment
+ while True:
+ line = input.readline().strip()
+ if line == 'EOF':
+ break
+ (source, target, alignment) = line.split(' ||| ')
+ self.inc_data.append((source, target, alignment))
+ # Extractor
+ self.extractor.add_instance(source, target, alignment)
+ # HPYPLM
+ hyp = self.decoder.decode(LIKELY_OOV)
+ self.ref_fifo.write('{}\n'.format(target))
+ self.ref_fifo.flush()
+ stop_time = time.time()
+ logging.info('Loaded state with {} sentences in {} seconds'.format(len(self.inc_data), stop_time - start_time))