From c8c315a4f78c464636ea5e3fd9a11416b2f966b9 Mon Sep 17 00:00:00 2001 From: redpony Date: Mon, 15 Nov 2010 20:22:22 +0000 Subject: rescoring working git-svn-id: https://ws10smt.googlecode.com/svn/trunk@726 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/decoder.cc | 17 +++++++++++------ decoder/lextrans.cc | 7 ++++--- 2 files changed, 15 insertions(+), 9 deletions(-) (limited to 'decoder') diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 065510a7..daf82f10 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -354,7 +354,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("show_tree_structure", "Show the Viterbi derivation structure") ("show_expected_length", "Show the expected translation length under the model") ("show_partition,z", "Compute and show the partition (inside score)") - ("show_partition_as_translation", "Output the partition to STDOUT instead of a translation") + ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") ("show_cfg_search_space", "Show the search space as a CFG") ("show_features","Show the feature vector for the viterbi translation") ("prelm_density_prune", po::value(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") @@ -680,7 +680,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (!SILENT) cerr << " NO PARSE FOUND.\n"; o->NotifySourceParseFailure(smeta); o->NotifyDecodingComplete(smeta); - if (conf.count("show_partition_as_translation")) { + if (conf.count("show_conditional_prob")) { cout << "-Inf" << endl << flush; } return false; @@ -807,6 +807,11 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } } + prob_t first_z; + if (conf.count("show_conditional_prob")) { + first_z = Inside(forest); + } + // TODO this should be handled by an Observer const int max_trans_beam_size = conf.count("max_translation_beam") ? conf["max_translation_beam"].as() : 0; @@ -910,9 +915,9 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (conf.count("graphviz")) forest.PrintGraphviz(); if (kbest) oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,"-"); - if (conf.count("show_partition_as_translation")) { - const prob_t z = Inside(forest); - cout << log(z) << endl << flush; + if (conf.count("show_conditional_prob")) { + const prob_t ref_z = Inside(forest); + cout << (log(ref_z) - log(first_z)) << endl << flush; } } else { o->NotifyAlignmentFailure(smeta); @@ -920,7 +925,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { if (write_gradient) { cout << endl << flush; } - if (conf.count("show_partition_as_translation")) { + if (conf.count("show_conditional_prob")) { cout << "-Inf" << endl << flush; } } diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc index 551e77e3..c3bd775f 100644 --- a/decoder/lextrans.cc +++ b/decoder/lextrans.cc @@ -60,7 +60,7 @@ struct LexicalTransImpl { } } - void BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) { + bool BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) { if (psg_file_) { const string offset = smeta.GetSGMLValue("psg"); if (offset.size() < 2 || offset[0] != '@') { @@ -86,7 +86,7 @@ struct LexicalTransImpl { gi = sup_grammar->GetRoot()->Extend(src_sym); if (!gi) { cerr << "No translations found for: " << TD::Convert(src_sym) << "\n"; - abort(); + return false; } } const RuleBin* rb = gi->GetRules(); @@ -117,6 +117,7 @@ struct LexicalTransImpl { Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); forest->ConnectEdgeToHeadNode(hg_edge, goal); + return true; } private: @@ -146,7 +147,7 @@ bool LexicalTrans::TranslateImpl(const string& input, abort(); } smeta->SetSourceLength(lattice.size()); - pimpl_->BuildTrellis(lattice, *smeta, forest); + if (!pimpl_->BuildTrellis(lattice, *smeta, forest)) return false; forest->is_linear_chain_ = true; forest->Reweight(weights); return true; -- cgit v1.2.3