summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--realtime/rt/decoder.py2
-rw-r--r--realtime/rt/rt.py54
2 files changed, 35 insertions, 21 deletions
diff --git a/realtime/rt/decoder.py b/realtime/rt/decoder.py
index 5082911d..1bdd3f1f 100644
--- a/realtime/rt/decoder.py
+++ b/realtime/rt/decoder.py
@@ -38,7 +38,7 @@ class CdecDecoder(Decoder):
class MIRADecoder(Decoder):
- def __init__(self, config, weights, metric='bleu'):
+ def __init__(self, config, weights, metric='ibm_bleu'):
cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
mira = os.path.join(cdec_root, 'training', 'mira', 'kbest_cut_mira')
# optimizer=2 step=0.001 best=500, k=500, uniq, stream, metric
diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py
index c0aec410..27eeb3ca 100644
--- a/realtime/rt/rt.py
+++ b/realtime/rt/rt.py
@@ -20,40 +20,46 @@ import util
# Dummy input token that is unlikely to appear in normalized data (but no fatal errors if it does)
LIKELY_OOV = '(OOV)'
+# For parsing rt.ini
+TRUE = ('true', 'True', 'TRUE')
+
logger = logging.getLogger('rt')
class RealtimeDecoder:
'''Do not use directly unless you know what you're doing. Use RealtimeTranslator.'''
- def __init__(self, configdir, tmpdir):
-
+ def __init__(self, configdir, tmpdir, hpyplm=False, metric='ibm_bleu'):
+
cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
self.tmp = tmpdir
os.mkdir(self.tmp)
# HPYPLM reference stream
- ref_fifo_file = os.path.join(self.tmp, 'ref.fifo')
- os.mkfifo(ref_fifo_file)
- self.ref_fifo = open(ref_fifo_file, 'w+')
- # Start with empty line (do not learn prior to first input)
- self.ref_fifo.write('\n')
- self.ref_fifo.flush()
+ self.hpyplm = hpyplm
+ if self.hpyplm:
+ ref_fifo_file = os.path.join(self.tmp, 'ref.fifo')
+ os.mkfifo(ref_fifo_file)
+ self.ref_fifo = open(ref_fifo_file, 'w+')
+ # Start with empty line (do not learn prior to first input)
+ self.ref_fifo.write('\n')
+ self.ref_fifo.flush()
# Decoder
decoder_config = [[f.strip() for f in line.split('=')] for line in open(os.path.join(configdir, 'cdec.ini'))]
- util.cdec_ini_for_realtime(decoder_config, os.path.abspath(configdir), ref_fifo_file)
+ util.cdec_ini_for_realtime(decoder_config, os.path.abspath(configdir), ref_fifo_file if self.hpyplm else None)
decoder_config_file = os.path.join(self.tmp, 'cdec.ini')
with open(decoder_config_file, 'w') as output:
for (k, v) in decoder_config:
output.write('{}={}\n'.format(k, v))
decoder_weights = os.path.join(configdir, 'weights.final')
- self.decoder = decoder.MIRADecoder(decoder_config_file, decoder_weights)
+ self.decoder = decoder.MIRADecoder(decoder_config_file, decoder_weights, metric=metric)
def close(self, force=False):
logger.info('Closing decoder and removing {}'.format(self.tmp))
self.decoder.close(force)
- self.ref_fifo.close()
+ if self.hpyplm:
+ self.ref_fifo.close()
shutil.rmtree(self.tmp)
class RealtimeTranslator:
@@ -73,6 +79,11 @@ class RealtimeTranslator:
cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ # rt.ini options
+ ini = dict(line.strip().split('=') for line in open(os.path.join(configdir, 'rt.ini')))
+ self.hpyplm = (ini.get('hpyplm', 'false') in TRUE)
+ self.metric = ini.get('metric', 'ibm_bleu')
+
### Single instance for all contexts
self.config = configdir
@@ -160,7 +171,7 @@ class RealtimeTranslator:
self.grammar_files[ctx_name] = collections.deque()
self.grammar_dict[ctx_name] = {}
tmpdir = os.path.join(self.tmp, 'decoder.{}'.format(ctx_name))
- self.decoders[ctx_name] = RealtimeDecoder(self.config, tmpdir)
+ self.decoders[ctx_name] = RealtimeDecoder(self.config, tmpdir, hpyplm=self.hpyplm, metric=self.metric)
def drop_ctx(self, ctx_name=None, force=False):
'''Delete a context (inc stopping the decoder)
@@ -239,8 +250,9 @@ class RealtimeTranslator:
stop_time = time.time()
logger.info('({}) Translation time: {} seconds'.format(ctx_name, stop_time - start_time))
# Empty reference: HPYPLM does not learn prior to next translation
- decoder.ref_fifo.write('\n')
- decoder.ref_fifo.flush()
+ if self.hpyplm:
+ decoder.ref_fifo.write('\n')
+ decoder.ref_fifo.flush()
if self.norm:
logger.info('({}) Normalized translation: {}'.format(ctx_name, hyp))
hyp = self.detokenize(hyp)
@@ -301,9 +313,10 @@ class RealtimeTranslator:
mira_log = decoder.decoder.update(source, grammar_file, target)
logger.info('({}) MIRA HBF: {}'.format(ctx_name, mira_log))
# Add to HPYPLM by writing to fifo (read on next translation)
- logger.info('({}) Adding to HPYPLM: {}'.format(ctx_name, target))
- decoder.ref_fifo.write('{}\n'.format(target))
- decoder.ref_fifo.flush()
+ if self.hpyplm:
+ logger.info('({}) Adding to HPYPLM: {}'.format(ctx_name, target))
+ decoder.ref_fifo.write('{}\n'.format(target))
+ decoder.ref_fifo.flush()
# Store incremental data for save/load
self.ctx_data[ctx_name].append((source, target, alignment))
# Add aligned sentence pair to grammar extractor
@@ -381,9 +394,10 @@ class RealtimeTranslator:
# Extractor
self.extractor.add_instance(source, target, alignment, ctx_name)
# HPYPLM
- hyp = decoder.decoder.decode(LIKELY_OOV)
- decoder.ref_fifo.write('{}\n'.format(target))
- decoder.ref_fifo.flush()
+ if self.hpyplm:
+ hyp = decoder.decoder.decode(LIKELY_OOV)
+ decoder.ref_fifo.write('{}\n'.format(target))
+ decoder.ref_fifo.flush()
stop_time = time.time()
logger.info('({}) Loaded state with {} sentences in {} seconds'.format(ctx_name, len(ctx_data), stop_time - start_time))
lock.release()