diff options
| author | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-11-05 15:29:46 +0100 | 
|---|---|---|
| committer | Patrick Simianer <simianer@cl.uni-heidelberg.de> | 2012-11-05 15:29:46 +0100 | 
| commit | 6f29f345dc06c1a1033475eac1d1340781d1d603 (patch) | |
| tree | 6fa4cdd7aefd7d54c9585c2c6274db61bb8b159a /klm/search/weights.cc | |
| parent | b510da2e562c695c90d565eb295c749569c59be8 (diff) | |
| parent | c615c37501fa8576584a510a9d2bfe2fdd5bace7 (diff) | |
merge upstream/master
Diffstat (limited to 'klm/search/weights.cc')
| -rw-r--r-- | klm/search/weights.cc | 71 | 
1 files changed, 71 insertions, 0 deletions
| diff --git a/klm/search/weights.cc b/klm/search/weights.cc new file mode 100644 index 00000000..d65471ad --- /dev/null +++ b/klm/search/weights.cc @@ -0,0 +1,71 @@ +#include "search/weights.hh" +#include "util/tokenize_piece.hh" + +#include <cstdlib> + +namespace search { + +namespace { +struct Insert { +  void operator()(boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) const { +    std::string copy(name.data(), name.size()); +    map[copy] = score; +  } +}; + +struct DotProduct { +  search::Score total; +  DotProduct() : total(0.0) {} + +  void operator()(const boost::unordered_map<std::string, search::Score> &map, StringPiece name, search::Score score) { +    boost::unordered_map<std::string, search::Score>::const_iterator i(FindStringPiece(map, name)); +    if (i != map.end())  +      total += score * i->second; +  } +}; + +template <class Map, class Op> void Parse(StringPiece text, Map &map, Op &op) { +  for (util::TokenIter<util::SingleCharacter, true> spaces(text, ' '); spaces; ++spaces) { +    util::TokenIter<util::SingleCharacter> equals(*spaces, '='); +    UTIL_THROW_IF(!equals, WeightParseException, "Bad weight token " << *spaces); +    StringPiece name(*equals); +    UTIL_THROW_IF(!++equals, WeightParseException, "Bad weight token " << *spaces); +    char *end; +    // Assumes proper termination.   +    double value = std::strtod(equals->data(), &end); +    UTIL_THROW_IF(end != equals->data() + equals->size(), WeightParseException, "Failed to parse weight" << *equals); +    UTIL_THROW_IF(++equals, WeightParseException, "Too many equals in " << *spaces); +    op(map, name, value); +  } +} + +} // namespace + +Weights::Weights(StringPiece text) { +  Insert op; +  Parse<Map, Insert>(text, map_, op); +  lm_ = Steal("LanguageModel"); +  oov_ = Steal("OOV"); +  word_penalty_ = Steal("WordPenalty"); +} + +Weights::Weights(Score lm, Score oov, Score word_penalty) : lm_(lm), oov_(oov), word_penalty_(word_penalty) {} + +search::Score Weights::DotNoLM(StringPiece text) const { +  DotProduct dot; +  Parse<const Map, DotProduct>(text, map_, dot); +  return dot.total; +} + +float Weights::Steal(const std::string &str) { +  Map::iterator i(map_.find(str)); +  if (i == map_.end()) { +    return 0.0; +  } else { +    float ret = i->second; +    map_.erase(i); +    return ret; +  } +} + +} // namespace search | 
