summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Victor Chahuneau <vchahune@cs.cmu.edu> 2012-06-21 17:27:48 -0400
committer Victor Chahuneau <vchahune@cs.cmu.edu> 2012-06-21 17:27:48 -0400
commit 899f78c7c4c3a8cff97494665ed52ddb3460d44a (patch)
tree 719a46d4b832572eb6e4501328b13057a1295a17
parent f3fef50a89e8e88da39e3c7000310c9e319d5cfc (diff)
Allow SA rule extraction to write to a python buffer
+ very small sa-extract cleanup
-rw-r--r-- python/cdec/scfg/extractor.py 12
-rw-r--r-- sa-extract/cn.py 46
-rw-r--r-- sa-extract/manager.py 2
-rw-r--r-- sa-extract/model.py 5
-rw-r--r-- sa-extract/rulefactory.pyx 15
5 files changed, 38 insertions, 42 deletions
diff --git a/python/cdec/scfg/extractor.py b/python/cdec/scfg/extractor.py
index 9f1e1137..0a45ddb8 100644
--- a/python/cdec/scfg/extractor.py
+++ b/python/cdec/scfg/extractor.py
@@ -1,5 +1,5 @@
-#!/usr/bin/env python
import StringIO
+from itertools import chain
import clex
import rulefactory
@@ -9,12 +9,12 @@ import cdat
import sym
import log
-log.level = -1
-
from features import EgivenFCoherent, SampleCountF, CountEF,\
MaxLexEgivenF, MaxLexFgivenE, IsSingletonF, IsSingletonFE
from features import contextless
+log.level = -1
+
class Output(StringIO.StringIO):
def close(self):
pass
@@ -22,8 +22,6 @@ class Output(StringIO.StringIO):
def __str__(self):
return self.getvalue()
-from itertools import chain
-
def get_cn(sentence):
sentence = chain(('<s>',), sentence.split(), ('</s>',))
sentence = (sym.fromstring(word, terminal=True) for word in sentence)
@@ -93,9 +91,11 @@ class GrammarExtractor:
self.models = tuple(contextless(feature) for feature in self.models)
def grammar(self, sentence):
+ if isinstance(sentence, unicode):
+ sentence = sentence.encode('utf8')
out = Output()
cn = get_cn(sentence)
- self.factory.input_file(cn, out)
+ self.factory.input(cn, output=out)
return str(out)
def main(config):
diff --git a/sa-extract/cn.py b/sa-extract/cn.py
index e534783f..6e45bcf9 100644
--- a/sa-extract/cn.py
+++ b/sa-extract/cn.py
@@ -4,11 +4,8 @@
# vim:tabstop=4:autoindent:expandtab
-import sys
-import math
import sym
import log
-import sgml
epsilon = sym.fromstring('*EPS*');
@@ -142,23 +139,26 @@ class ConfusionNet(object):
-
-#file = open(sys.argv[1], "rb")
-#sent = sgml.process_sgml_line(file.read())
-#print sent
-#cn = ConfusionNet(sent)
-#print cn
-#results = cn.listdown()
-#for result in results:
-# print sym.tostring(result)
-#print cn.next(0);
-#print cn.next(1);
-#print cn.next(2);
-#print cn.next(3);
-#print cn
-#cn = ConfusionNet()
-#k = 0
-#while (cn.read(file)):
-# print cn
-
-#print cn.stats
+"""
+import sys
+import sgml
+file = open(sys.argv[1], "rb")
+sent = sgml.process_sgml_line(file.read())
+print sent
+cn = ConfusionNet(sent)
+print cn
+results = cn.listdown()
+for result in results:
+ print sym.tostring(result)
+print cn.next(0);
+print cn.next(1);
+print cn.next(2);
+print cn.next(3);
+print cn
+cn = ConfusionNet()
+k = 0
+while (cn.read(file)):
+ print cn
+
+print cn.stats
+"""
diff --git a/sa-extract/manager.py b/sa-extract/manager.py
index 767192c1..3a079c2a 100644
--- a/sa-extract/manager.py
+++ b/sa-extract/manager.py
@@ -1,5 +1,6 @@
import csuf
import cdat
+import cintlist
class Sampler(object):
'''A Sampler implements a logic for choosing
@@ -15,7 +16,6 @@ class Sampler(object):
return cintlist.CIntList()
-
class Extractor(object):
'''Extractor is responsible for extracting rules
from a given context; once a sentence id/location
diff --git a/sa-extract/model.py b/sa-extract/model.py
index 66c51051..bcdf129a 100644
--- a/sa-extract/model.py
+++ b/sa-extract/model.py
@@ -1,4 +1,3 @@
-
class Model(object):
def __init__(self, name=None):
object.__init__(self)
@@ -6,7 +5,3 @@ class Model(object):
self.name = self.__class__.__name__
else:
self.name = name
-
- def input(self, fwords, meta):
- pass
-
diff --git a/sa-extract/rulefactory.pyx b/sa-extract/rulefactory.pyx
index 20ea80d2..792489c4 100644
--- a/sa-extract/rulefactory.pyx
+++ b/sa-extract/rulefactory.pyx
@@ -1321,7 +1321,7 @@ cdef class HieroCachingRuleFactory:
candidate.append([next_id,curr[1]+jump])
return sorted(result);
- def input(self, fwords, meta):
+ def input(self, fwords, meta=None, output=None):
'''When this function is called on the RuleFactory,
it looks up all of the rules that can be used to translate
the input sentence'''
@@ -1342,13 +1342,14 @@ cdef class HieroCachingRuleFactory:
nodes_isteps_away_buffer = {}
hit = 0
reachable_buffer = {}
- #print "id = ",meta
- #print "rule_file = ",self.rule_file
- dattrs = sgml.attrs_to_dict(meta)
- id = dattrs.get('id', 'NOID')
- if self.per_sentence_grammar:
+ if meta:
+ dattrs = sgml.attrs_to_dict(meta)
+ id = dattrs.get('id', 'NOID')
+ self.excluded_sent_id = int(dattrs.get('exclude', '-1'))
+ if output:
+ self.rule_filehandler = output
+ elif self.per_sentence_grammar:
self.rule_filehandler = open(self.rule_file+'.'+id, 'w')
- self.excluded_sent_id = int(dattrs.get('exclude', '-1'))
#print "max_initial_size = %i" % self.max_initial_size