diff options
author | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-12 18:23:01 +0000 |
---|---|---|
committer | trevor.cohn <trevor.cohn@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-12 18:23:01 +0000 |
commit | 17f8fa666dd0614dbbc520985d0cdcb8b0e69b05 (patch) | |
tree | 88ace1d2b8d190e4e1b3dc71947ba394ad9f2649 /gi/evaluation | |
parent | 5558d8fc9b67eb4dd98414587082ff3df27daaf9 (diff) |
Added slash limits
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@224 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/evaluation')
-rw-r--r-- | gi/evaluation/evaluate_entropy.py | 19 | ||||
-rw-r--r-- | gi/evaluation/extract_ccg_labels.py | 20 |
2 files changed, 32 insertions, 7 deletions
diff --git a/gi/evaluation/evaluate_entropy.py b/gi/evaluation/evaluate_entropy.py index 88533544..55449c59 100644 --- a/gi/evaluation/evaluate_entropy.py +++ b/gi/evaluation/evaluate_entropy.py @@ -4,6 +4,10 @@ import sys, math, itertools ginfile = open(sys.argv[1]) pinfile = open(sys.argv[2]) +if len(sys.argv) >= 4: + slash_threshold = int(sys.argv[3]) +else: + slash_threshold = 99999 # evaluating: H(G | P) = sum_{g,p} p(g,p) log { p(p) / p(g,p) } # = sum_{g,p} c(g,p)/N { log c(p) - log N - log c(g,p) + log N } @@ -24,16 +28,17 @@ for gline, pline in itertools.izip(ginfile, pinfile): gtag = gpart.split(':',1)[1] ptag = ppart.split(':',1)[1] - joint_frequencies.setdefault((gtag, ptag), 0) - joint_frequencies[gtag,ptag] += 1 + if gtag.count('/') <= slash_threshold: + joint_frequencies.setdefault((gtag, ptag), 0) + joint_frequencies[gtag,ptag] += 1 - predict_frequencies.setdefault(ptag, 0) - predict_frequencies[ptag] += 1 + predict_frequencies.setdefault(ptag, 0) + predict_frequencies[ptag] += 1 - gold_frequencies.setdefault(gtag, 0) - gold_frequencies[gtag] += 1 + gold_frequencies.setdefault(gtag, 0) + gold_frequencies[gtag] += 1 - N += 1 + N += 1 hg2p = 0 hp2g = 0 diff --git a/gi/evaluation/extract_ccg_labels.py b/gi/evaluation/extract_ccg_labels.py index 014e0399..5dd6eb65 100644 --- a/gi/evaluation/extract_ccg_labels.py +++ b/gi/evaluation/extract_ccg_labels.py @@ -60,12 +60,32 @@ def frontier(node, indices): else: return [node] +def project_heads(node): + #print 'project_heads', node + is_head = node.data.tag.endswith('-HEAD') + if node.children: + found = 0 + for child in node.children: + x = project_heads(child) + if x: + node.data.tag = x + found += 1 + assert found == 1 + elif is_head: + node.data.tag = node.data.tag[:-len('-HEAD')] + + if is_head: + return node.data.tag + else: + return None + for tline, eline in itertools.izip(tinfile, einfile): if tline.strip() != '(())': if tline.startswith('( '): tline = tline[2:-1].strip() tr = tree.parse_PST(tline) number_leaves(tr) + #project_heads(tr) # assumes Bikel-style head annotation for the input trees else: tr = None |