diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/decoder.cc | 17 | ||||
| -rw-r--r-- | decoder/lextrans.cc | 7 | 
2 files changed, 15 insertions, 9 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 065510a7..daf82f10 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -354,7 +354,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          ("show_tree_structure", "Show the Viterbi derivation structure")          ("show_expected_length", "Show the expected translation length under the model")          ("show_partition,z", "Compute and show the partition (inside score)") -        ("show_partition_as_translation", "Output the partition to STDOUT instead of a translation") +        ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation")          ("show_cfg_search_space", "Show the search space as a CFG")      ("show_features","Show the feature vector for the viterbi translation")      ("prelm_density_prune", po::value<double>(), "Applied to -LM forest just before final LM rescoring: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") @@ -680,7 +680,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {      if (!SILENT) cerr << "  NO PARSE FOUND.\n";      o->NotifySourceParseFailure(smeta);      o->NotifyDecodingComplete(smeta); -    if (conf.count("show_partition_as_translation")) { +    if (conf.count("show_conditional_prob")) {        cout << "-Inf" << endl << flush;      }      return false; @@ -807,6 +807,11 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {      }    } +  prob_t first_z; +  if (conf.count("show_conditional_prob")) { +    first_z = Inside<prob_t, EdgeProb>(forest); +  } +    // TODO this should be handled by an Observer    const int max_trans_beam_size = conf.count("max_translation_beam") ?      conf["max_translation_beam"].as<int>() : 0; @@ -910,9 +915,9 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {        if (conf.count("graphviz")) forest.PrintGraphviz();        if (kbest)          oracle.DumpKBest(sent_id, forest, conf["k_best"].as<int>(), unique_kbest,"-"); -      if (conf.count("show_partition_as_translation")) { -        const prob_t z = Inside<prob_t, EdgeProb>(forest); -        cout << log(z) << endl << flush; +      if (conf.count("show_conditional_prob")) { +        const prob_t ref_z = Inside<prob_t, EdgeProb>(forest); +        cout << (log(ref_z) - log(first_z)) << endl << flush;        }      } else {        o->NotifyAlignmentFailure(smeta); @@ -920,7 +925,7 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {        if (write_gradient) {          cout << endl << flush;        } -      if (conf.count("show_partition_as_translation")) { +      if (conf.count("show_conditional_prob")) {          cout << "-Inf" << endl << flush;        }      } diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc index 551e77e3..c3bd775f 100644 --- a/decoder/lextrans.cc +++ b/decoder/lextrans.cc @@ -60,7 +60,7 @@ struct LexicalTransImpl {      }    } -  void BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) { +  bool BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) {      if (psg_file_) {        const string offset = smeta.GetSGMLValue("psg");        if (offset.size() < 2 || offset[0] != '@') { @@ -86,7 +86,7 @@ struct LexicalTransImpl {              gi = sup_grammar->GetRoot()->Extend(src_sym);            if (!gi) {              cerr << "No translations found for: " << TD::Convert(src_sym) << "\n"; -            abort(); +            return false;            }          }          const RuleBin* rb = gi->GetRules(); @@ -117,6 +117,7 @@ struct LexicalTransImpl {      Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1);      Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);      forest->ConnectEdgeToHeadNode(hg_edge, goal); +    return true;    }   private: @@ -146,7 +147,7 @@ bool LexicalTrans::TranslateImpl(const string& input,      abort();    }    smeta->SetSourceLength(lattice.size()); -  pimpl_->BuildTrellis(lattice, *smeta, forest); +  if (!pimpl_->BuildTrellis(lattice, *smeta, forest)) return false;    forest->is_linear_chain_ = true;    forest->Reweight(weights);    return true;  | 
