From 86805dcb8aaaa716fdc73725ad41e411be53f6a6 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Wed, 22 Dec 2010 08:49:18 -0600 Subject: clean up names of feature functions, fix tagger, fix tests --- decoder/bottom_up_parser.cc | 19 --------- decoder/cdec_ff.cc | 10 ++--- decoder/ff_csplit.cc | 2 +- decoder/ff_tagger.cc | 22 +++++----- decoder/ff_tagger.h | 12 +++--- decoder/ff_wordalign.cc | 98 +++++++++++++++++++-------------------------- decoder/ff_wordalign.h | 35 ++++++++-------- decoder/tagger.cc | 2 +- 8 files changed, 82 insertions(+), 118 deletions(-) (limited to 'decoder') diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index 9504419c..aecf1cfa 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -14,20 +14,6 @@ using namespace std; -struct ParserStats { - ParserStats() : active_items(), passive_items() {} - void Reset() { active_items=0; passive_items=0; } - void Report() { - if (!SILENT) cerr << " ACTIVE ITEMS: " << active_items << "\tPASSIVE ITEMS: " << passive_items << endl; - } - int active_items; - int passive_items; - void NotifyActive(int , int ) { ++active_items; } - void NotifyPassive(int , int ) { ++passive_items; } -}; - -ParserStats stats; - class ActiveChart; class PassiveChart { public: @@ -90,7 +76,6 @@ class ActiveChart { void ExtendTerminal(int symbol, float src_cost, vector* out_cell) const { const GrammarIter* ni = gptr_->Extend(symbol); if (ni) { - stats.NotifyActive(-1,-1); // TRACKING STATS out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost)); } } @@ -98,7 +83,6 @@ class ActiveChart { int symbol = hg->nodes_[node_index].cat_; const GrammarIter* ni = gptr_->Extend(symbol); if (!ni) return; - stats.NotifyActive(-1,-1); // TRACKING STATS Hypergraph::TailNodeVector na(ant_nodes_.size() + 1); for (int i = 0; i < ant_nodes_.size(); ++i) na[i] = ant_nodes_[i]; @@ -181,7 +165,6 @@ void PassiveChart::ApplyRule(const int i, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, const float lattice_cost) { - stats.NotifyPassive(i,j); // TRACKING STATS Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); new_edge->prev_i_ = r->prev_i; new_edge->prev_j_ = r->prev_j; @@ -299,9 +282,7 @@ ExhaustiveBottomUpParser::ExhaustiveBottomUpParser( bool ExhaustiveBottomUpParser::Parse(const Lattice& input, Hypergraph* forest) const { - stats.Reset(); PassiveChart chart(goal_sym_, grammars_, input, forest); const bool result = chart.Parse(); - stats.Report(); return result; } diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 729d1214..75591af8 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -53,15 +53,15 @@ void register_feature_functions() { ff_registry.Register("LexNullJump", new FFFactory); ff_registry.Register("NewJump", new FFFactory); ff_registry.Register("SourceBigram", new FFFactory); + ff_registry.Register("Fertility", new FFFactory); ff_registry.Register("BlunsomSynchronousParseHack", new FFFactory); - ff_registry.Register("AlignerResults", new FFFactory); ff_registry.Register("CSplit_BasicFeatures", new FFFactory); ff_registry.Register("CSplit_ReverseCharLM", new FFFactory); - ff_registry.Register("Tagger_BigramIdentity", new FFFactory); - ff_registry.Register("LexicalPairIdentity", new FFFactory); - ff_registry.Register("OutputIdentity", new FFFactory); + ff_registry.Register("Tagger_BigramIndicator", new FFFactory); + ff_registry.Register("LexicalPairIndicator", new FFFactory); + ff_registry.Register("OutputIndicator", new FFFactory); ff_registry.Register("IdentityCycleDetector", new FFFactory); - ff_registry.Register("InputIdentity", new FFFactory); + ff_registry.Register("InputIndicator", new FFFactory); ff_registry.Register("LexicalTranslationTrigger", new FFFactory); ff_registry.Register("WordPairFeatures", new FFFactory); ff_registry.Register("WordSet", new FFFactory); diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc index f267f8e8..1485009b 100644 --- a/decoder/ff_csplit.cc +++ b/decoder/ff_csplit.cc @@ -208,9 +208,9 @@ void ReverseCharLMCSplitFeature::TraversalFeaturesImpl( if (edge.rule_->EWords() != 1) return; const double lpp = pimpl_->LeftPhonotacticProb(smeta.GetSourceLattice(), edge.i_); features->set_value(fid_, lpp); -#if 0 WordID neighbor_word = 0; const WordID word = edge.rule_->e_[1]; +#if 0 if (chars > 4 && (sword[0] == 's' || sword[0] == 'n')) { neighbor_word = TD::Convert(string(&sword[1])); } diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc index 21d0f812..46c85cf3 100644 --- a/decoder/ff_tagger.cc +++ b/decoder/ff_tagger.cc @@ -8,10 +8,10 @@ using namespace std; -Tagger_BigramIdentity::Tagger_BigramIdentity(const std::string& param) : +Tagger_BigramIndicator::Tagger_BigramIndicator(const std::string& param) : FeatureFunction(sizeof(WordID)) {} -void Tagger_BigramIdentity::FireFeature(const WordID& left, +void Tagger_BigramIndicator::FireFeature(const WordID& left, const WordID& right, SparseVector* features) const { int& fid = fmap_[left][right]; @@ -30,7 +30,7 @@ void Tagger_BigramIdentity::FireFeature(const WordID& left, features->set_value(fid, 1.0); } -void Tagger_BigramIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, +void Tagger_BigramIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, const std::vector& ant_contexts, SparseVector* features, @@ -53,18 +53,18 @@ void Tagger_BigramIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, } } -void LexicalPairIdentity::PrepareForInput(const SentenceMetadata& smeta) { +void LexicalPairIndicator::PrepareForInput(const SentenceMetadata& smeta) { lexmap_->PrepareForInput(smeta); } -LexicalPairIdentity::LexicalPairIdentity(const std::string& param) { +LexicalPairIndicator::LexicalPairIndicator(const std::string& param) { name_ = "Id"; if (param.size()) { // name corpus.f emap.txt vector params; SplitOnWhitespace(param, ¶ms); if (params.size() != 3) { - cerr << "LexicalPairIdentity takes 3 parameters: \n"; + cerr << "LexicalPairIndicator takes 3 parameters: \n"; cerr << " * may be used for corpus.src.txt or trgmap.txt to use surface forms\n"; cerr << " Received: " << param << endl; abort(); @@ -76,7 +76,7 @@ LexicalPairIdentity::LexicalPairIdentity(const std::string& param) { } } -void LexicalPairIdentity::FireFeature(WordID src, +void LexicalPairIndicator::FireFeature(WordID src, WordID trg, SparseVector* features) const { int& fid = fmap_[src][trg]; @@ -88,7 +88,7 @@ void LexicalPairIdentity::FireFeature(WordID src, features->set_value(fid, 1.0); } -void LexicalPairIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, +void LexicalPairIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, const std::vector& ant_contexts, SparseVector* features, @@ -105,9 +105,9 @@ void LexicalPairIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, } } -OutputIdentity::OutputIdentity(const std::string& param) {} +OutputIndicator::OutputIndicator(const std::string& param) {} -void OutputIdentity::FireFeature(WordID trg, +void OutputIndicator::FireFeature(WordID trg, SparseVector* features) const { int& fid = fmap_[trg]; if (!fid) { @@ -125,7 +125,7 @@ void OutputIdentity::FireFeature(WordID trg, features->set_value(fid, 1.0); } -void OutputIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, +void OutputIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, const std::vector& ant_contexts, SparseVector* features, diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h index 6adee5ab..3066866a 100644 --- a/decoder/ff_tagger.h +++ b/decoder/ff_tagger.h @@ -13,9 +13,9 @@ typedef std::map Class2Class2FID; // the sequence unfolds from left to right, which means it doesn't // have to split states based on left context. // fires unigram features as well -class Tagger_BigramIdentity : public FeatureFunction { +class Tagger_BigramIndicator : public FeatureFunction { public: - Tagger_BigramIdentity(const std::string& param); + Tagger_BigramIndicator(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, @@ -32,9 +32,9 @@ class Tagger_BigramIdentity : public FeatureFunction { // for each pair of symbols cooccuring in a lexicalized rule, fire // a feature (mostly used for tagging, but could be used for any model) -class LexicalPairIdentity : public FeatureFunction { +class LexicalPairIndicator : public FeatureFunction { public: - LexicalPairIdentity(const std::string& param); + LexicalPairIndicator(const std::string& param); virtual void PrepareForInput(const SentenceMetadata& smeta); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, @@ -53,9 +53,9 @@ class LexicalPairIdentity : public FeatureFunction { }; -class OutputIdentity : public FeatureFunction { +class OutputIndicator : public FeatureFunction { public: - OutputIdentity(const std::string& param); + OutputIndicator(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc index ef3310b4..cdb8662a 100644 --- a/decoder/ff_wordalign.cc +++ b/decoder/ff_wordalign.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -443,58 +444,6 @@ void LexicalTranslationTrigger::TraversalFeaturesImpl(const SentenceMetadata& sm } } -// state: src word used, number of trg words generated -AlignerResults::AlignerResults(const std::string& param) : - cur_sent_(-1), - cur_grid_(NULL) { - vector argv; - int argc = SplitOnWhitespace(param, &argv); - if (argc != 2) { - cerr << "Required format: AlignerResults [FeatureName] [file.pharaoh]\n"; - exit(1); - } - cerr << " feature: " << argv[0] << "\talignments: " << argv[1] << endl; - fid_ = FD::Convert(argv[0]); - ReadFile rf(argv[1]); - istream& in = *rf.stream(); int lc = 0; - while(in) { - string line; - getline(in, line); - if (!in) break; - ++lc; - is_aligned_.push_back(AlignmentPharaoh::ReadPharaohAlignmentGrid(line)); - } - cerr << " Loaded " << lc << " refs\n"; -} - -void AlignerResults::TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const vector& /* ant_states */, - SparseVector* features, - SparseVector* /* estimated_features */, - void* /* state */) const { - if (edge.i_ == -1 || edge.prev_i_ == -1) - return; - - if (cur_sent_ != smeta.GetSentenceID()) { - assert(smeta.HasReference()); - cur_sent_ = smeta.GetSentenceID(); - assert(cur_sent_ < is_aligned_.size()); - cur_grid_ = is_aligned_[cur_sent_].get(); - } - - //cerr << edge.rule_->AsString() << endl; - - int j = edge.i_; // source side (f) - int i = edge.prev_i_; // target side (e) - if (j < cur_grid_->height() && i < cur_grid_->width() && (*cur_grid_)(i, j)) { -// if (edge.rule_->e_[0] == smeta.GetReference()[i][0].label) { - features->set_value(fid_, 1.0); -// cerr << edge.rule_->AsString() << " (" << i << "," << j << ")\n"; -// } - } -} - BlunsomSynchronousParseHack::BlunsomSynchronousParseHack(const string& param) : FeatureFunction((100 / 8) + 1), fid_(FD::Convert("NotRef")), cur_sent_(-1) { ReadFile rf(param); @@ -618,10 +567,10 @@ void IdentityCycleDetector::TraversalFeaturesImpl(const SentenceMetadata& smeta, } -InputIdentity::InputIdentity(const std::string& param) {} +InputIndicator::InputIndicator(const std::string& param) {} -void InputIdentity::FireFeature(WordID src, - SparseVector* features) const { +void InputIndicator::FireFeature(WordID src, + SparseVector* features) const { int& fid = fmap_[src]; if (!fid) { static map escape; @@ -638,7 +587,7 @@ void InputIdentity::FireFeature(WordID src, features->set_value(fid, 1.0); } -void InputIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, +void InputIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, const std::vector& ant_contexts, SparseVector* features, @@ -770,3 +719,40 @@ void WordPairFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, } } +struct PathFertility { + unsigned char null_fertility; + unsigned char index_fertility[255]; + PathFertility& operator+=(const PathFertility& rhs) { + null_fertility += rhs.null_fertility; + for (int i = 0; i < 255; ++i) + index_fertility[i] += rhs.index_fertility[i]; + return *this; + } +}; + +Fertility::Fertility(const string& config) : + FeatureFunction(sizeof(PathFertility)) {} + +void Fertility::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + PathFertility& out_fert = *static_cast(context); + if (edge.Arity() == 0) { + if (edge.i_ < 0) { + out_fert.null_fertility = 1; + } else { + out_fert.index_fertility[edge.i_] = 1; + } + } else if (edge.Arity() == 2) { + const PathFertility left = *static_cast(ant_contexts[0]); + const PathFertility right = *static_cast(ant_contexts[1]); + out_fert += left; + out_fert += right; + } else if (edge.Arity() == 1) { + out_fert += *static_cast(ant_contexts[0]); + } +} + diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h index 8035000e..d7a2dda8 100644 --- a/decoder/ff_wordalign.h +++ b/decoder/ff_wordalign.h @@ -124,23 +124,6 @@ class LexicalTranslationTrigger : public FeatureFunction { std::vector > triggers_; }; -class AlignerResults : public FeatureFunction { - public: - AlignerResults(const std::string& param); - protected: - virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const std::vector& ant_contexts, - SparseVector* features, - SparseVector* estimated_features, - void* out_context) const; - private: - int fid_; - std::vector > > is_aligned_; - mutable int cur_sent_; - const Array2D mutable* cur_grid_; -}; - #include #include #include @@ -254,9 +237,9 @@ class IdentityCycleDetector : public FeatureFunction { mutable std::map big_enough_; }; -class InputIdentity : public FeatureFunction { +class InputIndicator : public FeatureFunction { public: - InputIdentity(const std::string& param); + InputIndicator(const std::string& param); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, @@ -270,4 +253,18 @@ class InputIdentity : public FeatureFunction { mutable Class2FID fmap_; }; +class Fertility : public FeatureFunction { + public: + Fertility(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + private: + mutable std::map fids_; +}; + #endif diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 4dded35f..54890e85 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -96,7 +96,7 @@ bool Tagger::TranslateImpl(const string& input, SentenceMetadata* smeta, const vector& weights, Hypergraph* forest) { - Lattice lattice; + Lattice& lattice = smeta->src_lattice_; LatticeTools::ConvertTextToLattice(input, &lattice); smeta->SetSourceLength(lattice.size()); vector sequence(lattice.size()); -- cgit v1.2.3