summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/examples/rampion.py11
-rw-r--r--python/pkg/cdec/sa/extract.py45
-rw-r--r--python/src/hypergraph.pxd2
3 files changed, 40 insertions, 18 deletions
diff --git a/python/examples/rampion.py b/python/examples/rampion.py
index 66d89a61..30244cf7 100644
--- a/python/examples/rampion.py
+++ b/python/examples/rampion.py
@@ -15,7 +15,7 @@ cost = lambda c: 10 * (1 - c.score) # cost definition
def rampion(decoder, sources, references):
# Empty k-best lists
- cs = [cdec.score.BLEU(refs).candidate_set() for refs in references]
+ candidate_sets = [cdec.score.BLEU(refs).candidate_set() for refs in references]
# Weight vector -> sparse
w = decoder.weights.tosparse()
w0 = w.copy()
@@ -25,7 +25,7 @@ def rampion(decoder, sources, references):
logging.info('Iteration {0}: translating...'.format(t+1))
# Get the hypergraphs and extend the k-best lists
hgs = []
- for src, candidates in izip(sources, cs):
+ for src, candidates in izip(sources, candidate_sets):
hg = decoder.translate(src)
hgs.append(hg)
candidates.add_kbest(hg, K)
@@ -36,17 +36,16 @@ def rampion(decoder, sources, references):
for _ in range(T2):
# y_i^+, h_i^+; i=1..N
plus = [max(candidates, key=lambda c: w.dot(c.fmap) - cost(c)).fmap
- for candidates in cs]
+ for candidates in candidate_sets]
for _ in range(T3):
- for fp, candidates in izip(plus, cs):
+ for fp, candidates in izip(plus, candidate_sets):
# y^-, h^-
fm = max(candidates, key=lambda c: w.dot(c.fmap) + cost(c)).fmap
# update weights (line 11-12)
w += eta * ((fp - fm) - C/N * (w - w0))
logging.info('Updated weight vector: {0}'.format(dict(w)))
# Update decoder weights
- for fname, fval in w:
- decoder.weights[fname] = fval
+ decoder.weights = w
def main():
logging.basicConfig(level=logging.INFO, format='%(message)s')
diff --git a/python/pkg/cdec/sa/extract.py b/python/pkg/cdec/sa/extract.py
index 875bf42e..39eac824 100644
--- a/python/pkg/cdec/sa/extract.py
+++ b/python/pkg/cdec/sa/extract.py
@@ -3,29 +3,52 @@ import sys
import os
import argparse
import logging
+import multiprocessing as mp
+import signal
import cdec.sa
+extractor, prefix = None, None
+def make_extractor(config, grammars):
+ global extractor, prefix
+ signal.signal(signal.SIGINT, signal.SIG_IGN) # Let parent process catch Ctrl+C
+ extractor = cdec.sa.GrammarExtractor(config)
+ prefix = grammars
+
+def extract(inp):
+ global extractor, prefix
+ i, sentence = inp
+ sentence = sentence[:-1]
+ grammar_file = os.path.join(prefix, 'grammar.{0}'.format(i))
+ with open(grammar_file, 'w') as output:
+ for rule in extractor.grammar(sentence):
+ output.write(str(rule)+'\n')
+ grammar_file = os.path.abspath(grammar_file)
+ return '<seg grammar="{0}" id="{1}">{2}</seg>'.format(grammar_file, i, sentence)
+
+
def main():
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description='Extract grammars from a compiled corpus.')
parser.add_argument('-c', '--config', required=True,
- help='Extractor configuration')
+ help='extractor configuration')
parser.add_argument('-g', '--grammars', required=True,
- help='Grammar output path')
+ help='grammar output path')
+ parser.add_argument('-j', '--jobs', type=int, default=1,
+ help='number of parallel extractors')
+ parser.add_argument('-s', '--chunksize', type=int, default=10,
+ help='number of sentences / chunk')
args = parser.parse_args()
if not os.path.exists(args.grammars):
os.mkdir(args.grammars)
- extractor = cdec.sa.GrammarExtractor(args.config)
- for i, sentence in enumerate(sys.stdin):
- sentence = sentence[:-1]
- grammar_file = os.path.join(args.grammars, 'grammar.{0}'.format(i))
- with open(grammar_file, 'w') as output:
- for rule in extractor.grammar(sentence):
- output.write(str(rule)+'\n')
- grammar_file = os.path.abspath(grammar_file)
- print('<seg grammar="{0}" id="{1}">{2}</seg>'.format(grammar_file, i, sentence))
+ logging.info('Starting %d workers; chunk size: %d', args.jobs, args.chunksize)
+ pool = mp.Pool(args.jobs, make_extractor, (args.config, args.grammars))
+ try:
+ for output in pool.imap(extract, enumerate(sys.stdin), args.chunksize):
+ print(output)
+ except KeyboardInterrupt:
+ pool.terminate()
if __name__ == '__main__':
main()
diff --git a/python/src/hypergraph.pxd b/python/src/hypergraph.pxd
index acab7244..dd3d39cc 100644
--- a/python/src/hypergraph.pxd
+++ b/python/src/hypergraph.pxd
@@ -38,7 +38,7 @@ cdef extern from "decoder/hg.h":
int GoalNode()
double NumberOfPaths()
void Reweight(vector[weight_t]& weights) nogil
- void Reweight(FastSparseVector& weights) nogil
+ void Reweight(FastSparseVector[weight_t]& weights) nogil
bint PruneInsideOutside(double beam_alpha,
double density,
EdgeMask* preserve_mask,