diff options
author | Michael Denkowski <mdenkows@cs.cmu.edu> | 2013-09-25 16:20:51 -0700 |
---|---|---|
committer | Michael Denkowski <mdenkows@cs.cmu.edu> | 2013-09-25 16:20:51 -0700 |
commit | 5866bdb0541bf136d897cc8ecc72c5ed4b6a93ee (patch) | |
tree | bd48d4a27da8d5a7a24d13ae987a948374600da9 | |
parent | 5684942eadd5b4c3fd54f4871d13975793a1f067 (diff) |
Super multi-user thread safety update
-rwxr-xr-x | realtime/realtime.py | 43 | ||||
-rw-r--r-- | realtime/rt/aligner.py | 11 | ||||
-rw-r--r-- | realtime/rt/decoder.py | 32 | ||||
-rw-r--r-- | realtime/rt/rt.py | 250 |
4 files changed, 241 insertions, 95 deletions
diff --git a/realtime/realtime.py b/realtime/realtime.py index 3c384fa2..282d3311 100755 --- a/realtime/realtime.py +++ b/realtime/realtime.py @@ -2,7 +2,9 @@ import argparse import logging +import signal import sys +import threading import rt @@ -22,34 +24,37 @@ def main(): parser.add_argument('-T', '--temp', help='Temp directory (default /tmp)', default='/tmp') parser.add_argument('-a', '--cache', help='Grammar cache size (default 5)', default='5') parser.add_argument('-v', '--verbose', help='Info to stderr', action='store_true') + parser.add_argument('-D', '--debug-test', help='Test thread safety (debug use only)', action='store_true') args = parser.parse_args() if args.verbose: logging.basicConfig(level=logging.INFO) - with rt.RealtimeDecoder(args.config, tmpdir=args.temp, cache_size=int(args.cache), norm=args.normalize) as rtd: + with rt.RealtimeTranslator(args.config, tmpdir=args.temp, cache_size=int(args.cache), norm=args.normalize) as translator: - try: # Load state if given if args.state: with open(args.state) as input: rtd.load_state(input) - # Read lines and commands - while True: - line = sys.stdin.readline() - if not line: - break - line = line.strip() - if '|||' in line: - rtd.command_line(line) - else: - hyp = rtd.decode(line) - sys.stdout.write('{}\n'.format(hyp)) - sys.stdout.flush() - - # Clean exit on ctrl+c - except KeyboardInterrupt: - logging.info('Caught KeyboardInterrupt, exiting') - + if not args.debug_test: + run(translator) + else: + # TODO: write test + run(translator) + +def run(translator, input=sys.stdin, output=sys.stdout, ctx_name=None): + # Read lines and commands + while True: + line = input.readline() + if not line: + break + line = line.strip() + if '|||' in line: + translator.command_line(line, ctx_name) + else: + hyp = translator.decode(line, ctx_name) + output.write('{}\n'.format(hyp)) + output.flush() + if __name__ == '__main__': main() diff --git a/realtime/rt/aligner.py b/realtime/rt/aligner.py index 80835412..a14121db 100644 --- a/realtime/rt/aligner.py +++ b/realtime/rt/aligner.py @@ -2,6 +2,7 @@ import logging import os import sys import subprocess +import threading import util @@ -29,10 +30,16 @@ class ForceAligner: logging.info('Executing: {}'.format(' '.join(tools_cmd))) self.tools = util.popen_io(tools_cmd) + # Used to guarantee thread safety + self.semaphore = threading.Semaphore() + def align(self, source, target): + '''Threadsafe''' return self.align_formatted('{} ||| {}'.format(source, target)) def align_formatted(self, line): + '''Threadsafe''' + self.semaphore.acquire() self.fwd_align.stdin.write('{}\n'.format(line)) self.rev_align.stdin.write('{}\n'.format(line)) # f words ||| e words ||| links ||| score @@ -40,7 +47,9 @@ class ForceAligner: rev_line = self.rev_align.stdout.readline().split('|||')[2].strip() self.tools.stdin.write('{}\n'.format(fwd_line)) self.tools.stdin.write('{}\n'.format(rev_line)) - return self.tools.stdout.readline().strip() + al_line = self.tools.stdout.readline().strip() + self.semaphore.release() + return al_line def close(self): self.fwd_align.stdin.close() diff --git a/realtime/rt/decoder.py b/realtime/rt/decoder.py index aa6db64d..72b5b959 100644 --- a/realtime/rt/decoder.py +++ b/realtime/rt/decoder.py @@ -1,27 +1,37 @@ import logging import os import subprocess +import threading import util class Decoder: - def close(self): + def close(self, force=False): + if not force: + self.semaphore.acquire() self.decoder.stdin.close() + if not force: + self.semaphore.release() def decode(self, sentence, grammar=None): + '''Threadsafe''' input = '<seg grammar="{g}">{s}</seg>\n'.format(s=sentence, g=grammar) if grammar else '{}\n'.format(sentence) + self.semaphore.acquire() self.decoder.stdin.write(input) - return self.decoder.stdout.readline().strip() + hyp = self.decoder.stdout.readline().strip() + self.semaphore.release() + return hyp class CdecDecoder(Decoder): - + def __init__(self, config, weights): cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) decoder = os.path.join(cdec_root, 'decoder', 'cdec') decoder_cmd = [decoder, '-c', config, '-w', weights] logging.info('Executing: {}'.format(' '.join(decoder_cmd))) self.decoder = util.popen_io(decoder_cmd) + self.semaphore = threading.Semaphore() class MIRADecoder(Decoder): @@ -32,15 +42,27 @@ class MIRADecoder(Decoder): mira_cmd = [mira, '-c', config, '-w', weights, '-o', '2', '-C', '0.001', '-b', '500', '-k', '500', '-u', '-t'] logging.info('Executing: {}'.format(' '.join(mira_cmd))) self.decoder = util.popen_io(mira_cmd) + self.semaphore = threading.Semaphore() def get_weights(self): + '''Threadsafe''' + self.semaphore.acquire() self.decoder.stdin.write('WEIGHTS ||| WRITE\n') - return self.decoder.stdout.readline().strip() + weights = self.decoder.stdout.readline().strip() + self.semaphore.release() + return weights def set_weights(self, w_line): + '''Threadsafe''' + self.semaphore.acquire() self.decoder.stdin.write('WEIGHTS ||| {}\n'.format(w_line)) + self.semaphore.release() def update(self, sentence, grammar, reference): + '''Threadsafe''' input = 'LEARN ||| <seg grammar="{g}">{s}</seg> ||| {r}\n'.format(s=sentence, g=grammar, r=reference) + self.semaphore.acquire() self.decoder.stdin.write(input) - return self.decoder.stdout.readline().strip() + log = self.decoder.stdout.readline().strip() + self.semaphore.release() + return log diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py index 033ed790..6f1fb70f 100644 --- a/realtime/rt/rt.py +++ b/realtime/rt/rt.py @@ -8,6 +8,7 @@ import shutil import sys import subprocess import tempfile +import threading import time import cdec @@ -15,18 +16,56 @@ import aligner import decoder import util -LIKELY_OOV = '("OOV")' +# Dummy input token that is unlikely to appear in normalized data (but no fatal errors if it does) +LIKELY_OOV = '(OOV)' class RealtimeDecoder: + '''Do not use directly unless you know what you're doing. Use RealtimeTranslator.''' + + def __init__(self, configdir, tmpdir): + + cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + self.tmp = tmpdir + os.mkdir(self.tmp) + + # HPYPLM reference stream + ref_fifo_file = os.path.join(self.tmp, 'ref.fifo') + os.mkfifo(ref_fifo_file) + self.ref_fifo = open(ref_fifo_file, 'w+') + # Start with empty line (do not learn prior to first input) + self.ref_fifo.write('\n') + self.ref_fifo.flush() + + # Decoder + decoder_config = [[f.strip() for f in line.split('=')] for line in open(os.path.join(configdir, 'cdec.ini'))] + util.cdec_ini_for_realtime(decoder_config, os.path.abspath(configdir), ref_fifo_file) + decoder_config_file = os.path.join(self.tmp, 'cdec.ini') + with open(decoder_config_file, 'w') as output: + for (k, v) in decoder_config: + output.write('{}={}\n'.format(k, v)) + decoder_weights = os.path.join(configdir, 'weights.final') + self.decoder = decoder.MIRADecoder(decoder_config_file, decoder_weights) + + def close(self, force=False): + logging.info('Closing decoder and removing {}'.format(self.tmp)) + self.decoder.close(force) + self.ref_fifo.close() + shutil.rmtree(self.tmp) + +class RealtimeTranslator: + '''Main entry point into API: serves translations to any number of concurrent users''' def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False, state=None): + # TODO: save/load self.commands = {'LEARN': self.learn, 'SAVE': self.save_state, 'LOAD': self.load_state} cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - self.inc_data = [] # instances of (source, target) + ### Single instance for all contexts + self.config = configdir # Temporary work dir self.tmp = tempfile.mkdtemp(dir=tmpdir, prefix='realtime.') logging.info('Using temp dir {}'.format(self.tmp)) @@ -35,7 +74,9 @@ class RealtimeDecoder: self.norm = norm if self.norm: self.tokenizer = util.popen_io([os.path.join(cdec_root, 'corpus', 'tokenize-anything.sh'), '-u']) + self.tokenizer_sem = threading.Semaphore() self.detokenizer = util.popen_io([os.path.join(cdec_root, 'corpus', 'untok.pl')]) + self.detokenizer_sem = threading.Semaphore() # Word aligner fwd_params = os.path.join(configdir, 'a.fwd_params') @@ -50,28 +91,24 @@ class RealtimeDecoder: util.sa_ini_for_realtime(sa_config, os.path.abspath(configdir)) sa_config.write() self.extractor = cdec.sa.GrammarExtractor(sa_config.filename, online=True) - self.grammar_files = collections.deque() - self.grammar_dict = {} self.cache_size = cache_size - # HPYPLM reference stream - ref_fifo_file = os.path.join(self.tmp, 'ref.fifo') - os.mkfifo(ref_fifo_file) - self.ref_fifo = open(ref_fifo_file, 'w+') - # Start with empty line (do not learn prior to first input) - self.ref_fifo.write('\n') - self.ref_fifo.flush() + ### One instance per context - # Decoder - decoder_config = [[f.strip() for f in line.split('=')] for line in open(os.path.join(configdir, 'cdec.ini'))] - util.cdec_ini_for_realtime(decoder_config, os.path.abspath(configdir), ref_fifo_file) - decoder_config_file = os.path.join(self.tmp, 'cdec.ini') - with open(decoder_config_file, 'w') as output: - for (k, v) in decoder_config: - output.write('{}={}\n'.format(k, v)) - decoder_weights = os.path.join(configdir, 'weights.final') - self.decoder = decoder.MIRADecoder(decoder_config_file, decoder_weights) + self.ctx_names = set() + # All context-dependent operations are atomic + self.ctx_sems = collections.defaultdict(threading.Semaphore) + # ctx -> list of (source, target, alignment) + self.ctx_data = {} + + # ctx -> deque of file + self.grammar_files = {} + # ctx -> dict of {sentence: file} + self.grammar_dict = {} + self.decoders = {} + + # TODO: state # Load state if given if state: with open(state) as input: @@ -80,125 +117,197 @@ class RealtimeDecoder: def __enter__(self): return self - def __exit__(self, type, value, traceback): - self.close() + def __exit__(self, ex_type, ex_value, ex_traceback): + self.close(ex_type is KeyboardInterrupt) - def close(self): + def close(self, force=False): + '''Cleanup''' + if force: + logging.info('Forced shutdown: stopping immediately') + for ctx_name in list(self.ctx_names): + self.drop_ctx(ctx_name, force) logging.info('Closing processes') self.aligner.close() - self.decoder.close() - self.ref_fifo.close() if self.norm: self.tokenizer.stdin.close() self.detokenizer.stdin.close() logging.info('Deleting {}'.format(self.tmp)) shutil.rmtree(self.tmp) - def grammar(self, sentence): - grammar_file = self.grammar_dict.get(sentence, None) + def lazy_ctx(self, ctx_name): + '''Initialize a context (inc starting a new decoder) if needed''' + self.ctx_sems[ctx_name].acquire() + if ctx_name in self.ctx_names: + self.ctx_sems[ctx_name].release() + return + logging.info('New context: {}'.format(ctx_name)) + self.ctx_names.add(ctx_name) + self.ctx_data[ctx_name] = [] + self.grammar_files[ctx_name] = collections.deque() + self.grammar_dict[ctx_name] = {} + tmpdir = os.path.join(self.tmp, 'decoder.{}'.format(ctx_name)) + self.decoders[ctx_name] = RealtimeDecoder(self.config, tmpdir) + self.ctx_sems[ctx_name].release() + + def drop_ctx(self, ctx_name, force=False): + '''Delete a context (inc stopping the decoder)''' + if not force: + sem = self.ctx_sems[ctx_name] + sem.acquire() + logging.info('Dropping context: {}'.format(ctx_name)) + self.ctx_names.remove(ctx_name) + self.ctx_data.pop(ctx_name) + self.extractor.drop_ctx(ctx_name) + self.grammar_files.pop(ctx_name) + self.grammar_dict.pop(ctx_name) + self.decoders.pop(ctx_name).close(force) + self.ctx_sems.pop(ctx_name) + if not force: + sem.release() + + def grammar(self, sentence, ctx_name=None): + '''Extract a sentence-level grammar on demand (or return cached)''' + self.lazy_ctx(ctx_name) + sem = self.ctx_sems[ctx_name] + sem.acquire() + grammar_dict = self.grammar_dict[ctx_name] + grammar_file = grammar_dict.get(sentence, None) # Cache hit if grammar_file: - logging.info('Grammar cache hit') + logging.info('Grammar cache hit: {}'.format(grammar_file)) + sem.release() return grammar_file # Extract and cache - (fid, grammar_file) = tempfile.mkstemp(dir=self.tmp, prefix='grammar.') + (fid, grammar_file) = tempfile.mkstemp(dir=self.decoders[ctx_name].tmp, prefix='grammar.') os.close(fid) with open(grammar_file, 'w') as output: - for rule in self.extractor.grammar(sentence): + for rule in self.extractor.grammar(sentence, ctx_name): output.write('{}\n'.format(str(rule))) - if len(self.grammar_files) == self.cache_size: - rm_sent = self.grammar_files.popleft() + grammar_files = self.grammar_files[ctx_name] + if len(grammar_files) == self.cache_size: + rm_sent = grammar_files.popleft() # If not already removed by learn method - if rm_sent in self.grammar_dict: - rm_grammar = self.grammar_dict.pop(rm_sent) + if rm_sent in grammar_dict: + rm_grammar = grammar_dict.pop(rm_sent) os.remove(rm_grammar) - self.grammar_files.append(sentence) - self.grammar_dict[sentence] = grammar_file + grammar_files.append(sentence) + grammar_dict[sentence] = grammar_file + sem.release() return grammar_file - def decode(self, sentence): + def decode(self, sentence, ctx_name=None): + '''Decode a sentence (inc extracting a grammar if needed)''' + self.lazy_ctx(ctx_name) # Empty in, empty out if sentence.strip() == '': return '' if self.norm: sentence = self.tokenize(sentence) logging.info('Normalized input: {}'.format(sentence)) - grammar_file = self.grammar(sentence) + # grammar method is threadsafe + grammar_file = self.grammar(sentence, ctx_name) + decoder = self.decoders[ctx_name] + sem = self.ctx_sems[ctx_name] + sem.acquire() start_time = time.time() - hyp = self.decoder.decode(sentence, grammar_file) + hyp = decoder.decoder.decode(sentence, grammar_file) stop_time = time.time() logging.info('Translation time: {} seconds'.format(stop_time - start_time)) # Empty reference: HPYPLM does not learn prior to next translation - self.ref_fifo.write('\n') - self.ref_fifo.flush() + decoder.ref_fifo.write('\n') + decoder.ref_fifo.flush() + sem.release() if self.norm: logging.info('Normalized translation: {}'.format(hyp)) hyp = self.detokenize(hyp) return hyp def tokenize(self, line): + self.tokenizer_sem.acquire() self.tokenizer.stdin.write('{}\n'.format(line)) - return self.tokenizer.stdout.readline().strip() + tok_line = self.tokenizer.stdout.readline().strip() + self.tokenizer_sem.release() + return tok_line def detokenize(self, line): + self.detokenizer_sem.acquire() self.detokenizer.stdin.write('{}\n'.format(line)) - return self.detokenizer.stdout.readline().strip() + detok_line = self.detokenizer.stdout.readline().strip() + self.detokenizer_sem.release() + return detok_line - def command_line(self, line): + # TODO + def command_line(self, line, ctx_name=None): args = [f.strip() for f in line.split('|||')] try: if len(args) == 2 and not args[1]: - self.commands[args[0]]() + self.commands[args[0]](ctx_name) else: - self.commands[args[0]](*args[1:]) + self.commands[args[0]](*args[1:], ctx_name=ctx_name) except: logging.info('Command error: {}'.format(' ||| '.join(args))) - def learn(self, source, target): + def learn(self, source, target, ctx_name=None): + self.lazy_ctx(ctx_name) if '' in (source.strip(), target.strip()): logging.info('Error empty source or target: {} ||| {}'.format(source, target)) return if self.norm: source = self.tokenize(source) target = self.tokenize(target) + # Align instance (threadsafe) + alignment = self.aligner.align(source, target) + # grammar method is threadsafe + grammar_file = self.grammar(source, ctx_name) + sem = self.ctx_sems[ctx_name] + sem.acquire() # MIRA update before adding data to grammar extractor - grammar_file = self.grammar(source) - mira_log = self.decoder.update(source, grammar_file, target) + decoder = self.decoders[ctx_name] + mira_log = decoder.decoder.update(source, grammar_file, target) logging.info('MIRA: {}'.format(mira_log)) - # Align instance - alignment = self.aligner.align(source, target) + # Add to HPYPLM by writing to fifo (read on next translation) + logging.info('Adding to HPYPLM: {}'.format(target)) + decoder.ref_fifo.write('{}\n'.format(target)) + decoder.ref_fifo.flush() # Store incremental data for save/load - self.inc_data.append((source, target, alignment)) + self.ctx_data[ctx_name].append((source, target, alignment)) # Add aligned sentence pair to grammar extractor logging.info('Adding to bitext: {} ||| {} ||| {}'.format(source, target, alignment)) - self.extractor.add_instance(source, target, alignment) + self.extractor.add_instance(source, target, alignment, ctx_name) # Clear (old) cached grammar - rm_grammar = self.grammar_dict.pop(source) + rm_grammar = self.grammar_dict[ctx_name].pop(source) os.remove(rm_grammar) - # Add to HPYPLM by writing to fifo (read on next translation) - logging.info('Adding to HPYPLM: {}'.format(target)) - self.ref_fifo.write('{}\n'.format(target)) - self.ref_fifo.flush() + sem.release() - def save_state(self, filename=None): + def save_state(self, filename=None, ctx_name=None): + self.lazy_ctx(ctx_name) out = open(filename, 'w') if filename else sys.stdout - logging.info('Saving state with {} sentences'.format(len(self.inc_data))) - out.write('{}\n'.format(self.decoder.get_weights())) - for (source, target, alignment) in self.inc_data: + sem = self.ctx_sems[ctx_name] + sem.acquire() + ctx_data = self.ctx_data[ctx_name] + logging.info('Saving state with {} sentences'.format(len(self.ctx_data))) + out.write('{}\n'.format(self.decoders[ctx_name].decoder.get_weights())) + for (source, target, alignment) in ctx_data: out.write('{} ||| {} ||| {}\n'.format(source, target, alignment)) + sem.release() out.write('EOF\n') if filename: out.close() - def load_state(self, input=sys.stdin): - # Non-initial load error - if self.inc_data: + def load_state(self, input=sys.stdin, ctx_name=None): + self.lazy_ctx(ctx_name) + sem = self.ctx_sems[ctx_name] + sem.acquire() + ctx_data = self.ctx_data[ctx_name] + decoder = self.decoders[ctx_name] + # Non-initial load error + if ctx_data: logging.info('Error: Incremental data has already been added to decoder.') logging.info(' State can only be loaded by a freshly started decoder.') return # MIRA weights line = input.readline().strip() - self.decoder.set_weights(line) + decoder.decoder.set_weights(line) logging.info('Loading state...') start_time = time.time() # Lines source ||| target ||| alignment @@ -207,12 +316,13 @@ class RealtimeDecoder: if line == 'EOF': break (source, target, alignment) = line.split(' ||| ') - self.inc_data.append((source, target, alignment)) + ctx_data.append((source, target, alignment)) # Extractor - self.extractor.add_instance(source, target, alignment) + self.extractor.add_instance(source, target, alignment, ctx_name) # HPYPLM - hyp = self.decoder.decode(LIKELY_OOV) + hyp = decoder.decoder.decode(LIKELY_OOV) self.ref_fifo.write('{}\n'.format(target)) self.ref_fifo.flush() stop_time = time.time() - logging.info('Loaded state with {} sentences in {} seconds'.format(len(self.inc_data), stop_time - start_time)) + logging.info('Loaded state with {} sentences in {} seconds'.format(len(ctx_data), stop_time - start_time)) + sem.release() |