From 08be69abb923b74f7dc27712d6bef7f6e4a05377 Mon Sep 17 00:00:00 2001 From: Michael Denkowski Date: Sun, 15 Sep 2013 20:32:59 -0700 Subject: Move to using named commands --- realtime/mkinput.py | 17 +++++++++++++++++ realtime/realtime.py | 4 ++-- realtime/rt/decoder.py | 2 +- realtime/rt/rt.py | 8 ++++++++ training/mira/kbest_cut_mira.cc | 15 ++++++++++++--- 5 files changed, 40 insertions(+), 6 deletions(-) create mode 100755 realtime/mkinput.py diff --git a/realtime/mkinput.py b/realtime/mkinput.py new file mode 100755 index 00000000..897b44fd --- /dev/null +++ b/realtime/mkinput.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python + +import itertools +import sys + +def main(): + + if len(sys.argv[1:]) != 2: + sys.stderr.write('usage: {} test.src test.ref >test.input\n'.format(sys.argv[0])) + sys.exit(2) + + for (src, ref) in itertools.izip(open(sys.argv[1]), open(sys.argv[2])): + sys.stdout.write('{}'.format(src)) + sys.stdout.write('LEARN ||| {} ||| {}'.format(src.strip(), ref)) + +if __name__ == '__main__': + main() diff --git a/realtime/realtime.py b/realtime/realtime.py index dff7e90c..554c52ca 100755 --- a/realtime/realtime.py +++ b/realtime/realtime.py @@ -38,8 +38,8 @@ def main(): hyp = rtd.decode(input[0]) sys.stdout.write('{}\n'.format(hyp)) sys.stdout.flush() - elif len(input) == 2: - rtd.learn(*input) + else: + rtd.command(input) # Clean exit on ctrl+c except KeyboardInterrupt: diff --git a/realtime/rt/decoder.py b/realtime/rt/decoder.py index 34b5d391..57739d93 100644 --- a/realtime/rt/decoder.py +++ b/realtime/rt/decoder.py @@ -34,6 +34,6 @@ class MIRADecoder(Decoder): self.decoder = util.popen_io(mira_cmd) def update(self, sentence, grammar, reference): - input = '{s} ||| {r}\n'.format(s=sentence, g=grammar, r=reference) + input = 'LEARN ||| {s} ||| {r}\n'.format(s=sentence, g=grammar, r=reference) self.decoder.stdin.write(input) return self.decoder.stdout.readline().strip() diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py index 1b8ac58c..0ce05a56 100644 --- a/realtime/rt/rt.py +++ b/realtime/rt/rt.py @@ -19,6 +19,8 @@ class RealtimeDecoder: def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False): + self.commands = {'LEARN': self.learn} + cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # Temporary work dir @@ -126,6 +128,12 @@ class RealtimeDecoder: self.detokenizer.stdin.write('{}\n'.format(line)) return self.detokenizer.stdout.readline().strip() + def command(self, args): + try: + self.commands[args[0]](*args[1:]) + except: + logging.info('Command error: {}'.format(' ||| '.join(args))) + def learn(self, source, target): if '' in (source.strip(), target.strip()): logging.info('Error empty source or target: {} ||| {}'.format(source, target)) diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc index e4435abb..a9a4aeb6 100644 --- a/training/mira/kbest_cut_mira.cc +++ b/training/mira/kbest_cut_mira.cc @@ -734,10 +734,19 @@ int main(int argc, char** argv) { ViterbiESentence(bobs.hypergraph[0], &trans); cout << TD::GetString(trans) << endl; continue; - // Translate and update (normal MIRA) + // Special command: + // CMD ||| arg1 ||| arg2 ... } else { - ds->update(buf.substr(delim + 5)); - buf = buf.substr(0, delim); + string cmd = buf.substr(0, delim); + buf = buf.substr(delim + 5); + // Translate and update (normal MIRA) + // LEARN ||| source ||| reference + if (cmd == "LEARN") { + delim = buf.find(" ||| "); + ds->update(buf.substr(delim + 5)); + buf = buf.substr(0, delim); + } + // TODO: additional commands } } //TODO: allow batch updating -- cgit v1.2.3