From 8c0e4a5c1f168a419b3a236a94815c97164bddbc Mon Sep 17 00:00:00 2001 From: "Wu, Ke" Date: Sun, 12 Oct 2014 16:30:02 -0400 Subject: Cherry picked Mr.MIRA compatibility mode code --- decoder/decoder.cc | 39 ++++++++++++++++++++++++++++------- decoder/oracle_bleu.h | 37 +++++++++++++++++++++++++-------- utils/Makefile.am | 3 ++- utils/b64featvector.cc | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++ utils/b64featvector.h | 12 +++++++++++ 5 files changed, 130 insertions(+), 16 deletions(-) create mode 100644 utils/b64featvector.cc create mode 100644 utils/b64featvector.h diff --git a/decoder/decoder.cc b/decoder/decoder.cc index c384c33f..93282576 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -17,6 +17,7 @@ namespace std { using std::tr1::unordered_map; } #include "fdict.h" #include "timing_stats.h" #include "verbose.h" +#include "b64featvector.h" #include "translator.h" #include "phrasebased_translator.h" @@ -195,7 +196,7 @@ struct DecoderImpl { } forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1); if (!forestname.empty()) forestname=" "+forestname; - if (!SILENT) { + if (!SILENT) { forest_stats(forest," Pruned "+forestname+" forest",false,false); cerr << " Pruned "< > rng; int sample_max_trans; bool aligner_mode; - bool graphviz; + bool graphviz; bool joshua_viz; bool encode_b64; bool kbest; @@ -301,6 +302,7 @@ struct DecoderImpl { bool feature_expectations; // TODO Observer bool output_training_vector; // TODO Observer bool remove_intersected_rule_annotations; + bool mr_mira_compat; // Mr.MIRA compatibility mode. boost::scoped_ptr incremental; @@ -414,7 +416,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("vector_format",po::value()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") ("combine_size,C",po::value()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") ("forest_output,O",po::value(),"Directory to write forests to") - ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)"); + ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)") + ("mr_mira_compat", "Mr.MIRA compatibility mode (applies weight delta if available; outputs number of lines before k-best)"); // ob.AddOptions(&opts); po::options_description clo("Command line options"); @@ -666,6 +669,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream get_oracle_forest = conf.count("get_oracle_forest"); oracle.show_derivation=conf.count("show_derivations"); remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations"); + mr_mira_compat = conf.count("mr_mira_compat"); combine_size = conf["combine_size"].as(); if (combine_size < 1) combine_size = 1; @@ -699,6 +703,24 @@ void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string static_cast(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string); } +static inline void ApplyWeightDelta(const string &delta_b64, vector *weights) { + SparseVector delta; + DecodeFeatureVector(delta_b64, &delta); + if (delta.empty()) return; + // Apply updates + for (SparseVector::iterator dit = delta.begin(); + dit != delta.end(); ++dit) { + int feat_id = dit->first; + union { weight_t weight; unsigned long long repr; } feat_delta; + feat_delta.weight = dit->second; + if (!SILENT) + cerr << "[decoder weight update] " << FD::Convert(feat_id) << " " << feat_delta.weight + << " = " << hex << feat_delta.repr << endl; + if (weights->size() <= feat_id) weights->resize(feat_id + 1); + (*weights)[feat_id] += feat_delta.weight; + } +} + bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { string buf = input; NgramCache::Clear(); // clear ngram cache for remote LM (if used) @@ -709,6 +731,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (sgml.find("id") != sgml.end()) sent_id = atoi(sgml["id"].c_str()); + // Add delta from input to weights before decoding + if (mr_mira_compat) + ApplyWeightDelta(sgml["delta"], init_weights.get()); + if (!SILENT) { cerr << "\nINPUT: "; if (buf.size() < 100) @@ -947,7 +973,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (kbest && !has_ref) { //TODO: does this work properly? const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; - oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"-", deriv_fname); + oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname); } else if (csplit_output_plf) { cout << HypergraphIO::AsPLF(forest, false) << endl; } else { @@ -1078,7 +1104,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (conf.count("graphviz")) forest.PrintGraphviz(); if (kbest) { const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; - oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"-", deriv_fname); + oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest, mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname); } if (conf.count("show_conditional_prob")) { const prob_t ref_z = Inside(forest); @@ -1098,4 +1124,3 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { o->NotifyDecodingComplete(smeta); return true; } - diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index d2c4715c..75db61e8 100644 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -21,6 +21,7 @@ #include "kbest.h" #include "timing_stats.h" #include "sentences.h" +#include "b64featvector.h" //TODO: put function impls into .cc //TODO: move Translation into its own .h and use in cdec @@ -253,18 +254,28 @@ struct OracleBleu { bool show_derivation; template - void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) { + void kbest(int sent_id, Hypergraph const& forest, int k, bool mr_mira_compat, + int src_len, std::ostream& kbest_out = std::cout, + std::ostream& deriv_out = std::cerr) { using namespace std; using namespace boost; typedef KBest::KBestDerivations K; K kbest(forest,k); //add length (f side) src length of this sentence to the psuedo-doc src length count float curr_src_length = doc_src_length + tmp_src_length; - for (int i = 0; i < k; ++i) { + if (mr_mira_compat) kbest_out << k << "\n"; + int i = 0; + for (; i < k; ++i) { typename K::Derivation *d = kbest.LazyKthBest(forest.nodes_.size() - 1, i); if (!d) break; - kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " - << d->feature_values << " ||| " << log(d->score); + kbest_out << sent_id << " ||| "; + if (mr_mira_compat) kbest_out << src_len << " ||| "; + kbest_out << TD::GetString(d->yield) << " ||| "; + if (mr_mira_compat) + kbest_out << EncodeFeatureVector(d->feature_values); + else + kbest_out << d->feature_values; + kbest_out << " ||| " << log(d->score); if (!refs.empty()) { ScoreP sentscore = GetScore(d->yield,sent_id); sentscore->PlusEquals(*doc_score,float(1)); @@ -279,10 +290,17 @@ struct OracleBleu { deriv_out<<"\n"< > >(sent_id,forest,k,ko.get(),oderiv.get()); + kbest > >( + sent_id, forest, k, mr_mira_compat, src_len, ko.get(), oderiv.get()); else { - kbest(sent_id,forest,k,ko.get(),oderiv.get()); + kbest(sent_id, forest, k, mr_mira_compat, src_len, + ko.get(), oderiv.get()); } } @@ -305,7 +325,8 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo { std::ostringstream kbest_string_stream; kbest_string_stream << forest_output << "/kbest_"< +#include +#include "b64tools.h" +#include "fdict.h" + +using namespace std; + +static inline void EncodeFeatureWeight(const string &featname, weight_t weight, + ostream *output) { + output->write(featname.data(), featname.size() + 1); + output->write(reinterpret_cast(&weight), sizeof(weight_t)); +} + +string EncodeFeatureVector(const SparseVector &vec) { + string b64; + { + ostringstream base64_strm; + { + ostringstream strm; + for (SparseVector::const_iterator it = vec.begin(); + it != vec.end(); ++it) + if (it->second != 0) + EncodeFeatureWeight(FD::Convert(it->first), it->second, &strm); + string data(strm.str()); + B64::b64encode(data.data(), data.size(), &base64_strm); + } + b64 = base64_strm.str(); + } + return b64; +} + +void DecodeFeatureVector(const string &data, SparseVector *vec) { + vec->clear(); + if (data.empty()) return; + // Decode data + size_t b64_len = data.size(), len = b64_len / 4 * 3; + boost::scoped_array buf(new char[len]); + bool res = + B64::b64decode(reinterpret_cast(data.data()), + b64_len, buf.get(), len); + assert(res); + // Apply updates + size_t cur = 0; + while (cur < len) { + string feat_name(buf.get() + cur); + if (feat_name.empty()) break; // Encountered trailing \0 + int feat_id = FD::Convert(feat_name); + weight_t feat_delta = + *reinterpret_cast(buf.get() + cur + feat_name.size() + 1); + (*vec)[feat_id] = feat_delta; + cur += feat_name.size() + 1 + sizeof(weight_t); + } +} diff --git a/utils/b64featvector.h b/utils/b64featvector.h new file mode 100644 index 00000000..6ac04d44 --- /dev/null +++ b/utils/b64featvector.h @@ -0,0 +1,12 @@ +#ifndef _B64FEATVECTOR_H_ +#define _B64FEATVECTOR_H_ + +#include + +#include "sparse_vector.h" +#include "weights.h" + +std::string EncodeFeatureVector(const SparseVector &); +void DecodeFeatureVector(const std::string &, SparseVector *); + +#endif // _B64FEATVECTOR_H_ -- cgit v1.2.3