diff options
author | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-08-10 19:03:38 -0400 |
---|---|---|
committer | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-08-10 19:03:38 -0400 |
commit | c3d0668c17f45247e1fec6ffe31b807fbbba6674 (patch) | |
tree | e8d673346ddbacfab81de2204be1fea99a0aecbb /python/examples/rampion.py | |
parent | b6474b5cdbf870725371b32670c9dc28671e394c (diff) |
[python] Examples directory including Rampion
Diffstat (limited to 'python/examples/rampion.py')
-rw-r--r-- | python/examples/rampion.py | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/python/examples/rampion.py b/python/examples/rampion.py new file mode 100644 index 00000000..66d89a61 --- /dev/null +++ b/python/examples/rampion.py @@ -0,0 +1,77 @@ +import argparse +import logging +from itertools import izip +import cdec, cdec.score + +def evaluate(hyp, ref): + """ Compute BLEU score for a set of hypotheses+references """ + return sum(cdec.score.BLEU(r).evaluate(h) for h, r in izip(hyp, ref)).score + +T1, T2, T3 = 5, 10, 20 # number of iterations (global, CCCP, SSD) +K = 500 # k-best list size +C = 1 # regularization coefficient +eta = 1e-4 # step size +cost = lambda c: 10 * (1 - c.score) # cost definition + +def rampion(decoder, sources, references): + # Empty k-best lists + cs = [cdec.score.BLEU(refs).candidate_set() for refs in references] + # Weight vector -> sparse + w = decoder.weights.tosparse() + w0 = w.copy() + + N = len(sources) + for t in range(T1): + logging.info('Iteration {0}: translating...'.format(t+1)) + # Get the hypergraphs and extend the k-best lists + hgs = [] + for src, candidates in izip(sources, cs): + hg = decoder.translate(src) + hgs.append(hg) + candidates.add_kbest(hg, K) + # BLEU score for the previous iteration + score = evaluate((hg.viterbi() for hg in hgs), references) + logging.info('BLEU: {:.2f}'.format(100 * score)) + logging.info('Optimizing...') + for _ in range(T2): + # y_i^+, h_i^+; i=1..N + plus = [max(candidates, key=lambda c: w.dot(c.fmap) - cost(c)).fmap + for candidates in cs] + for _ in range(T3): + for fp, candidates in izip(plus, cs): + # y^-, h^- + fm = max(candidates, key=lambda c: w.dot(c.fmap) + cost(c)).fmap + # update weights (line 11-12) + w += eta * ((fp - fm) - C/N * (w - w0)) + logging.info('Updated weight vector: {0}'.format(dict(w))) + # Update decoder weights + for fname, fval in w: + decoder.weights[fname] = fval + +def main(): + logging.basicConfig(level=logging.INFO, format='%(message)s') + + parser = argparse.ArgumentParser() + parser.add_argument('-c', '--config', help='cdec config', required=True) + parser.add_argument('-w', '--weights', help='initial weights', required=True) + parser.add_argument('-r', '--reference', help='reference file', required=True) + parser.add_argument('-s', '--source', help='source file', required=True) + args = parser.parse_args() + + with open(args.config) as fp: + config = fp.read() + + decoder = cdec.Decoder(config) + decoder.read_weights(args.weights) + with open(args.reference) as fp: + references = fp.readlines() + with open(args.source) as fp: + sources = fp.readlines() + assert len(references) == len(sources) + rampion(decoder, sources, references) + + for fname, fval in sorted(dict(decoder.weights).iteritems()): + print('{0}\t{1}'.format(fname, fval)) + +if __name__ == '__main__': + main() |