summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/examples/rampion.py77
-rw-r--r--python/examples/test.py (renamed from python/test.py)0
2 files changed, 77 insertions, 0 deletions
diff --git a/python/examples/rampion.py b/python/examples/rampion.py
new file mode 100644
index 00000000..66d89a61
--- /dev/null
+++ b/python/examples/rampion.py
@@ -0,0 +1,77 @@
+import argparse
+import logging
+from itertools import izip
+import cdec, cdec.score
+
+def evaluate(hyp, ref):
+ """ Compute BLEU score for a set of hypotheses+references """
+ return sum(cdec.score.BLEU(r).evaluate(h) for h, r in izip(hyp, ref)).score
+
+T1, T2, T3 = 5, 10, 20 # number of iterations (global, CCCP, SSD)
+K = 500 # k-best list size
+C = 1 # regularization coefficient
+eta = 1e-4 # step size
+cost = lambda c: 10 * (1 - c.score) # cost definition
+
+def rampion(decoder, sources, references):
+ # Empty k-best lists
+ cs = [cdec.score.BLEU(refs).candidate_set() for refs in references]
+ # Weight vector -> sparse
+ w = decoder.weights.tosparse()
+ w0 = w.copy()
+
+ N = len(sources)
+ for t in range(T1):
+ logging.info('Iteration {0}: translating...'.format(t+1))
+ # Get the hypergraphs and extend the k-best lists
+ hgs = []
+ for src, candidates in izip(sources, cs):
+ hg = decoder.translate(src)
+ hgs.append(hg)
+ candidates.add_kbest(hg, K)
+ # BLEU score for the previous iteration
+ score = evaluate((hg.viterbi() for hg in hgs), references)
+ logging.info('BLEU: {:.2f}'.format(100 * score))
+ logging.info('Optimizing...')
+ for _ in range(T2):
+ # y_i^+, h_i^+; i=1..N
+ plus = [max(candidates, key=lambda c: w.dot(c.fmap) - cost(c)).fmap
+ for candidates in cs]
+ for _ in range(T3):
+ for fp, candidates in izip(plus, cs):
+ # y^-, h^-
+ fm = max(candidates, key=lambda c: w.dot(c.fmap) + cost(c)).fmap
+ # update weights (line 11-12)
+ w += eta * ((fp - fm) - C/N * (w - w0))
+ logging.info('Updated weight vector: {0}'.format(dict(w)))
+ # Update decoder weights
+ for fname, fval in w:
+ decoder.weights[fname] = fval
+
+def main():
+ logging.basicConfig(level=logging.INFO, format='%(message)s')
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-c', '--config', help='cdec config', required=True)
+ parser.add_argument('-w', '--weights', help='initial weights', required=True)
+ parser.add_argument('-r', '--reference', help='reference file', required=True)
+ parser.add_argument('-s', '--source', help='source file', required=True)
+ args = parser.parse_args()
+
+ with open(args.config) as fp:
+ config = fp.read()
+
+ decoder = cdec.Decoder(config)
+ decoder.read_weights(args.weights)
+ with open(args.reference) as fp:
+ references = fp.readlines()
+ with open(args.source) as fp:
+ sources = fp.readlines()
+ assert len(references) == len(sources)
+ rampion(decoder, sources, references)
+
+ for fname, fval in sorted(dict(decoder.weights).iteritems()):
+ print('{0}\t{1}'.format(fname, fval))
+
+if __name__ == '__main__':
+ main()
diff --git a/python/test.py b/python/examples/test.py
index eb9e6a95..eb9e6a95 100644
--- a/python/test.py
+++ b/python/examples/test.py