summary | refs | log | tree | commit | diff
path: root/python/cdec/scfg/extractor.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/cdec/scfg/extractor.py')
-rw-r--r-- python/cdec/scfg/extractor.py 20
1 files changed, 14 insertions, 6 deletions
diff --git a/python/cdec/scfg/extractor.py b/python/cdec/scfg/extractor.py
index 0a45ddb8..1dfa2421 100644
--- a/python/cdec/scfg/extractor.py
+++ b/python/cdec/scfg/extractor.py
@@ -1,3 +1,5 @@
+import sys, os
+import re
import StringIO
from itertools import chain
@@ -32,7 +34,16 @@ class PhonyGrammar:
pass
class GrammarExtractor:
- def __init__(self, config):
+ def __init__(self, cfg):
+ if isinstance(cfg, dict):
+ config = cfg
+ elif isinstance(cfg, str):
+ cfg_file = os.path.basename(cfg)
+ if not re.match(r'^\w+\.py$', cfg_file):
+ raise ValueError('Config must be a *.py file')
+ sys.path.append(os.path.dirname(cfg))
+ config = __import__(cfg_file.replace('.py', '')).__dict__
+ sys.path.pop()
alignment = calignment.Alignment(config['a_file'], from_binary=True)
self.factory = rulefactory.HieroCachingRuleFactory(
# compiled alignment object (REQUIRED)
@@ -99,13 +110,10 @@ class GrammarExtractor:
return str(out)
def main(config):
- sys.path.append(os.path.dirname(config))
- module = __import__(os.path.basename(config).replace('.py', ''))
- extractor = GrammarExtractor(module.__dict__)
- print extractor.grammar(next(sys.stdin))
+ extractor = GrammarExtractor(config)
+ sys.stdout.write(extractor.grammar(next(sys.stdin)))
if __name__ == '__main__':
- import sys, os
if len(sys.argv) != 2 or not sys.argv[1].endswith('.py'):
sys.stderr.write('Usage: %s config.py\n' % sys.argv[0])
sys.exit(1)