summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Chris Dyer <redpony@gmail.com> 2009-12-18 01:27:19 -0500
committer Chris Dyer <redpony@gmail.com> 2009-12-18 01:27:19 -0500
commit 1aac806af7785ab440d300ca5cfa8833e3ed61d3 (patch)
tree 4a2ffa484af029ebc542f2cdf7bb6da93325b29a
parent 40ac2d31391c27b168b0294e7683cb69da29f868 (diff)
add support for freezing the feature set to a user-specified list, even if feature detectors create additional features
-rw-r--r--  decoder/cdec.cc  24
-rw-r--r--  decoder/dict.h  4
-rw-r--r--  decoder/fdict.cc  1
-rw-r--r--  decoder/fdict.h  11
-rw-r--r--  decoder/sparse_vector.h  14
-rw-r--r--  decoder/trule.cc  9
6 files changed, 47 insertions, 16 deletions
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index c6773cce..c6a0057f 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -53,6 +53,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("input,i",po::value<string>()->default_value("-"),"Source file")
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("weights,w",po::value<string>(),"Feature weights file")
+ ("no_freeze_feature_set,Z", "Do not freeze feature set after reading feature weights file")
("feature_function,F",po::value<vector<string> >()->composing(), "Additional feature function(s) (-L for list)")
("list_feature_functions,L","List available feature functions")
("add_pass_through_rules,P","Add rules to translate OOV words as themselves")
@@ -248,6 +249,20 @@ int main(int argc, char** argv) {
exit(1);
}
+ // load feature weights (and possibly freeze feature set)
+ vector<double> feature_weights;
+ Weights w;
+ if (conf.count("weights")) {
+ w.InitFromFile(conf["weights"].as<string>());
+ feature_weights.resize(FD::NumFeats());
+ w.InitVector(&feature_weights);
+ if (!conf.count("no_freeze_feature_set")) {
+ cerr << "Freezing feature set (use --no_freeze_feature_set to change)." << endl;
+ FD::Freeze();
+ }
+ }
+
+ // set up translation back end
if (formalism == "scfg")
translator.reset(new SCFGTranslator(conf));
else if (formalism == "fst")
@@ -263,14 +278,6 @@ int main(int argc, char** argv) {
else
assert(!"error");
- vector<double> feature_weights;
- Weights w;
- if (conf.count("weights")) {
- w.InitFromFile(conf["weights"].as<string>());
- feature_weights.resize(FD::NumFeats());
- w.InitVector(&feature_weights);
- }
-
// set up additional scoring features
vector<shared_ptr<FeatureFunction> > pffs;
vector<const FeatureFunction*> late_ffs;
@@ -480,6 +487,7 @@ int main(int argc, char** argv) {
}
if (output_training_vector) {
+ acc_vec.clear_value(0);
++g_count;
if (g_count % combine_size == 0) {
if (encode_b64) {
diff --git a/decoder/dict.h b/decoder/dict.h
index bae9debe..0cbc9ff0 100644
--- a/decoder/dict.h
+++ b/decoder/dict.h
@@ -16,9 +16,11 @@ class Dict {
public:
Dict() : b0_("<bad0>") { words_.reserve(1000); }
inline int max() const { return words_.size(); }
- inline WordID Convert(const std::string& word) {
+ inline WordID Convert(const std::string& word, bool frozen = false) {
Map::iterator i = d_.find(word);
if (i == d_.end()) {
+ if (frozen)
+ return 0;
words_.push_back(word);
d_[word] = words_.size();
return words_.size();
diff --git a/decoder/fdict.cc b/decoder/fdict.cc
index 83aa7cea..8218a5d3 100644
--- a/decoder/fdict.cc
+++ b/decoder/fdict.cc
@@ -1,4 +1,5 @@
#include "fdict.h"
Dict FD::dict_;
+bool FD::frozen_ = false;
diff --git a/decoder/fdict.h b/decoder/fdict.h
index ff491cfb..d05f1706 100644
--- a/decoder/fdict.h
+++ b/decoder/fdict.h
@@ -6,16 +6,23 @@
#include "dict.h"
struct FD {
- static Dict dict_;
+ // once the FD is frozen, new features not already in the
+ // dictionary will return 0
+ static void Freeze() {
+ frozen_ = true;
+ }
static inline int NumFeats() {
return dict_.max() + 1;
}
static inline WordID Convert(const std::string& s) {
- return dict_.Convert(s);
+ return dict_.Convert(s, frozen_);
}
static inline const std::string& Convert(const WordID& w) {
return dict_.Convert(w);
}
+ static Dict dict_;
+ private:
+ static bool frozen_;
};
#endif
diff --git a/decoder/sparse_vector.h b/decoder/sparse_vector.h
index 6a8c9bf4..2b4a63a9 100644
--- a/decoder/sparse_vector.h
+++ b/decoder/sparse_vector.h
@@ -185,10 +185,15 @@ public:
}
std::ostream &operator<<(std::ostream &out) const {
+ bool first = true;
for (typename std::map<int, T>::const_iterator
- it = _values.begin(); it != _values.end(); ++it)
- out << (it == _values.begin() ? "" : ";")
- << FD::Convert(it->first) << '=' << it->second;
+ it = _values.begin(); it != _values.end(); ++it) {
+ // by definition feature id 0 is a dummy value
+ if (it->first == 0) continue;
+ out << (first ? "" : ";")
+ << FD::Convert(it->first) << '=' << it->second;
+ first = false;
+ }
return out;
}
@@ -216,6 +221,9 @@ public:
void clear() {
_values.clear();
}
+ void clear_value(int index) {
+ _values.erase(index);
+ }
void swap(SparseVector<T>& other) {
_values.swap(other._values);
diff --git a/decoder/trule.cc b/decoder/trule.cc
index b8f6995e..505839c7 100644
--- a/decoder/trule.cc
+++ b/decoder/trule.cc
@@ -126,7 +126,11 @@ bool TRule::ReadFromString(const string& line, bool strict, bool mono) {
if (fv > 9) { cerr << "Too many phrasetable scores - used named format\n"; abort(); }
fname[12]='0' + fv;
++fv;
- scores_.set_value(FD::Convert(fname), atof(&ss[start]));
+ // if the feature set is frozen, this may return zero, indicating an
+ // undefined feature
+ const int fid = FD::Convert(fname);
+ if (fid)
+ scores_.set_value(fid, atof(&ss[start]));
//cerr << "F: " << fname << " VAL=" << scores_.value(FD::Convert(fname)) << endl;
} else {
const int fid = FD::Convert(ss.substr(start, end - start));
@@ -136,7 +140,8 @@ bool TRule::ReadFromString(const string& line, bool strict, bool mono) {
++end;
if (end < len) { ss[end] = 0; }
assert(start < len);
- scores_.set_value(fid, atof(&ss[start]));
+ if (fid)
+ scores_.set_value(fid, atof(&ss[start]));
//cerr << "F: " << FD::Convert(fid) << " VAL=" << scores_.value(fid) << endl;
}
start = end + 1;