diff options
| -rw-r--r-- | gi/evaluation/evaluate_entropy.py | 19 | ||||
| -rw-r--r-- | gi/evaluation/extract_ccg_labels.py | 20 | 
2 files changed, 32 insertions, 7 deletions
diff --git a/gi/evaluation/evaluate_entropy.py b/gi/evaluation/evaluate_entropy.py index 88533544..55449c59 100644 --- a/gi/evaluation/evaluate_entropy.py +++ b/gi/evaluation/evaluate_entropy.py @@ -4,6 +4,10 @@ import sys, math, itertools  ginfile = open(sys.argv[1])  pinfile = open(sys.argv[2]) +if len(sys.argv) >= 4: +    slash_threshold = int(sys.argv[3]) +else: +    slash_threshold = 99999  # evaluating: H(G | P) = sum_{g,p} p(g,p) log { p(p) / p(g,p) }  #                      = sum_{g,p} c(g,p)/N { log c(p) - log N - log c(g,p) + log N } @@ -24,16 +28,17 @@ for gline, pline in itertools.izip(ginfile, pinfile):          gtag = gpart.split(':',1)[1]          ptag = ppart.split(':',1)[1] -        joint_frequencies.setdefault((gtag, ptag), 0) -        joint_frequencies[gtag,ptag] += 1 +        if gtag.count('/') <= slash_threshold: +            joint_frequencies.setdefault((gtag, ptag), 0) +            joint_frequencies[gtag,ptag] += 1 -        predict_frequencies.setdefault(ptag, 0) -        predict_frequencies[ptag] += 1 +            predict_frequencies.setdefault(ptag, 0) +            predict_frequencies[ptag] += 1 -        gold_frequencies.setdefault(gtag, 0) -        gold_frequencies[gtag] += 1 +            gold_frequencies.setdefault(gtag, 0) +            gold_frequencies[gtag] += 1 -        N += 1 +            N += 1  hg2p = 0  hp2g = 0 diff --git a/gi/evaluation/extract_ccg_labels.py b/gi/evaluation/extract_ccg_labels.py index 014e0399..5dd6eb65 100644 --- a/gi/evaluation/extract_ccg_labels.py +++ b/gi/evaluation/extract_ccg_labels.py @@ -60,12 +60,32 @@ def frontier(node, indices):      else:          return [node] +def project_heads(node): +    #print 'project_heads', node +    is_head = node.data.tag.endswith('-HEAD') +    if node.children: +        found = 0 +        for child in node.children: +            x = project_heads(child) +            if x: +                node.data.tag = x +                found += 1 +        assert found == 1 +    elif is_head: +        node.data.tag = node.data.tag[:-len('-HEAD')] + +    if is_head: +        return node.data.tag +    else: +        return None +  for tline, eline in itertools.izip(tinfile, einfile):      if tline.strip() != '(())':          if tline.startswith('( '):              tline = tline[2:-1].strip()          tr = tree.parse_PST(tline)          number_leaves(tr) +        #project_heads(tr) # assumes Bikel-style head annotation for the input trees      else:          tr = None  | 
