summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xrealtime/mkinput.py17
-rwxr-xr-xrealtime/realtime.py4
-rw-r--r--realtime/rt/decoder.py2
-rw-r--r--realtime/rt/rt.py8
-rw-r--r--training/mira/kbest_cut_mira.cc15
5 files changed, 40 insertions, 6 deletions
diff --git a/realtime/mkinput.py b/realtime/mkinput.py
new file mode 100755
index 00000000..897b44fd
--- /dev/null
+++ b/realtime/mkinput.py
@@ -0,0 +1,17 @@
+#!/usr/bin/env python
+
+import itertools
+import sys
+
+def main():
+
+ if len(sys.argv[1:]) != 2:
+ sys.stderr.write('usage: {} test.src test.ref >test.input\n'.format(sys.argv[0]))
+ sys.exit(2)
+
+ for (src, ref) in itertools.izip(open(sys.argv[1]), open(sys.argv[2])):
+ sys.stdout.write('{}'.format(src))
+ sys.stdout.write('LEARN ||| {} ||| {}'.format(src.strip(), ref))
+
+if __name__ == '__main__':
+ main()
diff --git a/realtime/realtime.py b/realtime/realtime.py
index dff7e90c..554c52ca 100755
--- a/realtime/realtime.py
+++ b/realtime/realtime.py
@@ -38,8 +38,8 @@ def main():
hyp = rtd.decode(input[0])
sys.stdout.write('{}\n'.format(hyp))
sys.stdout.flush()
- elif len(input) == 2:
- rtd.learn(*input)
+ else:
+ rtd.command(input)
# Clean exit on ctrl+c
except KeyboardInterrupt:
diff --git a/realtime/rt/decoder.py b/realtime/rt/decoder.py
index 34b5d391..57739d93 100644
--- a/realtime/rt/decoder.py
+++ b/realtime/rt/decoder.py
@@ -34,6 +34,6 @@ class MIRADecoder(Decoder):
self.decoder = util.popen_io(mira_cmd)
def update(self, sentence, grammar, reference):
- input = '<seg grammar="{g}">{s}</seg> ||| {r}\n'.format(s=sentence, g=grammar, r=reference)
+ input = 'LEARN ||| <seg grammar="{g}">{s}</seg> ||| {r}\n'.format(s=sentence, g=grammar, r=reference)
self.decoder.stdin.write(input)
return self.decoder.stdout.readline().strip()
diff --git a/realtime/rt/rt.py b/realtime/rt/rt.py
index 1b8ac58c..0ce05a56 100644
--- a/realtime/rt/rt.py
+++ b/realtime/rt/rt.py
@@ -19,6 +19,8 @@ class RealtimeDecoder:
def __init__(self, configdir, tmpdir='/tmp', cache_size=5, norm=False):
+ self.commands = {'LEARN': self.learn}
+
cdec_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Temporary work dir
@@ -126,6 +128,12 @@ class RealtimeDecoder:
self.detokenizer.stdin.write('{}\n'.format(line))
return self.detokenizer.stdout.readline().strip()
+ def command(self, args):
+ try:
+ self.commands[args[0]](*args[1:])
+ except:
+ logging.info('Command error: {}'.format(' ||| '.join(args)))
+
def learn(self, source, target):
if '' in (source.strip(), target.strip()):
logging.info('Error empty source or target: {} ||| {}'.format(source, target))
diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc
index e4435abb..a9a4aeb6 100644
--- a/training/mira/kbest_cut_mira.cc
+++ b/training/mira/kbest_cut_mira.cc
@@ -734,10 +734,19 @@ int main(int argc, char** argv) {
ViterbiESentence(bobs.hypergraph[0], &trans);
cout << TD::GetString(trans) << endl;
continue;
- // Translate and update (normal MIRA)
+ // Special command:
+ // CMD ||| arg1 ||| arg2 ...
} else {
- ds->update(buf.substr(delim + 5));
- buf = buf.substr(0, delim);
+ string cmd = buf.substr(0, delim);
+ buf = buf.substr(delim + 5);
+ // Translate and update (normal MIRA)
+ // LEARN ||| source ||| reference
+ if (cmd == "LEARN") {
+ delim = buf.find(" ||| ");
+ ds->update(buf.substr(delim + 5));
+ buf = buf.substr(0, delim);
+ }
+ // TODO: additional commands
}
}
//TODO: allow batch updating