summaryrefslogtreecommitdiff
path: root/decoder/cfg_format.h
blob: 1036180410aff56c383d0c6182684fd937692e34 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#ifndef CFG_FORMAT_H
#define CFG_FORMAT_H

#include <ostream>
#include <string>
#include "wordid.h"
#include "feature_vector.h"
#include "program_options.h"

/// Output-format settings for printing CFG rules as text of the form
///   [LHS] ||| rhs [||| identity-target-rhs] [||| features]
/// where " ||| " is `partsep`. All fields are initialized by set_defaults()
/// (called from the constructor) and may then be overridden by the
/// command-line options registered in AddOptions().
struct CFGFormat {
  bool identity_scfg;  // also print a target side mirroring the rhs, NTs written as [1],[2],...
  bool features;       // print the rule's FeatureVector in print_features()
  bool logprob_feat;   // print <logprob_feat_name>=log(p) in print_features() when p != 1
  bool cfg_comma_nt;   // append ",k" (k = 1-based NT position) to source-side NTs
  bool nt_span;        // include the CFG-supplied part of NT names (see print_nt_name)
  std::string goal_nt_name;       // not read by this struct's print methods; see --goal_nt_name help
  std::string nt_prefix;          // literal prefix written at the start of every NT name
  std::string logprob_feat_name;  // feature name used for the log-probability feature
  std::string partsep;            // separator written before each rule part (" ||| " by default)

  /// Registers this struct's fields as command-line options on *opts.
  /// Templated so it works with both printable_opts and boost's
  /// (nonprintable) options_description.
  template <class Opts> // template to support both printable_opts and boost nonprintable
  void AddOptions(Opts *opts) {
    opts->add_options()
      ("identity_scfg",defaulted_value(&identity_scfg),"output an identity SCFG: add an identity target side - '[X12] ||| [X13,1] a ||| [1] a ||| feat= ...' - the redundant target '[1] a |||' is omitted otherwise.")
      ("features",defaulted_value(&features),"print the CFG feature vector")
      ("logprob_feat",defaulted_value(&logprob_feat),"print a LogProb=-1.5 feature irrespective of --features.")
      ("logprob_feat_name",defaulted_value(&logprob_feat_name),"alternate name for the LogProb feature")
      ("cfg_comma_nt",defaulted_value(&cfg_comma_nt),"if false, omit the usual [NP,1] ',1' variable index in the source side")
      ("goal_nt_name",defaulted_value(&goal_nt_name),"if nonempty, the first production will be '[goal_nt_name] ||| [x123] ||| LogProb=y' where x123 is the actual goal nt, and y is the pushed prob, if any")
      ("nt_prefix",defaulted_value(&nt_prefix),"NTs are [<nt_prefix>123] where 123 is the node number starting at 0, and the highest node (last in file) is the goal node in an acyclic hypergraph")
      ("nt_span",defaulted_value(&nt_span),"prefix A(i,j) for NT coming from hypergraph node with category A on span [i,j).  this is after --nt_prefix if any")
      ;
  }

  /// Hook for checking option consistency after parsing; currently performs no checks.
  void Validate() {  }

  /// Prints a source-side nonterminal as "[name]" or, when cfg_comma_nt is
  /// set, "[name,position]".
  template<class CFG>
  void print_source_nt(std::ostream &o,CFG const&cfg,int id,int position=1) const {
    o<<'[';
    print_nt_name(o,cfg,id);
    if (cfg_comma_nt) o<<','<<position;
    o<<']';
  }

  /// Prints the bare (unbracketed) name of nonterminal `id`:
  /// nt_prefix, then (if nt_span) the CFG-supplied part, then the numeric id.
  template <class CFG>
  void print_nt_name(std::ostream &o,CFG const& cfg,int id) const {
    o<<nt_prefix;
    // --nt_span is documented as controlling the "A(i,j)" part that follows
    // nt_prefix; the CFG supplies that part. Without this check the option
    // was registered but had no effect.
    // NOTE(review): assumes cfg.print_nt_name emits exactly that part — confirm against CFG.
    if (nt_span) cfg.print_nt_name(o,id);
    o<<id;
  }

  /// Prints a left-hand-side nonterminal as "[name]" (never with ",k").
  template <class CFG>
  void print_lhs(std::ostream &o,CFG const& cfg,int id) const {
    o<<'[';
    print_nt_name(o,cfg,id);
    o<<']';
  }

  /// Prints partsep followed by the space-separated rhs symbols in [begin,end).
  /// Symbols w > 0 are terminals (printed via TD::Convert); w <= 0 denote
  /// nonterminal -w, numbered left to right from 1. If identity_scfg is set,
  /// a second partsep and a copy of the rhs follow, with NTs written as [k].
  template <class CFG,class Iter>
  void print_rhs(std::ostream &o,CFG const&cfg,Iter begin,Iter end) const {
    o<<partsep;
    int pos=0;
    for (Iter i=begin;i!=end;++i) {
      WordID w=*i;
      if (i!=begin) o<<' ';
      if (w>0) o << TD::Convert(w);
      else print_source_nt(o,cfg,-w,++pos);
    }
    if (identity_scfg) {
      o<<partsep;
      pos=0; // restart NT numbering for the target side (reuse, don't shadow)
      for (Iter i=begin;i!=end;++i) {
        WordID w=*i;
        if (i!=begin) o<<' ';
        if (w>0) o << TD::Convert(w);
        else o << '['<<++pos<<']';
      }
    }
  }

  /// Prints the features part of a rule, preceded by partsep, but only when
  /// there is something to print. That is the case when `features` is set, or
  /// when logprob_feat is set and p != 1. The log-probability feature is written
  /// as "<logprob_feat_name>=<log(p)> " (note the trailing space), followed by
  /// fv if `features` is set.
  void print_features(std::ostream &o,prob_t p,FeatureVector const& fv=FeatureVector()) const {
    bool logp=(logprob_feat && p!=1);
    if (features || logp) {
      o << partsep;
      if (logp)
        o << logprob_feat_name<<'='<<log(p)<<' ';
      if (features)
        o << fv;
    }
  }

  /// Resets every field to its default value. Runs from the constructor, so
  /// these values are what AddOptions() registers as option defaults.
  void set_defaults() {
    identity_scfg=false;
    features=true;
    logprob_feat=true;
    cfg_comma_nt=true;
    goal_nt_name="S";
    logprob_feat_name="LogProb";
    nt_prefix="";
    partsep=" ||| ";
    nt_span=true;
  }

  CFGFormat() {
    set_defaults();
  }
};



#endif