diff options
Diffstat (limited to 'src/nl_reweighter.py')
-rw-r--r-- | src/nl_reweighter.py | 227 |
1 files changed, 227 insertions, 0 deletions
diff --git a/src/nl_reweighter.py b/src/nl_reweighter.py new file mode 100644 index 0000000..fcc8f85 --- /dev/null +++ b/src/nl_reweighter.py @@ -0,0 +1,227 @@ +import gzip +import re +from nlp_tools.hypergraph import Hypergraph +import itertools +import logging +from collections import defaultdict +import os + +class Rule: + + MOSES_SYMBOL = '[X]' + + def __init__(self, rule_id, symbol, src, tgt, coindexing): + self.rule_id = rule_id + self.symbol = symbol + self.src = src + self.tgt = tgt + self.coindexing = coindexing + self.degree = len(self.coindexing) + + @classmethod + def from_moses(cls, rule_id, rule_table_line): + nl, mrl, scores, alignments, counts = re.split(r'\ ?\|\|\|\ ?', + rule_table_line.strip()) + nl = nl.split()[:-1] + nl = [cls.MOSES_SYMBOL if t == '[X][X]' else t for t in nl] + mrl = mrl.split()[:-1] + mrl = [cls.MOSES_SYMBOL if t == '[X][X]' else t for t in mrl] + coindexing = [] + for pair in alignments.split(): + i_s, i_t = pair.split('-') + coindexing.append((int(i_s), int(i_t))) + return Rule(rule_id, cls.MOSES_SYMBOL, nl, mrl, coindexing) + + @classmethod + def glue(cls, rule_id): + return Rule(rule_id, cls.MOSES_SYMBOL, [cls.MOSES_SYMBOL, cls.MOSES_SYMBOL], + [cls.MOSES_SYMBOL, cls.MOSES_SYMBOL], [(0,0), (1,1)]) + + def __eq__(self, other): + return other.__class__ == self.__class__ and self.rule_id == other.rule_id + + def __hash__(self): + return self.rule_id + + def __repr__(self): + return 'Rule<(%d) %s -> %s : %s>' % (self.rule_id, self.symbol, self.src, + self.tgt) + +class NLReweighter: + + def __init__(self, config): + self.config = config + + def run(self): + rules = self.load_rule_table() + glue = Rule.glue(len(rules)) + all_counts = defaultdict(lambda: 0) + successful_counts = defaultdict(lambda: 0) + + with open('%s/unlabeled.nl' % self.config.experiment_dir) as ul_f: + for line in ul_f: + toks = line.strip().split() + chart = self.parse(toks, rules, glue) + if not chart: + continue + self.collect_all_counts(all_counts, chart) + self.collect_successful_counts(successful_counts, chart, toks) + + if not self.config.ul_only: + with open('%s/train.nl' % self.config.experiment_dir) as t_f: + for line in t_f: + toks = line.strip().split() + chart = self.parse(toks, rules, glue) + # TODO is this an OOV issue? + if not chart: + continue + self.collect_all_counts(all_counts, chart) + self.collect_successful_counts(successful_counts, chart, toks) + + #self.write_updated_model(all_counts) + self.write_updated_model(successful_counts) + + def load_rule_table(self): + rule_table_path = '%s/model/rule-table.gz' % self.config.experiment_dir + rules = {} + with gzip.open(rule_table_path) as rule_table_f: + for line in rule_table_f.readlines(): + rule = Rule.from_moses(len(rules), line) + rules[rule.rule_id] = rule + return rules + + def write_updated_model(self, counts): + old_rule_table_path = '%s/model/rule-table.gz' % self.config.experiment_dir + new_rule_table_path = '%s/model/rule-table-new.gz' % self.config.experiment_dir + counter = 0 + with gzip.open(old_rule_table_path) as old_rule_table_f: + with gzip.open(new_rule_table_path, 'w') as new_rule_table_f: + for line in old_rule_table_f: + nl, mrl, scores, alignments, rule_counts = re.split(r'\ ?\|\|\|\ ?', + line.strip()) + scores = '%s %f' % (scores, counts[counter]) + newline = ' ||| '.join([nl, mrl, scores, alignments, rule_counts]) + newline = re.sub(r'\s+', ' ', newline) + print >>new_rule_table_f, newline + counter += 1 + + old_config_path = '%s/model/moses.ini' % self.config.experiment_dir + new_config_path = '%s/model/moses-new.ini' % self.config.experiment_dir + with open(old_config_path) as old_config_f: + with open(new_config_path, 'w') as new_config_f: + for line in old_config_f: + if line[-14:-1] == 'rule-table.gz': + line = line[:6] + '6' + line[7:] + #line[6] = '6' + print >>new_config_f, line, + if line == '[weight-t]\n': + print >>new_config_f, '0.20' + + os.rename(new_rule_table_path, old_rule_table_path) + os.rename(new_config_path, old_config_path) + + def parse(self, sent, grammar, glue): + chart = dict() + + for span in range(1, len(sent)+1): + for start in range(len(sent)+1-span): + chart[start,span] = list() + for rule in grammar.values(): + matches = self.match(sent, rule, start, span, chart) + chart[start,span] += matches + + for i in range(1, len(sent)): + if chart[0,i] and chart[i,len(sent)-i]: + psets = [(c1, c2) for c1 in chart[0,i] for c2 in chart[i,len(sent)-i]] + chart[0,len(sent)].append(Hypergraph(glue, psets)) + + if not chart[0,len(sent)]: + #logging.debug('failed to parse') + return None + else: + #logging.debug('parse OK!') + return chart + + def match(self, sent, rule, start, span, chart): + + if rule.degree == 0: + if span != len(rule.src): + return [] + if sent[start:start+span] != rule.src: + return [] + return [Hypergraph(rule, [])] + + elif rule.degree == 1: + nt_start = start + rule.coindexing[0][0] + nt_span = span - len(rule.src) + 1 + if nt_span <= 0: + return [] + if sent[start:nt_start] != rule.src[0:rule.coindexing[0][0]]: + return [] + if sent[nt_start+nt_span:start+span] != rule.src[rule.coindexing[0][0]+1:]: + return [] + + pointer_sets = [i for i in chart[nt_start, nt_span] if i.label.symbol == + rule.src[rule.coindexing[0][0]]] + ## if not chart[nt_start, nt_span]: + ## return [] + if not pointer_sets: + return [] + return [Hypergraph(rule, [(i,) for i in pointer_sets])] + + elif rule.degree == 2: + matches = [] + before_dist = rule.coindexing[0][0] + between_dist = rule.coindexing[1][0] - rule.coindexing[0][0] - 1 + before_2_dist = rule.coindexing[1][0] + nt_total_span = span - len(rule.src) + 2 + if nt_total_span <= 0: + return [] + nt1_start = start + before_dist + for nt1_span in range(1,nt_total_span): + nt2_start = nt1_start + nt1_span + between_dist + nt2_span = nt_total_span - nt1_span + + if sent[start:nt1_start] != rule.src[0:before_dist]: + continue + if sent[nt1_start+nt1_span:nt2_start] != rule.src[before_dist+1:before_2_dist]: + continue + if sent[nt2_start+nt2_span:start+span] != rule.src[before_2_dist+1:]: + continue + + pointer_sets_1 = [i for i in chart[nt1_start,nt1_span] if i.label.symbol == + rule.src[rule.coindexing[0][0]]] + pointer_sets_2 = [i for i in chart[nt2_start,nt2_span] if i.label.symbol == + rule.src[rule.coindexing[1][0]]] + + if not (pointer_sets_1 and pointer_sets_2): + continue + + matches.append(Hypergraph(rule, list(itertools.product(pointer_sets_1, + pointer_sets_2)))) + #matches.append(rule.rule_id) + + return matches + + assert False + + def collect_all_counts(self, counts, chart): + for cell in chart.values(): + for node in cell: + counts[node.label.rule_id] += 1 + + def collect_successful_counts(self, counts, chart, sent): + used = set() + for cell in chart[0, len(sent)]: + self.mark_used(used, cell) + for cell in chart.values(): + for node in cell: + if node in used: + counts[node.label.rule_id] += 1 + + def mark_used(self, used, cell): + for edge in cell.edges: + for ccell in edge: + if ccell not in used: + self.mark_used(used, ccell) + used.add(cell) |