summaryrefslogtreecommitdiff
path: root/python/pkg/cdec/sa
diff options
context:
space:
mode:
Diffstat (limited to 'python/pkg/cdec/sa')
-rw-r--r--python/pkg/cdec/sa/__init__.py1
-rw-r--r--python/pkg/cdec/sa/compile.py13
-rw-r--r--python/pkg/cdec/sa/extract.py57
-rw-r--r--python/pkg/cdec/sa/extractor.py36
-rw-r--r--python/pkg/cdec/sa/features.py117
5 files changed, 185 insertions, 39 deletions
diff --git a/python/pkg/cdec/sa/__init__.py b/python/pkg/cdec/sa/__init__.py
index e0a344b7..14ba5ecb 100644
--- a/python/pkg/cdec/sa/__init__.py
+++ b/python/pkg/cdec/sa/__init__.py
@@ -1,4 +1,5 @@
from cdec.sa._sa import make_lattice, decode_lattice, decode_sentence,\
+ encode_words, decode_words, isvar,\
SuffixArray, DataArray, LCP, Precomputation, Alignment, BiLex,\
HieroCachingRuleFactory, Sampler, Scorer
from cdec.sa.extractor import GrammarExtractor
diff --git a/python/pkg/cdec/sa/compile.py b/python/pkg/cdec/sa/compile.py
index 393c72a4..ce249c0f 100644
--- a/python/pkg/cdec/sa/compile.py
+++ b/python/pkg/cdec/sa/compile.py
@@ -4,9 +4,10 @@ import os
import logging
import cdec.configobj
import cdec.sa
+import sys
MAX_PHRASE_LENGTH = 4
-def precompute(f_sa, max_len, max_nt, max_size, min_gap, rank1, rank2):
+def precompute(f_sa, max_len, max_nt, max_size, min_gap, rank1, rank2, tight_phrases):
lcp = cdec.sa.LCP(f_sa)
stats = sorted(lcp.compute_stats(MAX_PHRASE_LENGTH), reverse=True)
precomp = cdec.sa.Precomputation(from_stats=stats,
@@ -20,6 +21,8 @@ def precompute(f_sa, max_len, max_nt, max_size, min_gap, rank1, rank2):
return precomp
def main():
+ sys.setrecursionlimit(sys.getrecursionlimit() * 100)
+
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('cdec.sa.compile')
parser = argparse.ArgumentParser(description='Compile a corpus into a suffix array.')
@@ -35,6 +38,8 @@ def main():
help='Number of pre-computed frequent patterns')
parser.add_argument('--rank2', '-r2', type=int, default=10,
help='Number of pre-computed super-frequent patterns)')
+ parser.add_argument('--loose', action='store_true',
+ help='Enable loose phrase extraction (default: tight)')
parser.add_argument('-c', '--config', default='/dev/stdout',
help='Output configuration')
parser.add_argument('-f', '--source',
@@ -53,8 +58,10 @@ def main():
parser.error('a parallel corpus is required\n'
'\tuse -f (source) with -e (target) or -b (bitext)')
- param_names = ("max_len", "max_nt", "max_size", "min_gap", "rank1", "rank2")
- params = (args.maxlen, args.maxnt, args.maxsize, args.mingap, args.rank1, args.rank2)
+ param_names = ('max_len', 'max_nt', 'max_size', 'min_gap',
+ 'rank1', 'rank2', 'tight_phrases')
+ params = (args.maxlen, args.maxnt, args.maxsize, args.mingap,
+ args.rank1, args.rank2, not args.loose)
if not os.path.exists(args.output):
os.mkdir(args.output)
diff --git a/python/pkg/cdec/sa/extract.py b/python/pkg/cdec/sa/extract.py
index 10a81556..782bed8b 100644
--- a/python/pkg/cdec/sa/extract.py
+++ b/python/pkg/cdec/sa/extract.py
@@ -1,19 +1,25 @@
#!/usr/bin/env python
import sys
import os
+import re
+import gzip
import argparse
import logging
-import multiprocessing as mp
import signal
+import multiprocessing as mp
import cdec.sa
extractor, prefix = None, None
-def make_extractor(config, grammars, features):
- global extractor, prefix
+online, compress = False, False
+
+def make_extractor(args):
+ global extractor, prefix, online, compress
signal.signal(signal.SIGINT, signal.SIG_IGN) # Let parent process catch Ctrl+C
- load_features(features)
- extractor = cdec.sa.GrammarExtractor(config)
- prefix = grammars
+ load_features(args.features)
+ extractor = cdec.sa.GrammarExtractor(args.config, online)
+ prefix = args.grammars
+ online = args.online
+ compress = args.compress
def load_features(features):
for featdef in features:
@@ -24,15 +30,36 @@ def load_features(features):
sys.path.remove(prefix)
def extract(inp):
- global extractor, prefix
+ global extractor, prefix, online, compress
i, sentence = inp
sentence = sentence[:-1]
- grammar_file = os.path.join(prefix, 'grammar.{0}'.format(i))
- with open(grammar_file, 'w') as output:
+ fields = re.split('\s*\|\|\|\s*', sentence)
+ suffix = ''
+ # 3 fields for online mode, 1 for normal
+ if online:
+ if len(fields) < 3:
+ sys.stderr.write('Error: online mode requires references and alignments.'
+ ' Not adding sentence to training data: {}\n'.format(sentence))
+ sentence = fields[0]
+ else:
+ sentence, reference, alignment = fields[0:3]
+ if len(fields) > 3:
+ suffix = ' ||| ' + ' ||| '.join(fields[3:])
+ else:
+ if len(fields) > 1:
+ sentence = fields[0]
+ suffix = ' ||| ' + ' ||| '.join(fields[1:])
+
+ grammar_file = os.path.join(prefix, 'grammar.'+str(i))
+ if compress: grammar_file += '.gz'
+ with (gzip.open if compress else open)(grammar_file, 'w') as output:
for rule in extractor.grammar(sentence):
output.write(str(rule)+'\n')
+ # Add training instance _after_ extracting grammars
+ if online:
+ extractor.add_instance(sentence, reference, alignment)
grammar_file = os.path.abspath(grammar_file)
- return '<seg grammar="{0}" id="{1}">{2}</seg>'.format(grammar_file, i, sentence)
+ return '<seg grammar="{}" id="{}">{}</seg>{}'.format(grammar_file, i, sentence, suffix)
def main():
logging.basicConfig(level=logging.INFO)
@@ -47,26 +74,30 @@ def main():
help='number of sentences / chunk')
parser.add_argument('-f', '--features', nargs='*', default=[],
help='additional feature definitions')
+ parser.add_argument('-o', '--online', action='store_true',
+ help='online grammar extraction')
+ parser.add_argument('-z', '--compress', action='store_true',
+ help='compress grammars with gzip')
args = parser.parse_args()
if not os.path.exists(args.grammars):
os.mkdir(args.grammars)
for featdef in args.features:
if not featdef.endswith('.py'):
- sys.stderr.write('Error: feature definition file <{0}>'
+ sys.stderr.write('Error: feature definition file <{}>'
' should be a python module\n'.format(featdef))
sys.exit(1)
if args.jobs > 1:
logging.info('Starting %d workers; chunk size: %d', args.jobs, args.chunksize)
- pool = mp.Pool(args.jobs, make_extractor, (args.config, args.grammars, args.features))
+ pool = mp.Pool(args.jobs, make_extractor, (args,))
try:
for output in pool.imap(extract, enumerate(sys.stdin), args.chunksize):
print(output)
except KeyboardInterrupt:
pool.terminate()
else:
- make_extractor(args.config, args.grammars, args.features)
+ make_extractor(args)
for output in map(extract, enumerate(sys.stdin)):
print(output)
diff --git a/python/pkg/cdec/sa/extractor.py b/python/pkg/cdec/sa/extractor.py
index a5ce8a68..acc13cbc 100644
--- a/python/pkg/cdec/sa/extractor.py
+++ b/python/pkg/cdec/sa/extractor.py
@@ -1,16 +1,17 @@
from itertools import chain
-import os
+import os, sys
import cdec.configobj
from cdec.sa.features import EgivenFCoherent, SampleCountF, CountEF,\
- MaxLexEgivenF, MaxLexFgivenE, IsSingletonF, IsSingletonFE
+ MaxLexEgivenF, MaxLexFgivenE, IsSingletonF, IsSingletonFE,\
+ IsSupportedOnline
import cdec.sa
# maximum span of a grammar rule in TEST DATA
MAX_INITIAL_SIZE = 15
class GrammarExtractor:
- def __init__(self, config, features=None):
- if isinstance(config, str) or isinstance(config, unicode):
+ def __init__(self, config, online=False, features=None):
+ if isinstance(config, basestring):
if not os.path.exists(config):
raise IOError('cannot read configuration from {0}'.format(config))
config = cdec.configobj.ConfigObj(config, unrepr=True)
@@ -50,18 +51,26 @@ class GrammarExtractor:
train_max_initial_size=config['max_size'],
# minimum span of an RHS nonterminal in a rule extracted from TRAINING DATA
train_min_gap_size=config['min_gap'],
- # True if phrases should be tight, False otherwise (better but slower)
- tight_phrases=True,
+ # False if phrases should be loose (better but slower), True otherwise
+ tight_phrases=config.get('tight_phrases', True),
)
# lexical weighting tables
tt = cdec.sa.BiLex(from_binary=config['lex_file'])
+ # TODO: clean this up
+ extended_features = []
+ if online:
+ extended_features.append(IsSupportedOnline)
+
# TODO: use @cdec.sa.features decorator for standard features too
# + add a mask to disable features
+ for f in cdec.sa._SA_FEATURES:
+ extended_features.append(f)
+
scorer = cdec.sa.Scorer(EgivenFCoherent, SampleCountF, CountEF,
MaxLexFgivenE(tt), MaxLexEgivenF(tt), IsSingletonF, IsSingletonFE,
- *cdec.sa._SA_FEATURES)
+ *extended_features)
fsarray = cdec.sa.SuffixArray(from_binary=config['f_sa_file'])
edarray = cdec.sa.DataArray(from_binary=config['e_file'])
@@ -82,3 +91,16 @@ class GrammarExtractor:
meta = cdec.sa.annotate(words)
cnet = cdec.sa.make_lattice(words)
return self.factory.input(cnet, meta)
+
+ # Add training instance to data
+ def add_instance(self, sentence, reference, alignment):
+ f_words = cdec.sa.encode_words(sentence.split())
+ e_words = cdec.sa.encode_words(reference.split())
+ al = sorted(tuple(int(i) for i in pair.split('-')) for pair in alignment.split())
+ self.factory.add_instance(f_words, e_words, al)
+
+ # Debugging
+ def dump_online_stats(self):
+ self.factory.dump_online_stats()
+ def dump_online_rules(self):
+ self.factory.dump_online_rules() \ No newline at end of file
diff --git a/python/pkg/cdec/sa/features.py b/python/pkg/cdec/sa/features.py
index a4ae23e8..46412cd5 100644
--- a/python/pkg/cdec/sa/features.py
+++ b/python/pkg/cdec/sa/features.py
@@ -1,57 +1,142 @@
from __future__ import division
import math
+from cdec.sa import isvar
+
MAXSCORE = 99
def EgivenF(ctx): # p(e|f) = c(e, f)/c(f)
- return -math.log10(ctx.paircount/ctx.fcount)
+ if not ctx.online:
+ prob = ctx.paircount/ctx.fcount
+ else:
+ prob = (ctx.paircount + ctx.online.paircount) / (ctx.fcount + ctx.online.fcount)
+ return -math.log10(prob)
def CountEF(ctx): # c(e, f)
- return math.log10(1 + ctx.paircount)
+ if not ctx.online:
+ count = 1 + ctx.paircount
+ else:
+ count = 1 + ctx.paircount + ctx.online.paircount
+ return math.log10(count)
def SampleCountF(ctx): # sample c(f)
- return math.log10(1 + ctx.fsample_count)
+ if not ctx.online:
+ count = 1 + ctx.fsample_count
+ else:
+ count = 1 + ctx.fsample_count + ctx.online.fsample_count
+ return math.log10(count)
def EgivenFCoherent(ctx): # c(e, f) / sample c(f)
- prob = ctx.paircount/ctx.fsample_count
+ if not ctx.online:
+ prob = ctx.paircount/ctx.fsample_count
+ else:
+ prob = (ctx.paircount + ctx.online.paircount) / (ctx.fsample_count + ctx.online.fsample_count)
return -math.log10(prob) if prob > 0 else MAXSCORE
def CoherenceProb(ctx): # c(f) / sample c(f)
- return -math.log10(ctx.fcount/ctx.fsample_count)
+ if not ctx.online:
+ prob = ctx.fcount/ctx.fsample_count
+ else:
+ prob = (ctx.fcount + ctx.online.fcount) / (ctx.fsample_count + ctx.online.fsample_count)
+ return -math.log10(prob)
def MaxLexEgivenF(ttable):
def MaxLexEgivenF(ctx):
fwords = ctx.fphrase.words
fwords.append('NULL')
- def score():
+ # Always use this for now
+ if not ctx.online or ctx.online:
+ maxOffScore = 0.0
+ for e in ctx.ephrase.words:
+ maxScore = max(ttable.get_score(f, e, 0) for f in fwords)
+ maxOffScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE
+ return maxOffScore
+ else:
+ # For now, straight average
+ maxOffScore = 0.0
+ maxOnScore = 0.0
for e in ctx.ephrase.words:
- maxScore = max(ttable.get_score(f, e, 0) for f in fwords)
- yield -math.log10(maxScore) if maxScore > 0 else MAXSCORE
- return sum(score())
+ maxScore = max(ttable.get_score(f, e, 0) for f in fwords)
+ maxOffScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE
+ for e in ctx.ephrase:
+ if not isvar(e):
+ maxScore = 0.0
+ for f in ctx.fphrase:
+ if not isvar(f):
+ b_f = ctx.online.bilex_f.get(f, 0)
+ if b_f:
+ maxScore = max(maxScore, ctx.online.bilex_fe.get(f, {}).get(e))
+ maxOnScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE
+ return (maxOffScore + maxOnScore) / 2
return MaxLexEgivenF
def MaxLexFgivenE(ttable):
def MaxLexFgivenE(ctx):
ewords = ctx.ephrase.words
ewords.append('NULL')
- def score():
+ # Always use this for now
+ if not ctx.online or ctx.online:
+ maxOffScore = 0.0
for f in ctx.fphrase.words:
- maxScore = max(ttable.get_score(f, e, 1) for e in ewords)
- yield -math.log10(maxScore) if maxScore > 0 else MAXSCORE
- return sum(score())
+ maxScore = max(ttable.get_score(f, e, 1) for e in ewords)
+ maxOffScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE
+ return maxOffScore
+ else:
+ # For now, straight average
+ maxOffScore = 0.0
+ maxOnScore = 0.0
+ for f in ctx.fphrase.words:
+ maxScore = max(ttable.get_score(f, e, 1) for e in ewords)
+ maxOffScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE
+ for f in ctx.fphrase:
+ if not isvar(f):
+ maxScore = 0.0
+ for e in ctx.ephrase:
+ if not isvar(e):
+ b_e = ctx.online.bilex_e.get(e, 0)
+ if b_e:
+ maxScore = max(maxScore, ctx.online.bilex_fe.get(f, {}).get(e, 0) / b_e )
+ maxOnScore += -math.log10(maxScore) if maxScore > 0 else MAXSCORE
+ return (maxOffScore + maxOnScore) / 2
return MaxLexFgivenE
def IsSingletonF(ctx):
- return (ctx.fcount == 1)
+ if not ctx.online:
+ count = ctx.fcount
+ else:
+ count = ctx.fcount + ctx.online.fcount
+ return (count == 1)
def IsSingletonFE(ctx):
- return (ctx.paircount == 1)
+ if not ctx.online:
+ count = ctx.paircount
+ else:
+ count = ctx.paircount + ctx.online.paircount
+ return (count == 1)
def IsNotSingletonF(ctx):
- return (ctx.fcount > 1)
+ if not ctx.online:
+ count = ctx.fcount
+ else:
+ count = ctx.fcount + ctx.online.fcount
+ return (count > 1)
def IsNotSingletonFE(ctx):
+ if not ctx.online:
+ count = ctx.paircount
+ else:
+ count = ctx.paircount + ctx.online.paircount
return (ctx.paircount > 1)
def IsFEGreaterThanZero(ctx):
+ if not ctx.online:
+ count = ctx.paircount
+ else:
+ count = ctx.paircount + ctx.online.paircount
return (ctx.paircount > 0.01)
+
+def IsSupportedOnline(ctx): # Occurs in online data?
+ if ctx.online:
+ return (ctx.online.paircount > 0.01)
+ else:
+ return False \ No newline at end of file