diff options
| author | Chris Dyer <redpony@gmail.com> | 2014-11-24 00:16:18 -0500 | 
|---|---|---|
| committer | Chris Dyer <redpony@gmail.com> | 2014-11-24 00:16:18 -0500 | 
| commit | 326ea26db4680219f801eeb6622dd2ee378e974b (patch) | |
| tree | b60b1662b1976d27adc222764666786b04e19ad5 | |
| parent | 25af985d0b31732f5dd6cea9ab001495a1aabb01 (diff) | |
| parent | 2931396900c89eb19a50407955574960c364d0ee (diff) | |
Merge pull request #59 from kho/mrmira
Mr.MIRA compatibility
| -rw-r--r-- | decoder/decoder.cc | 32 | ||||
| -rw-r--r-- | decoder/oracle_bleu.h | 37 | ||||
| -rw-r--r-- | utils/Makefile.am | 3 | ||||
| -rw-r--r-- | utils/b64featvector.cc | 55 | ||||
| -rw-r--r-- | utils/b64featvector.h | 12 | 
5 files changed, 127 insertions, 12 deletions
| diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 737201e7..1e6c3194 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -17,6 +17,7 @@ namespace std { using std::tr1::unordered_map; }  #include "fdict.h"  #include "timing_stats.h"  #include "verbose.h" +#include "b64featvector.h"  #include "translator.h"  #include "phrasebased_translator.h" @@ -301,6 +302,7 @@ struct DecoderImpl {    bool feature_expectations; // TODO Observer    bool output_training_vector; // TODO Observer    bool remove_intersected_rule_annotations; +  bool mr_mira_compat;  // Mr.MIRA compatibility mode.    boost::scoped_ptr<IncrementalBase> incremental; @@ -415,7 +417,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")          ("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")          ("forest_output,O",po::value<string>(),"Directory to write forests to") -        ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)"); +        ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)") +        ("mr_mira_compat", "Mr.MIRA compatibility mode (applies weight delta if available; outputs number of lines before k-best)");    // ob.AddOptions(&opts);    po::options_description clo("Command line options"); @@ -668,6 +671,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    oracle.show_derivation=conf.count("show_derivations");    oracle.show_derivation_mask=conf["show_derivations_mask"].as<int>();    remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations"); +  mr_mira_compat = conf.count("mr_mira_compat");    combine_size = conf["combine_size"].as<int>();    if (combine_size < 1) combine_size = 1; @@ -701,6 +705,24 @@ void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string    static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string);  } +static inline void ApplyWeightDelta(const string &delta_b64, vector<weight_t> *weights) { +  SparseVector<weight_t> delta; +  DecodeFeatureVector(delta_b64, &delta); +  if (delta.empty()) return; +  // Apply updates +  for (SparseVector<weight_t>::iterator dit = delta.begin(); +       dit != delta.end(); ++dit) { +    int feat_id = dit->first; +    union { weight_t weight; unsigned long long repr; } feat_delta; +    feat_delta.weight = dit->second; +    if (!SILENT) +      cerr << "[decoder weight update] " << FD::Convert(feat_id) << " " << feat_delta.weight +           << " = " << hex << feat_delta.repr << endl; +    if (weights->size() <= feat_id) weights->resize(feat_id + 1); +    (*weights)[feat_id] += feat_delta.weight; +  } +} +  bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {    string buf = input;    NgramCache::Clear();   // clear ngram cache for remote LM (if used) @@ -711,6 +733,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {    if (sgml.find("id") != sgml.end())      sent_id = atoi(sgml["id"].c_str()); +  // Add delta from input to weights before decoding +  if (mr_mira_compat) +    ApplyWeightDelta(sgml["delta"], init_weights.get()); +    if (!SILENT) {      cerr << "\nINPUT: ";      if (buf.size() < 100) @@ -949,7 +975,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {      if (kbest && !has_ref) {        //TODO: does this work properly?        const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; -      oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); +      oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);      } else if (csplit_output_plf) {        cout << HypergraphIO::AsPLF(forest, false) << endl;      } else { @@ -1080,7 +1106,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {        if (conf.count("graphviz")) forest.PrintGraphviz();        if (kbest) {          const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; -        oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname); +        oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest, mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);        }        if (conf.count("show_conditional_prob")) {          const prob_t ref_z = Inside<prob_t, EdgeProb>(forest); diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index 893e36ca..cd587833 100644 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -21,6 +21,7 @@  #include "kbest.h"  #include "timing_stats.h"  #include "sentences.h" +#include "b64featvector.h"  //TODO: put function impls into .cc  //TODO: move Translation into its own .h and use in cdec @@ -255,18 +256,28 @@ struct OracleBleu {    int show_derivation_mask;    template <class Filter> -  void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) { +  void kbest(int sent_id, Hypergraph const& forest, int k, bool mr_mira_compat, +             int src_len, std::ostream& kbest_out = std::cout, +             std::ostream& deriv_out = std::cerr) {      using namespace std;      using namespace boost;      typedef KBest::KBestDerivations<Sentence, ESentenceTraversal,Filter> K;      K kbest(forest,k);      //add length (f side) src length of this sentence to the psuedo-doc src length count      float curr_src_length = doc_src_length + tmp_src_length; -    for (int i = 0; i < k; ++i) { +    if (mr_mira_compat) kbest_out << k << "\n"; +    int i = 0; +    for (; i < k; ++i) {        typename K::Derivation *d = kbest.LazyKthBest(forest.nodes_.size() - 1, i);        if (!d) break; -      kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| " -                << d->feature_values << " ||| " << log(d->score); +      kbest_out << sent_id << " ||| "; +      if (mr_mira_compat) kbest_out << src_len << " ||| "; +      kbest_out << TD::GetString(d->yield) << " ||| "; +      if (mr_mira_compat) +        kbest_out << EncodeFeatureVector(d->feature_values); +      else +        kbest_out << d->feature_values; +      kbest_out << " ||| " << log(d->score);        if (!refs.empty()) {          ScoreP sentscore = GetScore(d->yield,sent_id);          sentscore->PlusEquals(*doc_score,float(1)); @@ -281,10 +292,17 @@ struct OracleBleu {          deriv_out<<"\n"<<flush;        }      } +    if (mr_mira_compat) { +      for (; i < k; ++i) kbest_out << "\n"; +      kbest_out << flush; +    }    }  // TODO decoder output should probably be moved to another file - how about oracle_bleu.h -  void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) { +  void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, +                 const bool unique, const bool mr_mira_compat, +                 const int src_len, std::string const& kbest_out_filename_, +                 std::string const& deriv_out_filename_) {      WriteFile ko(kbest_out_filename_);      std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl; @@ -297,9 +315,11 @@ struct OracleBleu {      WriteFile oderiv(sderiv.str());      if (!unique) -      kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get()); +      kbest<KBest::NoFilter<std::vector<WordID> > >( +          sent_id, forest, k, mr_mira_compat, src_len, ko.get(), oderiv.get());      else { -      kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get()); +      kbest<KBest::FilterUnique>(sent_id, forest, k, mr_mira_compat, src_len, +                                 ko.get(), oderiv.get());      }    } @@ -307,7 +327,8 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo    {      std::ostringstream kbest_string_stream;      kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id; -    DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-"); +    DumpKBest(sent_id, forest, k, unique, false, -1, kbest_string_stream.str(), +              "-");    }  }; diff --git a/utils/Makefile.am b/utils/Makefile.am index 727fa8a5..64f6d433 100644 --- a/utils/Makefile.am +++ b/utils/Makefile.am @@ -22,6 +22,7 @@ libutils_a_SOURCES = \    alias_sampler.h \    alignment_io.h \    array2d.h \ +  b64featvector.h \    b64tools.h \    batched_append.h \    city.h \ @@ -70,6 +71,7 @@ libutils_a_SOURCES = \    fast_lexical_cast.hpp \    intrusive_refcount.hpp \    alignment_io.cc \ +  b64featvector.cc \    b64tools.cc \    corpus_tools.cc \    dict.cc \ @@ -117,4 +119,3 @@ stringlib_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_U  # do NOT NOT NOT add any other -I includes NO NO NO NO NO ######  AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -I. -I$(top_srcdir) -DTEST_DATA=\"$(top_srcdir)/utils/test_data\"  ################################################################ - diff --git a/utils/b64featvector.cc b/utils/b64featvector.cc new file mode 100644 index 00000000..c7d08b29 --- /dev/null +++ b/utils/b64featvector.cc @@ -0,0 +1,55 @@ +#include "b64featvector.h" + +#include <sstream> +#include <boost/scoped_array.hpp> +#include "b64tools.h" +#include "fdict.h" + +using namespace std; + +static inline void EncodeFeatureWeight(const string &featname, weight_t weight, +                                       ostream *output) { +  output->write(featname.data(), featname.size() + 1); +  output->write(reinterpret_cast<char *>(&weight), sizeof(weight_t)); +} + +string EncodeFeatureVector(const SparseVector<weight_t> &vec) { +  string b64; +  { +    ostringstream base64_strm; +    { +      ostringstream strm; +      for (SparseVector<weight_t>::const_iterator it = vec.begin(); +           it != vec.end(); ++it) +        if (it->second != 0) +          EncodeFeatureWeight(FD::Convert(it->first), it->second, &strm); +      string data(strm.str()); +      B64::b64encode(data.data(), data.size(), &base64_strm); +    } +    b64 = base64_strm.str(); +  } +  return b64; +} + +void DecodeFeatureVector(const string &data, SparseVector<weight_t> *vec) { +  vec->clear(); +  if (data.empty()) return; +  // Decode data +  size_t b64_len = data.size(), len = b64_len / 4 * 3; +  boost::scoped_array<char> buf(new char[len]); +  bool res = +      B64::b64decode(reinterpret_cast<const unsigned char *>(data.data()), +                     b64_len, buf.get(), len); +  assert(res); +  // Apply updates +  size_t cur = 0; +  while (cur < len) { +    string feat_name(buf.get() + cur); +    if (feat_name.empty()) break;  // Encountered trailing \0 +    int feat_id = FD::Convert(feat_name); +    weight_t feat_delta = +        *reinterpret_cast<weight_t *>(buf.get() + cur + feat_name.size() + 1); +    (*vec)[feat_id] = feat_delta; +    cur += feat_name.size() + 1 + sizeof(weight_t); +  } +} diff --git a/utils/b64featvector.h b/utils/b64featvector.h new file mode 100644 index 00000000..6ac04d44 --- /dev/null +++ b/utils/b64featvector.h @@ -0,0 +1,12 @@ +#ifndef _B64FEATVECTOR_H_ +#define _B64FEATVECTOR_H_ + +#include <string> + +#include "sparse_vector.h" +#include "weights.h" + +std::string EncodeFeatureVector(const SparseVector<weight_t> &); +void DecodeFeatureVector(const std::string &, SparseVector<weight_t> *); + +#endif  // _B64FEATVECTOR_H_ | 
