summaryrefslogtreecommitdiff
path: root/src/nl_reweighter.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/nl_reweighter.py')
-rw-r--r--src/nl_reweighter.py227
1 files changed, 227 insertions, 0 deletions
diff --git a/src/nl_reweighter.py b/src/nl_reweighter.py
new file mode 100644
index 0000000..fcc8f85
--- /dev/null
+++ b/src/nl_reweighter.py
@@ -0,0 +1,227 @@
+import gzip
+import re
+from nlp_tools.hypergraph import Hypergraph
+import itertools
+import logging
+from collections import defaultdict
+import os
+
+class Rule:
+
+ MOSES_SYMBOL = '[X]'
+
+ def __init__(self, rule_id, symbol, src, tgt, coindexing):
+ self.rule_id = rule_id
+ self.symbol = symbol
+ self.src = src
+ self.tgt = tgt
+ self.coindexing = coindexing
+ self.degree = len(self.coindexing)
+
+ @classmethod
+ def from_moses(cls, rule_id, rule_table_line):
+ nl, mrl, scores, alignments, counts = re.split(r'\ ?\|\|\|\ ?',
+ rule_table_line.strip())
+ nl = nl.split()[:-1]
+ nl = [cls.MOSES_SYMBOL if t == '[X][X]' else t for t in nl]
+ mrl = mrl.split()[:-1]
+ mrl = [cls.MOSES_SYMBOL if t == '[X][X]' else t for t in mrl]
+ coindexing = []
+ for pair in alignments.split():
+ i_s, i_t = pair.split('-')
+ coindexing.append((int(i_s), int(i_t)))
+ return Rule(rule_id, cls.MOSES_SYMBOL, nl, mrl, coindexing)
+
+ @classmethod
+ def glue(cls, rule_id):
+ return Rule(rule_id, cls.MOSES_SYMBOL, [cls.MOSES_SYMBOL, cls.MOSES_SYMBOL],
+ [cls.MOSES_SYMBOL, cls.MOSES_SYMBOL], [(0,0), (1,1)])
+
+ def __eq__(self, other):
+ return other.__class__ == self.__class__ and self.rule_id == other.rule_id
+
+ def __hash__(self):
+ return self.rule_id
+
+ def __repr__(self):
+ return 'Rule<(%d) %s -> %s : %s>' % (self.rule_id, self.symbol, self.src,
+ self.tgt)
+
+class NLReweighter:
+
+ def __init__(self, config):
+ self.config = config
+
+ def run(self):
+ rules = self.load_rule_table()
+ glue = Rule.glue(len(rules))
+ all_counts = defaultdict(lambda: 0)
+ successful_counts = defaultdict(lambda: 0)
+
+ with open('%s/unlabeled.nl' % self.config.experiment_dir) as ul_f:
+ for line in ul_f:
+ toks = line.strip().split()
+ chart = self.parse(toks, rules, glue)
+ if not chart:
+ continue
+ self.collect_all_counts(all_counts, chart)
+ self.collect_successful_counts(successful_counts, chart, toks)
+
+ if not self.config.ul_only:
+ with open('%s/train.nl' % self.config.experiment_dir) as t_f:
+ for line in t_f:
+ toks = line.strip().split()
+ chart = self.parse(toks, rules, glue)
+ # TODO is this an OOV issue?
+ if not chart:
+ continue
+ self.collect_all_counts(all_counts, chart)
+ self.collect_successful_counts(successful_counts, chart, toks)
+
+ #self.write_updated_model(all_counts)
+ self.write_updated_model(successful_counts)
+
+ def load_rule_table(self):
+ rule_table_path = '%s/model/rule-table.gz' % self.config.experiment_dir
+ rules = {}
+ with gzip.open(rule_table_path) as rule_table_f:
+ for line in rule_table_f.readlines():
+ rule = Rule.from_moses(len(rules), line)
+ rules[rule.rule_id] = rule
+ return rules
+
+ def write_updated_model(self, counts):
+ old_rule_table_path = '%s/model/rule-table.gz' % self.config.experiment_dir
+ new_rule_table_path = '%s/model/rule-table-new.gz' % self.config.experiment_dir
+ counter = 0
+ with gzip.open(old_rule_table_path) as old_rule_table_f:
+ with gzip.open(new_rule_table_path, 'w') as new_rule_table_f:
+ for line in old_rule_table_f:
+ nl, mrl, scores, alignments, rule_counts = re.split(r'\ ?\|\|\|\ ?',
+ line.strip())
+ scores = '%s %f' % (scores, counts[counter])
+ newline = ' ||| '.join([nl, mrl, scores, alignments, rule_counts])
+ newline = re.sub(r'\s+', ' ', newline)
+ print >>new_rule_table_f, newline
+ counter += 1
+
+ old_config_path = '%s/model/moses.ini' % self.config.experiment_dir
+ new_config_path = '%s/model/moses-new.ini' % self.config.experiment_dir
+ with open(old_config_path) as old_config_f:
+ with open(new_config_path, 'w') as new_config_f:
+ for line in old_config_f:
+ if line[-14:-1] == 'rule-table.gz':
+ line = line[:6] + '6' + line[7:]
+ #line[6] = '6'
+ print >>new_config_f, line,
+ if line == '[weight-t]\n':
+ print >>new_config_f, '0.20'
+
+ os.rename(new_rule_table_path, old_rule_table_path)
+ os.rename(new_config_path, old_config_path)
+
+ def parse(self, sent, grammar, glue):
+ chart = dict()
+
+ for span in range(1, len(sent)+1):
+ for start in range(len(sent)+1-span):
+ chart[start,span] = list()
+ for rule in grammar.values():
+ matches = self.match(sent, rule, start, span, chart)
+ chart[start,span] += matches
+
+ for i in range(1, len(sent)):
+ if chart[0,i] and chart[i,len(sent)-i]:
+ psets = [(c1, c2) for c1 in chart[0,i] for c2 in chart[i,len(sent)-i]]
+ chart[0,len(sent)].append(Hypergraph(glue, psets))
+
+ if not chart[0,len(sent)]:
+ #logging.debug('failed to parse')
+ return None
+ else:
+ #logging.debug('parse OK!')
+ return chart
+
+ def match(self, sent, rule, start, span, chart):
+
+ if rule.degree == 0:
+ if span != len(rule.src):
+ return []
+ if sent[start:start+span] != rule.src:
+ return []
+ return [Hypergraph(rule, [])]
+
+ elif rule.degree == 1:
+ nt_start = start + rule.coindexing[0][0]
+ nt_span = span - len(rule.src) + 1
+ if nt_span <= 0:
+ return []
+ if sent[start:nt_start] != rule.src[0:rule.coindexing[0][0]]:
+ return []
+ if sent[nt_start+nt_span:start+span] != rule.src[rule.coindexing[0][0]+1:]:
+ return []
+
+ pointer_sets = [i for i in chart[nt_start, nt_span] if i.label.symbol ==
+ rule.src[rule.coindexing[0][0]]]
+ ## if not chart[nt_start, nt_span]:
+ ## return []
+ if not pointer_sets:
+ return []
+ return [Hypergraph(rule, [(i,) for i in pointer_sets])]
+
+ elif rule.degree == 2:
+ matches = []
+ before_dist = rule.coindexing[0][0]
+ between_dist = rule.coindexing[1][0] - rule.coindexing[0][0] - 1
+ before_2_dist = rule.coindexing[1][0]
+ nt_total_span = span - len(rule.src) + 2
+ if nt_total_span <= 0:
+ return []
+ nt1_start = start + before_dist
+ for nt1_span in range(1,nt_total_span):
+ nt2_start = nt1_start + nt1_span + between_dist
+ nt2_span = nt_total_span - nt1_span
+
+ if sent[start:nt1_start] != rule.src[0:before_dist]:
+ continue
+ if sent[nt1_start+nt1_span:nt2_start] != rule.src[before_dist+1:before_2_dist]:
+ continue
+ if sent[nt2_start+nt2_span:start+span] != rule.src[before_2_dist+1:]:
+ continue
+
+ pointer_sets_1 = [i for i in chart[nt1_start,nt1_span] if i.label.symbol ==
+ rule.src[rule.coindexing[0][0]]]
+ pointer_sets_2 = [i for i in chart[nt2_start,nt2_span] if i.label.symbol ==
+ rule.src[rule.coindexing[1][0]]]
+
+ if not (pointer_sets_1 and pointer_sets_2):
+ continue
+
+ matches.append(Hypergraph(rule, list(itertools.product(pointer_sets_1,
+ pointer_sets_2))))
+ #matches.append(rule.rule_id)
+
+ return matches
+
+ assert False
+
+ def collect_all_counts(self, counts, chart):
+ for cell in chart.values():
+ for node in cell:
+ counts[node.label.rule_id] += 1
+
+ def collect_successful_counts(self, counts, chart, sent):
+ used = set()
+ for cell in chart[0, len(sent)]:
+ self.mark_used(used, cell)
+ for cell in chart.values():
+ for node in cell:
+ if node in used:
+ counts[node.label.rule_id] += 1
+
+ def mark_used(self, used, cell):
+ for edge in cell.edges:
+ for ccell in edge:
+ if ccell not in used:
+ self.mark_used(used, ccell)
+ used.add(cell)