summaryrefslogtreecommitdiff
path: root/decoder/tdict.cc
blob: 7b56d2593d2a7c492b63d6cec719e8fe5b8493ee (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <sstream>
#include "Ngram.h"
#include "dict.h"
#include "tdict.h"
#include "Vocab.h"
#include "stringlib.h"

using namespace std;

//FIXME: valgrind errors (static init order?)
// Definitions of TD's static data members.  Within this translation unit they
// are initialized in the order written, so dict_ is constructed before
// ss/se/unk query it for their indices.
// NOTE(review): a plausible source of the valgrind errors is code in *other*
// translation units calling TD::* during their own static initialization,
// before dict_ here has been constructed (cross-TU init order is unspecified).
// Not confirmed.
// NOTE(review): the (0, TD::max_wordid) arguments presumably bound the index
// range Vocab may hand out -- confirm against the Vocab constructor.
Vocab TD::dict_(0,TD::max_wordid);
// Cached indices of the vocabulary's special tokens, read from dict_.
WordID TD::ss=dict_.ssIndex();
WordID TD::se=dict_.seIndex();
WordID TD::unk=dict_.unkIndex();
// Spellings of those same special tokens, taken from the Vocab_* names
// provided by Vocab.h.
char const*const TD::ss_str=Vocab_SentStart;
char const*const TD::se_str=Vocab_SentEnd;
char const*const TD::unk_str=Vocab_Unknown;

// Intern the tokens pre+(i-base)+">" for every i in [base,e) and assert that
// each one is assigned exactly the WordID i.  For example, pad("<FILLER",5,7)
// interns "<FILLER0>" and "<FILLER1>" and expects ids 5 and 6.
//
// NOTE(review): the id==i assertion can still fail if the first free id is not
// `base`.  One candidate: callers passing TD::end() (dict_.highIndex()), if that
// is the last used index rather than one past it.  Confirm against Vocab.
inline void pad(std::string const& pre,int base,int e) {
  assert(base<=e);
  for (int i=base;i<e;++i) {
    // Build a fresh stream per token.  Resetting a shared stream with
    // o.str(pre) leaves the write position at the start of the buffer, so the
    // number and '>' would overwrite the prefix ("<FILLER" + 0 -> "0>ILLER")
    // instead of being appended to it.
    std::ostringstream o;
    o<<pre<<(i-base)<<'>';
    WordID const id=TD::Convert(o.str());
    assert(id==i);
    (void)id; // keep NDEBUG builds free of unused-variable warnings
  }
}


// File-local object whose constructor runs during this translation unit's
// static initialization, after the TD static members defined earlier in the
// file.
namespace {
struct TD_init {
  // Meant to sanity-check the special tokens and pad the id space with
  // filler/reserved tokens (see the commented-out body).  Currently a no-op.
  TD_init() {
    /*
      // disabled for now since it's breaking trunk
    assert(TD::Convert(TD::ss_str)==TD::ss);
    assert(TD::Convert(TD::se_str)==TD::se);
    assert(TD::Convert(TD::unk_str)==TD::unk);
    assert(TD::none==Vocab_None);
    pad("<FILLER",TD::end(),TD::reserved_begin);
    assert(TD::end()==TD::reserved_begin);
    int reserved_end=TD::begin();
    pad("<RESERVED",TD::end(),reserved_end);
    assert(TD::end()==reserved_end);
    */
  }
};

// The single instance whose construction invokes TD_init() at load time.
TD_init td_init;
}

// Number of words in the dictionary, as reported by dict_.numWords().
unsigned int TD::NumWords() {
  unsigned int const count = dict_.numWords();
  return count;
}
// Returns dict_.highIndex().
// NOTE(review): presumably the highest WordID currently in use (not one past
// it) -- confirm against Vocab before treating this as an exclusive bound.
WordID TD::end() {
  WordID const high = dict_.highIndex();
  return high;
}

// std::string overload of Convert: forwards the C string to the
// char const* overload, which looks it up / adds it via dict_.addWord and
// returns the resulting WordID.
WordID TD::Convert(const std::string& s) {
  return TD::Convert(s.c_str());
}

// C-string overload of Convert: hands s to dict_.addWord and returns the id it
// yields (addWord presumably returns the existing id when s is already known).
WordID TD::Convert(char const* s) {
  VocabString word = (VocabString)s;
  return dict_.addWord(word);
}

const char* TD::Convert(const WordID& w) {
  return dict_.getWord((VocabIndex)w);
}


void TD::GetWordIDs(const std::vector<std::string>& strings, std::vector<WordID>* ids) {
  ids->clear();
  for (vector<string>::const_iterator i = strings.begin(); i != strings.end(); ++i)
    ids->push_back(TD::Convert(*i));
}

std::string TD::GetString(const std::vector<WordID>& str) {
  ostringstream o;
  for (int i=0;i<str.size();++i) {
    if (i) o << ' ';
    o << TD::Convert(str[i]);
  }
  return o.str();
}

// Render the WordIDs in the half-open range [i,e) as their spellings separated
// by single spaces; an empty range yields "".
std::string TD::GetString(WordID const* i,WordID const* e) {
  std::ostringstream out;
  if (i < e) {
    out << TD::Convert(*i);
    for (WordID const* p = i + 1; p < e; ++p)
      out << ' ' << TD::Convert(*p);
  }
  return out.str();
}

// Copy the spelling of w into buffer starting at offset pos.  Copying stops at
// the end of the word or at offset bufsize, whichever comes first, so a word
// that does not fit is silently truncated.  No NUL terminator is written.
// Returns the offset just past the last character copied (pos itself if
// nothing was copied), so successive calls can be chained.
int TD::AppendString(const WordID& w, int pos, int bufsize, char* buffer)
{
  const char* src = TD::Convert(w);
  int at = pos;
  for (; at < bufsize && *src != '\0'; ++at, ++src)
    buffer[at] = *src;
  return at;
}


namespace {
// Token visitor used with VisitTokens (see TD::ConvertSentence).  Each token
// it is handed is converted with TD::Convert and appended to *ids.  It holds a
// non-owning pointer; the target vector must outlive the visitor.
struct add_wordids {
  typedef std::vector<WordID> Ws;
  Ws *ids;
  explicit add_wordids(Ws *i) : ids(i) {  }
  // No hand-written copy constructor: the implicitly generated one copies the
  // pointer exactly as the removed user-written version did (Rule of Zero).
  void operator()(char const* s) {
    ids->push_back(TD::Convert(s));
  }
  void operator()(std::string const& s) {
    ids->push_back(TD::Convert(s));
  }
};

}

void TD::ConvertSentence(std::string const& s, std::vector<WordID>* ids) {
  ids->clear();
  VisitTokens(s,add_wordids(ids));
}