diff options
author | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-08-10 19:03:38 -0400 |
---|---|---|
committer | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-08-10 19:03:38 -0400 |
commit | 3f2cc751d1f2655aa0ff14ca735da648899edc40 (patch) | |
tree | 2dfd7520031a751165da2815dde4881502059511 | |
parent | 0ec6eab158320dc87054057a5a6aaa3536d2fc91 (diff) |
[python] Examples directory including Rampion
-rw-r--r-- | python/examples/rampion.py | 77 | ||||
-rw-r--r-- | python/examples/test.py (renamed from python/test.py) | 0 |
2 files changed, 77 insertions, 0 deletions
diff --git a/python/examples/rampion.py b/python/examples/rampion.py
new file mode 100644
index 00000000..66d89a61
--- /dev/null
+++ b/python/examples/rampion.py
@@ -0,0 +1,77 @@
+import argparse
+import logging
+from itertools import izip
+import cdec, cdec.score
+
+def evaluate(hyp, ref):
+    """ Compute BLEU score for a set of hypotheses+references """
+    return sum(cdec.score.BLEU(r).evaluate(h) for h, r in izip(hyp, ref)).score
+
+T1, T2, T3 = 5, 10, 20 # number of iterations (global, CCCP, SSD)
+K = 500 # k-best list size
+C = 1 # regularization coefficient
+eta = 1e-4 # step size
+cost = lambda c: 10 * (1 - c.score) # cost definition
+
+def rampion(decoder, sources, references):
+    # Empty k-best lists
+    cs = [cdec.score.BLEU(refs).candidate_set() for refs in references]
+    # Weight vector -> sparse
+    w = decoder.weights.tosparse()
+    w0 = w.copy()
+
+    N = len(sources)
+    for t in range(T1):
+        logging.info('Iteration {0}: translating...'.format(t+1))
+        # Get the hypergraphs and extend the k-best lists
+        hgs = []
+        for src, candidates in izip(sources, cs):
+            hg = decoder.translate(src)
+            hgs.append(hg)
+            candidates.add_kbest(hg, K)
+        # BLEU score for the previous iteration
+        score = evaluate((hg.viterbi() for hg in hgs), references)
+        logging.info('BLEU: {:.2f}'.format(100 * score))
+        logging.info('Optimizing...')
+        for _ in range(T2):
+            # y_i^+, h_i^+; i=1..N
+            plus = [max(candidates, key=lambda c: w.dot(c.fmap) - cost(c)).fmap
+                    for candidates in cs]
+            for _ in range(T3):
+                for fp, candidates in izip(plus, cs):
+                    # y^-, h^-
+                    fm = max(candidates, key=lambda c: w.dot(c.fmap) + cost(c)).fmap
+                    # update weights (line 11-12)
+                    w += eta * ((fp - fm) - C/N * (w - w0))
+        logging.info('Updated weight vector: {0}'.format(dict(w)))
+        # Update decoder weights
+        for fname, fval in w:
+            decoder.weights[fname] = fval
+
+def main():
+    logging.basicConfig(level=logging.INFO, format='%(message)s')
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-c', '--config', help='cdec config', required=True)
+    parser.add_argument('-w', '--weights', help='initial weights', required=True)
+    parser.add_argument('-r', '--reference', help='reference file', required=True)
+    parser.add_argument('-s', '--source', help='source file', required=True)
+    args = parser.parse_args()
+
+    with open(args.config) as fp:
+        config = fp.read()
+
+    decoder = cdec.Decoder(config)
+    decoder.read_weights(args.weights)
+    with open(args.reference) as fp:
+        references = fp.readlines()
+    with open(args.source) as fp:
+        sources = fp.readlines()
+    assert len(references) == len(sources)
+    rampion(decoder, sources, references)
+
+    for fname, fval in sorted(dict(decoder.weights).iteritems()):
+        print('{0}\t{1}'.format(fname, fval))
+
+if __name__ == '__main__':
+    main()
diff --git a/python/test.py b/python/examples/test.py
index eb9e6a95..eb9e6a95 100644
--- a/python/test.py
+++ b/python/examples/test.py