summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gi/evaluation/evaluate_entropy.py19
-rw-r--r--gi/evaluation/extract_ccg_labels.py20
2 files changed, 32 insertions, 7 deletions
diff --git a/gi/evaluation/evaluate_entropy.py b/gi/evaluation/evaluate_entropy.py
index 88533544..55449c59 100644
--- a/gi/evaluation/evaluate_entropy.py
+++ b/gi/evaluation/evaluate_entropy.py
@@ -4,6 +4,10 @@ import sys, math, itertools
ginfile = open(sys.argv[1])
pinfile = open(sys.argv[2])
+if len(sys.argv) >= 4:
+ slash_threshold = int(sys.argv[3])
+else:
+ slash_threshold = 99999
# evaluating: H(G | P) = sum_{g,p} p(g,p) log { p(p) / p(g,p) }
# = sum_{g,p} c(g,p)/N { log c(p) - log N - log c(g,p) + log N }
@@ -24,16 +28,17 @@ for gline, pline in itertools.izip(ginfile, pinfile):
gtag = gpart.split(':',1)[1]
ptag = ppart.split(':',1)[1]
- joint_frequencies.setdefault((gtag, ptag), 0)
- joint_frequencies[gtag,ptag] += 1
+ if gtag.count('/') <= slash_threshold:
+ joint_frequencies.setdefault((gtag, ptag), 0)
+ joint_frequencies[gtag,ptag] += 1
- predict_frequencies.setdefault(ptag, 0)
- predict_frequencies[ptag] += 1
+ predict_frequencies.setdefault(ptag, 0)
+ predict_frequencies[ptag] += 1
- gold_frequencies.setdefault(gtag, 0)
- gold_frequencies[gtag] += 1
+ gold_frequencies.setdefault(gtag, 0)
+ gold_frequencies[gtag] += 1
- N += 1
+ N += 1
hg2p = 0
hp2g = 0
diff --git a/gi/evaluation/extract_ccg_labels.py b/gi/evaluation/extract_ccg_labels.py
index 014e0399..5dd6eb65 100644
--- a/gi/evaluation/extract_ccg_labels.py
+++ b/gi/evaluation/extract_ccg_labels.py
@@ -60,12 +60,32 @@ def frontier(node, indices):
else:
return [node]
+def project_heads(node):
+ #print 'project_heads', node
+ is_head = node.data.tag.endswith('-HEAD')
+ if node.children:
+ found = 0
+ for child in node.children:
+ x = project_heads(child)
+ if x:
+ node.data.tag = x
+ found += 1
+ assert found == 1
+ elif is_head:
+ node.data.tag = node.data.tag[:-len('-HEAD')]
+
+ if is_head:
+ return node.data.tag
+ else:
+ return None
+
for tline, eline in itertools.izip(tinfile, einfile):
if tline.strip() != '(())':
if tline.startswith('( '):
tline = tline[2:-1].strip()
tr = tree.parse_PST(tline)
number_leaves(tr)
+ #project_heads(tr) # assumes Bikel-style head annotation for the input trees
else:
tr = None