summaryrefslogtreecommitdiff
path: root/decoder/tdict.cc
blob: f0588cfc3dd617d5a46cdb70734f16063ed6541b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#define TD_ALLOW_UNDEFINED_WORDIDS 1

// if 1, word ids that are >= end() will give a numeric token name (single per-thread shared buffer), which of course won't be Convert-able back to the id, because it's not added to the dict.  This is a convenience for logging fake token indices.  Any tokens actually added to the dict may cause end() to overlap the range of fake ids you were using - that's up to you to prevent.

#include <stdlib.h>
#include <cstring>
#include <sstream>
#include "Ngram.h"
#include "dict.h"
#include "tdict.h"
#include "Vocab.h"
#include "stringlib.h"
#include "threadlocal.h"

using namespace std;

Vocab TD::dict_(0,TD::max_wordid);
WordID TD::ss=dict_.ssIndex();
WordID TD::se=dict_.seIndex();
WordID TD::unk=dict_.unkIndex();
char const*const TD::ss_str=Vocab_SentStart;
char const*const TD::se_str=Vocab_SentEnd;
char const*const TD::unk_str=Vocab_Unknown;

// pre+(i-base)+">" for i in [base,e)
inline void pad(std::string const& pre,int base,int e) {
  assert(base<=e);
  ostringstream o;
  for (int i=base;i<e;++i) {
    o.str(pre);
    o<<(i-base)<<'>';
    WordID id=TD::Convert(o.str());
    assert(id==i); // this fails.  why?
  }
}


namespace {
struct TD_init {
  TD_init() {
    /*
      // disabled for now since it's breaking trunk
    assert(TD::Convert(TD::ss_str)==TD::ss);
    assert(TD::Convert(TD::se_str)==TD::se);
    assert(TD::Convert(TD::unk_str)==TD::unk);
    assert(TD::none==Vocab_None);
    pad("<FILLER",TD::end(),TD::reserved_begin);
    assert(TD::end()==TD::reserved_begin);
    int reserved_end=TD::begin();
    pad("<RESERVED",TD::end(),reserved_end);
    assert(TD::end()==reserved_end);
    */
  }
};

TD_init td_init;
}

unsigned int TD::NumWords() {
  return dict_.numWords();
}
WordID TD::end() {
  return dict_.highIndex();
}

WordID TD::Convert(const std::string& s) {
  return dict_.addWord((VocabString)s.c_str());
}

WordID TD::Convert(char const* s) {
  return dict_.addWord((VocabString)s);
}


#if TD_ALLOW_UNDEFINED_WORDIDS
# include "static_utoa.h"
char undef_prefix[]="UNDEF_";
static const int undefpre_n=sizeof(undef_prefix)/sizeof(undef_prefix[0]);
THREADLOCAL char undef_buf[]="UNDEF_________________";
inline char const* undef_token(WordID w)
{
  append_utoa(undef_buf+undefpre_n,w);
  return undef_buf;
}
#endif

const char* TD::Convert(WordID w) {
#if TD_ALLOW_UNDEFINED_WORDIDS
  if (w>=dict_.highIndex()) return undef_token(w);
#endif
  return dict_.getWord((VocabIndex)w);
}


void TD::GetWordIDs(const std::vector<std::string>& strings, std::vector<WordID>* ids) {
  ids->clear();
  for (vector<string>::const_iterator i = strings.begin(); i != strings.end(); ++i)
    ids->push_back(TD::Convert(*i));
}

std::string TD::GetString(const std::vector<WordID>& str) {
  ostringstream o;
  for (int i=0;i<str.size();++i) {
    if (i) o << ' ';
    o << TD::Convert(str[i]);
  }
  return o.str();
}

std::string TD::GetString(WordID const* i,WordID const* e) {
  ostringstream o;
  bool sp=false;
  for (;i<e;++i,sp=true) {
    if (sp)
      o << ' ';
    o << TD::Convert(*i);
  }
  return o.str();
}

int TD::AppendString(const WordID& w, int pos, int bufsize, char* buffer)
{
  const char* word = TD::Convert(w);
  const char* const end_buf = buffer + bufsize;
  char* dest = buffer + pos;
  while(dest < end_buf && *word) {
    *dest = *word;
    ++dest;
    ++word;
  }
  return (dest - buffer);
}


namespace {
struct add_wordids {
  typedef std::vector<WordID> Ws;
  Ws *ids;
  explicit add_wordids(Ws *i) : ids(i) {  }
  add_wordids(const add_wordids& o) : ids(o.ids) {  }
  void operator()(char const* s) {
    ids->push_back(TD::Convert(s));
  }
  void operator()(std::string const& s) {
    ids->push_back(TD::Convert(s));
  }
};

}

void TD::ConvertSentence(std::string const& s, std::vector<WordID>* ids) {
  ids->clear();
  VisitTokens(s,add_wordids(ids));
}