summary | refs | log | tree | commit | diff
path: root/klm/search/weights.cc
diff options
context:
space:
mode:
author: Kenneth Heafield <github@kheafield.com> 2012-09-11 14:30:16 +0100
committer: Kenneth Heafield <github@kheafield.com> 2012-09-11 14:31:42 +0100
commit: 58d7f847cd5b3c56682e834a2d9b897c6943fafc (patch)
tree: 04b370b0d92240a4a60dbf921548c927e0a5f00d /klm/search/weights.cc
parent: 7aa4baf365a80380bebacfc4d4a1ef1b9d757590 (diff)
Add search library to cdec (not used yet)
Diffstat (limited to 'klm/search/weights.cc')
-rw-r--r-- klm/search/weights.cc 69
1 files changed, 69 insertions, 0 deletions
diff --git a/klm/search/weights.cc b/klm/search/weights.cc
new file mode 100644
index 00000000..82ff3f12
--- /dev/null
+++ b/klm/search/weights.cc
@@ -0,0 +1,69 @@
+#include "search/weights.hh"
+#include "util/tokenize_piece.hh"
+
+#include <cstdlib>
+
+namespace search {
+
+namespace {
+// Parse() callback that stores each parsed (name, score) pair in the map.
+// A score already stored under the same name is overwritten.
+struct Insert {
+  void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const {
+    // Build an owning std::string key from the piece's data()/size() range.
+    map[std::string(name.data(), name.size())] = score;
+  }
+};
+
+// Parse() callback that accumulates a dot product in `total`: for every parsed
+// name that is present in the map, the parsed score times the stored score is
+// added.  Names that are not in the map contribute nothing.
+struct DotProduct {
+  search::Score total;
+  DotProduct() : total(0.0) {}
+
+  void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) {
+    typedef boost::unordered_map<std::string, search::Score>::const_iterator ConstIter;
+    ConstIter found(FindStringPiece(map, name));
+    // Unknown feature name: leave the running total untouched.
+    if (found == map.end()) return;
+    total += score * found->second;
+  }
+};
+
+// Tokenizes text on ' ' and treats every token as name=value, calling
+// op(map, name, value) once per token in the order the tokens appear.
+// (The `true` argument on the outer TokenIter presumably skips empty tokens,
+// i.e. runs of spaces -- confirm in util/tokenize_piece.hh.)
+//
+// Throws WeightParseException when a token
+//   - has no '=' separator,
+//   - has a value that is empty or that strtod does not consume completely,
+//   - contains more than one '='.
+template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) {
+  for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) {
+    util::TokenIter<util::SingleCharacter> equals(*spaces, '=');
+    UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces);
+    StringPiece name(*equals);
+    UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces);
+    // strtod requires a NUL-terminated string, but the value token is only a
+    // slice of text.  Parsing the slice in place could read past its end (or
+    // past the end of the underlying buffer), so parse a terminated copy.
+    const std::string number(equals->data(), equals->size());
+    char *end;
+    const double value = std::strtod(number.c_str(), &end);
+    // An empty value must be rejected explicitly: strtod performs no
+    // conversion on "" and reports 0 with end == start, which would otherwise
+    // pass the "consumed everything" check.
+    UTIL_THROW_IF(number.empty() || end != number.c_str() + number.size(), WeightParseException, "Failed to parse weight " << *equals);
+    UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces);
+    op(map, name, value);
+  }
+}
+
+} // namespace
+
+// Parses text (space-separated name=value pairs) into map_, then moves the
+// "LanguageModel", "OOV" and "WordPenalty" entries out of map_ into their
+// dedicated members.  Each of those members is 0.0 when its name was absent.
+// Parse errors surface as WeightParseException.
+Weights::Weights(StringPiece text) {
+  Insert inserter;
+  Parse(text, map_, inserter);
+  lm_ = Steal("LanguageModel");
+  oov_ = Steal("OOV");
+  word_penalty_ = Steal("WordPenalty");
+}
+
+// Parses text as space-separated name=value pairs and returns the sum of
+// value * stored weight over the names that are present in map_.  Names that
+// are not in map_ are ignored.  Parse errors surface as WeightParseException.
+search::Score Weights::DotNoLM(StringPiece text) const {
+  DotProduct accumulator;
+  Parse(text, map_, accumulator);
+  return accumulator.total;
+}
+
+// Removes the entry named str from map_ and returns its value.
+// Returns 0.0 and leaves map_ unchanged when no such entry exists.
+float Weights::Steal(const std::string &str) {
+  Map::iterator found(map_.find(str));
+  if (found == map_.end()) return 0.0;
+  // Read the value before erasing; the iterator is invalid afterwards.
+  float stolen = found->second;
+  map_.erase(found);
+  return stolen;
+}
+
+} // namespace search