Diffstat (limited to 'decoder/tromble_loss.cc')
-rw-r--r--  decoder/tromble_loss.cc | 309
1 files changed, 309 insertions, 0 deletions
diff --git a/decoder/tromble_loss.cc b/decoder/tromble_loss.cc
new file mode 100644
index 00000000..9ebd8ab1
--- /dev/null
+++ b/decoder/tromble_loss.cc
@@ -0,0 +1,309 @@
+#include "tromble_loss.h"
+
+#include <boost/algorithm/string/predicate.hpp>
+#include <boost/circular_buffer.hpp>
+#include <boost/functional/hash.hpp>
+#include <boost/lexical_cast.hpp>
+#include <boost/range/iterator_range.hpp>
+#include <boost/tokenizer.hpp>
+#include <boost/unordered_map.hpp>
+
+#include <algorithm>
+#include <cassert>
+#include <cmath>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <vector>
+
+#include "sentence_metadata.h"
+#include "trule.h"
+#include "tdict.h"
+
+using namespace std;
+
+namespace {
+
+typedef unsigned char GramCount;
+
+struct RefCounts {
+  GramCount max;
+  std::vector<GramCount> refs;
+  size_t length;
+};
+
+typedef boost::unordered_map<std::vector<WordID>, size_t, boost::hash<std::vector<WordID> > > NGramMap;
+
+// Take all the n-grams in the references and stuff them into ngrams.
+void MakeNGramMapFromReferences(const vector<vector<WordID> > &references,
+                                int n,
+                                vector<RefCounts> *counts,
+                                NGramMap *ngrams) {
+  ngrams->clear();
+  std::pair<vector<WordID>, size_t> insert_me;
+  vector<WordID> &ngram = insert_me.first;
+  ngram.reserve(n);
+  size_t &id = insert_me.second;
+  id = 0;
+  for (size_t refi = 0; refi < references.size(); ++refi) {
+    const vector<WordID> &ref = references[refi];
+    const int s = ref.size();
+    for (int j = 0; j < s; ++j) {
+      const int remaining = s - j;
+      const int k = (n < remaining ? n : remaining);
+      ngram.clear();
+      for (int i = 0; i < k; ++i) {
+        ngram.push_back(ref[j + i]);
+        std::pair<NGramMap::iterator, bool> ret(ngrams->insert(insert_me));
+        if (ret.second) {
+          // First occurrence of this n-gram in any reference.
+          counts->resize(id + 1);
+          RefCounts &ref_counts = counts->back();
+          ref_counts.max = 1;
+          ref_counts.refs.resize(references.size());
+          ref_counts.refs[refi] = 1;
+          ref_counts.length = ngram.size();
+          ++id;
+        } else {
+          RefCounts &ref_counts = (*counts)[ret.first->second];
+          ref_counts.max = std::max(ref_counts.max, ++ref_counts.refs[refi]);
+        }
+      }
+    }
+  }
+}
+
+// Typed views of the per-node hypothesis state: a length, n-1 boundary words
+// on each side, and per-n-gram-id counts, packed into one untyped block.
+struct MutableState {
+  MutableState(void *from, size_t n)
+      : length(reinterpret_cast<size_t*>(from)),
+        left(reinterpret_cast<WordID *>(length + 1)),
+        right(left + n - 1),
+        counts(reinterpret_cast<GramCount *>(right + n - 1)) {}
+  size_t *length;
+  WordID *left, *right;
+  GramCount *counts;
+  static size_t Size(size_t n, size_t bound_ngram_id) {
+    return sizeof(size_t) + (n - 1) * 2 * sizeof(WordID) + bound_ngram_id * sizeof(GramCount);
+  }
+};
+
+struct ConstState {
+  ConstState(const void *from, size_t n)
+      : length(reinterpret_cast<const size_t*>(from)),
+        left(reinterpret_cast<const WordID *>(length + 1)),
+        right(left + n - 1),
+        counts(reinterpret_cast<const GramCount *>(right + n - 1)) {}
+  const size_t *length;
+  const WordID *left, *right;
+  const GramCount *counts;
+  static size_t Size(size_t n, size_t bound_ngram_id) {
+    return sizeof(size_t) + (n - 1) * 2 * sizeof(WordID) + bound_ngram_id * sizeof(GramCount);
+  }
+};
+
+// Hash and equality functors that let an iterator_range be looked up against
+// the map's std::vector<WordID> keys without copying.
+template <class T> struct CompatibleHashRange : public std::unary_function<const boost::iterator_range<T> &, size_t> {
+  size_t operator()(const boost::iterator_range<T> &range) const {
+    return boost::hash_range(range.begin(), range.end());
+  }
+};
+
+template <class T> struct CompatibleEqualsRange : public std::binary_function<const boost::iterator_range<T> &, const std::vector<WordID> &, bool> {
+  bool operator()(const boost::iterator_range<T> &range, const std::vector<WordID> &vec) const {
+    return boost::algorithm::equals(range, vec);
+  }
+  bool operator()(const std::vector<WordID> &vec, const boost::iterator_range<T> &range) const {
+    return boost::algorithm::equals(range, vec);
+  }
+};
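+
+// A note on the disabled block in AddWord below: as written, the find() call
+// passes only an iterator_range, which does not convert to the map's
+// std::vector<WordID> key type, so the block cannot compile -- presumably why
+// it is #if 0'd out.  Boost.Unordered provides a compatible-key lookup,
+// find(key, hash, equal), which appears to be what CompatibleHashRange and
+// CompatibleEqualsRange above were written for.  An illustrative sketch, not
+// code from the original commit:
+//
+//   NGramMap::const_iterator found =
+//       ref_grams.find(SegmentRange(seg_start, segment.end()),
+//                      CompatibleHashRange<BufferIt>(),
+//                      CompatibleEqualsRange<BufferIt>());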
+
+void AddWord(const boost::circular_buffer<WordID> &segment, size_t min_length, const NGramMap &ref_grams, GramCount *counters) {
+  typedef boost::circular_buffer<WordID>::const_iterator BufferIt;
+  typedef boost::iterator_range<BufferIt> SegmentRange;
+  if (segment.size() < min_length) return;
+#if 0
+  CompatibleHashRange<BufferIt> hasher;
+  CompatibleEqualsRange<BufferIt> equals;
+  // Extend the matched suffix leftward until an n-gram is no longer found in
+  // the references.
+  for (BufferIt seg_start(segment.end() - min_length); ; --seg_start) {
+    NGramMap::const_iterator found = ref_grams.find(SegmentRange(seg_start, segment.end()));
+    if (found == ref_grams.end()) break;
+    ++counters[found->second];
+    if (seg_start == segment.begin()) break;
+  }
+#endif
+}
+
+}  // namespace
+
+class TrombleLossComputerImpl {
+ public:
+  explicit TrombleLossComputerImpl(const std::string &params) : star_(TD::Convert("<{STAR}>")) {
+    typedef boost::tokenizer<boost::char_separator<char> > Tokenizer;
+    // Argument parsing: "<ref file> <number of refs> <theta_1> [theta_2 ...]"
+    std::string ref_file_name;
+    Tokenizer tok(params, boost::char_separator<char>(" "));
+    Tokenizer::iterator i = tok.begin();
+    if (i == tok.end()) {
+      std::cerr << "TrombleLossComputer needs a reference file name." << std::endl;
+      exit(1);
+    }
+    ref_file_name = *i++;
+    if (i == tok.end()) {
+      std::cerr << "TrombleLossComputer needs to know how many references." << std::endl;
+      exit(1);
+    }
+    num_refs_ = boost::lexical_cast<unsigned int>(*i++);
+    for (; i != tok.end(); ++i) {
+      thetas_.push_back(boost::lexical_cast<double>(*i));
+    }
+    if (thetas_.empty()) {
+      std::cerr << "TrombleLossComputer is pointless with no weight on n-grams." << std::endl;
+      exit(1);
+    }
+
+    // Read the references file: num_refs_ consecutive lines per sentence.
+    std::ifstream ref_file(ref_file_name.c_str());
+    if (!ref_file) {
+      std::cerr << "Could not open TrombleLossComputer file " << ref_file_name << std::endl;
+      exit(1);
+    }
+    std::string ref;
+    vector<vector<WordID> > references(num_refs_);
+    bound_ngram_id_ = 0;
+    for (unsigned int sentence = 0; ref_file; ++sentence) {
+      for (unsigned int refidx = 0; refidx < num_refs_; ++refidx) {
+        if (!getline(ref_file, ref)) {
+          if (refidx == 0) break;
+          std::cerr << "Short read of " << refidx << " references for sentence " << sentence << std::endl;
+          exit(1);
+        }
+        TD::ConvertSentence(ref, &references[refidx]);
+      }
+      // Clean EOF before any reference of this sentence was read; without
+      // this check, the previous sentence's n-grams would be recorded twice.
+      if (!ref_file) break;
+      ref_ids_.resize(sentence + 1);
+      ref_counts_.resize(sentence + 1);
+      MakeNGramMapFromReferences(references, thetas_.size(), &ref_counts_.back(), &ref_ids_.back());
+      bound_ngram_id_ = std::max(bound_ngram_id_, ref_ids_.back().size());
+    }
+  }
+
+  size_t StateSize() const {
+    // n-1 boundary words on each side plus per-n-gram counts, currently stored
+    // as bytes even though most would fit in a few bits.
+    // This value is cached by classes higher up, so no need to cache it here.
+    return MutableState::Size(thetas_.size(), bound_ngram_id_);
+  }
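+
+  // Worked example of the state footprint (illustrative numbers only): with
+  // 4-gram thetas (n = 4) and bound_ngram_id_ = 10000, assuming an 8-byte
+  // size_t and a 4-byte WordID, each node's state occupies
+  //   sizeof(size_t) + (4 - 1) * 2 * sizeof(WordID) + 10000 * sizeof(GramCount)
+  //   = 8 + 24 + 10000 = 10032 bytes,
+  // so the count array dominates and grows with the reference n-gram inventory.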
+
+  double Traversal(const SentenceMetadata &smeta,
+                   const TRule &rule,
+                   const vector<const void*> &ant_contexts,
+                   void *out_context) const {
+    // TODO: get references from the sentence metadata.
+    // This will require resizable features.
+    if (smeta.GetSentenceID() >= ref_ids_.size()) {
+      std::cerr << "Sentence ID " << smeta.GetSentenceID() << " doesn't have references; there are only " << ref_ids_.size() << " references." << std::endl;
+      exit(1);
+    }
+    const NGramMap &ngrams = ref_ids_[smeta.GetSentenceID()];
+    MutableState out_state(out_context, thetas_.size());
+    memset(out_state.counts, 0, bound_ngram_id_ * sizeof(GramCount));
+    boost::circular_buffer<WordID> history(thetas_.size());
+    std::vector<const void*>::const_iterator ant_context = ant_contexts.begin();
+    *out_state.length = 0;
+    size_t pushed = 0;
+    const size_t keep = thetas_.size() - 1;
+    for (vector<WordID>::const_iterator rhs = rule.e().begin(); rhs != rule.e().end(); ++rhs) {
+      if (*rhs < 1) {
+        // Nonterminal: fold in the antecedent constituent's state.
+        assert(ant_context != ant_contexts.end());
+        ConstState rhs_state(*ant_context, thetas_.size());
+        *out_state.length += *rhs_state.length;
+        {
+          GramCount *accum = out_state.counts;
+          for (const GramCount *c = rhs_state.counts; c != rhs_state.counts + ngrams.size(); ++c, ++accum) {
+            *accum += *c;
+          }
+        }
+        const WordID *w = rhs_state.left;
+        bool long_constit = true;
+        for (size_t i = 1; i <= keep; ++i, ++w) {
+          if (*w == star_) {
+            long_constit = false;
+            break;
+          }
+          history.push_back(*w);
+          if (++pushed == keep) {
+            std::copy(history.begin(), history.end(), out_state.left);
+          }
+          // Here i is the length of the history coming from this constituent,
+          // so a cross-child match needs at least i+1 words.
+          AddWord(history, i + 1, ngrams, out_state.counts);
+        }
+        // If the constituent is shorter than thetas_.size(), its left state is
+        // the entire constituent, so history is already correct.  Otherwise,
+        // the constituent's right state becomes the history.
+        if (long_constit) {
+          history.assign(thetas_.size(), rhs_state.right, rhs_state.right + keep);
+        }
+        ++ant_context;
+      } else {
+        // Terminal word.
+        ++*out_state.length;
+        history.push_back(*rhs);
+        if (++pushed == keep) {
+          std::copy(history.begin(), history.end(), out_state.left);
+        }
+        AddWord(history, 1, ngrams, out_state.counts);
+      }
+    }
+    // Fill in the left and right boundary words.
+    if (pushed < keep) {
+      std::copy(history.begin(), history.end(), out_state.left);
+      for (WordID *i = out_state.left + pushed; i != out_state.left + keep; ++i) {
+        *i = star_;
+      }
+      std::copy(out_state.left, out_state.left + keep, out_state.right);
+    } else if (pushed == keep) {
+      std::copy(history.begin(), history.end(), out_state.right);
+    } else if ((pushed > keep) && !history.empty()) {
+      std::copy(history.begin() + 1, history.end(), out_state.right);
+    }
+    std::vector<RefCounts>::const_iterator ref_info = ref_counts_[smeta.GetSentenceID()].begin();
+    // Clip the counts and tally matches, indexed by reference then by length.
+    std::vector<std::vector<unsigned int> > matches(num_refs_, std::vector<unsigned int>(thetas_.size()));
+    for (GramCount *c = out_state.counts; c != out_state.counts + ngrams.size(); ++c, ++ref_info) {
+      *c = std::min(*c, ref_info->max);
+      if (*c) {
+        for (unsigned int refidx = 0; refidx < num_refs_; ++refidx) {
+          assert(ref_info->length >= 1);
+          assert(ref_info->length - 1 < thetas_.size());
+          matches[refidx][ref_info->length - 1] += std::min(*c, ref_info->refs[refidx]);
+        }
+      }
+    }
+    double best_score = 0.0;
+    for (unsigned int refidx = 0; refidx < num_refs_; ++refidx) {
+      double score = 0.0;
+      for (unsigned int j = 0; j < std::min(*out_state.length, thetas_.size()); ++j) {
+        score += thetas_[j] * static_cast<double>(matches[refidx][j]) / static_cast<double>(*out_state.length - j);
+      }
+      best_score = std::max(best_score, score);
+    }
+    return best_score;
+  }
+
+ private:
+  unsigned int num_refs_;
+  // Indexed by sentence id.
+  std::vector<NGramMap> ref_ids_;
+  // Then by n-gram id from ref_ids_.
+  std::vector<std::vector<RefCounts> > ref_counts_;
+
+  // thetas_[0] is the weight for 1-grams.
+  std::vector<double> thetas_;
+
+  // All n-gram ids in ref_ids_ are < this value.
+  size_t bound_ngram_id_;
+
+  const WordID star_;
+};
+
+TrombleLossComputer::TrombleLossComputer(const std::string &params) :
+    boost::base_from_member<PImpl>(new TrombleLossComputerImpl(params)),
+    FeatureFunction(boost::base_from_member<PImpl>::member->StateSize()),
+    fid_(FD::Convert("TrombleLossComputer")) {}
+
+TrombleLossComputer::~TrombleLossComputer() {}
+
+void TrombleLossComputer::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+                                                const Hypergraph::Edge& edge,
+                                                const vector<const void*>& ant_contexts,
+                                                SparseVector<double>* features,
+                                                SparseVector<double>* estimated_features,
+                                                void* out_context) const {
+  (void) estimated_features;
+  const double loss = boost::base_from_member<PImpl>::member->Traversal(smeta, *edge.rule_, ant_contexts, out_context);
+  features->set_value(fid_, loss);
+}
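For readers tracing the arithmetic at the end of Traversal: for a hypothesis of length L, with N = thetas_.size() and m_{r,j} the clipped count of (j+1)-grams matching reference r, the returned value is the best per-reference sum of weighted n-gram precisions; the denominator L - j is the number of (j+1)-grams in a length-L string. A sketch of the formula as implemented (the notation here is this summary's, not the original commit's):

\mathrm{score} \;=\; \max_{r} \sum_{j=0}^{\min(L,\,N)-1} \theta_j \cdot \frac{m_{r,j}}{L - j}

This is a linear, additively decomposable n-gram score of the kind used for lattice minimum Bayes-risk decoding, which is presumably what the "tromble" in the file name refers to (Tromble et al., 2008).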