diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/decoder.cc | 9 | ||||
| -rw-r--r-- | decoder/oracle_bleu.h | 4 | 
2 files changed, 8 insertions, 5 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 9e8d692a..3cc77d27 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -195,7 +195,7 @@ struct DecoderImpl {        }        forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1);        if (!forestname.empty()) forestname=" "+forestname; -      if (!SILENT) {  +      if (!SILENT) {          forest_stats(forest,"  Pruned "+forestname+" forest",false,false);          cerr << "  Pruned "<<forestname<<" forest portion of edges kept: "<<forest.edges_.size()/presize<<endl;        } @@ -261,7 +261,7 @@ struct DecoderImpl {        assert(ref);        LatticeTools::ConvertTextOrPLF(sref, ref);      } -  }  +  }    // used to construct the suffix string to get the name of arguments for multiple passes    // e.g., the "2" in --weights2 @@ -284,7 +284,7 @@ struct DecoderImpl {    boost::shared_ptr<RandomNumberGenerator<boost::mt19937> > rng;    int sample_max_trans;    bool aligner_mode; -  bool graphviz;  +  bool graphviz;    bool joshua_viz;    bool encode_b64;    bool kbest; @@ -404,6 +404,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")          ("extract_rules", po::value<string>(), "Extract the rules used in translation (not de-duped!) to a file in this directory")          ("show_derivations", po::value<string>(), "Directory to print the derivation structures to") +        ("show_derivations_mask", po::value<int>()->default_value(Hypergraph::SPAN|Hypergraph::RULE), "Bit-mask for what to print in derivation structures")          ("graphviz","Show (constrained) translation forest in GraphViz format")          ("max_translation_beam,x", po::value<int>(), "Beam approximation to get max translation from the chart")          ("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart") @@ -665,6 +666,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    unique_kbest = conf.count("unique_k_best");    get_oracle_forest = conf.count("get_oracle_forest");    oracle.show_derivation=conf.count("show_derivations"); +  oracle.show_derivation_mask=conf["show_derivations_mask"].as<int>();    remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations");    combine_size = conf["combine_size"].as<int>(); @@ -1098,4 +1100,3 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {    o->NotifyDecodingComplete(smeta);    return true;  } - diff --git a/decoder/oracle_bleu.h b/decoder/oracle_bleu.h index d2c4715c..893e36ca 100644 --- a/decoder/oracle_bleu.h +++ b/decoder/oracle_bleu.h @@ -252,6 +252,8 @@ struct OracleBleu {    }    bool show_derivation; +  int show_derivation_mask; +    template <class Filter>    void kbest(int sent_id,Hypergraph const& forest,int k,std::ostream &kbest_out=std::cout,std::ostream &deriv_out=std::cerr) {      using namespace std; @@ -275,7 +277,7 @@ struct OracleBleu {        if (show_derivation) {          deriv_out<<"\nsent_id="<<sent_id<<"."<<i<<" ||| "; //where i is candidate #/k          deriv_out<<log(d->score)<<"\n"; -        deriv_out<<kbest.derivation_tree(*d,true); +        deriv_out<<kbest.derivation_tree(*d,true, show_derivation_mask);          deriv_out<<"\n"<<flush;        }      }  | 
