summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-11-24 00:16:18 -0500
committerChris Dyer <redpony@gmail.com>2014-11-24 00:16:18 -0500
commit326ea26db4680219f801eeb6622dd2ee378e974b (patch)
treeb60b1662b1976d27adc222764666786b04e19ad5
parent25af985d0b31732f5dd6cea9ab001495a1aabb01 (diff)
parent2931396900c89eb19a50407955574960c364d0ee (diff)
Merge pull request #59 from kho/mrmira
Mr.MIRA compatibility
-rw-r--r--decoder/decoder.cc32
-rw-r--r--decoder/oracle_bleu.h37
-rw-r--r--utils/Makefile.am3
-rw-r--r--utils/b64featvector.cc55
-rw-r--r--utils/b64featvector.h12
5 files changed, 127 insertions, 12 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 737201e7..1e6c3194 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -17,6 +17,7 @@ namespace std { using std::tr1::unordered_map; }
#include "fdict.h"
#include "timing_stats.h"
#include "verbose.h"
+#include "b64featvector.h"
#include "translator.h"
#include "phrasebased_translator.h"
@@ -301,6 +302,7 @@ struct DecoderImpl {
bool feature_expectations; // TODO Observer
bool output_training_vector; // TODO Observer
bool remove_intersected_rule_annotations;
+ bool mr_mira_compat; // Mr.MIRA compatibility mode.
boost::scoped_ptr<IncrementalBase> incremental;
@@ -415,7 +417,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
("forest_output,O",po::value<string>(),"Directory to write forests to")
- ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)");
+ ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)")
+ ("mr_mira_compat", "Mr.MIRA compatibility mode (applies weight delta if available; outputs number of lines before k-best)");
// ob.AddOptions(&opts);
po::options_description clo("Command line options");
@@ -668,6 +671,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
oracle.show_derivation=conf.count("show_derivations");
oracle.show_derivation_mask=conf["show_derivations_mask"].as<int>();
remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
+ mr_mira_compat = conf.count("mr_mira_compat");
combine_size = conf["combine_size"].as<int>();
if (combine_size < 1) combine_size = 1;
@@ -701,6 +705,24 @@ void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string
static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string);
}
+static inline void ApplyWeightDelta(const string &delta_b64, vector<weight_t> *weights) {
+ SparseVector<weight_t> delta;
+ DecodeFeatureVector(delta_b64, &delta);
+ if (delta.empty()) return;
+ // Apply updates
+ for (SparseVector<weight_t>::iterator dit = delta.begin();
+ dit != delta.end(); ++dit) {
+ int feat_id = dit->first;
+ union { weight_t weight; unsigned long long repr; } feat_delta;
+ feat_delta.weight = dit->second;
+ if (!SILENT)
+ cerr << "[decoder weight update] " << FD::Convert(feat_id) << " " << feat_delta.weight
+ << " = " << hex << feat_delta.repr << endl;
+ if (weights->size() <= feat_id) weights->resize(feat_id + 1);
+ (*weights)[feat_id] += feat_delta.weight;
+ }
+}
+
bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
string buf = input;
NgramCache::Clear(); // clear ngram cache for remote LM (if used)
@@ -711,6 +733,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (sgml.find("id") != sgml.end())
sent_id = atoi(sgml["id"].c_str());
+ // Add delta from input to weights before decoding
+ if (mr_mira_compat)
+ ApplyWeightDelta(sgml["delta"], init_weights.get());
+
if (!SILENT) {
cerr << "\nINPUT: ";
if (buf.size() < 100)
@@ -949,7 +975,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (kbest && !has_ref) {
//TODO: does this work properly?
const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {
@@ -1080,7 +1106,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (conf.count("graphviz")) forest.PrintGraphviz();
if (kbest) {
const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest, mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);
}
if (conf.count("show_conditional_prob")) {
const prob_t ref_z = Inside<prob_t, EdgeProb>(forest);
diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h
index 893e36ca..cd587833 100644
--- a/decoder/oracle_bleu.h
+++ b/decoder/oracle_bleu.h
@@ -21,6 +21,7 @@
#include "kbest.h"
#include "timing_stats.h"
#include "sentences.h"
+#include "b64featvector.h"
//TODO: put function impls into .cc
//TODO: move Translation into its own .h and use in cdec
@@ -255,18 +256,28 @@ struct OracleBleu {
int show_derivation_mask;
template <class Filter>
- void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) {
+ void kbest(int sent_id, Hypergraph const& forest, int k, bool mr_mira_compat,
+ int src_len, std::ostream& kbest_out = std::cout,
+ std::ostream& deriv_out = std::cerr) {
using namespace std;
using namespace boost;
typedef KBest::KBestDerivations<Sentence, ESentenceTraversal,Filter> K;
K kbest(forest,k);
//add length (f side) src length of this sentence to the psuedo-doc src length count
float curr_src_length = doc_src_length + tmp_src_length;
- for (int i = 0; i < k; ++i) {
+ if (mr_mira_compat) kbest_out << k << "\n";
+ int i = 0;
+ for (; i < k; ++i) {
typename K::Derivation *d = kbest.LazyKthBest(forest.nodes_.size() - 1, i);
if (!d) break;
- kbest_out << sent_id << " ||| " << TD::GetString(d->yield) << " ||| "
- << d->feature_values << " ||| " << log(d->score);
+ kbest_out << sent_id << " ||| ";
+ if (mr_mira_compat) kbest_out << src_len << " ||| ";
+ kbest_out << TD::GetString(d->yield) << " ||| ";
+ if (mr_mira_compat)
+ kbest_out << EncodeFeatureVector(d->feature_values);
+ else
+ kbest_out << d->feature_values;
+ kbest_out << " ||| " << log(d->score);
if (!refs.empty()) {
ScoreP sentscore = GetScore(d->yield,sent_id);
sentscore->PlusEquals(*doc_score,float(1));
@@ -281,10 +292,17 @@ struct OracleBleu {
deriv_out<<"\n"<<flush;
}
}
+ if (mr_mira_compat) {
+ for (; i < k; ++i) kbest_out << "\n";
+ kbest_out << flush;
+ }
}
// TODO decoder output should probably be moved to another file - how about oracle_bleu.h
- void DumpKBest(const int sent_id, const Hypergraph& forest, const int k, const bool unique, std::string const &kbest_out_filename_, std::string const &deriv_out_filename_) {
+ void DumpKBest(const int sent_id, const Hypergraph& forest, const int k,
+ const bool unique, const bool mr_mira_compat,
+ const int src_len, std::string const& kbest_out_filename_,
+ std::string const& deriv_out_filename_) {
WriteFile ko(kbest_out_filename_);
std::cerr << "Output kbest to " << kbest_out_filename_ <<std::endl;
@@ -297,9 +315,11 @@ struct OracleBleu {
WriteFile oderiv(sderiv.str());
if (!unique)
- kbest<KBest::NoFilter<std::vector<WordID> > >(sent_id,forest,k,ko.get(),oderiv.get());
+ kbest<KBest::NoFilter<std::vector<WordID> > >(
+ sent_id, forest, k, mr_mira_compat, src_len, ko.get(), oderiv.get());
else {
- kbest<KBest::FilterUnique>(sent_id,forest,k,ko.get(),oderiv.get());
+ kbest<KBest::FilterUnique>(sent_id, forest, k, mr_mira_compat, src_len,
+ ko.get(), oderiv.get());
}
}
@@ -307,7 +327,8 @@ void DumpKBest(std::string const& suffix,const int sent_id, const Hypergraph& fo
{
std::ostringstream kbest_string_stream;
kbest_string_stream << forest_output << "/kbest_"<<suffix<< "." << sent_id;
- DumpKBest(sent_id, forest, k, unique, kbest_string_stream.str(), "-");
+ DumpKBest(sent_id, forest, k, unique, false, -1, kbest_string_stream.str(),
+ "-");
}
};
diff --git a/utils/Makefile.am b/utils/Makefile.am
index 727fa8a5..64f6d433 100644
--- a/utils/Makefile.am
+++ b/utils/Makefile.am
@@ -22,6 +22,7 @@ libutils_a_SOURCES = \
alias_sampler.h \
alignment_io.h \
array2d.h \
+ b64featvector.h \
b64tools.h \
batched_append.h \
city.h \
@@ -70,6 +71,7 @@ libutils_a_SOURCES = \
fast_lexical_cast.hpp \
intrusive_refcount.hpp \
alignment_io.cc \
+ b64featvector.cc \
b64tools.cc \
corpus_tools.cc \
dict.cc \
@@ -117,4 +119,3 @@ stringlib_test_LDADD = libutils.a $(BOOST_UNIT_TEST_FRAMEWORK_LDFLAGS) $(BOOST_U
# do NOT NOT NOT add any other -I includes NO NO NO NO NO ######
AM_CPPFLAGS = -DBOOST_TEST_DYN_LINK -W -Wall -I. -I$(top_srcdir) -DTEST_DATA=\"$(top_srcdir)/utils/test_data\"
################################################################
-
diff --git a/utils/b64featvector.cc b/utils/b64featvector.cc
new file mode 100644
index 00000000..c7d08b29
--- /dev/null
+++ b/utils/b64featvector.cc
@@ -0,0 +1,55 @@
+#include "b64featvector.h"
+
+#include <sstream>
+#include <boost/scoped_array.hpp>
+#include "b64tools.h"
+#include "fdict.h"
+
+using namespace std;
+
+static inline void EncodeFeatureWeight(const string &featname, weight_t weight,
+ ostream *output) {
+ output->write(featname.data(), featname.size() + 1);
+ output->write(reinterpret_cast<char *>(&weight), sizeof(weight_t));
+}
+
+string EncodeFeatureVector(const SparseVector<weight_t> &vec) {
+ string b64;
+ {
+ ostringstream base64_strm;
+ {
+ ostringstream strm;
+ for (SparseVector<weight_t>::const_iterator it = vec.begin();
+ it != vec.end(); ++it)
+ if (it->second != 0)
+ EncodeFeatureWeight(FD::Convert(it->first), it->second, &strm);
+ string data(strm.str());
+ B64::b64encode(data.data(), data.size(), &base64_strm);
+ }
+ b64 = base64_strm.str();
+ }
+ return b64;
+}
+
+void DecodeFeatureVector(const string &data, SparseVector<weight_t> *vec) {
+ vec->clear();
+ if (data.empty()) return;
+ // Decode data
+ size_t b64_len = data.size(), len = b64_len / 4 * 3;
+ boost::scoped_array<char> buf(new char[len]);
+ bool res =
+ B64::b64decode(reinterpret_cast<const unsigned char *>(data.data()),
+ b64_len, buf.get(), len);
+ assert(res);
+ // Apply updates
+ size_t cur = 0;
+ while (cur < len) {
+ string feat_name(buf.get() + cur);
+ if (feat_name.empty()) break; // Encountered trailing \0
+ int feat_id = FD::Convert(feat_name);
+ weight_t feat_delta =
+ *reinterpret_cast<weight_t *>(buf.get() + cur + feat_name.size() + 1);
+ (*vec)[feat_id] = feat_delta;
+ cur += feat_name.size() + 1 + sizeof(weight_t);
+ }
+}
diff --git a/utils/b64featvector.h b/utils/b64featvector.h
new file mode 100644
index 00000000..6ac04d44
--- /dev/null
+++ b/utils/b64featvector.h
@@ -0,0 +1,12 @@
+#ifndef _B64FEATVECTOR_H_
+#define _B64FEATVECTOR_H_
+
+#include <string>
+
+#include "sparse_vector.h"
+#include "weights.h"
+
+std::string EncodeFeatureVector(const SparseVector<weight_t> &);
+void DecodeFeatureVector(const std::string &, SparseVector<weight_t> *);
+
+#endif // _B64FEATVECTOR_H_