summaryrefslogtreecommitdiff
path: root/decoder/decoder.cc
diff options
context:
space:
mode:
authorWu, Ke <wuke@cs.umd.edu>2014-12-17 16:15:13 -0500
committerWu, Ke <wuke@cs.umd.edu>2014-12-17 16:15:13 -0500
commit6829a0bc624b02ebefc79f8cf9ec89d7d64a7c30 (patch)
tree125dfb20f73342873476c793995397b26fd202dd /decoder/decoder.cc
parentb455a108a21f4ba5a58ab1bc53a8d2bf4d829067 (diff)
parent7468e8d85e99b4619442c7afaf4a0d92870111bb (diff)
Merge branch 'const_reorder_2' into softsyn_2
Diffstat (limited to 'decoder/decoder.cc')
-rw-r--r--decoder/decoder.cc55
1 files changed, 41 insertions, 14 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index c384c33f..1e6c3194 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -17,6 +17,7 @@ namespace std { using std::tr1::unordered_map; }
#include "fdict.h"
#include "timing_stats.h"
#include "verbose.h"
+#include "b64featvector.h"
#include "translator.h"
#include "phrasebased_translator.h"
@@ -86,7 +87,7 @@ struct ELengthWeightFunction {
}
};
inline void ShowBanner() {
- cerr << "cdec (c) 2009--2014 by Chris Dyer\n";
+ cerr << "cdec (c) 2009--2014 by Chris Dyer" << endl;
}
inline string str(char const* name,po::variables_map const& conf) {
@@ -195,7 +196,7 @@ struct DecoderImpl {
}
forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1);
if (!forestname.empty()) forestname=" "+forestname;
- if (!SILENT) {
+ if (!SILENT) {
forest_stats(forest," Pruned "+forestname+" forest",false,false);
cerr << " Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl;
}
@@ -261,7 +262,7 @@ struct DecoderImpl {
assert(ref);
LatticeTools::ConvertTextOrPLF(sref, ref);
}
- }
+ }
// used to construct the suffix string to get the name of arguments for multiple passes
// e.g., the "2" in --weights2
@@ -284,7 +285,7 @@ struct DecoderImpl {
boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;
int sample_max_trans;
bool aligner_mode;
- bool graphviz;
+ bool graphviz;
bool joshua_viz;
bool encode_b64;
bool kbest;
@@ -301,6 +302,7 @@ struct DecoderImpl {
bool feature_expectations; // TODO Observer
bool output_training_vector; // TODO Observer
bool remove_intersected_rule_annotations;
+ bool mr_mira_compat; // Mr.MIRA compatibility mode.
boost::scoped_ptr<IncrementalBase> incremental;
@@ -404,6 +406,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
("extract_rules", po::value<string>(), "Extract the rules used in translation (not de-duped!) to a file in this directory")
("show_derivations", po::value<string>(), "Directory to print the derivation structures to")
+ ("show_derivations_mask", po::value<int>()->default_value(Hypergraph::SPAN|Hypergraph::RULE), "Bit-mask for what to print in derivation structures")
("graphviz","Show (constrained) translation forest in GraphViz format")
("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")
("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")
@@ -414,7 +417,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)")
("combine_size,C",po::value<int>()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)")
("forest_output,O",po::value<string>(),"Directory to write forests to")
- ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)");
+ ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)")
+ ("mr_mira_compat", "Mr.MIRA compatibility mode (applies weight delta if available; outputs number of lines before k-best)");
// ob.AddOptions(&opts);
po::options_description clo("Command line options");
@@ -665,7 +669,9 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
unique_kbest = conf.count("unique_k_best");
get_oracle_forest = conf.count("get_oracle_forest");
oracle.show_derivation=conf.count("show_derivations");
+ oracle.show_derivation_mask=conf["show_derivations_mask"].as<int>();
remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");
+ mr_mira_compat = conf.count("mr_mira_compat");
combine_size = conf["combine_size"].as<int>();
if (combine_size < 1) combine_size = 1;
@@ -699,6 +705,24 @@ void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string
static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string);
}
+static inline void ApplyWeightDelta(const string &delta_b64, vector<weight_t> *weights) {
+ SparseVector<weight_t> delta;
+ DecodeFeatureVector(delta_b64, &delta);
+ if (delta.empty()) return;
+ // Apply updates
+ for (SparseVector<weight_t>::iterator dit = delta.begin();
+ dit != delta.end(); ++dit) {
+ int feat_id = dit->first;
+ union { weight_t weight; unsigned long long repr; } feat_delta;
+ feat_delta.weight = dit->second;
+ if (!SILENT)
+ cerr << "[decoder weight update] " << FD::Convert(feat_id) << " " << feat_delta.weight
+ << " = " << hex << feat_delta.repr << endl;
+ if (weights->size() <= feat_id) weights->resize(feat_id + 1);
+ (*weights)[feat_id] += feat_delta.weight;
+ }
+}
+
bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
string buf = input;
NgramCache::Clear(); // clear ngram cache for remote LM (if used)
@@ -709,6 +733,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (sgml.find("id") != sgml.end())
sent_id = atoi(sgml["id"].c_str());
+ // Add delta from input to weights before decoding
+ if (mr_mira_compat)
+ ApplyWeightDelta(sgml["delta"], init_weights.get());
+
if (!SILENT) {
cerr << "\nINPUT: ";
if (buf.size() < 100)
@@ -928,14 +956,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Hypergraph new_hg;
{
ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);
if (!succeeded) abort();
}
HG::Union(forest, &new_hg);
- bool succeeded = writer.Write(new_hg, false);
+ bool succeeded = writer.Write(new_hg);
if (!succeeded) abort();
} else {
- bool succeeded = writer.Write(forest, false);
+ bool succeeded = writer.Write(forest);
if (!succeeded) abort();
}
}
@@ -947,7 +975,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (kbest && !has_ref) {
//TODO: does this work properly?
const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);
} else if (csplit_output_plf) {
cout << HypergraphIO::AsPLF(forest, false) << endl;
} else {
@@ -1021,14 +1049,14 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
Hypergraph new_hg;
{
ReadFile rf(writer.fname_);
- bool succeeded = HypergraphIO::ReadFromJSON(rf.stream(), &new_hg);
+ bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg);
if (!succeeded) abort();
}
HG::Union(forest, &new_hg);
- bool succeeded = writer.Write(new_hg, false);
+ bool succeeded = writer.Write(new_hg);
if (!succeeded) abort();
} else {
- bool succeeded = writer.Write(forest, false);
+ bool succeeded = writer.Write(forest);
if (!succeeded) abort();
}
}
@@ -1078,7 +1106,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
if (conf.count("graphviz")) forest.PrintGraphviz();
if (kbest) {
const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-";
- oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-", deriv_fname);
+ oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest, mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname);
}
if (conf.count("show_conditional_prob")) {
const prob_t ref_z = Inside<prob_t, EdgeProb>(forest);
@@ -1098,4 +1126,3 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
o->NotifyDecodingComplete(smeta);
return true;
}
-