#include "data_array.h"

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

using namespace std;

namespace extractor {

// Ids and surface forms of the two reserved vocabulary entries. They are
// registered by InitializeDataArray() before any corpus word, so they always
// occupy ids 0 and 1.
int DataArray::NULL_WORD = 0;
int DataArray::END_OF_LINE = 1;
string DataArray::NULL_WORD_STR = "__NULL__";
string DataArray::END_OF_LINE_STR = "__END_OF_LINE__";

// Creates a data array that contains only the reserved vocabulary entries.
DataArray::DataArray() {
  InitializeDataArray();
}

// Builds the data array from a text file, one sentence per line.
// If the file cannot be read, no line is produced and the resulting array
// contains zero sentences.
DataArray::DataArray(const string& filename) {
  InitializeDataArray();

  ifstream infile(filename.c_str());
  vector<string> lines;
  string line;
  while (getline(infile, line)) {
    lines.push_back(line);
  }

  CreateDataArray(lines);
}

// Builds the data array from one side of a parallel text file in which every
// line has the form "source ||| target".
// For side == SOURCE the text before the first "|||" is used; for any other
// side the text after it is used.
DataArray::DataArray(const string& filename, const Side& side) {
  InitializeDataArray();

  ifstream infile(filename.c_str());
  vector<string> lines;
  const string delimiter = "|||";
  string line;
  while (getline(infile, line)) {
    // Keep the unsigned type returned by find() so that the "not found" value
    // (string::npos) can be detected instead of wrapping around in arithmetic.
    const string::size_type position = line.find(delimiter);
    if (side == SOURCE) {
      // substr(0, npos) yields the whole line, so a line without a delimiter
      // is treated as consisting of source text only.
      lines.push_back(line.substr(0, position));
    } else if (position == string::npos) {
      // No delimiter means there is no target text. Still emit an (empty)
      // sentence so that sentence numbering stays aligned with the source
      // side of the same file.
      lines.push_back(string());
    } else {
      lines.push_back(line.substr(position + delimiter.size()));
    }
  }

  CreateDataArray(lines);
}

// Registers the reserved NULL and END_OF_LINE entries in both directions of
// the vocabulary mapping.
void DataArray::InitializeDataArray() {
  word2id[NULL_WORD_STR] = NULL_WORD;
  id2word.push_back(NULL_WORD_STR);
  word2id[END_OF_LINE_STR] = END_OF_LINE;
  id2word.push_back(END_OF_LINE_STR);
}

// Appends the given sentences to the data array.
// Every line is split on whitespace; unseen words receive the next free id in
// order of first occurrence. Each sentence is terminated by an END_OF_LINE
// marker in `data`, and `sentence_id` records the line index of every entry
// of `data` (markers included).
void DataArray::CreateDataArray(const vector<string>& lines) {
  for (size_t i = 0; i < lines.size(); ++i) {
    sentence_start.push_back(data.size());

    istringstream iss(lines[i]);
    string word;
    while (iss >> word) {
      int word_id;
      auto entry = word2id.find(word);
      if (entry == word2id.end()) {
        // First occurrence: the next free id equals the vocabulary size.
        word_id = id2word.size();
        word2id[word] = word_id;
        id2word.push_back(word);
      } else {
        word_id = entry->second;
      }
      data.push_back(word_id);
      sentence_id.push_back(i);
    }

    data.push_back(END_OF_LINE);
    sentence_id.push_back(i);
  }

  // Sentinel entry: lets GetSentenceLength() and GetNumSentences() work
  // without special-casing the last sentence.
  sentence_start.push_back(data.size());

  data.shrink_to_fit();
  sentence_id.shrink_to_fit();
  sentence_start.shrink_to_fit();
}

DataArray::~DataArray() {}

// Returns a copy of the word id sequence (END_OF_LINE markers included).
vector<int> DataArray::GetData() const {
  return data;
}

// Returns the word id stored at the given position. The index is not
// bounds-checked.
int DataArray::AtIndex(int index) const {
  return data[index];
}

// Returns the surface form of the word stored at the given position. The
// index is not bounds-checked.
string DataArray::GetWordAtIndex(int index) const {
  return id2word[data[index]];
}

// Returns the `size` word ids starting at position `index`. The range is not
// bounds-checked.
vector<int> DataArray::GetWordIds(int index, int size) const {
  return vector<int>(data.begin() + index, data.begin() + index + size);
}

// Returns the surface forms of the `size` words starting at position
// `start_index`. The range is not bounds-checked.
vector<string> DataArray::GetWords(int start_index, int size) const {
  const vector<int> word_ids = GetWordIds(start_index, size);

  vector<string> words;
  words.reserve(word_ids.size());
  for (int word_id: word_ids) {
    words.push_back(id2word[word_id]);
  }
  return words;
}

// Returns the number of entries in the data array (END_OF_LINE markers
// included).
int DataArray::GetSize() const {
  return data.size();
}

// Returns the number of distinct words, including the two reserved entries.
int DataArray::GetVocabularySize() const {
  return id2word.size();
}

// Returns the number of sentences. `sentence_start` holds one trailing
// sentinel once CreateDataArray() has run, hence the -1.
int DataArray::GetNumSentences() const {
  return sentence_start.size() - 1;
}

// Returns the position in the data array at which the given sentence starts.
// The argument is a sentence index; it is not bounds-checked.
int DataArray::GetSentenceStart(int position) const {
  return sentence_start[position];
}

// Returns the number of words in the given sentence.
int DataArray::GetSentenceLength(int sentence_id) const {
  // Ignore end of line markers.
  return sentence_start[sentence_id + 1] - sentence_start[sentence_id] - 1;
}

// Returns the index of the sentence that the given data position belongs to.
int DataArray::GetSentenceId(int position) const {
  return sentence_id[position];
}

// Returns the id of the given word, or -1 if the word is not in the
// vocabulary.
int DataArray::GetWordId(const string& word) const {
  auto result = word2id.find(word);
  return result == word2id.end() ? -1 : result->second;
}

// Returns the surface form of the given word id. The id is not
// bounds-checked.
string DataArray::GetWord(int word_id) const {
  return id2word[word_id];
}

// Two data arrays are equal when their vocabularies, word id sequences and
// sentence bookkeeping are all identical.
bool DataArray::operator==(const DataArray& other) const {
  return word2id == other.word2id && id2word == other.id2word &&
         data == other.data && sentence_start == other.sentence_start &&
         sentence_id == other.sentence_id;
}

} // namespace extractor