summaryrefslogtreecommitdiff
path: root/python/examples
diff options
context:
space:
mode:
authorVictor Chahuneau <vchahune@cs.cmu.edu>2012-08-10 19:03:38 -0400
committerVictor Chahuneau <vchahune@cs.cmu.edu>2012-08-10 19:03:38 -0400
commitc3d0668c17f45247e1fec6ffe31b807fbbba6674 (patch)
treee8d673346ddbacfab81de2204be1fea99a0aecbb /python/examples
parentb6474b5cdbf870725371b32670c9dc28671e394c (diff)
[python] Examples directory including Rampion
Diffstat (limited to 'python/examples')
-rw-r--r--python/examples/rampion.py77
-rw-r--r--python/examples/test.py70
2 files changed, 147 insertions, 0 deletions
diff --git a/python/examples/rampion.py b/python/examples/rampion.py
new file mode 100644
index 00000000..66d89a61
--- /dev/null
+++ b/python/examples/rampion.py
@@ -0,0 +1,77 @@
+import argparse
+import logging
+from itertools import izip
+import cdec, cdec.score
+
+def evaluate(hyp, ref):
+ """ Compute BLEU score for a set of hypotheses+references """
+ return sum(cdec.score.BLEU(r).evaluate(h) for h, r in izip(hyp, ref)).score
+
+T1, T2, T3 = 5, 10, 20 # number of iterations (global, CCCP, SSD)
+K = 500 # k-best list size
+C = 1 # regularization coefficient
+eta = 1e-4 # step size
+cost = lambda c: 10 * (1 - c.score) # cost definition
+
+def rampion(decoder, sources, references):
+ # Empty k-best lists
+ cs = [cdec.score.BLEU(refs).candidate_set() for refs in references]
+ # Weight vector -> sparse
+ w = decoder.weights.tosparse()
+ w0 = w.copy()
+
+ N = len(sources)
+ for t in range(T1):
+ logging.info('Iteration {0}: translating...'.format(t+1))
+ # Get the hypergraphs and extend the k-best lists
+ hgs = []
+ for src, candidates in izip(sources, cs):
+ hg = decoder.translate(src)
+ hgs.append(hg)
+ candidates.add_kbest(hg, K)
+ # BLEU score for the previous iteration
+ score = evaluate((hg.viterbi() for hg in hgs), references)
+ logging.info('BLEU: {:.2f}'.format(100 * score))
+ logging.info('Optimizing...')
+ for _ in range(T2):
+ # y_i^+, h_i^+; i=1..N
+ plus = [max(candidates, key=lambda c: w.dot(c.fmap) - cost(c)).fmap
+ for candidates in cs]
+ for _ in range(T3):
+ for fp, candidates in izip(plus, cs):
+ # y^-, h^-
+ fm = max(candidates, key=lambda c: w.dot(c.fmap) + cost(c)).fmap
+ # update weights (line 11-12)
+ w += eta * ((fp - fm) - C/N * (w - w0))
+ logging.info('Updated weight vector: {0}'.format(dict(w)))
+ # Update decoder weights
+ for fname, fval in w:
+ decoder.weights[fname] = fval
+
+def main():
+ logging.basicConfig(level=logging.INFO, format='%(message)s')
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-c', '--config', help='cdec config', required=True)
+ parser.add_argument('-w', '--weights', help='initial weights', required=True)
+ parser.add_argument('-r', '--reference', help='reference file', required=True)
+ parser.add_argument('-s', '--source', help='source file', required=True)
+ args = parser.parse_args()
+
+ with open(args.config) as fp:
+ config = fp.read()
+
+ decoder = cdec.Decoder(config)
+ decoder.read_weights(args.weights)
+ with open(args.reference) as fp:
+ references = fp.readlines()
+ with open(args.source) as fp:
+ sources = fp.readlines()
+ assert len(references) == len(sources)
+ rampion(decoder, sources, references)
+
+ for fname, fval in sorted(dict(decoder.weights).iteritems()):
+ print('{0}\t{1}'.format(fname, fval))
+
+if __name__ == '__main__':
+ main()
diff --git a/python/examples/test.py b/python/examples/test.py
new file mode 100644
index 00000000..eb9e6a95
--- /dev/null
+++ b/python/examples/test.py
@@ -0,0 +1,70 @@
+#coding: utf8
+import cdec
+import gzip
+
+weights = '../tests/system_tests/australia/weights'
+grammar_file = '../tests/system_tests/australia/australia.scfg.gz'
+
+# Load decoder width configuration
+decoder = cdec.Decoder(formalism='scfg')
+# Read weights
+decoder.read_weights(weights)
+
+print dict(decoder.weights)
+
+# Read grammar
+with gzip.open(grammar_file) as f:
+ grammar = f.read()
+
+# Input sentence
+sentence = u'澳洲 是 与 北韩 有 邦交 的 少数 国家 之一 。'
+print ' Input:', sentence.encode('utf8')
+
+# Decode
+forest = decoder.translate(sentence, grammar=grammar)
+
+# Get viterbi translation
+print 'Output[0]:', forest.viterbi().encode('utf8')
+f_tree, e_tree = forest.viterbi_trees()
+print ' FTree[0]:', f_tree.encode('utf8')
+print ' ETree[0]:', e_tree.encode('utf8')
+print 'LgProb[0]:', forest.viterbi_features().dot(decoder.weights)
+
+# Get k-best translations
+kbest = zip(forest.kbest(5), forest.kbest_trees(5), forest.kbest_features(5))
+for i, (sentence, (f_tree, e_tree), features) in enumerate(kbest, 1):
+ print 'Output[%d]:' % i, sentence.encode('utf8')
+ print ' FTree[%d]:' % i, f_tree.encode('utf8')
+ print ' ETree[%d]:' % i, e_tree.encode('utf8')
+ print ' FVect[%d]:' % i, dict(features)
+
+# Sample translations from the forest
+for sentence in forest.sample(5):
+ print 'Sample:', sentence.encode('utf8')
+
+# Get feature vector for 1best
+fsrc = forest.viterbi_features()
+
+# Feature expectations
+print 'Feature expectations:', dict(forest.inside_outside())
+
+# Reference lattice
+lattice = ((('australia',0,1),),(('is',0,1),),(('one',0,1),),(('of',0,1),),(('the',0,4),('a',0,4),('a',0,1),('the',0,1),),(('small',0,1),('tiny',0,1),('miniscule',0,1),('handful',0,2),),(('number',0,1),('group',0,1),),(('of',0,2),),(('few',0,1),),(('countries',0,1),),(('that',0,1),),(('has',0,1),('have',0,1),),(('diplomatic',0,1),),(('relations',0,1),),(('with',0,1),),(('north',0,1),),(('korea',0,1),),(('.',0,1),),)
+
+lat = cdec.Lattice(lattice)
+assert (lattice == tuple(lat))
+
+# Intersect forest and lattice
+assert forest.intersect(lat)
+
+# Get best synchronous parse
+f_tree, e_tree = forest.viterbi_trees()
+print 'FTree:', f_tree.encode('utf8')
+print 'ETree:', e_tree.encode('utf8')
+
+# Compare 1best and reference feature vectors
+fref = forest.viterbi_features()
+print dict(fsrc - fref)
+
+# Prune hypergraph
+forest.prune(density=100)