diff options
author | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-11-05 15:29:46 +0100 |
---|---|---|
committer | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-11-05 15:29:46 +0100 |
commit | 6f29f345dc06c1a1033475eac1d1340781d1d603 (patch) | |
tree | 6fa4cdd7aefd7d54c9585c2c6274db61bb8b159a /klm/search/weights.cc | |
parent | b510da2e562c695c90d565eb295c749569c59be8 (diff) | |
parent | c615c37501fa8576584a510a9d2bfe2fdd5bace7 (diff) |
merge upstream/master
Diffstat (limited to 'klm/search/weights.cc')
-rw-r--r-- | klm/search/weights.cc | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/klm/search/weights.cc b/klm/search/weights.cc new file mode 100644 index 00000000..d65471ad --- /dev/null +++ b/klm/search/weights.cc @@ -0,0 +1,71 @@ +#include "search/weights.hh" +#include "util/tokenize_piece.hh" + +#include <cstdlib> + +namespace search { + +namespace { +struct Insert { + void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const { + std::string copy(name.data(), name.size()); + map[copy] = score; + } +}; + +struct DotProduct { + search::Score total; + DotProduct() : total(0.0) {} + + void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) { + boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name)); + if (i != map.end()) + total += score * i->second; + } +}; + +template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) { + for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) { + util::TokenIter<util::SingleCharacter> equals(*spaces, '='); + UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); + StringPiece name(*equals); + UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); + char *end; + // Assumes proper termination. + double value = std::strtod(equals->data(), &end); + UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); + UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); + op(map, name, value); + } +} + +} // namespace + +Weights::Weights(StringPiece text) { + Insert op; + Parse<Map, Insert>(text, map_, op); + lm_ = Steal("LanguageModel"); + oov_ = Steal("OOV"); + word_penalty_ = Steal("WordPenalty"); +} + +Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} + +search::Score Weights::DotNoLM(StringPiece text) const { + DotProduct dot; + Parse<const Map, DotProduct>(text, map_, dot); + return dot.total; +} + +float Weights::Steal(const std::string &str) { + Map::iterator i(map_.find(str)); + if (i == map_.end()) { + return 0.0; + } else { + float ret = i->second; + map_.erase(i); + return ret; + } +} + +} // namespace search |