diff options
author | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-09-04 10:21:25 +0100 |
---|---|---|
committer | Victor Chahuneau <vchahune@cs.cmu.edu> | 2012-09-04 10:21:25 +0100 |
commit | 38c38f707e58960f80a8dc216673ae0bb0796ade (patch) | |
tree | 0255d49b96f7b2cec46c7b7ac5e14b5224c64537 /python/examples | |
parent | e828fab2b485dc0e50d9b9d5c5a599db695ce252 (diff) |
Multi-processing grammar extraction
+ various surface fixes
Diffstat (limited to 'python/examples')
-rw-r--r-- | python/examples/rampion.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/python/examples/rampion.py b/python/examples/rampion.py index 66d89a61..30244cf7 100644 --- a/python/examples/rampion.py +++ b/python/examples/rampion.py @@ -15,7 +15,7 @@ cost = lambda c: 10 * (1 - c.score) # cost definition def rampion(decoder, sources, references): # Empty k-best lists - cs = [cdec.score.BLEU(refs).candidate_set() for refs in references] + candidate_sets = [cdec.score.BLEU(refs).candidate_set() for refs in references] # Weight vector -> sparse w = decoder.weights.tosparse() w0 = w.copy() @@ -25,7 +25,7 @@ def rampion(decoder, sources, references): logging.info('Iteration {0}: translating...'.format(t+1)) # Get the hypergraphs and extend the k-best lists hgs = [] - for src, candidates in izip(sources, cs): + for src, candidates in izip(sources, candidate_sets): hg = decoder.translate(src) hgs.append(hg) candidates.add_kbest(hg, K) @@ -36,17 +36,16 @@ def rampion(decoder, sources, references): for _ in range(T2): # y_i^+, h_i^+; i=1..N plus = [max(candidates, key=lambda c: w.dot(c.fmap) - cost(c)).fmap - for candidates in cs] + for candidates in candidate_sets] for _ in range(T3): - for fp, candidates in izip(plus, cs): + for fp, candidates in izip(plus, candidate_sets): # y^-, h^- fm = max(candidates, key=lambda c: w.dot(c.fmap) + cost(c)).fmap # update weights (line 11-12) w += eta * ((fp - fm) - C/N * (w - w0)) logging.info('Updated weight vector: {0}'.format(dict(w))) # Update decoder weights - for fname, fval in w: - decoder.weights[fname] = fval + decoder.weights = w def main(): logging.basicConfig(level=logging.INFO, format='%(message)s') |