diff options
| -rw-r--r-- | python/cdec/scfg/extractor.py | 12 | ||||
| -rw-r--r-- | sa-extract/cn.py | 46 | ||||
| -rw-r--r-- | sa-extract/manager.py | 2 | ||||
| -rw-r--r-- | sa-extract/model.py | 5 | ||||
| -rw-r--r-- | sa-extract/rulefactory.pyx | 15 | 
5 files changed, 38 insertions, 42 deletions
| diff --git a/python/cdec/scfg/extractor.py b/python/cdec/scfg/extractor.py index 9f1e1137..0a45ddb8 100644 --- a/python/cdec/scfg/extractor.py +++ b/python/cdec/scfg/extractor.py @@ -1,5 +1,5 @@ -#!/usr/bin/env python  import StringIO +from itertools import chain  import clex  import rulefactory @@ -9,12 +9,12 @@ import cdat  import sym  import log -log.level = -1 -  from features import EgivenFCoherent, SampleCountF, CountEF,\          MaxLexEgivenF, MaxLexFgivenE, IsSingletonF, IsSingletonFE  from features import contextless +log.level = -1 +  class Output(StringIO.StringIO):      def close(self):          pass @@ -22,8 +22,6 @@ class Output(StringIO.StringIO):      def __str__(self):          return self.getvalue() -from itertools import chain -  def get_cn(sentence):      sentence = chain(('<s>',), sentence.split(), ('</s>',))      sentence = (sym.fromstring(word, terminal=True) for word in sentence) @@ -93,9 +91,11 @@ class GrammarExtractor:          self.models = tuple(contextless(feature) for feature in self.models)      def grammar(self, sentence): +        if isinstance(sentence, unicode): +            sentence = sentence.encode('utf8')          out = Output()          cn = get_cn(sentence) -        self.factory.input_file(cn, out) +        self.factory.input(cn, output=out)          return str(out)  def main(config): diff --git a/sa-extract/cn.py b/sa-extract/cn.py index e534783f..6e45bcf9 100644 --- a/sa-extract/cn.py +++ b/sa-extract/cn.py @@ -4,11 +4,8 @@  # vim:tabstop=4:autoindent:expandtab -import sys -import math  import sym  import log -import sgml  epsilon = sym.fromstring('*EPS*'); @@ -142,23 +139,26 @@ class ConfusionNet(object): - -#file = open(sys.argv[1], "rb") -#sent = sgml.process_sgml_line(file.read()) -#print sent -#cn = ConfusionNet(sent) -#print cn -#results = cn.listdown() -#for result in results: -#    print sym.tostring(result) -#print cn.next(0); -#print cn.next(1); -#print cn.next(2); -#print cn.next(3); -#print cn -#cn = ConfusionNet() -#k = 0 -#while (cn.read(file)): -#  print cn -   -#print cn.stats +""" +import sys +import sgml +file = open(sys.argv[1], "rb") +sent = sgml.process_sgml_line(file.read()) +print sent +cn = ConfusionNet(sent) +print cn +results = cn.listdown() +for result in results: +    print sym.tostring(result) +print cn.next(0); +print cn.next(1); +print cn.next(2); +print cn.next(3); +print cn +cn = ConfusionNet() +k = 0 +while (cn.read(file)): +  print cn +  +print cn.stats +""" diff --git a/sa-extract/manager.py b/sa-extract/manager.py index 767192c1..3a079c2a 100644 --- a/sa-extract/manager.py +++ b/sa-extract/manager.py @@ -1,5 +1,6 @@  import csuf  import cdat +import cintlist  class Sampler(object):  	'''A Sampler implements a logic for choosing @@ -15,7 +16,6 @@ class Sampler(object):  		return cintlist.CIntList() -  class Extractor(object):  	'''Extractor is responsible for extracting rules  	from a given context; once a sentence id/location diff --git a/sa-extract/model.py b/sa-extract/model.py index 66c51051..bcdf129a 100644 --- a/sa-extract/model.py +++ b/sa-extract/model.py @@ -1,4 +1,3 @@ -  class Model(object):      def __init__(self, name=None):          object.__init__(self) @@ -6,7 +5,3 @@ class Model(object):              self.name = self.__class__.__name__          else:              self.name = name - -    def input(self, fwords, meta): -        pass - diff --git a/sa-extract/rulefactory.pyx b/sa-extract/rulefactory.pyx index 20ea80d2..792489c4 100644 --- a/sa-extract/rulefactory.pyx +++ b/sa-extract/rulefactory.pyx @@ -1321,7 +1321,7 @@ cdef class HieroCachingRuleFactory:            candidate.append([next_id,curr[1]+jump])      return sorted(result); -  def input(self, fwords, meta): +  def input(self, fwords, meta=None, output=None):      '''When this function is called on the RuleFactory,      it looks up all of the rules that can be used to translate      the input sentence''' @@ -1342,13 +1342,14 @@ cdef class HieroCachingRuleFactory:      nodes_isteps_away_buffer = {}      hit = 0      reachable_buffer = {} -    #print "id = ",meta -    #print "rule_file = ",self.rule_file -    dattrs = sgml.attrs_to_dict(meta) -    id = dattrs.get('id', 'NOID') -    if self.per_sentence_grammar: +    if meta: +        dattrs = sgml.attrs_to_dict(meta) +        id = dattrs.get('id', 'NOID') +        self.excluded_sent_id = int(dattrs.get('exclude', '-1')) +    if output: +      self.rule_filehandler = output +    elif self.per_sentence_grammar:        self.rule_filehandler = open(self.rule_file+'.'+id, 'w') -    self.excluded_sent_id = int(dattrs.get('exclude', '-1'))      #print "max_initial_size = %i" % self.max_initial_size | 
