summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2010-12-22 08:49:18 -0600
committerChris Dyer <cdyer@cs.cmu.edu>2010-12-22 08:49:18 -0600
commitbf767bf309459e7ea8801d823d4833ae480b2038 (patch)
treee1f13503a33754b3db0db01e5d46e007a4d3f5c4 /decoder
parentf2814314a1245fa0da3cba248cbe59b7f7cd87a8 (diff)
clean up names of feature functions, fix tagger, fix tests
Diffstat (limited to 'decoder')
-rw-r--r--decoder/bottom_up_parser.cc19
-rw-r--r--decoder/cdec_ff.cc10
-rw-r--r--decoder/ff_csplit.cc2
-rw-r--r--decoder/ff_tagger.cc22
-rw-r--r--decoder/ff_tagger.h12
-rw-r--r--decoder/ff_wordalign.cc98
-rw-r--r--decoder/ff_wordalign.h35
-rw-r--r--decoder/tagger.cc2
8 files changed, 82 insertions, 118 deletions
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc
index 9504419c..aecf1cfa 100644
--- a/decoder/bottom_up_parser.cc
+++ b/decoder/bottom_up_parser.cc
@@ -14,20 +14,6 @@
using namespace std;
-struct ParserStats {
- ParserStats() : active_items(), passive_items() {}
- void Reset() { active_items=0; passive_items=0; }
- void Report() {
- if (!SILENT) cerr << " ACTIVE ITEMS: " << active_items << "\tPASSIVE ITEMS: " << passive_items << endl;
- }
- int active_items;
- int passive_items;
- void NotifyActive(int , int ) { ++active_items; }
- void NotifyPassive(int , int ) { ++passive_items; }
-};
-
-ParserStats stats;
-
class ActiveChart;
class PassiveChart {
public:
@@ -90,7 +76,6 @@ class ActiveChart {
void ExtendTerminal(int symbol, float src_cost, vector<ActiveItem>* out_cell) const {
const GrammarIter* ni = gptr_->Extend(symbol);
if (ni) {
- stats.NotifyActive(-1,-1); // TRACKING STATS
out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost));
}
}
@@ -98,7 +83,6 @@ class ActiveChart {
int symbol = hg->nodes_[node_index].cat_;
const GrammarIter* ni = gptr_->Extend(symbol);
if (!ni) return;
- stats.NotifyActive(-1,-1); // TRACKING STATS
Hypergraph::TailNodeVector na(ant_nodes_.size() + 1);
for (int i = 0; i < ant_nodes_.size(); ++i)
na[i] = ant_nodes_[i];
@@ -181,7 +165,6 @@ void PassiveChart::ApplyRule(const int i,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
const float lattice_cost) {
- stats.NotifyPassive(i,j); // TRACKING STATS
Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
new_edge->prev_i_ = r->prev_i;
new_edge->prev_j_ = r->prev_j;
@@ -299,9 +282,7 @@ ExhaustiveBottomUpParser::ExhaustiveBottomUpParser(
bool ExhaustiveBottomUpParser::Parse(const Lattice& input,
Hypergraph* forest) const {
- stats.Reset();
PassiveChart chart(goal_sym_, grammars_, input, forest);
const bool result = chart.Parse();
- stats.Report();
return result;
}
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 729d1214..75591af8 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -53,15 +53,15 @@ void register_feature_functions() {
ff_registry.Register("LexNullJump", new FFFactory<LexNullJump>);
ff_registry.Register("NewJump", new FFFactory<NewJump>);
ff_registry.Register("SourceBigram", new FFFactory<SourceBigram>);
+ ff_registry.Register("Fertility", new FFFactory<Fertility>);
ff_registry.Register("BlunsomSynchronousParseHack", new FFFactory<BlunsomSynchronousParseHack>);
- ff_registry.Register("AlignerResults", new FFFactory<AlignerResults>);
ff_registry.Register("CSplit_BasicFeatures", new FFFactory<BasicCSplitFeatures>);
ff_registry.Register("CSplit_ReverseCharLM", new FFFactory<ReverseCharLMCSplitFeature>);
- ff_registry.Register("Tagger_BigramIdentity", new FFFactory<Tagger_BigramIdentity>);
- ff_registry.Register("LexicalPairIdentity", new FFFactory<LexicalPairIdentity>);
- ff_registry.Register("OutputIdentity", new FFFactory<OutputIdentity>);
+ ff_registry.Register("Tagger_BigramIndicator", new FFFactory<Tagger_BigramIndicator>);
+ ff_registry.Register("LexicalPairIndicator", new FFFactory<LexicalPairIndicator>);
+ ff_registry.Register("OutputIndicator", new FFFactory<OutputIndicator>);
ff_registry.Register("IdentityCycleDetector", new FFFactory<IdentityCycleDetector>);
- ff_registry.Register("InputIdentity", new FFFactory<InputIdentity>);
+ ff_registry.Register("InputIndicator", new FFFactory<InputIndicator>);
ff_registry.Register("LexicalTranslationTrigger", new FFFactory<LexicalTranslationTrigger>);
ff_registry.Register("WordPairFeatures", new FFFactory<WordPairFeatures>);
ff_registry.Register("WordSet", new FFFactory<WordSet>);
diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc
index f267f8e8..1485009b 100644
--- a/decoder/ff_csplit.cc
+++ b/decoder/ff_csplit.cc
@@ -208,9 +208,9 @@ void ReverseCharLMCSplitFeature::TraversalFeaturesImpl(
if (edge.rule_->EWords() != 1) return;
const double lpp = pimpl_->LeftPhonotacticProb(smeta.GetSourceLattice(), edge.i_);
features->set_value(fid_, lpp);
-#if 0
WordID neighbor_word = 0;
const WordID word = edge.rule_->e_[1];
+#if 0
if (chars > 4 && (sword[0] == 's' || sword[0] == 'n')) {
neighbor_word = TD::Convert(string(&sword[1]));
}
diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc
index 21d0f812..46c85cf3 100644
--- a/decoder/ff_tagger.cc
+++ b/decoder/ff_tagger.cc
@@ -8,10 +8,10 @@
using namespace std;
-Tagger_BigramIdentity::Tagger_BigramIdentity(const std::string& param) :
+Tagger_BigramIndicator::Tagger_BigramIndicator(const std::string& param) :
FeatureFunction(sizeof(WordID)) {}
-void Tagger_BigramIdentity::FireFeature(const WordID& left,
+void Tagger_BigramIndicator::FireFeature(const WordID& left,
const WordID& right,
SparseVector<double>* features) const {
int& fid = fmap_[left][right];
@@ -30,7 +30,7 @@ void Tagger_BigramIdentity::FireFeature(const WordID& left,
features->set_value(fid, 1.0);
}
-void Tagger_BigramIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+void Tagger_BigramIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
const std::vector<const void*>& ant_contexts,
SparseVector<double>* features,
@@ -53,18 +53,18 @@ void Tagger_BigramIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
}
-void LexicalPairIdentity::PrepareForInput(const SentenceMetadata& smeta) {
+void LexicalPairIndicator::PrepareForInput(const SentenceMetadata& smeta) {
lexmap_->PrepareForInput(smeta);
}
-LexicalPairIdentity::LexicalPairIdentity(const std::string& param) {
+LexicalPairIndicator::LexicalPairIndicator(const std::string& param) {
name_ = "Id";
if (param.size()) {
// name corpus.f emap.txt
vector<string> params;
SplitOnWhitespace(param, &params);
if (params.size() != 3) {
- cerr << "LexicalPairIdentity takes 3 parameters: <name> <corpus.src.txt> <trgmap.txt>\n";
+ cerr << "LexicalPairIndicator takes 3 parameters: <name> <corpus.src.txt> <trgmap.txt>\n";
cerr << " * may be used for corpus.src.txt or trgmap.txt to use surface forms\n";
cerr << " Received: " << param << endl;
abort();
@@ -76,7 +76,7 @@ LexicalPairIdentity::LexicalPairIdentity(const std::string& param) {
}
}
-void LexicalPairIdentity::FireFeature(WordID src,
+void LexicalPairIndicator::FireFeature(WordID src,
WordID trg,
SparseVector<double>* features) const {
int& fid = fmap_[src][trg];
@@ -88,7 +88,7 @@ void LexicalPairIdentity::FireFeature(WordID src,
features->set_value(fid, 1.0);
}
-void LexicalPairIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+void LexicalPairIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
const std::vector<const void*>& ant_contexts,
SparseVector<double>* features,
@@ -105,9 +105,9 @@ void LexicalPairIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
}
-OutputIdentity::OutputIdentity(const std::string& param) {}
+OutputIndicator::OutputIndicator(const std::string& param) {}
-void OutputIdentity::FireFeature(WordID trg,
+void OutputIndicator::FireFeature(WordID trg,
SparseVector<double>* features) const {
int& fid = fmap_[trg];
if (!fid) {
@@ -125,7 +125,7 @@ void OutputIdentity::FireFeature(WordID trg,
features->set_value(fid, 1.0);
}
-void OutputIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+void OutputIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
const std::vector<const void*>& ant_contexts,
SparseVector<double>* features,
diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h
index 6adee5ab..3066866a 100644
--- a/decoder/ff_tagger.h
+++ b/decoder/ff_tagger.h
@@ -13,9 +13,9 @@ typedef std::map<WordID, Class2FID> Class2Class2FID;
// the sequence unfolds from left to right, which means it doesn't
// have to split states based on left context.
// fires unigram features as well
-class Tagger_BigramIdentity : public FeatureFunction {
+class Tagger_BigramIndicator : public FeatureFunction {
public:
- Tagger_BigramIdentity(const std::string& param);
+ Tagger_BigramIndicator(const std::string& param);
protected:
virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
@@ -32,9 +32,9 @@ class Tagger_BigramIdentity : public FeatureFunction {
// for each pair of symbols cooccuring in a lexicalized rule, fire
// a feature (mostly used for tagging, but could be used for any model)
-class LexicalPairIdentity : public FeatureFunction {
+class LexicalPairIndicator : public FeatureFunction {
public:
- LexicalPairIdentity(const std::string& param);
+ LexicalPairIndicator(const std::string& param);
virtual void PrepareForInput(const SentenceMetadata& smeta);
protected:
virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
@@ -53,9 +53,9 @@ class LexicalPairIdentity : public FeatureFunction {
};
-class OutputIdentity : public FeatureFunction {
+class OutputIndicator : public FeatureFunction {
public:
- OutputIdentity(const std::string& param);
+ OutputIndicator(const std::string& param);
protected:
virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc
index ef3310b4..cdb8662a 100644
--- a/decoder/ff_wordalign.cc
+++ b/decoder/ff_wordalign.cc
@@ -6,6 +6,7 @@
#include <sstream>
#include <string>
#include <cmath>
+#include <bitset>
#include <tr1/unordered_map>
#include <boost/tuple/tuple.hpp>
@@ -443,58 +444,6 @@ void LexicalTranslationTrigger::TraversalFeaturesImpl(const SentenceMetadata& sm
}
}
-// state: src word used, number of trg words generated
-AlignerResults::AlignerResults(const std::string& param) :
- cur_sent_(-1),
- cur_grid_(NULL) {
- vector<string> argv;
- int argc = SplitOnWhitespace(param, &argv);
- if (argc != 2) {
- cerr << "Required format: AlignerResults [FeatureName] [file.pharaoh]\n";
- exit(1);
- }
- cerr << " feature: " << argv[0] << "\talignments: " << argv[1] << endl;
- fid_ = FD::Convert(argv[0]);
- ReadFile rf(argv[1]);
- istream& in = *rf.stream(); int lc = 0;
- while(in) {
- string line;
- getline(in, line);
- if (!in) break;
- ++lc;
- is_aligned_.push_back(AlignmentPharaoh::ReadPharaohAlignmentGrid(line));
- }
- cerr << " Loaded " << lc << " refs\n";
-}
-
-void AlignerResults::TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const vector<const void*>& /* ant_states */,
- SparseVector<double>* features,
- SparseVector<double>* /* estimated_features */,
- void* /* state */) const {
- if (edge.i_ == -1 || edge.prev_i_ == -1)
- return;
-
- if (cur_sent_ != smeta.GetSentenceID()) {
- assert(smeta.HasReference());
- cur_sent_ = smeta.GetSentenceID();
- assert(cur_sent_ < is_aligned_.size());
- cur_grid_ = is_aligned_[cur_sent_].get();
- }
-
- //cerr << edge.rule_->AsString() << endl;
-
- int j = edge.i_; // source side (f)
- int i = edge.prev_i_; // target side (e)
- if (j < cur_grid_->height() && i < cur_grid_->width() && (*cur_grid_)(i, j)) {
-// if (edge.rule_->e_[0] == smeta.GetReference()[i][0].label) {
- features->set_value(fid_, 1.0);
-// cerr << edge.rule_->AsString() << " (" << i << "," << j << ")\n";
-// }
- }
-}
-
BlunsomSynchronousParseHack::BlunsomSynchronousParseHack(const string& param) :
FeatureFunction((100 / 8) + 1), fid_(FD::Convert("NotRef")), cur_sent_(-1) {
ReadFile rf(param);
@@ -618,10 +567,10 @@ void IdentityCycleDetector::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
-InputIdentity::InputIdentity(const std::string& param) {}
+InputIndicator::InputIndicator(const std::string& param) {}
-void InputIdentity::FireFeature(WordID src,
- SparseVector<double>* features) const {
+void InputIndicator::FireFeature(WordID src,
+ SparseVector<double>* features) const {
int& fid = fmap_[src];
if (!fid) {
static map<WordID, WordID> escape;
@@ -638,7 +587,7 @@ void InputIdentity::FireFeature(WordID src,
features->set_value(fid, 1.0);
}
-void InputIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+void InputIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
const std::vector<const void*>& ant_contexts,
SparseVector<double>* features,
@@ -770,3 +719,40 @@ void WordPairFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
}
+struct PathFertility {
+ unsigned char null_fertility;
+ unsigned char index_fertility[255];
+ PathFertility& operator+=(const PathFertility& rhs) {
+ null_fertility += rhs.null_fertility;
+ for (int i = 0; i < 255; ++i)
+ index_fertility[i] += rhs.index_fertility[i];
+ return *this;
+ }
+};
+
+Fertility::Fertility(const string& config) :
+ FeatureFunction(sizeof(PathFertility)) {}
+
+void Fertility::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ PathFertility& out_fert = *static_cast<PathFertility*>(context);
+ if (edge.Arity() == 0) {
+ if (edge.i_ < 0) {
+ out_fert.null_fertility = 1;
+ } else {
+ out_fert.index_fertility[edge.i_] = 1;
+ }
+ } else if (edge.Arity() == 2) {
+ const PathFertility left = *static_cast<const PathFertility*>(ant_contexts[0]);
+ const PathFertility right = *static_cast<const PathFertility*>(ant_contexts[1]);
+ out_fert += left;
+ out_fert += right;
+ } else if (edge.Arity() == 1) {
+ out_fert += *static_cast<const PathFertility*>(ant_contexts[0]);
+ }
+}
+
diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h
index 8035000e..d7a2dda8 100644
--- a/decoder/ff_wordalign.h
+++ b/decoder/ff_wordalign.h
@@ -124,23 +124,6 @@ class LexicalTranslationTrigger : public FeatureFunction {
std::vector<std::vector<WordID> > triggers_;
};
-class AlignerResults : public FeatureFunction {
- public:
- AlignerResults(const std::string& param);
- protected:
- virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const std::vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* out_context) const;
- private:
- int fid_;
- std::vector<boost::shared_ptr<Array2D<bool> > > is_aligned_;
- mutable int cur_sent_;
- const Array2D<bool> mutable* cur_grid_;
-};
-
#include <tr1/unordered_map>
#include <boost/functional/hash.hpp>
#include <cassert>
@@ -254,9 +237,9 @@ class IdentityCycleDetector : public FeatureFunction {
mutable std::map<WordID, bool> big_enough_;
};
-class InputIdentity : public FeatureFunction {
+class InputIndicator : public FeatureFunction {
public:
- InputIdentity(const std::string& param);
+ InputIndicator(const std::string& param);
protected:
virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
@@ -270,4 +253,18 @@ class InputIdentity : public FeatureFunction {
mutable Class2FID fmap_;
};
+class Fertility : public FeatureFunction {
+ public:
+ Fertility(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ private:
+ mutable std::map<WordID, int> fids_;
+};
+
#endif
diff --git a/decoder/tagger.cc b/decoder/tagger.cc
index 4dded35f..54890e85 100644
--- a/decoder/tagger.cc
+++ b/decoder/tagger.cc
@@ -96,7 +96,7 @@ bool Tagger::TranslateImpl(const string& input,
SentenceMetadata* smeta,
const vector<double>& weights,
Hypergraph* forest) {
- Lattice lattice;
+ Lattice& lattice = smeta->src_lattice_;
LatticeTools::ConvertTextToLattice(input, &lattice);
smeta->SetSourceLength(lattice.size());
vector<WordID> sequence(lattice.size());