summaryrefslogtreecommitdiff
path: root/nlp_tools/vocabulary.py
diff options
context:
space:
mode:
Diffstat (limited to 'nlp_tools/vocabulary.py')
-rw-r--r--nlp_tools/vocabulary.py49
1 files changed, 49 insertions, 0 deletions
diff --git a/nlp_tools/vocabulary.py b/nlp_tools/vocabulary.py
new file mode 100644
index 0000000..ed200f5
--- /dev/null
+++ b/nlp_tools/vocabulary.py
@@ -0,0 +1,49 @@
+import cPickle
+
+class Vocabulary:
+
+ OOV_VAL = -1
+
+ def __init__(self):
+ self.str_to_tok = {}
+ self.tok_to_str = {}
+
+ def put(self, string):
+ if string in self.str_to_tok:
+ raise ValueError("%s is already in this vocabulary (token %d)" % \
+ (string, self.str_to_tok[string]))
+ return self.ensure(string)
+
+ def ensure(self, string):
+ if string in self.str_to_tok:
+ return
+ tok = len(self)
+ self.str_to_tok[string] = tok
+ self.tok_to_str[tok] = string
+ return tok
+
+ def gett(self, string):
+ if string not in self.str_to_tok:
+ return self.OOV_VAL
+ return self.str_to_tok[string]
+
+ def gets(self, tok):
+ return self.tok_to_str[tok]
+
+ def strs(self):
+ return self.str_to_tok.keys()
+
+ def toks(self):
+ return self.tok_to_str.keys()
+
+ def __len__(self):
+ return len(self.str_to_tok)
+
+ def save(self, path):
+ with open(path, 'w') as f:
+ cPickle.dump(self, f)
+
+ @classmethod
+ def load(cls, path):
+ with open(path) as f:
+ return cPickle.load(f)