summaryrefslogtreecommitdiff
path: root/extractor
diff options
context:
space:
mode:
authorPaul Baltescu <pauldb89@gmail.com>2013-11-26 01:14:28 +0000
committerPaul Baltescu <pauldb89@gmail.com>2013-11-26 01:31:40 +0000
commit1cd86c44e1799c441cdcda2a022be0ee6e52d38c (patch)
tree226777c6b734bbe8bf9ad0d4c941df198c72f2d1 /extractor
parent6bdb362473cf0ee1c636ca0c3f4cca63d82a5573 (diff)
Serialize vocabulary.
Diffstat (limited to 'extractor')
-rw-r--r--extractor/Makefile.am12
-rw-r--r--extractor/compile.cc4
-rw-r--r--extractor/vocabulary.cc4
-rw-r--r--extractor/vocabulary.h23
-rw-r--r--extractor/vocabulary_test.cc45
5 files changed, 84 insertions, 4 deletions
diff --git a/extractor/Makefile.am b/extractor/Makefile.am
index 65a3d436..64a5a2b5 100644
--- a/extractor/Makefile.am
+++ b/extractor/Makefile.am
@@ -24,7 +24,8 @@ EXTRA_PROGRAMS = alignment_test \
scorer_test \
suffix_array_test \
target_phrase_extractor_test \
- translation_table_test
+ translation_table_test \
+ vocabulary_test
if HAVE_GTEST
RUNNABLE_TESTS = alignment_test \
@@ -48,12 +49,14 @@ if HAVE_GTEST
scorer_test \
suffix_array_test \
target_phrase_extractor_test \
- translation_table_test
+ translation_table_test \
+ vocabulary_test
endif
noinst_PROGRAMS = $(RUNNABLE_TESTS)
-TESTS = $(RUNNABLE_TESTS)
+# TESTS = $(RUNNABLE_TESTS)
+TESTS = vocabulary_test
alignment_test_SOURCES = alignment_test.cc
alignment_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
@@ -99,6 +102,8 @@ target_phrase_extractor_test_SOURCES = target_phrase_extractor_test.cc
target_phrase_extractor_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
translation_table_test_SOURCES = translation_table_test.cc
translation_table_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) $(GMOCK_LDFLAGS) $(GMOCK_LIBS) libextractor.a
+vocabulary_test_SOURCES = vocabulary_test.cc
+vocabulary_test_LDADD = $(GTEST_LDFLAGS) $(GTEST_LIBS) libextractor.a
noinst_LIBRARIES = libextractor.a libcompile.a
@@ -115,6 +120,7 @@ libcompile_a_SOURCES = \
suffix_array.cc \
time_util.cc \
translation_table.cc \
+ vocabulary.cc \
alignment.h \
data_array.h \
fast_intersector.h \
diff --git a/extractor/compile.cc b/extractor/compile.cc
index 0d62757e..9e8044ad 100644
--- a/extractor/compile.cc
+++ b/extractor/compile.cc
@@ -145,6 +145,10 @@ int main(int argc, char** argv) {
ofstream precomp_fstream((output_dir / fs::path("precomp.bin")).string());
ar::binary_oarchive precomp_stream(precomp_fstream);
precomp_stream << precomputation;
+
+ ofstream vocab_fstream((output_dir / fs::path("vocab.bin")).string());
+ ar::binary_oarchive vocab_stream(vocab_fstream);
+ vocab_stream << *vocabulary;
stop_write = Clock::now();
write_duration += GetDuration(start_write, stop_write);
diff --git a/extractor/vocabulary.cc b/extractor/vocabulary.cc
index aef674a5..c9c2d6f4 100644
--- a/extractor/vocabulary.cc
+++ b/extractor/vocabulary.cc
@@ -35,4 +35,8 @@ string Vocabulary::GetTerminalValue(int symbol) {
return word;
}
+bool Vocabulary::operator==(const Vocabulary& other) const {
+ return words == other.words && dictionary == other.dictionary;
+}
+
} // namespace extractor
diff --git a/extractor/vocabulary.h b/extractor/vocabulary.h
index c8fd9411..db092e99 100644
--- a/extractor/vocabulary.h
+++ b/extractor/vocabulary.h
@@ -5,6 +5,10 @@
#include <unordered_map>
#include <vector>
+#include <boost/serialization/serialization.hpp>
+#include <boost/serialization/string.hpp>
+#include <boost/serialization/vector.hpp>
+
using namespace std;
namespace extractor {
@@ -14,7 +18,7 @@ namespace extractor {
*
* This strucure contains words located in the frequent collocations and words
* encountered during the grammar extraction time. This dictionary is
- * considerably smaller than the dictionaries in the data arrays (and so is the
+ * considerably smaller than the dictionaries in the data arays (and so is the
* query time). Note that this is the single data structure that changes state
* and needs to have thread safe read/write operations.
*
@@ -38,7 +42,24 @@ class Vocabulary {
// Returns the word corresponding to the given word id.
virtual string GetTerminalValue(int symbol);
+ bool operator==(const Vocabulary& vocabulary) const;
+
private:
+ friend class boost::serialization::access;
+
+ template<class Archive> void save(Archive& ar, unsigned int) const {
+ ar << words;
+ }
+
+ template<class Archive> void load(Archive& ar, unsigned int) {
+ ar >> words;
+ for (size_t i = 0; i < words.size(); ++i) {
+ dictionary[words[i]] = i;
+ }
+ }
+
+ BOOST_SERIALIZATION_SPLIT_MEMBER();
+
unordered_map<string, int> dictionary;
vector<string> words;
};
diff --git a/extractor/vocabulary_test.cc b/extractor/vocabulary_test.cc
new file mode 100644
index 00000000..cf5e3e36
--- /dev/null
+++ b/extractor/vocabulary_test.cc
@@ -0,0 +1,45 @@
+#include <gtest/gtest.h>
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include <boost/archive/text_iarchive.hpp>
+#include <boost/archive/text_oarchive.hpp>
+
+#include "vocabulary.h"
+
+using namespace std;
+using namespace ::testing;
+namespace ar = boost::archive;
+
+namespace extractor {
+namespace {
+
+TEST(VocabularyTest, TestIndexes) {
+ Vocabulary vocabulary;
+ EXPECT_EQ(0, vocabulary.GetTerminalIndex("zero"));
+ EXPECT_EQ("zero", vocabulary.GetTerminalValue(0));
+
+ EXPECT_EQ(1, vocabulary.GetTerminalIndex("one"));
+ EXPECT_EQ("one", vocabulary.GetTerminalValue(1));
+}
+
+TEST(VocabularyTest, TestSerialization) {
+ Vocabulary vocabulary;
+ EXPECT_EQ(0, vocabulary.GetTerminalIndex("zero"));
+ EXPECT_EQ("zero", vocabulary.GetTerminalValue(0));
+
+ stringstream stream(ios_base::out | ios_base::in);
+ ar::text_oarchive output_stream(stream, ar::no_header);
+ output_stream << vocabulary;
+
+ Vocabulary vocabulary_copy;
+ ar::text_iarchive input_stream(stream, ar::no_header);
+ input_stream >> vocabulary_copy;
+
+ EXPECT_EQ(vocabulary, vocabulary_copy);
+}
+
+} // namespace
+} // namespace extractor