summaryrefslogtreecommitdiff
path: root/python/pkg/cdec/sa
diff options
context:
space:
mode:
authorPaul Baltescu <pauldb89@gmail.com>2013-04-24 17:18:10 +0100
committerPaul Baltescu <pauldb89@gmail.com>2013-04-24 17:18:10 +0100
commitba206aaac1d95e76126443c9e7ccc5941e879849 (patch)
tree13a918da3f3983fd8e4cb74e7cdc3f5e1fc01cd1 /python/pkg/cdec/sa
parentc2aede0f19b7a5e43581768b8c4fbfae8b92c68c (diff)
parentdb960a8bba81df3217660ec5a96d73e0d6baa01b (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'python/pkg/cdec/sa')
-rw-r--r--python/pkg/cdec/sa/extract.py41
-rwxr-xr-xpython/pkg/cdec/sa/online_extractor.py337
2 files changed, 24 insertions, 354 deletions
diff --git a/python/pkg/cdec/sa/extract.py b/python/pkg/cdec/sa/extract.py
index dc72c18c..b6502c52 100644
--- a/python/pkg/cdec/sa/extract.py
+++ b/python/pkg/cdec/sa/extract.py
@@ -1,23 +1,26 @@
#!/usr/bin/env python
import sys
import os
+import re
+import gzip
import argparse
import logging
-import re
-import multiprocessing as mp
import signal
+import multiprocessing as mp
import cdec.sa
from cdec.sa._sa import monitor_cpu
extractor, prefix = None, None
-online = False
+online, compress = False, False
-def make_extractor(config, grammars, features):
- global extractor, prefix, online
+def make_extractor(args):
+ global extractor, prefix, online, compress
signal.signal(signal.SIGINT, signal.SIG_IGN) # Let parent process catch Ctrl+C
- load_features(features)
- extractor = cdec.sa.GrammarExtractor(config, online)
- prefix = grammars
+ load_features(args.features)
+ extractor = cdec.sa.GrammarExtractor(args.config, online)
+ prefix = args.grammars
+ online = args.online
+ compress = args.compress
def load_features(features):
for featdef in features:
@@ -28,7 +31,7 @@ def load_features(features):
sys.path.remove(prefix)
def extract(inp):
- global extractor, prefix, online
+ global extractor, prefix, online, compress
i, sentence = inp
sentence = sentence[:-1]
fields = re.split('\s*\|\|\|\s*', sentence)
@@ -37,7 +40,7 @@ def extract(inp):
if online:
if len(fields) < 3:
sys.stderr.write('Error: online mode requires references and alignments.'
- ' Not adding sentence to training data: {0}\n'.format(sentence))
+ ' Not adding sentence to training data: {}\n'.format(sentence))
sentence = fields[0]
else:
sentence, reference, alignment = fields[0:3]
@@ -47,15 +50,17 @@ def extract(inp):
if len(fields) > 1:
sentence = fields[0]
suffix = ' ||| ' + ' ||| '.join(fields[1:])
- grammar_file = os.path.join(prefix, 'grammar.{0}'.format(i))
- with open(grammar_file, 'w') as output:
+
+ grammar_file = os.path.join(prefix, 'grammar.'+str(i))
+ if compress: grammar_file += '.gz'
+ with (gzip.open if compress else open)(grammar_file, 'w') as output:
for rule in extractor.grammar(sentence):
output.write(str(rule)+'\n')
# Add training instance _after_ extracting grammars
if online:
extractor.add_instance(sentence, reference, alignment)
grammar_file = os.path.abspath(grammar_file)
- return '<seg grammar="{0}" id="{1}"> {2} </seg>{3}'.format(grammar_file, i, sentence, suffix)
+ return '<seg grammar="{}" id="{}">{}</seg>{}'.format(grammar_file, i, sentence, suffix)
def main():
global online
@@ -71,15 +76,17 @@ def main():
help='number of sentences / chunk')
parser.add_argument('-f', '--features', nargs='*', default=[],
help='additional feature definitions')
- parser.add_argument('-o', '--online', action='store_true', default=False,
+ parser.add_argument('-o', '--online', action='store_true',
help='online grammar extraction')
+ parser.add_argument('-z', '--compress', action='store_true',
+ help='compress grammars with gzip')
args = parser.parse_args()
if not os.path.exists(args.grammars):
os.mkdir(args.grammars)
for featdef in args.features:
if not featdef.endswith('.py'):
- sys.stderr.write('Error: feature definition file <{0}>'
+ sys.stderr.write('Error: feature definition file <{}>'
' should be a python module\n'.format(featdef))
sys.exit(1)
@@ -88,14 +95,14 @@ def main():
start_time = monitor_cpu()
if args.jobs > 1:
logging.info('Starting %d workers; chunk size: %d', args.jobs, args.chunksize)
- pool = mp.Pool(args.jobs, make_extractor, (args.config, args.grammars, args.features))
+ pool = mp.Pool(args.jobs, make_extractor, (args,))
try:
for output in pool.imap(extract, enumerate(sys.stdin), args.chunksize):
print(output)
except KeyboardInterrupt:
pool.terminate()
else:
- make_extractor(args.config, args.grammars, args.features)
+ make_extractor(args)
for output in map(extract, enumerate(sys.stdin)):
print(output)
diff --git a/python/pkg/cdec/sa/online_extractor.py b/python/pkg/cdec/sa/online_extractor.py
deleted file mode 100755
index 03a46b3b..00000000
--- a/python/pkg/cdec/sa/online_extractor.py
+++ /dev/null
@@ -1,337 +0,0 @@
-#!/usr/bin/env python
-
-import collections, sys
-
-import cdec.configobj
-
-CAT = '[X]' # Default non-terminal
-MAX_SIZE = 15 # Max span of a grammar rule (source)
-MAX_LEN = 5 # Max number of terminals and non-terminals in a rule (source)
-MAX_NT = 2 # Max number of non-terminals in a rule
-MIN_GAP = 1 # Min number of terminals between non-terminals (source)
-
-# Spans are _inclusive_ on both ends [i, j]
-# TODO: Replace all of this with bit vectors?
-def span_check(vec, i, j):
- k = i
- while k <= j:
- if vec[k]:
- return False
- k += 1
- return True
-
-def span_flip(vec, i, j):
- k = i
- while k <= j:
- vec[k] = ~vec[k]
- k += 1
-
-# Next non-terminal
-def next_nt(nt):
- if not nt:
- return 1
- return nt[-1][0] + 1
-
-class NonTerminal:
- def __init__(self, index):
- self.index = index
- def __str__(self):
- return '[X,{0}]'.format(self.index)
-
-def fmt_rule(f_sym, e_sym, links):
- a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links)
- return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(str(sym) for sym in f_sym),
- ' '.join(str(sym) for sym in e_sym),
- a_str)
-
-class OnlineGrammarExtractor:
-
- def __init__(self, config=None):
- if isinstance(config, str) or isinstance(config, unicode):
- if not os.path.exists(config):
- raise IOError('cannot read configuration from {0}'.format(config))
- config = cdec.configobj.ConfigObj(config, unrepr=True)
- elif not config:
- config = collections.defaultdict(lambda: None)
- self.category = CAT
- self.max_size = MAX_SIZE
- self.max_length = config['max_len'] or MAX_LEN
- self.max_nonterminals = config['max_nt'] or MAX_NT
- self.min_gap_size = MIN_GAP
- # Hard coded: require at least one aligned word
- # Hard coded: require tight phrases
-
- # Phrase counts
- self.phrases_f = collections.defaultdict(lambda: 0)
- self.phrases_e = collections.defaultdict(lambda: 0)
- self.phrases_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
-
- # Bilexical counts
- self.bilex_f = collections.defaultdict(lambda: 0)
- self.bilex_e = collections.defaultdict(lambda: 0)
- self.bilex_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
-
- # Aggregate bilexical counts
- def aggr_bilex(self, f_words, e_words):
-
- for e_w in e_words:
- self.bilex_e[e_w] += 1
-
- for f_w in f_words:
- self.bilex_f[f_w] += 1
- for e_w in e_words:
- self.bilex_fe[f_w][e_w] += 1
-
- # Aggregate stats from a training instance:
- # Extract hierarchical phrase pairs
- # Update bilexical counts
- def add_instance(self, f_words, e_words, alignment):
-
- # Bilexical counts
- self.aggr_bilex(f_words, e_words)
-
- # Phrase pairs extracted from this instance
- phrases = set()
-
- f_len = len(f_words)
- e_len = len(e_words)
-
- # Pre-compute alignment info
- al = [[] for i in range(f_len)]
- al_span = [[f_len + 1, -1] for i in range(f_len)]
- for (f, e) in alignment:
- al[f].append(e)
- al_span[f][0] = min(al_span[f][0], e)
- al_span[f][1] = max(al_span[f][1], e)
-
- # Target side word coverage
- # TODO: Does Cython do bit vectors?
- cover = [0] * e_len
-
- # Extract all possible hierarchical phrases starting at a source index
- # f_ i and j are current, e_ i and j are previous
- def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open):
- # Phrase extraction limits
- if wc + len(nt) > self.max_length or (f_j + 1) > f_len or \
- (f_j - f_i) + 1 > self.max_size:
- return
- # Unaligned word
- if not al[f_j]:
- # Open non-terminal: extend
- if nt_open:
- nt[-1][2] += 1
- extract(f_i, f_j + 1, e_i, e_j, wc, links, nt, True)
- nt[-1][2] -= 1
- # No open non-terminal: extend with word
- else:
- extract(f_i, f_j + 1, e_i, e_j, wc + 1, links, nt, False)
- return
- # Aligned word
- link_i = al_span[f_j][0]
- link_j = al_span[f_j][1]
- new_e_i = min(link_i, e_i)
- new_e_j = max(link_j, e_j)
- # Open non-terminal: close, extract, extend
- if nt_open:
- # Close non-terminal, checking for collisions
- old_last_nt = nt[-1][:]
- nt[-1][2] = f_j
- if link_i < nt[-1][3]:
- if not span_check(cover, link_i, nt[-1][3] - 1):
- nt[-1] = old_last_nt
- return
- span_flip(cover, link_i, nt[-1][3] - 1)
- nt[-1][3] = link_i
- if link_j > nt[-1][4]:
- if not span_check(cover, nt[-1][4] + 1, link_j):
- nt[-1] = old_last_nt
- return
- span_flip(cover, nt[-1][4] + 1, link_j)
- nt[-1][4] = link_j
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- phrases.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
- nt[-1] = old_last_nt
- if link_i < nt[-1][3]:
- span_flip(cover, link_i, nt[-1][3] - 1)
- if link_j > nt[-1][4]:
- span_flip(cover, nt[-1][4] + 1, link_j)
- return
- # No open non-terminal
- # Extract, extend with word
- collision = False
- for link in al[f_j]:
- if cover[link]:
- collision = True
- # Collisions block extraction and extension, but may be okay for
- # continuing non-terminals
- if not collision:
- plus_links = []
- for link in al[f_j]:
- plus_links.append((f_j, link))
- cover[link] = ~cover[link]
- links.append(plus_links)
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- phrases.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False)
- links.pop()
- for link in al[f_j]:
- cover[link] = ~cover[link]
- # Try to add a word to a (closed) non-terminal, extract, extend
- if nt and nt[-1][2] == f_j - 1:
- # Add to non-terminal, checking for collisions
- old_last_nt = nt[-1][:]
- nt[-1][2] = f_j
- if link_i < nt[-1][3]:
- if not span_check(cover, link_i, nt[-1][3] - 1):
- nt[-1] = old_last_nt
- return
- span_flip(cover, link_i, nt[-1][3] - 1)
- nt[-1][3] = link_i
- if link_j > nt[-1][4]:
- if not span_check(cover, nt[-1][4] + 1, link_j):
- nt[-1] = old_last_nt
- return
- span_flip(cover, nt[-1][4] + 1, link_j)
- nt[-1][4] = link_j
- # Require at least one word in phrase
- if links:
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- phrases.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
- nt[-1] = old_last_nt
- if new_e_i < nt[-1][3]:
- span_flip(cover, link_i, nt[-1][3] - 1)
- if link_j > nt[-1][4]:
- span_flip(cover, nt[-1][4] + 1, link_j)
- # Try to start a new non-terminal, extract, extend
- if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals:
- # Check for collisions
- if not span_check(cover, link_i, link_j):
- return
- span_flip(cover, link_i, link_j)
- nt.append([next_nt(nt), f_j, f_j, link_i, link_j])
- # Require at least one word in phrase
- if links:
- for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links):
- phrases.add(rule)
- extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False)
- nt.pop()
- span_flip(cover, link_i, link_j)
- # TODO: try adding NT to start, end, both
- # check: one aligned word on boundary that is not part of a NT
-
- # Try to extract phrases from every f index
- f_i = 0
- while f_i < f_len:
- # Skip if phrases won't be tight on left side
- if not al[f_i]:
- f_i += 1
- continue
- extract(f_i, f_i, f_len + 1, -1, 1, [], [], False)
- f_i += 1
-
- for rule in sorted(phrases):
- print rule
-
- # Create a rule from source, target, non-terminals, and alignments
- def form_rules(self, f_i, e_i, f_span, e_span, nt, al):
-
- # This could be more efficient but is unlikely to be the bottleneck
-
- rules = []
-
- nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3]))
-
- f_sym = f_span[:]
- off = f_i
- for next_nt in nt:
- nt_len = (next_nt[2] - next_nt[1]) + 1
- i = 0
- while i < nt_len:
- f_sym.pop(next_nt[1] - off)
- i += 1
- f_sym.insert(next_nt[1] - off, NonTerminal(next_nt[0]))
- off += (nt_len - 1)
-
- e_sym = e_span[:]
- off = e_i
- for next_nt in nt_inv:
- nt_len = (next_nt[4] - next_nt[3]) + 1
- i = 0
- while i < nt_len:
- e_sym.pop(next_nt[3] - off)
- i += 1
- e_sym.insert(next_nt[3] - off, NonTerminal(next_nt[0]))
- off += (nt_len - 1)
-
- # Adjusting alignment links takes some doing
- links = [list(link) for sub in al for link in sub]
- links_len = len(links)
- nt_len = len(nt)
- nt_i = 0
- off = f_i
- i = 0
- while i < links_len:
- while nt_i < nt_len and links[i][0] > nt[nt_i][1]:
- off += (nt[nt_i][2] - nt[nt_i][1])
- nt_i += 1
- links[i][0] -= off
- i += 1
- nt_i = 0
- off = e_i
- i = 0
- while i < links_len:
- while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]:
- off += (nt_inv[nt_i][4] - nt_inv[nt_i][3])
- nt_i += 1
- links[i][1] -= off
- i += 1
-
- # Rule
- rules.append(fmt_rule(f_sym, e_sym, links))
- if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals:
- return rules
- last_index = nt[-1][0] if nt else 0
- # Rule [X]
- if not nt or not isinstance(f_sym[-1], NonTerminal):
- f_sym.append(NonTerminal(last_index + 1))
- e_sym.append(NonTerminal(last_index + 1))
- rules.append(fmt_rule(f_sym, e_sym, links))
- f_sym.pop()
- e_sym.pop()
- # [X] Rule
- if not nt or not isinstance(f_sym[0], NonTerminal):
- for sym in f_sym:
- if isinstance(sym, NonTerminal):
- sym.index += 1
- for sym in e_sym:
- if isinstance(sym, NonTerminal):
- sym.index += 1
- for link in links:
- link[0] += 1
- link[1] += 1
- f_sym.insert(0, NonTerminal(1))
- e_sym.insert(0, NonTerminal(1))
- rules.append(fmt_rule(f_sym, e_sym, links))
- if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals:
- return rules
- # [X] Rule [X]
- if not nt or not isinstance(f_sym[-1], NonTerminal):
- f_sym.append(NonTerminal(last_index + 2))
- e_sym.append(NonTerminal(last_index + 2))
- rules.append(fmt_rule(f_sym, e_sym, links))
- return rules
-
-def main(argv):
-
- extractor = OnlineGrammarExtractor()
-
- for line in sys.stdin:
- print >> sys.stderr, line.strip()
- f_words, e_words, a_str = (x.split() for x in line.split('|||'))
- alignment = sorted(tuple(int(y) for y in x.split('-')) for x in a_str)
- extractor.add_instance(f_words, e_words, alignment)
-
-if __name__ == '__main__':
- main(sys.argv)