diff options
Diffstat (limited to 'src/extractor.py')
-rw-r--r-- | src/extractor.py | 418 |
1 files changed, 418 insertions, 0 deletions
diff --git a/src/extractor.py b/src/extractor.py new file mode 100644 index 0000000..ff2abdb --- /dev/null +++ b/src/extractor.py @@ -0,0 +1,418 @@ +from nltk.stem.porter import PorterStemmer +from nltk.stem.snowball import GermanStemmer +import os +import re +import util +import xml.etree.ElementTree as ET + +class IdStemmer: + def stem(self, word): + return word + +class Extractor: + + NP_WEIGHT = 50 + + def __init__(self, config): + self.config = config + if config.stem: + if config.lang == 'en': + self.stemmer = PorterStemmer() + elif config.lang == 'de': + self.stemmer = GermanStemmer() + else: + self.stemmer = IdStemmer() + + def run(self): + if self.config.corpus == 'geo': + self.run_geo() + elif self.config.corpus == 'robo': + self.run_robo() + elif self.config.corpus == 'atis': + self.run_atis() + else: + assert False + + def run_atis(self): + + train_nl = open('%s/train.nl' % self.config.experiment_dir, 'w') + train_nl_lm = open('%s/train.nl.lm' % self.config.experiment_dir, 'w') + train_nl_np = open('%s/train.np.nl' % self.config.experiment_dir, 'w') + train_mrl = open('%s/train.mrl' % self.config.experiment_dir, 'w') + train_mrl_lm = open('%s/train.mrl.lm' % self.config.experiment_dir, 'w') + train_mrl_np = open('%s/train.np.mrl' % self.config.experiment_dir, 'w') + train_fun = open('%s/train.fun' % self.config.experiment_dir, 'w') + tune_nl = open('%s/tune.nl' % self.config.experiment_dir, 'w') + tune_mrl = open('%s/tune.mrl' % self.config.experiment_dir, 'w') + test_nl = open('%s/test.nl' % self.config.experiment_dir, 'w') + test_mrl = open('%s/test.mrl' % self.config.experiment_dir, 'w') + test_fun = open('%s/test.fun' % self.config.experiment_dir, 'w') + + if self.config.run == 'debug': + with open('%s/atis-train.sem' % self.config.data_dir) as data_file: + counter = 0 + for line in data_file: + nl, slot = line.split('<=>', 1) + nl = self.preprocess_nl(nl) + slot = self.replace_specials(slot) + fun = self.slot_to_fun(slot) + mrl = util.fun_to_mrl(fun, True) + if counter % 4 in (0,1): + print >>train_nl, nl + print >>train_mrl, mrl + print >>train_fun, fun + print >>train_nl_np, nl + print >>train_mrl_np, mrl + print >>train_nl_lm, '<s>', nl, '</s>' + print >>train_mrl_lm, '<s>', mrl, '</s>' + elif counter % 4 == 2: + print >>tune_nl, nl + print >>tune_mrl, mrl + else: + print >>test_nl, nl + print >>test_mrl, mrl + print >>test_fun, fun + counter += 1 + + else: + train_path = '%s/atis-train.sem' % self.config.data_dir + if self.config.run == 'dev': + tune_path = train_path + test_path = '%s/atis-dev.sem' % self.config.data_dir + elif self.config.run == 'test': + tune_path = '%s/atis-dev.sem' % self.config.data_dir + test_path = '%s/atis-test.sem' % self.config.data_dir + + with open(train_path) as train_file: + for line in train_file: + nl, slot = line.split('<=>', 1) + nl = self.preprocess_nl(nl) + slot = self.replace_specials(slot) + fun = self.slot_to_fun(slot) + mrl = util.fun_to_mrl(fun, True) + print >>train_nl, nl + print >>train_mrl, mrl + print >>train_fun, fun + print >>train_nl_np, nl + print >>train_mrl_np, mrl + print >>train_nl_lm, '<s>', nl, '</s>' + print >>train_mrl_lm, '<s>', mrl, '</s>' + + with open(tune_path) as tune_file: + for line in tune_file: + nl, slot = line.split('<=>', 1) + nl = self.preprocess_nl(nl) + slot = self.replace_specials(slot) + fun = self.slot_to_fun(slot) + mrl = util.fun_to_mrl(fun, True) + print >>tune_nl, nl + print >>tune_mrl, mrl + + with open(test_path) as test_file: + for line in test_file: + nl, slot = line.split('<=>', 1) + nl = self.preprocess_nl(nl) + slot = self.replace_specials(slot) + fun = self.slot_to_fun(slot) + mrl = util.fun_to_mrl(fun, True) + print >>test_nl, nl + print >>test_mrl, mrl + print >>test_fun, fun + + for np_name in os.listdir('%s/db' % self.config.data_dir): + np_path = '%s/db/%s' % (self.config.data_dir, np_name) + with open(np_path) as np_file: + for line in np_file: + names = re.findall(r'"([^"]+)"', line) + for name in names: + nl = name + mrl = "%s" % self.replace_specials(name) + mrl = mrl.replace(' ', '_') + mrl = mrl + '@s' + print >>train_nl_np, nl + print >>train_mrl_np, mrl + print >>train_nl_lm, nl + print >>train_mrl_lm, mrl + + train_nl.close() + train_nl_lm.close() + train_mrl.close() + train_mrl_lm.close() + train_fun.close() + test_nl.close() + test_mrl.close() + test_fun.close() + tune_nl.close() + tune_mrl.close() + + def run_robo(self): + + train_ids, tune_ids, test_ids = self.get_folds() + tune_ids = test_ids + + train_nl = open('%s/train.nl' % self.config.experiment_dir, 'w') + train_nl_lm = open('%s/train.nl.lm' % self.config.experiment_dir, 'w') + train_nl_np = open('%s/train.np.nl' % self.config.experiment_dir, 'w') + train_mrl = open('%s/train.mrl' % self.config.experiment_dir, 'w') + train_mrl_lm = open('%s/train.mrl.lm' % self.config.experiment_dir, 'w') + train_mrl_np = open('%s/train.np.mrl' % self.config.experiment_dir, 'w') + train_fun = open('%s/train.fun' % self.config.experiment_dir, 'w') + tune_nl = open('%s/tune.nl' % self.config.experiment_dir, 'w') + tune_mrl = open('%s/tune.mrl' % self.config.experiment_dir, 'w') + test_nl = open('%s/test.nl' % self.config.experiment_dir, 'w') + test_mrl = open('%s/test.mrl' % self.config.experiment_dir, 'w') + test_fun = open('%s/test.fun' % self.config.experiment_dir, 'w') + + corpus = ET.parse('%s/corpus.xml' % self.config.data_dir) + corpus_root = corpus.getroot() + + for node in corpus_root.findall('example'): + nl = node.find("nl[@lang='%s']" % self.config.lang).text + nl = self.preprocess_nl(nl) + clang = node.find("mrl[@lang='robocup-clang']").text + clang = self.replace_specials(clang) + fun = self.clang_to_fun(clang) + #print fun + mrl = util.fun_to_mrl(fun) + eid = int(node.attrib['id']) + + if eid in tune_ids: + print >>tune_nl, nl + print >>tune_mrl, mrl + elif eid in train_ids: + print >>train_nl, nl + print >>train_mrl, mrl + print >>train_fun, fun + print >>train_nl_np, nl + print >>train_mrl_np, mrl + print >>train_nl_lm, '<s>', nl, '</s>' + print >>train_mrl_lm, '<s>', mrl, '</s>' + if eid in test_ids: + #elif eid in test_ids: + print >>test_nl, nl + print >>test_mrl, mrl + print >>test_fun, fun + + nps_file = open('%s/names' % self.config.data_dir) + while True: + line = nps_file.readline() + if not line: + break + nl = nps_file.readline().strip()[3:] + nl = self.preprocess_nl(nl) + nps_file.readline() + nps_file.readline() + while True: + line = nps_file.readline().strip() + if line == '': + break + m = re.match('^\*n:(Num|Unum|Ident) -> \(\{ (\S+) \}\)$', line) + mrl = m.group(2) + '@0' + for i in range(self.NP_WEIGHT): + print >>train_nl_np, nl + print >>train_mrl_np, mrl + print >>train_nl_lm, nl + print >>train_mrl_lm, mrl + + train_nl.close() + train_nl_lm.close() + train_mrl.close() + train_mrl_lm.close() + train_fun.close() + test_nl.close() + test_mrl.close() + test_fun.close() + tune_nl.close() + tune_mrl.close() + + def run_geo(self): + train_ids, tune_ids, test_ids = self.get_folds() + + train_nl = open('%s/train.nl' % self.config.experiment_dir, 'w') + train_nl_lm = open('%s/train.nl.lm' % self.config.experiment_dir, 'w') + train_nl_np = open('%s/train.np.nl' % self.config.experiment_dir, 'w') + train_mrl = open('%s/train.mrl' % self.config.experiment_dir, 'w') + train_mrl_lm = open('%s/train.mrl.lm' % self.config.experiment_dir, 'w') + train_mrl_np = open('%s/train.np.mrl' % self.config.experiment_dir, 'w') + train_fun = open('%s/train.fun' % self.config.experiment_dir, 'w') + unlabeled_nl = open('%s/unlabeled.nl' % self.config.experiment_dir, 'w') + tune_nl = open('%s/tune.nl' % self.config.experiment_dir, 'w') + tune_mrl = open('%s/tune.mrl' % self.config.experiment_dir, 'w') + test_nl = open('%s/test.nl' % self.config.experiment_dir, 'w') + test_mrl = open('%s/test.mrl' % self.config.experiment_dir, 'w') + test_fun = open('%s/test.fun' % self.config.experiment_dir, 'w') + + corpus = ET.parse('%s/corpus-true.xml' % self.config.data_dir) + corpus_root = corpus.getroot() + + counter = 0 + #stop_labeling = False + for node in corpus_root.findall('example'): + nl = node.find("nl[@lang='%s']" % self.config.lang).text + nl = self.preprocess_nl(nl) + fun = node.find("mrl[@lang='geo-funql']").text + fun = self.preprocess_fun(fun) + #fun = self.replace_specials(fun) + mrl = util.fun_to_mrl(fun) + eid = int(node.attrib['id']) + + unlabel_this = (counter >= 10 * self.config.lfrac) + counter += 1 + counter %= 10 + + if eid in tune_ids: + print >>tune_nl, nl + print >>tune_mrl, mrl + elif eid in train_ids and not unlabel_this: + print >>train_nl, nl + print >>train_mrl, mrl + print >>train_fun, fun + print >>train_nl_np, nl + print >>train_mrl_np, mrl + print >>train_nl_lm, '<s>', nl, '</s>' + print >>train_mrl_lm, '<s>', mrl, '</s>' + elif eid in train_ids and unlabel_this: + print >>unlabeled_nl, nl + elif eid in test_ids: + print >>test_nl, nl + print >>test_mrl, mrl + print >>test_fun, fun + + nplist = ET.parse('%s/nps-true.xml' % self.config.data_dir) + nplist_root = nplist.getroot() + for node in nplist_root.findall('example'): + fun = node.find("mrl[@lang='geo-funql']").text + fun = self.preprocess_fun(fun) + #fun = self.replace_specials(fun) + mrl = util.fun_to_mrl(fun) + big_np = len(mrl.split()) > 1 + if (self.config.np_type == 'big' and not big_np) or \ + (self.config.np_type == 'small' and big_np): + continue + for nl_node in node.findall("nl[@lang='%s']" % self.config.lang): + nl = nl_node.text + nl = self.preprocess_nl(nl) + for i in range(self.NP_WEIGHT): + print >>train_nl_np, nl + print >>train_mrl_np, mrl + print >>train_nl_lm, nl + print >>train_mrl_lm, mrl + + train_nl.close() + train_nl_lm.close() + train_mrl.close() + train_mrl_lm.close() + train_fun.close() + test_nl.close() + test_mrl.close() + test_fun.close() + tune_nl.close() + tune_mrl.close() + + def get_folds(self): + + if self.config.corpus == 'geo': + if self.config.run in ('debug', 'dev'): + train_ids_file = '%s/folds600/fold-%d-train.ids' \ + % (self.config.data_dir, self.config.fold) + tune_ids_file = None + test_ids_file = '%s/folds600/fold-%d-test.ids' \ + % (self.config.data_dir, self.config.fold) + elif self.config.run == 'test': + train_ids_file = '%s/split880/fold-0-train.ids' % self.config.data_dir + tune_ids_file = '%s/split880/fold-0-tune.ids' % self.config.data_dir + test_ids_file = '%s/split880/fold-0-test.ids' % self.config.data_dir + + elif self.config.corpus == 'robo': + if self.config.run in ('debug', 'dev'): + train_ids_file = '%s/split-300/run-0/fold-%d/train-N270' \ + % (self.config.data_dir, self.config.fold) + tune_ids_file = None + test_ids_file = '%s/split-300/run-0/fold-%d/test' \ + % (self.config.data_dir, self.config.fold) + else: + assert False + + train_ids = set() + tune_ids = set() + test_ids = set() + with open(train_ids_file) as fold_file: + for line in fold_file.readlines(): + train_ids.add(int(line)) + if tune_ids_file: + with open(tune_ids_file) as fold_file: + for line in fold_file.readlines(): + tune_ids.add(int(line)) + with open(test_ids_file) as fold_file: + for line in fold_file.readlines(): + test_ids.add(int(line)) + + return train_ids, tune_ids, test_ids + + def preprocess_nl(self, nl): + nl = nl.strip().lower() + if self.config.stem and self.config.lang == 'de': + # German stemmer can't handle UTF-8 + nl = nl.encode('ascii', 'ignore') + else: + nl = nl.encode('utf-8', 'ignore') + if nl[-2:] == ' .' or nl[-2:] == ' ?': + nl = nl[:-2] + if self.config.stem: + nl = ' '.join([self.stemmer.stem(tok) for tok in nl.split()]) + return nl + + def preprocess_fun(self, fun): + return fun.strip() + + def replace_specials(self, mrl): + mrl = mrl.replace('.', 'xxd') + mrl = mrl.replace("'", 'xxq') + mrl = mrl.replace('/', 'xxs') + #mrl = re.sub(r"(' *[^'()]*)\'([^'()]* *')", r'\1_q_\2', mrl) + #mrl = re.sub(r"(' *[^'()]*)\.([^'()]* *')", r'\1_dot_\2', mrl) + #mrl = re.sub(r"(' *[^'()]*)\/([^'()]* *')", r'\1_slash_\2', mrl) + return mrl + + def clang_to_fun(self, clang): + clang = clang.strip() + clang = re.sub(r'\s+', ' ', clang) + clang = re.sub(r'\{([\d|X]+( [\d|X]+)*)\}', r'(set \1)', clang) + clang = re.sub(r'\(([\w.-]+) ?', r'\1(', clang) + clang = self.strip_bare_parens(clang) + clang = clang.replace('()', '') + clang = clang.replace(' ', ',') + clang = clang.replace('"', '') + + clang = re.sub(r'definerule\([^,]+,[^,]+,', r'definerule(', clang) + + return clang + + def strip_bare_parens(self, clang): + try: + start = clang.index(' (')+1 + except ValueError: + return clang + + end = start+1 + pcounter = 0 + while pcounter >= 0: + c = clang[end:end+1] + if c == '(': + pcounter += 1 + elif c == ')': + pcounter -= 1 + end += 1 + end -= 1 + + r = clang[:start] + clang[start+1:end] + clang[end+1:] + return r + + def slot_to_fun(self, slot): + slot = slot.strip() + slot = slot.replace('value', '"value"') + slot = slot.replace('="', "('") + slot = slot.replace('",', "'),") + slot = slot.replace('")', "'))") + slot = slot.replace("'value'", 'value') + return slot |