1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
|
import argparse
import logging
from itertools import izip
import cdec, cdec.score
def evaluate(hyp, ref):
    """Return the BLEU score of a set of hypotheses against their references.

    Each hypothesis is evaluated against the reference at the same position;
    the per-pair evaluation results are added together and the ``score``
    attribute of the total is returned.

    Args:
        hyp: iterable of hypotheses.
        ref: iterable of references, aligned with ``hyp``.
    """
    # One evaluation result per (hypothesis, reference) pair, produced lazily.
    pair_results = (
        cdec.score.BLEU(reference).evaluate(hypothesis)
        for hypothesis, reference in izip(hyp, ref)
    )
    total = sum(pair_results)
    return total.score
# Number of iterations of the three nested optimization loops.
T1 = 5   # global (outer) iterations
T2 = 10  # CCCP iterations
T3 = 20  # SSD iterations

K = 500     # k-best list size
C = 1       # regularization coefficient
eta = 1e-4  # step size


def cost(c):
    """Cost of candidate ``c``, defined as ``10 * (1 - c.score)``."""
    return 10 * (1 - c.score)
def rampion(decoder, sources, references):
    """Tune the decoder's weights in place (Rampion training loop).

    Runs T1 outer iterations. Each one translates every source sentence,
    extends the per-sentence candidate sets with the K best entries of the
    resulting hypergraph, logs the BLEU of the Viterbi translations, and then
    performs T2 rounds of T3 passes of per-sentence weight updates. After each
    outer iteration the updated weights are logged and assigned back to
    ``decoder.weights`` so the next round decodes with them.

    Args:
        decoder: object providing ``weights`` and ``translate(src)``.
        sources: sequence of source sentences (its ``len()`` is used).
        references: iterable of references, aligned with ``sources``.

    Returns:
        None. The tuned weights are left in ``decoder.weights``.
    """
    # Empty k-best lists, one per reference
    candidate_sets = [cdec.score.BLEU(refs).candidate_set()
                      for refs in references]
    # Weight vector -> sparse; w0 keeps the initial weights for the
    # regularization term below.
    w = decoder.weights.tosparse()
    w0 = w.copy()

    N = len(sources)
    for t in range(T1):
        logging.info('Iteration {0}: translating...'.format(t+1))
        # Get the hypergraphs and extend the k-best lists
        hgs = []
        for src, candidates in izip(sources, candidate_sets):
            hg = decoder.translate(src)
            hgs.append(hg)
            candidates.add_kbest(hg, K)
        # BLEU score of the translations produced with the current weights
        # (i.e. the result of the previous iteration's update)
        score = evaluate((hg.viterbi() for hg in hgs), references)
        logging.info('BLEU: {:.2f}'.format(100 * score))
        logging.info('Optimizing...')
        for _ in range(T2):
            # y_i^+, h_i^+; i=1..N  (model score minus cost, fixed for the
            # whole inner T3 loop)
            plus = [max(candidates, key=lambda c: w.dot(c.fmap) - cost(c)).fmap
                    for candidates in candidate_sets]
            for _ in range(T3):
                for fp, candidates in izip(plus, candidate_sets):
                    # y^-, h^-  (model score plus cost, under the current w)
                    fm = max(candidates,
                             key=lambda c: w.dot(c.fmap) + cost(c)).fmap
                    # update weights (line 11-12)
                    # float(C): with integer C and N, Python 2's ``C / N``
                    # floors to 0 for N > 1, which would silently drop the
                    # regularization term pulling w towards w0.
                    w += eta * ((fp - fm) - (float(C) / N) * (w - w0))
        logging.info('Updated weight vector: {0}'.format(dict(w)))
        # Update decoder weights
        decoder.weights = w
def main():
    """Command-line entry point: tune cdec weights and print them.

    Reads the cdec configuration, initial weights, reference file and source
    file named on the command line, runs ``rampion`` on them, and prints the
    resulting weights to stdout as ``name<TAB>value`` lines sorted by name.

    Raises:
        ValueError: if the reference and source files have different
            numbers of lines.
    """
    logging.basicConfig(level=logging.INFO, format='%(message)s')

    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', help='cdec config', required=True)
    parser.add_argument('-w', '--weights', help='initial weights', required=True)
    parser.add_argument('-r', '--reference', help='reference file', required=True)
    parser.add_argument('-s', '--source', help='source file', required=True)
    args = parser.parse_args()

    with open(args.config) as fp:
        config = fp.read()

    decoder = cdec.Decoder(config)
    decoder.read_weights(args.weights)

    with open(args.reference) as fp:
        references = fp.readlines()
    with open(args.source) as fp:
        sources = fp.readlines()

    # Explicit check instead of ``assert``: assertions are removed under
    # ``python -O``, and the pairwise iteration in training would then
    # silently truncate to the shorter file.
    if len(references) != len(sources):
        raise ValueError(
            'reference file has {0} lines but source file has {1}'.format(
                len(references), len(sources)))

    rampion(decoder, sources, references)

    # items() rather than iteritems(): identical once sorted, and not tied
    # to Python 2 dictionaries.
    for fname, fval in sorted(dict(decoder.weights).items()):
        print('{0}\t{1}'.format(fname, fval))
# Run the command-line driver only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|