summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Denkowski <mdenkows@cs.cmu.edu>2013-09-25 16:20:51 -0700
committerMichael Denkowski <mdenkows@cs.cmu.edu>2013-09-25 16:20:51 -0700
commitbd2fe67ac2e2f7c22bf279aeef5439820329e6dc (patch)
tree5cbf8f8d45b84f383504941e02a844df46d19985
parent17497f2e77e63e6aa549eedc279cac46cfd25e2b (diff)
Super multi-user thread safety update
-rwxr-xr-xrealtime/realtime.py43
-rw-r--r--realtime/rt/aligner.py11
-rw-r--r--realtime/rt/decoder.py32
-rw-r--r--realtime/rt/rt.py250
4 files changed, 241 insertions, 95 deletions
diff --git a/realtime/realtime.py b/realtime/realtime.py
index 3c384fa2..282d3311 100755
--- a/realtime/realtime.py
+++ b/realtime/realtime.py
@@ -2,7 +2,9 @@
import argparse
import logging
+import signal
import sys
+import threading
import rt
@@ -22,34 +24,37 @@ def main():
parser.add_argument('-T', '--temp', help='Temp directory (default /tmp)', default='/tmp')
parser.add_argument('-a', '--cache', help='Grammar cache size (default 5)', default='5')
parser.add_argument('-v', '--verbose', help='Info to stderr', action='store_true')
+ parser.add_argument('-D', '--debug-test', help='Test thread safety (debug use only)', action='store_true')
args = parser.parse_args()
if args.verbose:
logging.basicConfig(level=logging.INFO)
- with rt.RealtimeDecoder(args.config, tmpdir=args.temp, cache_size=int(args.cache), norm=args.normalize) as rtd:
+ with rt.RealtimeTranslator(args.config, tmpdir=args.temp, cache_size=int(args.cache), norm=args.normalize) as translator:
- try:
# Load state if given
if args.state:
with open(args.state) as input:
rtd.load_state(input)
- # Read lines and commands
- while True:
- line = sys.stdin.readline()
- if not line:
- break
- line = line.strip()
- if '|||' in line:
- rtd.command_line(line)
- else:
- hyp = rtd.decode(line)
- sys.stdout.write('{}\n'.format(hyp))
- sys.stdout.flush()
-
- # Clean exit on ctrl+c
- except KeyboardInterrupt:
- logging.info('Caught KeyboardInterrupt, exiting')
-
+ if not args.debug_test:
+ run(translator)
+ else:
+ # TODO: write test
+ run(translator)
+
+def run(translator, input=sys.stdin, output=sys.stdout, ctx_name=None):
+ # Read lines and commands
+ while True:
+ line = input.readline()
+ if not line:
+ break
+ line = line.strip()
+ if '|||' in line:
+ translator.command_line(line, ctx_name)
+ else:
+ hyp = translator.decode(line, ctx_name)
+ output.write('{}\n'.format(hyp))
+ output.flush()
+
if __name__ == '__main__':
main()
diff --git a/realtime/rt/aligner.py b/realtime/rt/aligner.py
index 80835412..a14121db 100644
--- a/realtime/rt/aligner.py
+++ b/realtime/rt/aligner.py
@@ -2,6 +2,7 @@ import logging
import os
import sys
import subprocess
+import threading
import util
@@ -29,10 +30,16 @@ class ForceAligner:
logging.info('Executing: {}'.format(' '.join(tools_cmd)))
self.tools = util.popen_io(tools_cmd)
+ # Used to guarantee thread safety
+ self.semaphore = threading.Semaphore()
+
def align(self, source, target):
+ '''Threadsafe'''
return self.align_formatted('{} ||| {}'.format(source, target))
def align_formatted(self, line):
+ '''Threadsafe'''
+ self.semaphore.acquire()
self.fwd_align.stdin.write('{}\n'.format(line))
self.rev_align.stdin.write('{}\n'.format(line))
# f words ||| e words ||| links ||| score
@@ -40,7 +47,9 @@ class ForceAligner:
rev_line = self.rev_align.stdout.readline().split('|||')[2].strip()
self.tools.stdin.write('{}\n'.format(fwd_line))
self.tools.stdin.write('{}\n'.format(rev_line))
- return self.tools.stdout.readline().strip()
+ al_line = self.tools.stdout.readline().strip()
+ self.semaphore.release()
+ return al_line
def close(self):
self.fwd_align.stdin.close()
diff --git a/realtime/rt/decoder.py b/realtime/rt/decoder.py
index aa6db64d..72b5b959 100644
--- a/realtime/rt/decoder.py
+++ b/realtime/rt/decoder.py
@@ -1,27 +1,37 @@
import logging
import os
import subprocess
+import threading
import util
class Decoder:
- def close(self):
+ def close(self, force=False):
+ if not force:
+ self.semaphore.acquire()
self.decoder.stdin.close()
+ if not force:
+ self.semaphore.release()
def decode(self, sentence, grammar=None):
+ '''Threadsafe'''
input = '<seg grammar="{g}">{s}</seg>\n'.format(s=sentence, g=grammar) if grammar else '{}\n'.format(sentence)
+ self.semaphore.acquire()
self.decoder.stdin.write(input)
- return self.decoder.stdout.readline().strip()
+ hyp = self.decoder.stdout.readline().strip()
+ self.semaphore.release()
+ return hyp
class CdecDecoder(Decoder):
-
+
def __init__(self, config, weights):
cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
decoder = os.path.join(cdec_root, 'decoder', 'cdec')
decoder_cmd = [decoder, '-c', config, '-w', weights]
logging.info('Executing: {}'.format(' '.join(decoder_cmd)))
self.decoder = util.popen_io(decoder_cmd)
+ self.semaphore = threading.Semaphore()
class MIRADecoder(Decoder):
@@ -32,15 +42,27 @@ class MIRADecoder(Decoder):
mira_cmd = [mira, '-c', config, '-w', weights, '-o', '2', '-C', '0.001', '-b', '500', '-k', '500', '-u', '-t']
logging.info('Executing: {}'.format(' '.join(mira_cmd)))
self.decoder = util.popen_io(mira_cmd)
+ self.semaphore = threading.Semaphore()
def get_weights(self):
+ '''Threadsafe'''
+ self.semaphore.acquire()
self.decoder.stdin.write('WEIGHTS ||| WRITE\n')
- return self.decoder.stdout.readline().strip()
+ weights = self.decoder.stdout.readline().strip()
+ self.semaphore.release()
+ return weights
def set_weights(self, w_line):
+ '''Threadsafe'''
+ self.semaphore.acquire()
self.decoder.stdin.write('WEIGHTS ||| {}\n'.format(w_line))
+ self.semaphore.release()
def update(self, sentence, grammar, reference):
+ '''Threadsafe'''
input = 'LEARN ||| <seg grammar="{g}">{s}</seg> ||| {r}\n'.format(s=sentence, g=grammar, r=reference)
+ self.semaphore.acquire()
self.decoder.stdin.write(input)
- return self.decoder.stdout.readline().strip()
+ log = self.decoder.stdout.readline().strip()
+ self.semaphore.release()
+ return log
diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py
index 033ed790..6f1fb70f 100644
--- a/realtime/rt/rt.py
+++ b/realtime/rt/rt.py
@@ -8,6 +8,7 @@ import shutil
import sys
import subprocess
import tempfile
+import threading
import time
import cdec
@@ -15,18 +16,56 @@ import aligner
import decoder
import util
-LIKELY_OOV = '("OOV")'
+# Dummy input token that is unlikely to appear in normalized data (but no fatal errors if it does)
+LIKELY_OOV = '(OOV)'
class RealtimeDecoder:
+ '''Do not use directly unless you know what you're doing. Use RealtimeTranslator.'''
+
+ def __init__(self, configdir, tmpdir):
+
+ cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+ self.tmp = tmpdir
+ os.mkdir(self.tmp)
+
+ # HPYPLM reference stream
+ ref_fifo_file = os.path.join(self.tmp, 'ref.fifo')
+ os.mkfifo(ref_fifo_file)
+ self.ref_fifo = open(ref_fifo_file, 'w+')
+ # Start with empty line (do not learn prior to first input)
+ self.ref_fifo.write('\n')
+ self.ref_fifo.flush()
+
+ # Decoder
+ decoder_config = [[f.strip() for f in line.split('=')] for line in open(os.path.join(configdir, 'cdec.ini'))]
+ util.cdec_ini_for_realtime(decoder_config, os.path.abspath(configdir), ref_fifo_file)
+ decoder_config_file = os.path.join(self.tmp, 'cdec.ini')
+ with open(decoder_config_file, 'w') as output:
+ for (k, v) in decoder_config:
+ output.write('{}={}\n'.format(k, v))
+ decoder_weights = os.path.join(configdir, 'weights.final')
+ self.decoder = decoder.MIRADecoder(decoder_config_file, decoder_weights)
+
+ def close(self, force=False):
+ logging.info('Closing decoder and removing {}'.format(self.tmp))
+ self.decoder.close(force)
+ self.ref_fifo.close()
+ shutil.rmtree(self.tmp)
+
+class RealtimeTranslator:
+ '''Main entry point into API: serves translations to any number of concurrent users'''
def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False, state=None):
+ # TODO: save/load
self.commands = {'LEARN': self.learn, 'SAVE': self.save_state, 'LOAD': self.load_state}
cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- self.inc_data = [] # instances of (source, target)
+ ### Single instance for all contexts
+ self.config = configdir
# Temporary work dir
self.tmp = tempfile.mkdtemp(dir=tmpdir, prefix='realtime.')
logging.info('Using temp dir {}'.format(self.tmp))
@@ -35,7 +74,9 @@ class RealtimeDecoder:
self.norm = norm
if self.norm:
self.tokenizer = util.popen_io([os.path.join(cdec_root, 'corpus', 'tokenize-anything.sh'), '-u'])
+ self.tokenizer_sem = threading.Semaphore()
self.detokenizer = util.popen_io([os.path.join(cdec_root, 'corpus', 'untok.pl')])
+ self.detokenizer_sem = threading.Semaphore()
# Word aligner
fwd_params = os.path.join(configdir, 'a.fwd_params')
@@ -50,28 +91,24 @@ class RealtimeDecoder:
util.sa_ini_for_realtime(sa_config, os.path.abspath(configdir))
sa_config.write()
self.extractor = cdec.sa.GrammarExtractor(sa_config.filename, online=True)
- self.grammar_files = collections.deque()
- self.grammar_dict = {}
self.cache_size = cache_size
- # HPYPLM reference stream
- ref_fifo_file = os.path.join(self.tmp, 'ref.fifo')
- os.mkfifo(ref_fifo_file)
- self.ref_fifo = open(ref_fifo_file, 'w+')
- # Start with empty line (do not learn prior to first input)
- self.ref_fifo.write('\n')
- self.ref_fifo.flush()
+ ### One instance per context
- # Decoder
- decoder_config = [[f.strip() for f in line.split('=')] for line in open(os.path.join(configdir, 'cdec.ini'))]
- util.cdec_ini_for_realtime(decoder_config, os.path.abspath(configdir), ref_fifo_file)
- decoder_config_file = os.path.join(self.tmp, 'cdec.ini')
- with open(decoder_config_file, 'w') as output:
- for (k, v) in decoder_config:
- output.write('{}={}\n'.format(k, v))
- decoder_weights = os.path.join(configdir, 'weights.final')
- self.decoder = decoder.MIRADecoder(decoder_config_file, decoder_weights)
+ self.ctx_names = set()
+ # All context-dependent operations are atomic
+ self.ctx_sems = collections.defaultdict(threading.Semaphore)
+ # ctx -> list of (source, target, alignment)
+ self.ctx_data = {}
+
+ # ctx -> deque of file
+ self.grammar_files = {}
+ # ctx -> dict of {sentence: file}
+ self.grammar_dict = {}
+ self.decoders = {}
+
+ # TODO: state
# Load state if given
if state:
with open(state) as input:
@@ -80,125 +117,197 @@ class RealtimeDecoder:
def __enter__(self):
return self
- def __exit__(self, type, value, traceback):
- self.close()
+ def __exit__(self, ex_type, ex_value, ex_traceback):
+ self.close(ex_type is KeyboardInterrupt)
- def close(self):
+ def close(self, force=False):
+ '''Cleanup'''
+ if force:
+ logging.info('Forced shutdown: stopping immediately')
+ for ctx_name in list(self.ctx_names):
+ self.drop_ctx(ctx_name, force)
logging.info('Closing processes')
self.aligner.close()
- self.decoder.close()
- self.ref_fifo.close()
if self.norm:
self.tokenizer.stdin.close()
self.detokenizer.stdin.close()
logging.info('Deleting {}'.format(self.tmp))
shutil.rmtree(self.tmp)
- def grammar(self, sentence):
- grammar_file = self.grammar_dict.get(sentence, None)
+ def lazy_ctx(self, ctx_name):
+ '''Initialize a context (inc starting a new decoder) if needed'''
+ self.ctx_sems[ctx_name].acquire()
+ if ctx_name in self.ctx_names:
+ self.ctx_sems[ctx_name].release()
+ return
+ logging.info('New context: {}'.format(ctx_name))
+ self.ctx_names.add(ctx_name)
+ self.ctx_data[ctx_name] = []
+ self.grammar_files[ctx_name] = collections.deque()
+ self.grammar_dict[ctx_name] = {}
+ tmpdir = os.path.join(self.tmp, 'decoder.{}'.format(ctx_name))
+ self.decoders[ctx_name] = RealtimeDecoder(self.config, tmpdir)
+ self.ctx_sems[ctx_name].release()
+
+ def drop_ctx(self, ctx_name, force=False):
+ '''Delete a context (inc stopping the decoder)'''
+ if not force:
+ sem = self.ctx_sems[ctx_name]
+ sem.acquire()
+ logging.info('Dropping context: {}'.format(ctx_name))
+ self.ctx_names.remove(ctx_name)
+ self.ctx_data.pop(ctx_name)
+ self.extractor.drop_ctx(ctx_name)
+ self.grammar_files.pop(ctx_name)
+ self.grammar_dict.pop(ctx_name)
+ self.decoders.pop(ctx_name).close(force)
+ self.ctx_sems.pop(ctx_name)
+ if not force:
+ sem.release()
+
+ def grammar(self, sentence, ctx_name=None):
+ '''Extract a sentence-level grammar on demand (or return cached)'''
+ self.lazy_ctx(ctx_name)
+ sem = self.ctx_sems[ctx_name]
+ sem.acquire()
+ grammar_dict = self.grammar_dict[ctx_name]
+ grammar_file = grammar_dict.get(sentence, None)
# Cache hit
if grammar_file:
- logging.info('Grammar cache hit')
+ logging.info('Grammar cache hit: {}'.format(grammar_file))
+ sem.release()
return grammar_file
# Extract and cache
- (fid, grammar_file) = tempfile.mkstemp(dir=self.tmp, prefix='grammar.')
+ (fid, grammar_file) = tempfile.mkstemp(dir=self.decoders[ctx_name].tmp, prefix='grammar.')
os.close(fid)
with open(grammar_file, 'w') as output:
- for rule in self.extractor.grammar(sentence):
+ for rule in self.extractor.grammar(sentence, ctx_name):
output.write('{}\n'.format(str(rule)))
- if len(self.grammar_files) == self.cache_size:
- rm_sent = self.grammar_files.popleft()
+ grammar_files = self.grammar_files[ctx_name]
+ if len(grammar_files) == self.cache_size:
+ rm_sent = grammar_files.popleft()
# If not already removed by learn method
- if rm_sent in self.grammar_dict:
- rm_grammar = self.grammar_dict.pop(rm_sent)
+ if rm_sent in grammar_dict:
+ rm_grammar = grammar_dict.pop(rm_sent)
os.remove(rm_grammar)
- self.grammar_files.append(sentence)
- self.grammar_dict[sentence] = grammar_file
+ grammar_files.append(sentence)
+ grammar_dict[sentence] = grammar_file
+ sem.release()
return grammar_file
- def decode(self, sentence):
+ def decode(self, sentence, ctx_name=None):
+ '''Decode a sentence (inc extracting a grammar if needed)'''
+ self.lazy_ctx(ctx_name)
# Empty in, empty out
if sentence.strip() == '':
return ''
if self.norm:
sentence = self.tokenize(sentence)
logging.info('Normalized input: {}'.format(sentence))
- grammar_file = self.grammar(sentence)
+ # grammar method is threadsafe
+ grammar_file = self.grammar(sentence, ctx_name)
+ decoder = self.decoders[ctx_name]
+ sem = self.ctx_sems[ctx_name]
+ sem.acquire()
start_time = time.time()
- hyp = self.decoder.decode(sentence, grammar_file)
+ hyp = decoder.decoder.decode(sentence, grammar_file)
stop_time = time.time()
logging.info('Translation time: {} seconds'.format(stop_time - start_time))
# Empty reference: HPYPLM does not learn prior to next translation
- self.ref_fifo.write('\n')
- self.ref_fifo.flush()
+ decoder.ref_fifo.write('\n')
+ decoder.ref_fifo.flush()
+ sem.release()
if self.norm:
logging.info('Normalized translation: {}'.format(hyp))
hyp = self.detokenize(hyp)
return hyp
def tokenize(self, line):
+ self.tokenizer_sem.acquire()
self.tokenizer.stdin.write('{}\n'.format(line))
- return self.tokenizer.stdout.readline().strip()
+ tok_line = self.tokenizer.stdout.readline().strip()
+ self.tokenizer_sem.release()
+ return tok_line
def detokenize(self, line):
+ self.detokenizer_sem.acquire()
self.detokenizer.stdin.write('{}\n'.format(line))
- return self.detokenizer.stdout.readline().strip()
+ detok_line = self.detokenizer.stdout.readline().strip()
+ self.detokenizer_sem.release()
+ return detok_line
- def command_line(self, line):
+ # TODO
+ def command_line(self, line, ctx_name=None):
args = [f.strip() for f in line.split('|||')]
try:
if len(args) == 2 and not args[1]:
- self.commands[args[0]]()
+ self.commands[args[0]](ctx_name)
else:
- self.commands[args[0]](*args[1:])
+ self.commands[args[0]](*args[1:], ctx_name=ctx_name)
except:
logging.info('Command error: {}'.format(' ||| '.join(args)))
- def learn(self, source, target):
+ def learn(self, source, target, ctx_name=None):
+ self.lazy_ctx(ctx_name)
if '' in (source.strip(), target.strip()):
logging.info('Error empty source or target: {} ||| {}'.format(source, target))
return
if self.norm:
source = self.tokenize(source)
target = self.tokenize(target)
+ # Align instance (threadsafe)
+ alignment = self.aligner.align(source, target)
+ # grammar method is threadsafe
+ grammar_file = self.grammar(source, ctx_name)
+ sem = self.ctx_sems[ctx_name]
+ sem.acquire()
# MIRA update before adding data to grammar extractor
- grammar_file = self.grammar(source)
- mira_log = self.decoder.update(source, grammar_file, target)
+ decoder = self.decoders[ctx_name]
+ mira_log = decoder.decoder.update(source, grammar_file, target)
logging.info('MIRA: {}'.format(mira_log))
- # Align instance
- alignment = self.aligner.align(source, target)
+ # Add to HPYPLM by writing to fifo (read on next translation)
+ logging.info('Adding to HPYPLM: {}'.format(target))
+ decoder.ref_fifo.write('{}\n'.format(target))
+ decoder.ref_fifo.flush()
# Store incremental data for save/load
- self.inc_data.append((source, target, alignment))
+ self.ctx_data[ctx_name].append((source, target, alignment))
# Add aligned sentence pair to grammar extractor
logging.info('Adding to bitext: {} ||| {} ||| {}'.format(source, target, alignment))
- self.extractor.add_instance(source, target, alignment)
+ self.extractor.add_instance(source, target, alignment, ctx_name)
# Clear (old) cached grammar
- rm_grammar = self.grammar_dict.pop(source)
+ rm_grammar = self.grammar_dict[ctx_name].pop(source)
os.remove(rm_grammar)
- # Add to HPYPLM by writing to fifo (read on next translation)
- logging.info('Adding to HPYPLM: {}'.format(target))
- self.ref_fifo.write('{}\n'.format(target))
- self.ref_fifo.flush()
+ sem.release()
- def save_state(self, filename=None):
+ def save_state(self, filename=None, ctx_name=None):
+ self.lazy_ctx(ctx_name)
out = open(filename, 'w') if filename else sys.stdout
- logging.info('Saving state with {} sentences'.format(len(self.inc_data)))
- out.write('{}\n'.format(self.decoder.get_weights()))
- for (source, target, alignment) in self.inc_data:
+ sem = self.ctx_sems[ctx_name]
+ sem.acquire()
+ ctx_data = self.ctx_data[ctx_name]
+ logging.info('Saving state with {} sentences'.format(len(self.ctx_data)))
+ out.write('{}\n'.format(self.decoders[ctx_name].decoder.get_weights()))
+ for (source, target, alignment) in ctx_data:
out.write('{} ||| {} ||| {}\n'.format(source, target, alignment))
+ sem.release()
out.write('EOF\n')
if filename:
out.close()
- def load_state(self, input=sys.stdin):
- # Non-initial load error
- if self.inc_data:
+ def load_state(self, input=sys.stdin, ctx_name=None):
+ self.lazy_ctx(ctx_name)
+ sem = self.ctx_sems[ctx_name]
+ sem.acquire()
+ ctx_data = self.ctx_data[ctx_name]
+ decoder = self.decoders[ctx_name]
+ # Non-initial load error
+ if ctx_data:
logging.info('Error: Incremental data has already been added to decoder.')
logging.info(' State can only be loaded by a freshly started decoder.')
return
# MIRA weights
line = input.readline().strip()
- self.decoder.set_weights(line)
+ decoder.decoder.set_weights(line)
logging.info('Loading state...')
start_time = time.time()
# Lines source ||| target ||| alignment
@@ -207,12 +316,13 @@ class RealtimeDecoder:
if line == 'EOF':
break
(source, target, alignment) = line.split(' ||| ')
- self.inc_data.append((source, target, alignment))
+ ctx_data.append((source, target, alignment))
# Extractor
- self.extractor.add_instance(source, target, alignment)
+ self.extractor.add_instance(source, target, alignment, ctx_name)
# HPYPLM
- hyp = self.decoder.decode(LIKELY_OOV)
+ hyp = decoder.decoder.decode(LIKELY_OOV)
self.ref_fifo.write('{}\n'.format(target))
self.ref_fifo.flush()
stop_time = time.time()
- logging.info('Loaded state with {} sentences in {} seconds'.format(len(self.inc_data), stop_time - start_time))
+ logging.info('Loaded state with {} sentences in {} seconds'.format(len(ctx_data), stop_time - start_time))
+ sem.release()