diff options
Diffstat (limited to 'python/pkg/cdec/sa')
-rw-r--r-- | python/pkg/cdec/sa/__init__.py | 1 | ||||
-rw-r--r-- | python/pkg/cdec/sa/extract.py | 33 | ||||
-rw-r--r-- | python/pkg/cdec/sa/extractor.py | 30 | ||||
-rw-r--r-- | python/pkg/cdec/sa/features.py | 117 | ||||
-rwxr-xr-x | python/pkg/cdec/sa/online_extractor.py | 337 |
5 files changed, 492 insertions, 26 deletions
diff --git a/python/pkg/cdec/sa/__init__.py b/python/pkg/cdec/sa/__init__.py index e0a344b7..14ba5ecb 100644 --- a/python/pkg/cdec/sa/__init__.py +++ b/python/pkg/cdec/sa/__init__.py @@ -1,4 +1,5 @@ from cdec.sa._sa import make_lattice, decode_lattice, decode_sentence,\ + encode_words, decode_words, isvar,\ SuffixArray, DataArray, LCP, Precomputation, Alignment, BiLex,\ HieroCachingRuleFactory, Sampler, Scorer from cdec.sa.extractor import GrammarExtractor diff --git a/python/pkg/cdec/sa/extract.py b/python/pkg/cdec/sa/extract.py index 87b7d5d4..dc72c18c 100644 --- a/python/pkg/cdec/sa/extract.py +++ b/python/pkg/cdec/sa/extract.py @@ -10,11 +10,13 @@ import cdec.sa from cdec.sa._sa import monitor_cpu extractor, prefix = None, None +online = False + def make_extractor(config, grammars, features): - global extractor, prefix + global extractor, prefix, online signal.signal(signal.SIGINT, signal.SIG_IGN) # Let parent process catch Ctrl+C load_features(features) - extractor = cdec.sa.GrammarExtractor(config) + extractor = cdec.sa.GrammarExtractor(config, online) prefix = grammars def load_features(features): @@ -26,22 +28,37 @@ def load_features(features): sys.path.remove(prefix) def extract(inp): - global extractor, prefix + global extractor, prefix, online i, sentence = inp sentence = sentence[:-1] fields = re.split('\s*\|\|\|\s*', sentence) suffix = '' - if len(fields) > 1: - sentence = fields[0] - suffix = ' ||| ' + ' ||| '.join(fields[1:]) + # 3 fields for online mode, 1 for normal + if online: + if len(fields) < 3: + sys.stderr.write('Error: online mode requires references and alignments.' + ' Not adding sentence to training data: {0}\n'.format(sentence)) + sentence = fields[0] + else: + sentence, reference, alignment = fields[0:3] + if len(fields) > 3: + suffix = ' ||| ' + ' ||| '.join(fields[3:]) + else: + if len(fields) > 1: + sentence = fields[0] + suffix = ' ||| ' + ' ||| '.join(fields[1:]) grammar_file = os.path.join(prefix, 'grammar.{0}'.format(i)) with open(grammar_file, 'w') as output: for rule in extractor.grammar(sentence): output.write(str(rule)+'\n') + # Add training instance _after_ extracting grammars + if online: + extractor.add_instance(sentence, reference, alignment) grammar_file = os.path.abspath(grammar_file) return '<seg grammar="{0}" id="{1}"> {2} </seg>{3}'.format(grammar_file, i, sentence, suffix) def main(): + global online logging.basicConfig(level=logging.INFO) parser = argparse.ArgumentParser(description='Extract grammars from a compiled corpus.') parser.add_argument('-c', '--config', required=True, @@ -54,6 +71,8 @@ def main(): help='number of sentences / chunk') parser.add_argument('-f', '--features', nargs='*', default=[], help='additional feature definitions') + parser.add_argument('-o', '--online', action='store_true', default=False, + help='online grammar extraction') args = parser.parse_args() if not os.path.exists(args.grammars): @@ -64,6 +83,8 @@ def main(): ' should be a python module\n'.format(featdef)) sys.exit(1) + online = args.online + start_time = monitor_cpu() if args.jobs > 1: logging.info('Starting %d workers; chunk size: %d', args.jobs, args.chunksize) diff --git a/python/pkg/cdec/sa/extractor.py b/python/pkg/cdec/sa/extractor.py index e09f79ea..acc13cbc 100644 --- a/python/pkg/cdec/sa/extractor.py +++ b/python/pkg/cdec/sa/extractor.py @@ -1,15 +1,16 @@ from itertools import chain -import os +import os, sys import cdec.configobj from cdec.sa.features import EgivenFCoherent, SampleCountF, CountEF,\ - MaxLexEgivenF, MaxLexFgivenE, IsSingletonF, IsSingletonFE + MaxLexEgivenF, MaxLexFgivenE, IsSingletonF, IsSingletonFE,\ + IsSupportedOnline import cdec.sa # maximum span of a grammar rule in TEST DATA MAX_INITIAL_SIZE = 15 class GrammarExtractor: - def __init__(self, config, features=None): + def __init__(self, config, online=False, features=None): if isinstance(config, basestring): if not os.path.exists(config): raise IOError('cannot read configuration from {0}'.format(config)) @@ -57,11 +58,19 @@ class GrammarExtractor: # lexical weighting tables tt = cdec.sa.BiLex(from_binary=config['lex_file']) + # TODO: clean this up + extended_features = [] + if online: + extended_features.append(IsSupportedOnline) + # TODO: use @cdec.sa.features decorator for standard features too # + add a mask to disable features + for f in cdec.sa._SA_FEATURES: + extended_features.append(f) + scorer = cdec.sa.Scorer(EgivenFCoherent, SampleCountF, CountEF, MaxLexFgivenE(tt), MaxLexEgivenF(tt), IsSingletonF, IsSingletonFE, - *cdec.sa._SA_FEATURES) + *extended_features) fsarray = cdec.sa.SuffixArray(from_binary=config['f_sa_file']) edarray = cdec.sa.DataArray(from_binary=config['e_file']) @@ -82,3 +91,16 @@ class GrammarExtractor: meta = cdec.sa.annotate(words) cnet = cdec.sa.make_lattice(words) return self.factory.input(cnet, meta) + + # Add training instance to data + def add_instance(self, sentence, reference, alignment): + f_words = cdec.sa.encode_words(sentence.split()) + e_words = cdec.sa.encode_words(reference.split()) + al = sorted(tuple(int(i) for i in pair.split('-')) for pair in alignment.split()) + self.factory.add_instance(f_words, e_words, al) + + # Debugging + def dump_online_stats(self): + self.factory.dump_online_stats() + def dump_online_rules(self): + self.factory.dump_online_rules()
\ No newline at end of file diff --git a/python/pkg/cdec/sa/features.py b/python/pkg/cdec/sa/features.py index a4ae23e8..46412cd5 100644 --- a/python/pkg/cdec/sa/features.py +++ b/python/pkg/cdec/sa/features.py @@ -1,57 +1,142 @@ from __future__ import division import math +from cdec.sa import isvar + MAXSCORE = 99 def EgivenF(ctx): # p(e|f) = c(e, f)/c(f) - return -math.log10(ctx.paircount/ctx.fcount) + if not ctx.online: + prob = ctx.paircount/ctx.fcount + else: + prob = (ctx.paircount + ctx.online.paircount) / (ctx.fcount + ctx.online.fcount) + return -math.log10(prob) def CountEF(ctx): # c(e, f) - return math.log10(1 + ctx.paircount) + if not ctx.online: + count = 1 + ctx.paircount + else: + count = 1 + ctx.paircount + ctx.online.paircount + return math.log10(count) def SampleCountF(ctx): # sample c(f) - return math.log10(1 + ctx.fsample_count) + if not ctx.online: + count = 1 + ctx.fsample_count + else: + count = 1 + ctx.fsample_count + ctx.online.fsample_count + return math.log10(count) def EgivenFCoherent(ctx): # c(e, f) / sample c(f) - prob = ctx.paircount/ctx.fsample_count + if not ctx.online: + prob = ctx.paircount/ctx.fsample_count + else: + prob = (ctx.paircount + ctx.online.paircount) / (ctx.fsample_count + ctx.online.fsample_count) return -math.log10(prob) if prob > 0 else MAXSCORE def CoherenceProb(ctx): # c(f) / sample c(f) - return -math.log10(ctx.fcount/ctx.fsample_count) + if not ctx.online: + prob = ctx.fcount/ctx.fsample_count + else: + prob = (ctx.fcount + ctx.online.fcount) / (ctx.fsample_count + ctx.online.fsample_count) + return -math.log10(prob) def MaxLexEgivenF(ttable): def MaxLexEgivenF(ctx): fwords = ctx.fphrase.words fwords.append('NULL') - def score(): + # Always use this for now + if not ctx.online or ctx.online: + maxOffScore = 0.0 + for e in ctx.ephrase.words: + maxScore = max(ttable.get_score(f, e, 0) for f in fwords) + maxOffScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE + return maxOffScore + else: + # For now, straight average + maxOffScore = 0.0 + maxOnScore = 0.0 for e in ctx.ephrase.words: - maxScore = max(ttable.get_score(f, e, 0) for f in fwords) - yield -math.log10(maxScore) if maxScore > 0 else MAXSCORE - return sum(score()) + maxScore = max(ttable.get_score(f, e, 0) for f in fwords) + maxOffScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE + for e in ctx.ephrase: + if not isvar(e): + maxScore = 0.0 + for f in ctx.fphrase: + if not isvar(f): + b_f = ctx.online.bilex_f.get(f, 0) + if b_f: + maxScore = max(maxScore, ctx.online.bilex_fe.get(f, {}).get(e)) + maxOnScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE + return (maxOffScore + maxOnScore) / 2 return MaxLexEgivenF def MaxLexFgivenE(ttable): def MaxLexFgivenE(ctx): ewords = ctx.ephrase.words ewords.append('NULL') - def score(): + # Always use this for now + if not ctx.online or ctx.online: + maxOffScore = 0.0 for f in ctx.fphrase.words: - maxScore = max(ttable.get_score(f, e, 1) for e in ewords) - yield -math.log10(maxScore) if maxScore > 0 else MAXSCORE - return sum(score()) + maxScore = max(ttable.get_score(f, e, 1) for e in ewords) + maxOffScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE + return maxOffScore + else: + # For now, straight average + maxOffScore = 0.0 + maxOnScore = 0.0 + for f in ctx.fphrase.words: + maxScore = max(ttable.get_score(f, e, 1) for e in ewords) + maxOffScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE + for f in ctx.fphrase: + if not isvar(f): + maxScore = 0.0 + for e in ctx.ephrase: + if not isvar(e): + b_e = ctx.online.bilex_e.get(e, 0) + if b_e: + maxScore = max(maxScore, ctx.online.bilex_fe.get(f, {}).get(e, 0) / b_e ) + maxOnScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE + return (maxOffScore + maxOnScore) / 2 return MaxLexFgivenE def IsSingletonF(ctx): - return (ctx.fcount == 1) + if not ctx.online: + count = ctx.fcount + else: + count = ctx.fcount + ctx.online.fcount + return (count == 1) def IsSingletonFE(ctx): - return (ctx.paircount == 1) + if not ctx.online: + count = ctx.paircount + else: + count = ctx.paircount + ctx.online.paircount + return (count == 1) def IsNotSingletonF(ctx): - return (ctx.fcount > 1) + if not ctx.online: + count = ctx.fcount + else: + count = ctx.fcount + ctx.online.fcount + return (count > 1) def IsNotSingletonFE(ctx): + if not ctx.online: + count = ctx.paircount + else: + count = ctx.paircount + ctx.online.paircount return (ctx.paircount > 1) def IsFEGreaterThanZero(ctx): + if not ctx.online: + count = ctx.paircount + else: + count = ctx.paircount + ctx.online.paircount return (ctx.paircount > 0.01) + +def IsSupportedOnline(ctx): # Occurs in online data? + if ctx.online: + return (ctx.online.paircount > 0.01) + else: + return False
\ No newline at end of file diff --git a/python/pkg/cdec/sa/online_extractor.py b/python/pkg/cdec/sa/online_extractor.py new file mode 100755 index 00000000..03a46b3b --- /dev/null +++ b/python/pkg/cdec/sa/online_extractor.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python + +import collections, sys + +import cdec.configobj + +CAT = '[X]' # Default non-terminal +MAX_SIZE = 15 # Max span of a grammar rule (source) +MAX_LEN = 5 # Max number of terminals and non-terminals in a rule (source) +MAX_NT = 2 # Max number of non-terminals in a rule +MIN_GAP = 1 # Min number of terminals between non-terminals (source) + +# Spans are _inclusive_ on both ends [i, j] +# TODO: Replace all of this with bit vectors? +def span_check(vec, i, j): + k = i + while k <= j: + if vec[k]: + return False + k += 1 + return True + +def span_flip(vec, i, j): + k = i + while k <= j: + vec[k] = ~vec[k] + k += 1 + +# Next non-terminal +def next_nt(nt): + if not nt: + return 1 + return nt[-1][0] + 1 + +class NonTerminal: + def __init__(self, index): + self.index = index + def __str__(self): + return '[X,{0}]'.format(self.index) + +def fmt_rule(f_sym, e_sym, links): + a_str = ' '.join('{0}-{1}'.format(i, j) for (i, j) in links) + return '[X] ||| {0} ||| {1} ||| {2}'.format(' '.join(str(sym) for sym in f_sym), + ' '.join(str(sym) for sym in e_sym), + a_str) + +class OnlineGrammarExtractor: + + def __init__(self, config=None): + if isinstance(config, str) or isinstance(config, unicode): + if not os.path.exists(config): + raise IOError('cannot read configuration from {0}'.format(config)) + config = cdec.configobj.ConfigObj(config, unrepr=True) + elif not config: + config = collections.defaultdict(lambda: None) + self.category = CAT + self.max_size = MAX_SIZE + self.max_length = config['max_len'] or MAX_LEN + self.max_nonterminals = config['max_nt'] or MAX_NT + self.min_gap_size = MIN_GAP + # Hard coded: require at least one aligned word + # Hard coded: require tight phrases + + # Phrase counts + self.phrases_f = collections.defaultdict(lambda: 0) + self.phrases_e = collections.defaultdict(lambda: 0) + self.phrases_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) + + # Bilexical counts + self.bilex_f = collections.defaultdict(lambda: 0) + self.bilex_e = collections.defaultdict(lambda: 0) + self.bilex_fe = collections.defaultdict(lambda: collections.defaultdict(lambda: 0)) + + # Aggregate bilexical counts + def aggr_bilex(self, f_words, e_words): + + for e_w in e_words: + self.bilex_e[e_w] += 1 + + for f_w in f_words: + self.bilex_f[f_w] += 1 + for e_w in e_words: + self.bilex_fe[f_w][e_w] += 1 + + # Aggregate stats from a training instance: + # Extract hierarchical phrase pairs + # Update bilexical counts + def add_instance(self, f_words, e_words, alignment): + + # Bilexical counts + self.aggr_bilex(f_words, e_words) + + # Phrase pairs extracted from this instance + phrases = set() + + f_len = len(f_words) + e_len = len(e_words) + + # Pre-compute alignment info + al = [[] for i in range(f_len)] + al_span = [[f_len + 1, -1] for i in range(f_len)] + for (f, e) in alignment: + al[f].append(e) + al_span[f][0] = min(al_span[f][0], e) + al_span[f][1] = max(al_span[f][1], e) + + # Target side word coverage + # TODO: Does Cython do bit vectors? + cover = [0] * e_len + + # Extract all possible hierarchical phrases starting at a source index + # f_ i and j are current, e_ i and j are previous + def extract(f_i, f_j, e_i, e_j, wc, links, nt, nt_open): + # Phrase extraction limits + if wc + len(nt) > self.max_length or (f_j + 1) > f_len or \ + (f_j - f_i) + 1 > self.max_size: + return + # Unaligned word + if not al[f_j]: + # Open non-terminal: extend + if nt_open: + nt[-1][2] += 1 + extract(f_i, f_j + 1, e_i, e_j, wc, links, nt, True) + nt[-1][2] -= 1 + # No open non-terminal: extend with word + else: + extract(f_i, f_j + 1, e_i, e_j, wc + 1, links, nt, False) + return + # Aligned word + link_i = al_span[f_j][0] + link_j = al_span[f_j][1] + new_e_i = min(link_i, e_i) + new_e_j = max(link_j, e_j) + # Open non-terminal: close, extract, extend + if nt_open: + # Close non-terminal, checking for collisions + old_last_nt = nt[-1][:] + nt[-1][2] = f_j + if link_i < nt[-1][3]: + if not span_check(cover, link_i, nt[-1][3] - 1): + nt[-1] = old_last_nt + return + span_flip(cover, link_i, nt[-1][3] - 1) + nt[-1][3] = link_i + if link_j > nt[-1][4]: + if not span_check(cover, nt[-1][4] + 1, link_j): + nt[-1] = old_last_nt + return + span_flip(cover, nt[-1][4] + 1, link_j) + nt[-1][4] = link_j + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + phrases.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + nt[-1] = old_last_nt + if link_i < nt[-1][3]: + span_flip(cover, link_i, nt[-1][3] - 1) + if link_j > nt[-1][4]: + span_flip(cover, nt[-1][4] + 1, link_j) + return + # No open non-terminal + # Extract, extend with word + collision = False + for link in al[f_j]: + if cover[link]: + collision = True + # Collisions block extraction and extension, but may be okay for + # continuing non-terminals + if not collision: + plus_links = [] + for link in al[f_j]: + plus_links.append((f_j, link)) + cover[link] = ~cover[link] + links.append(plus_links) + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + phrases.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc + 1, links, nt, False) + links.pop() + for link in al[f_j]: + cover[link] = ~cover[link] + # Try to add a word to a (closed) non-terminal, extract, extend + if nt and nt[-1][2] == f_j - 1: + # Add to non-terminal, checking for collisions + old_last_nt = nt[-1][:] + nt[-1][2] = f_j + if link_i < nt[-1][3]: + if not span_check(cover, link_i, nt[-1][3] - 1): + nt[-1] = old_last_nt + return + span_flip(cover, link_i, nt[-1][3] - 1) + nt[-1][3] = link_i + if link_j > nt[-1][4]: + if not span_check(cover, nt[-1][4] + 1, link_j): + nt[-1] = old_last_nt + return + span_flip(cover, nt[-1][4] + 1, link_j) + nt[-1][4] = link_j + # Require at least one word in phrase + if links: + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + phrases.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + nt[-1] = old_last_nt + if new_e_i < nt[-1][3]: + span_flip(cover, link_i, nt[-1][3] - 1) + if link_j > nt[-1][4]: + span_flip(cover, nt[-1][4] + 1, link_j) + # Try to start a new non-terminal, extract, extend + if (not nt or f_j - nt[-1][2] > 1) and len(nt) < self.max_nonterminals: + # Check for collisions + if not span_check(cover, link_i, link_j): + return + span_flip(cover, link_i, link_j) + nt.append([next_nt(nt), f_j, f_j, link_i, link_j]) + # Require at least one word in phrase + if links: + for rule in self.form_rules(f_i, new_e_i, f_words[f_i:f_j + 1], e_words[new_e_i:new_e_j + 1], nt, links): + phrases.add(rule) + extract(f_i, f_j + 1, new_e_i, new_e_j, wc, links, nt, False) + nt.pop() + span_flip(cover, link_i, link_j) + # TODO: try adding NT to start, end, both + # check: one aligned word on boundary that is not part of a NT + + # Try to extract phrases from every f index + f_i = 0 + while f_i < f_len: + # Skip if phrases won't be tight on left side + if not al[f_i]: + f_i += 1 + continue + extract(f_i, f_i, f_len + 1, -1, 1, [], [], False) + f_i += 1 + + for rule in sorted(phrases): + print rule + + # Create a rule from source, target, non-terminals, and alignments + def form_rules(self, f_i, e_i, f_span, e_span, nt, al): + + # This could be more efficient but is unlikely to be the bottleneck + + rules = [] + + nt_inv = sorted(nt, cmp=lambda x, y: cmp(x[3], y[3])) + + f_sym = f_span[:] + off = f_i + for next_nt in nt: + nt_len = (next_nt[2] - next_nt[1]) + 1 + i = 0 + while i < nt_len: + f_sym.pop(next_nt[1] - off) + i += 1 + f_sym.insert(next_nt[1] - off, NonTerminal(next_nt[0])) + off += (nt_len - 1) + + e_sym = e_span[:] + off = e_i + for next_nt in nt_inv: + nt_len = (next_nt[4] - next_nt[3]) + 1 + i = 0 + while i < nt_len: + e_sym.pop(next_nt[3] - off) + i += 1 + e_sym.insert(next_nt[3] - off, NonTerminal(next_nt[0])) + off += (nt_len - 1) + + # Adjusting alignment links takes some doing + links = [list(link) for sub in al for link in sub] + links_len = len(links) + nt_len = len(nt) + nt_i = 0 + off = f_i + i = 0 + while i < links_len: + while nt_i < nt_len and links[i][0] > nt[nt_i][1]: + off += (nt[nt_i][2] - nt[nt_i][1]) + nt_i += 1 + links[i][0] -= off + i += 1 + nt_i = 0 + off = e_i + i = 0 + while i < links_len: + while nt_i < nt_len and links[i][1] > nt_inv[nt_i][3]: + off += (nt_inv[nt_i][4] - nt_inv[nt_i][3]) + nt_i += 1 + links[i][1] -= off + i += 1 + + # Rule + rules.append(fmt_rule(f_sym, e_sym, links)) + if len(f_sym) >= self.max_length or len(nt) >= self.max_nonterminals: + return rules + last_index = nt[-1][0] if nt else 0 + # Rule [X] + if not nt or not isinstance(f_sym[-1], NonTerminal): + f_sym.append(NonTerminal(last_index + 1)) + e_sym.append(NonTerminal(last_index + 1)) + rules.append(fmt_rule(f_sym, e_sym, links)) + f_sym.pop() + e_sym.pop() + # [X] Rule + if not nt or not isinstance(f_sym[0], NonTerminal): + for sym in f_sym: + if isinstance(sym, NonTerminal): + sym.index += 1 + for sym in e_sym: + if isinstance(sym, NonTerminal): + sym.index += 1 + for link in links: + link[0] += 1 + link[1] += 1 + f_sym.insert(0, NonTerminal(1)) + e_sym.insert(0, NonTerminal(1)) + rules.append(fmt_rule(f_sym, e_sym, links)) + if len(f_sym) >= self.max_length or len(nt) + 1 >= self.max_nonterminals: + return rules + # [X] Rule [X] + if not nt or not isinstance(f_sym[-1], NonTerminal): + f_sym.append(NonTerminal(last_index + 2)) + e_sym.append(NonTerminal(last_index + 2)) + rules.append(fmt_rule(f_sym, e_sym, links)) + return rules + +def main(argv): + + extractor = OnlineGrammarExtractor() + + for line in sys.stdin: + print >> sys.stderr, line.strip() + f_words, e_words, a_str = (x.split() for x in line.split('|||')) + alignment = sorted(tuple(int(y) for y in x.split('-')) for x in a_str) + extractor.add_instance(f_words, e_words, alignment) + +if __name__ == '__main__': + main(sys.argv) |