diff options
| author | Chris Dyer <redpony@gmail.com> | 2009-12-18 01:27:19 -0500 | 
|---|---|---|
| committer | Chris Dyer <redpony@gmail.com> | 2009-12-18 01:27:19 -0500 | 
| commit | 1aac806af7785ab440d300ca5cfa8833e3ed61d3 (patch) | |
| tree | 4a2ffa484af029ebc542f2cdf7bb6da93325b29a /decoder | |
| parent | 40ac2d31391c27b168b0294e7683cb69da29f868 (diff) | |
Add support for freezing the feature set to a user-specified list (the features named in the feature weights file), even if feature detectors create additional features.
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/cdec.cc | 24 | ||||
| -rw-r--r-- | decoder/dict.h | 4 | ||||
| -rw-r--r-- | decoder/fdict.cc | 1 | ||||
| -rw-r--r-- | decoder/fdict.h | 11 | ||||
| -rw-r--r-- | decoder/sparse_vector.h | 14 | ||||
| -rw-r--r-- | decoder/trule.cc | 9 | 
6 files changed, 47 insertions, 16 deletions
| diff --git a/decoder/cdec.cc b/decoder/cdec.cc index c6773cce..c6a0057f 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -53,6 +53,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {          ("input,i",po::value<string>()->default_value("-"),"Source file")          ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")          ("weights,w",po::value<string>(),"Feature weights file") +        ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file")          ("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)")          ("list_feature_functions,L","List available feature functions")          ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") @@ -248,6 +249,20 @@ int main(int argc, char** argv) {      exit(1);    } +  // load feature weights (and possibly freeze feature set) +  vector<double> feature_weights; +  Weights w; +  if (conf.count("weights")) { +    w.InitFromFile(conf["weights"].as<string>()); +    feature_weights.resize(FD::NumFeats()); +    w.InitVector(&feature_weights); +    if (!conf.count("no_freeze_feature_set")) { +      cerr << "Freezing feature set (use --no_freeze_feature_set to change)." 
<< endl; +      FD::Freeze(); +    } +  } + +  // set up translation back end    if (formalism == "scfg")      translator.reset(new SCFGTranslator(conf));    else if (formalism == "fst") @@ -263,14 +278,6 @@ int main(int argc, char** argv) {    else      assert(!"error"); -  vector<double> feature_weights; -  Weights w; -  if (conf.count("weights")) { -    w.InitFromFile(conf["weights"].as<string>()); -    feature_weights.resize(FD::NumFeats()); -    w.InitVector(&feature_weights); -  } -    // set up additional scoring features    vector<shared_ptr<FeatureFunction> > pffs;    vector<const FeatureFunction*> late_ffs; @@ -480,6 +487,7 @@ int main(int argc, char** argv) {          }          if (output_training_vector) { +          acc_vec.clear_value(0);            ++g_count;            if (g_count % combine_size == 0) {              if (encode_b64) { diff --git a/decoder/dict.h b/decoder/dict.h index bae9debe..0cbc9ff0 100644 --- a/decoder/dict.h +++ b/decoder/dict.h @@ -16,9 +16,11 @@ class Dict {   public:    Dict() : b0_("<bad0>") { words_.reserve(1000); }    inline int max() const { return words_.size(); } -  inline WordID Convert(const std::string& word) { +  inline WordID Convert(const std::string& word, bool frozen = false) {      Map::iterator i = d_.find(word);      if (i == d_.end()) { +      if (frozen) +        return 0;        words_.push_back(word);        d_[word] = words_.size();        return words_.size(); diff --git a/decoder/fdict.cc b/decoder/fdict.cc index 83aa7cea..8218a5d3 100644 --- a/decoder/fdict.cc +++ b/decoder/fdict.cc @@ -1,4 +1,5 @@  #include "fdict.h"  Dict FD::dict_; +bool FD::frozen_ = false; diff --git a/decoder/fdict.h b/decoder/fdict.h index ff491cfb..d05f1706 100644 --- a/decoder/fdict.h +++ b/decoder/fdict.h @@ -6,16 +6,23 @@  #include "dict.h"  struct FD { -  static Dict dict_; +  // once the FD is frozen, new features not already in the +  // dictionary will return 0 +  static void Freeze() { +    frozen_ = true; +  }    
static inline int NumFeats() {      return dict_.max() + 1;    }    static inline WordID Convert(const std::string& s) { -    return dict_.Convert(s); +    return dict_.Convert(s, frozen_);    }    static inline const std::string& Convert(const WordID& w) {      return dict_.Convert(w);    } +  static Dict dict_; + private: +  static bool frozen_;  };  #endif diff --git a/decoder/sparse_vector.h b/decoder/sparse_vector.h index 6a8c9bf4..2b4a63a9 100644 --- a/decoder/sparse_vector.h +++ b/decoder/sparse_vector.h @@ -185,10 +185,15 @@ public:      }      std::ostream &operator<<(std::ostream &out) const { +        bool first = true;          for (typename std::map<int, T>::const_iterator  -                it = _values.begin(); it != _values.end(); ++it) -            out << (it == _values.begin() ? "" : ";") -	        << FD::Convert(it->first) << '=' << it->second; +                it = _values.begin(); it != _values.end(); ++it) { +          // by definition feature id 0 is a dummy value +          if (it->first == 0) continue; +          out << (first ? 
"" : ";") +	      << FD::Convert(it->first) << '=' << it->second; +          first = false; +        }          return out;      } @@ -216,6 +221,9 @@ public:      void clear() {          _values.clear();      } +    void clear_value(int index) { +      _values.erase(index); +    }      void swap(SparseVector<T>& other) {        _values.swap(other._values); diff --git a/decoder/trule.cc b/decoder/trule.cc index b8f6995e..505839c7 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -126,7 +126,11 @@ bool TRule::ReadFromString(const string& line, bool strict, bool mono) {            if (fv > 9) { cerr << "Too many phrasetable scores - used named format\n"; abort(); }            fname[12]='0' + fv;            ++fv; -          scores_.set_value(FD::Convert(fname), atof(&ss[start])); +          // if the feature set is frozen, this may return zero, indicating an +          // undefined feature +          const int fid = FD::Convert(fname); +          if (fid) +            scores_.set_value(fid, atof(&ss[start]));            //cerr << "F: " << fname << " VAL=" << scores_.value(FD::Convert(fname)) << endl;          } else {            const int fid = FD::Convert(ss.substr(start, end - start)); @@ -136,7 +140,8 @@ bool TRule::ReadFromString(const string& line, bool strict, bool mono) {              ++end;            if (end < len) { ss[end] = 0; }  	  assert(start < len); -          scores_.set_value(fid, atof(&ss[start])); +          if (fid) +            scores_.set_value(fid, atof(&ss[start]));            //cerr << "F: " << FD::Convert(fid) << " VAL=" << scores_.value(fid) << endl;          }          start = end + 1; | 
