summaryrefslogtreecommitdiff
path: root/python/examples/rampion.py
diff options
context:
space:
mode:
authorVictor Chahuneau <vchahune@cs.cmu.edu>2012-09-04 10:21:25 +0100
committerVictor Chahuneau <vchahune@cs.cmu.edu>2012-09-04 10:21:25 +0100
commitb774a1ce6aced0e17d308d775cb32ba18ab755a8 (patch)
tree5ac4e3edcbe3d7ad3d2283eb080e862a2f30091d /python/examples/rampion.py
parent063152d73f2814be32dfa8e927fa00caf1af1855 (diff)
Multi-processing grammar extraction
+ various surface fixes
Diffstat (limited to 'python/examples/rampion.py')
-rw-r--r--python/examples/rampion.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/python/examples/rampion.py b/python/examples/rampion.py
index 66d89a61..30244cf7 100644
--- a/python/examples/rampion.py
+++ b/python/examples/rampion.py
@@ -15,7 +15,7 @@ cost = lambda c: 10 * (1 - c.score) # cost definition
def rampion(decoder, sources, references):
# Empty k-best lists
- cs = [cdec.score.BLEU(refs).candidate_set() for refs in references]
+ candidate_sets = [cdec.score.BLEU(refs).candidate_set() for refs in references]
# Weight vector -> sparse
w = decoder.weights.tosparse()
w0 = w.copy()
@@ -25,7 +25,7 @@ def rampion(decoder, sources, references):
logging.info('Iteration {0}: translating...'.format(t+1))
# Get the hypergraphs and extend the k-best lists
hgs = []
- for src, candidates in izip(sources, cs):
+ for src, candidates in izip(sources, candidate_sets):
hg = decoder.translate(src)
hgs.append(hg)
candidates.add_kbest(hg, K)
@@ -36,17 +36,16 @@ def rampion(decoder, sources, references):
for _ in range(T2):
# y_i^+, h_i^+; i=1..N
plus = [max(candidates, key=lambda c: w.dot(c.fmap) - cost(c)).fmap
- for candidates in cs]
+ for candidates in candidate_sets]
for _ in range(T3):
- for fp, candidates in izip(plus, cs):
+ for fp, candidates in izip(plus, candidate_sets):
# y^-, h^-
fm = max(candidates, key=lambda c: w.dot(c.fmap) + cost(c)).fmap
# update weights (line 11-12)
w += eta * ((fp - fm) - C/N * (w - w0))
logging.info('Updated weight vector: {0}'.format(dict(w)))
# Update decoder weights
- for fname, fval in w:
- decoder.weights[fname] = fval
+ decoder.weights = w
def main():
logging.basicConfig(level=logging.INFO, format='%(message)s')