diff options
Diffstat (limited to 'nlp_tools/vocabulary.py')
-rw-r--r-- | nlp_tools/vocabulary.py | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/nlp_tools/vocabulary.py b/nlp_tools/vocabulary.py new file mode 100644 index 0000000..ed200f5 --- /dev/null +++ b/nlp_tools/vocabulary.py @@ -0,0 +1,49 @@ +import cPickle + +class Vocabulary: + + OOV_VAL = -1 + + def __init__(self): + self.str_to_tok = {} + self.tok_to_str = {} + + def put(self, string): + if string in self.str_to_tok: + raise ValueError("%s is already in this vocabulary (token %d)" % \ + (string, self.str_to_tok[string])) + return self.ensure(string) + + def ensure(self, string): + if string in self.str_to_tok: + return + tok = len(self) + self.str_to_tok[string] = tok + self.tok_to_str[tok] = string + return tok + + def gett(self, string): + if string not in self.str_to_tok: + return self.OOV_VAL + return self.str_to_tok[string] + + def gets(self, tok): + return self.tok_to_str[tok] + + def strs(self): + return self.str_to_tok.keys() + + def toks(self): + return self.tok_to_str.keys() + + def __len__(self): + return len(self.str_to_tok) + + def save(self, path): + with open(path, 'w') as f: + cPickle.dump(self, f) + + @classmethod + def load(cls, path): + with open(path) as f: + return cPickle.load(f) |