summaryrefslogtreecommitdiff
path: root/utils/corpus_tools.cc
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-02-29 01:12:40 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-02-29 01:12:40 -0500
commit54bcfb835232d190a5ab6f0bd825de8a50dae126 (patch)
tree161988135be75a47524cdf2efbba7bdf06c9bd65 /utils/corpus_tools.cc
parent89238977fc9d8f8d9a6421b0d4f35afc200f08e7 (diff)
cleanup, mpi-ify lblmodel
Diffstat (limited to 'utils/corpus_tools.cc')
-rw-r--r--utils/corpus_tools.cc62
1 files changed, 62 insertions, 0 deletions
diff --git a/utils/corpus_tools.cc b/utils/corpus_tools.cc
new file mode 100644
index 00000000..a0542b6e
--- /dev/null
+++ b/utils/corpus_tools.cc
@@ -0,0 +1,62 @@
+#include "corpus_tools.h"
+
+#include <iostream>
+
+#include "tdict.h"
+#include "filelib.h"
+#include "verbose.h"
+
+using namespace std;
+
+void CorpusTools::ReadFromFile(const string& filename,
+ vector<vector<WordID> >* src,
+ set<WordID>* src_vocab,
+ vector<vector<WordID> >* trg,
+ set<WordID>* trg_vocab,
+ int rank,
+ int size) {
+ assert(rank >= 0);
+ assert(size > 0);
+ assert(rank < size);
+ if (src) src->clear();
+ if (src_vocab) src_vocab->clear();
+ if (trg) trg->clear();
+ if (trg_vocab) trg_vocab->clear();
+ const int expected_fields = 1 + (trg == NULL ? 0 : 1);
+ if (!SILENT) cerr << "Reading from " << filename << " ...\n";
+ ReadFile rf(filename);
+ istream& in = *rf.stream();
+ string line;
+ int lc = 0;
+ static const WordID kDIV = TD::Convert("|||");
+ vector<WordID> tmp;
+ while(getline(in, line)) {
+ const bool skip = (lc % size != rank);
+ ++lc;
+ if (skip) continue;
+ TD::ConvertSentence(line, &tmp);
+ src->push_back(vector<WordID>());
+ vector<WordID>* d = &src->back();
+ set<WordID>* v = src_vocab;
+ int s = 0;
+ for (unsigned i = 0; i < tmp.size(); ++i) {
+ if (tmp[i] == kDIV) {
+ ++s;
+ if (s > 1) { cerr << "Unexpected format in line " << lc << ": " << line << endl; abort(); }
+ assert(trg);
+ trg->push_back(vector<WordID>());
+ d = &trg->back();
+ v = trg_vocab;
+ } else {
+ d->push_back(tmp[i]);
+ if (v) v->insert(tmp[i]);
+ }
+ }
+ ++s;
+ if (expected_fields != s) {
+ cerr << "Wrong number of fields in line " << lc << ": " << line << endl; abort();
+ }
+ }
+}
+
+