From 1aac806af7785ab440d300ca5cfa8833e3ed61d3 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Fri, 18 Dec 2009 01:27:19 -0500 Subject: add support for freezing the feature set to a user-specified list, even if feature detectors create additional features --- decoder/cdec.cc | 24 ++++++++++++++++-------- decoder/dict.h | 4 +++- decoder/fdict.cc | 1 + decoder/fdict.h | 11 +++++++++-- decoder/sparse_vector.h | 14 +++++++++++--- decoder/trule.cc | 9 +++++++-- 6 files changed, 47 insertions(+), 16 deletions(-) diff --git a/decoder/cdec.cc b/decoder/cdec.cc index c6773cce..c6a0057f 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -53,6 +53,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("input,i",po::value()->default_value("-"),"Source file") ("grammar,g",po::value >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") ("weights,w",po::value(),"Feature weights file") + ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file") ("feature_function,F",po::value >()->composing(), "Additional feature function(s) (-L for list)") ("list_feature_functions,L","List available feature functions") ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") @@ -248,6 +249,20 @@ int main(int argc, char** argv) { exit(1); } + // load feature weights (and possibly freeze feature set) + vector feature_weights; + Weights w; + if (conf.count("weights")) { + w.InitFromFile(conf["weights"].as()); + feature_weights.resize(FD::NumFeats()); + w.InitVector(&feature_weights); + if (!conf.count("no_freeze_feature_set")) { + cerr << "Freezing feature set (use --no_freeze_feature_set to change)." << endl; + FD::Freeze(); + } + } + + // set up translation back end if (formalism == "scfg") translator.reset(new SCFGTranslator(conf)); else if (formalism == "fst") @@ -263,14 +278,6 @@ int main(int argc, char** argv) { else assert(!"error"); - vector feature_weights; - Weights w; - if (conf.count("weights")) { - w.InitFromFile(conf["weights"].as()); - feature_weights.resize(FD::NumFeats()); - w.InitVector(&feature_weights); - } - // set up additional scoring features vector > pffs; vector late_ffs; @@ -480,6 +487,7 @@ int main(int argc, char** argv) { } if (output_training_vector) { + acc_vec.clear_value(0); ++g_count; if (g_count % combine_size == 0) { if (encode_b64) { diff --git a/decoder/dict.h b/decoder/dict.h index bae9debe..0cbc9ff0 100644 --- a/decoder/dict.h +++ b/decoder/dict.h @@ -16,9 +16,11 @@ class Dict { public: Dict() : b0_("") { words_.reserve(1000); } inline int max() const { return words_.size(); } - inline WordID Convert(const std::string& word) { + inline WordID Convert(const std::string& word, bool frozen = false) { Map::iterator i = d_.find(word); if (i == d_.end()) { + if (frozen) + return 0; words_.push_back(word); d_[word] = words_.size(); return words_.size(); diff --git a/decoder/fdict.cc b/decoder/fdict.cc index 83aa7cea..8218a5d3 100644 --- a/decoder/fdict.cc +++ b/decoder/fdict.cc @@ -1,4 +1,5 @@ #include "fdict.h" Dict FD::dict_; +bool FD::frozen_ = false; diff --git a/decoder/fdict.h b/decoder/fdict.h index ff491cfb..d05f1706 100644 --- a/decoder/fdict.h +++ b/decoder/fdict.h @@ -6,16 +6,23 @@ #include "dict.h" struct FD { - static Dict dict_; + // once the FD is frozen, new features not already in the + // dictionary will return 0 + static void Freeze() { + frozen_ = true; + } static inline int NumFeats() { return dict_.max() + 1; } static inline WordID Convert(const std::string& s) { - return dict_.Convert(s); + return dict_.Convert(s, frozen_); } static inline const std::string& Convert(const WordID& w) { return dict_.Convert(w); } + static Dict dict_; + private: + static bool frozen_; }; #endif diff --git a/decoder/sparse_vector.h b/decoder/sparse_vector.h index 6a8c9bf4..2b4a63a9 100644 --- a/decoder/sparse_vector.h +++ b/decoder/sparse_vector.h @@ -185,10 +185,15 @@ public: } std::ostream &operator<<(std::ostream &out) const { + bool first = true; for (typename std::map::const_iterator - it = _values.begin(); it != _values.end(); ++it) - out << (it == _values.begin() ? "" : ";") - << FD::Convert(it->first) << '=' << it->second; + it = _values.begin(); it != _values.end(); ++it) { + // by definition feature id 0 is a dummy value + if (it->first == 0) continue; + out << (first ? "" : ";") + << FD::Convert(it->first) << '=' << it->second; + first = false; + } return out; } @@ -216,6 +221,9 @@ public: void clear() { _values.clear(); } + void clear_value(int index) { + _values.erase(index); + } void swap(SparseVector& other) { _values.swap(other._values); diff --git a/decoder/trule.cc b/decoder/trule.cc index b8f6995e..505839c7 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -126,7 +126,11 @@ bool TRule::ReadFromString(const string& line, bool strict, bool mono) { if (fv > 9) { cerr << "Too many phrasetable scores - used named format\n"; abort(); } fname[12]='0' + fv; ++fv; - scores_.set_value(FD::Convert(fname), atof(&ss[start])); + // if the feature set is frozen, this may return zero, indicating an + // undefined feature + const int fid = FD::Convert(fname); + if (fid) + scores_.set_value(fid, atof(&ss[start])); //cerr << "F: " << fname << " VAL=" << scores_.value(FD::Convert(fname)) << endl; } else { const int fid = FD::Convert(ss.substr(start, end - start)); @@ -136,7 +140,8 @@ bool TRule::ReadFromString(const string& line, bool strict, bool mono) { ++end; if (end < len) { ss[end] = 0; } assert(start < len); - scores_.set_value(fid, atof(&ss[start])); + if (fid) + scores_.set_value(fid, atof(&ss[start])); //cerr << "F: " << FD::Convert(fid) << " VAL=" << scores_.value(fid) << endl; } start = end + 1; -- cgit v1.2.3