summaryrefslogtreecommitdiff
path: root/extractor/grammar_extractor_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'extractor/grammar_extractor_test.cc')
-rw-r--r--extractor/grammar_extractor_test.cc51
1 files changed, 51 insertions, 0 deletions
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc
new file mode 100644
index 00000000..823bb8b4
--- /dev/null
+++ b/extractor/grammar_extractor_test.cc
@@ -0,0 +1,51 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "grammar.h"
+#include "grammar_extractor.h"
+#include "mocks/mock_rule_factory.h"
+#include "mocks/mock_vocabulary.h"
+#include "rule.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+TEST(GrammarExtractorTest, TestAnnotatingWords) {
+ shared_ptr<MockVocabulary> vocabulary = make_shared<MockVocabulary>();
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("<s>"))
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("Anna"))
+ .WillRepeatedly(Return(1));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("has"))
+ .WillRepeatedly(Return(2));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("many"))
+ .WillRepeatedly(Return(3));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("apples"))
+ .WillRepeatedly(Return(4));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("."))
+ .WillRepeatedly(Return(5));
+ EXPECT_CALL(*vocabulary, GetTerminalIndex("</s>"))
+ .WillRepeatedly(Return(6));
+
+ shared_ptr<MockHieroCachingRuleFactory> factory =
+ make_shared<MockHieroCachingRuleFactory>();
+ vector<int> word_ids = {0, 1, 2, 3, 3, 4, 5, 6};
+ vector<Rule> rules;
+ vector<string> feature_names;
+ Grammar grammar(rules, feature_names);
+ EXPECT_CALL(*factory, GetGrammar(word_ids))
+ .WillOnce(Return(grammar));
+
+ GrammarExtractor extractor(vocabulary, factory);
+ string sentence = "Anna has many many apples .";
+ extractor.GetGrammar(sentence);
+}
+
+} // namespace
+} // namespace extractor