diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2012-03-08 01:46:32 -0500 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2012-03-08 01:46:32 -0500 |
commit | e2998fc79c9dd549b1c1bad537fdf1052276f82c (patch) | |
tree | 5ed5a727b030c804ddb026e61b184e179b685132 | |
parent | 7fd9fe26f00cf31a7b407364399d37b4eaf04eba (diff) |
simple context feature for tagger
-rw-r--r-- | decoder/Makefile.am | 1 | ||||
-rw-r--r-- | decoder/cdec_ff.cc | 2 | ||||
-rw-r--r-- | decoder/ff_context.cc | 99 | ||||
-rw-r--r-- | decoder/ff_context.h | 23 | ||||
-rw-r--r-- | gi/pf/align-tl.cc | 6 | ||||
-rw-r--r-- | gi/pf/reachability.cc | 2 | ||||
-rw-r--r-- | gi/pf/reachability.h | 6 | ||||
-rw-r--r-- | gi/pf/transliterations.cc | 198 | ||||
-rw-r--r-- | gi/pf/transliterations.h | 5 |
9 files changed, 194 insertions, 148 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 30eaf04d..a00b18af 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -63,6 +63,7 @@ libcdec_a_SOURCES = \ ff.cc \ ff_rules.cc \ ff_wordset.cc \ + ff_context.cc \ ff_charset.cc \ ff_lm.cc \ ff_klm.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 4ce5749e..b516c386 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -1,6 +1,7 @@ #include <boost/shared_ptr.hpp> #include "ff.h" +#include "ff_context.h" #include "ff_spans.h" #include "ff_lm.h" #include "ff_klm.h" @@ -42,6 +43,7 @@ void register_feature_functions() { #endif ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>()); ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>()); + ff_registry.Register("RuleContextFeatures", new FFFactory<RuleContextFeatures>()); ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>()); ff_registry.Register("SourceSyntaxFeatures", new FFFactory<SourceSyntaxFeatures>); ff_registry.Register("SourceSpanSizeFeatures", new FFFactory<SourceSpanSizeFeatures>); diff --git a/decoder/ff_context.cc b/decoder/ff_context.cc new file mode 100644 index 00000000..19f9a413 --- /dev/null +++ b/decoder/ff_context.cc @@ -0,0 +1,99 @@ +#include "ff_context.h" + +#include <sstream> +#include <cassert> +#include <cmath> + +#include "filelib.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "lattice.h" +#include "fdict.h" +#include "verbose.h" + +using namespace std; + +namespace { + string Escape(const string& x) { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; + } +} + +RuleContextFeatures::RuleContextFeatures(const std::string& param) { + kSOS = TD::Convert("<s>"); + kEOS = TD::Convert("</s>"); + + // TODO param lets you pass in a string from the cdec.ini file +} + +void RuleContextFeatures::PrepareForInput(const SentenceMetadata& smeta) { + const Lattice& sl = smeta.GetSourceLattice(); + current_input.resize(sl.size()); + for (unsigned i = 0; i < sl.size(); ++i) { + if (sl[i].size() != 1) { + cerr << "Context features not supported with lattice inputs!\nid=" << smeta.GetSentenceId() << endl; + abort(); + } + current_input[i] = sl[i][0].label; + } +} + +void RuleContextFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const { + const TRule& rule = *edge.rule_; + + if (rule.Arity() != 0 || // arity = 0, no nonterminals + rule.e_.size() != 1) return; // size = 1, predicted label is a single token + + + // you can see the current label "for free" + const WordID cur_label = rule.e_[0]; + // (if you want to see more labels, you have to be very careful, and muck + // about with contexts and ant_contexts) + + // but... you can look at as much of the source as you want! + const int from_src_index = edge.i_; // start of the span in the input being labeled + const int to_src_index = edge.j_; // end of the span in the input + // (note: in the case of tagging the size of the spans being labeled will + // always be 1, but in other formalisms, you can have bigger spans.) + + // this is the current token being labeled: + const WordID cur_input = current_input[from_src_index]; + + // let's get the previous token in the input (may be to the left of the start + // of the sentence!) + WordID prev_input = kSOS; + if (from_src_index > 0) { prev_input = current_input[from_src_index - 1]; } + // let's get the next token (may be to the left of the start of the sentence!) + WordID next_input = kEOS; + if (to_src_index < current_input.size()) { next_input = current_input[to_src_index]; } + + // now, build a feature string + ostringstream os; + // TD::Convert converts from the internal integer representation of a token + // to the actual token + os << "C1:" << TD::Convert(prev_input) << '_' + << TD::Convert(cur_input) << '|' << TD::Convert(cur_label); + // C1 is just to prevent a name clash + + // pick a value + double fval = 1.0; // can be any real value + + // add it to the feature vector FD::Convert converts the feature string to a + // feature int, Escape makes sure the feature string doesn't have any bad + // symbols that could confuse a parser somewhere + features->add_value(FD::Convert(Escape(os.str())), fval); + // that's it! + + // create more features if you like... +} + diff --git a/decoder/ff_context.h b/decoder/ff_context.h new file mode 100644 index 00000000..0d22b027 --- /dev/null +++ b/decoder/ff_context.h @@ -0,0 +1,23 @@ +#ifndef _FF_CONTEXT_H_ +#define _FF_CONTEXT_H_ + +#include <vector> +#include "ff.h" + +class RuleContextFeatures : public FeatureFunction { + public: + RuleContextFeatures(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + virtual void PrepareForInput(const SentenceMetadata& smeta); + private: + std::vector<WordID> current_input; + WordID kSOS, kEOS; +}; + +#endif diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc index 0e0454e5..6bb8c886 100644 --- a/gi/pf/align-tl.cc +++ b/gi/pf/align-tl.cc @@ -310,18 +310,16 @@ int main(int argc, char** argv) { // TODO CONFIGURE THIS int min_trans_src = 4; - cerr << "Initializing transliteration DPs ...\n"; + cerr << "Initializing transliteration graph structures ...\n"; for (int i = 0; i < corpus.size(); ++i) { const vector<int>& src = corpus[i].src; const vector<int>& trg = corpus[i].trg; - cerr << '.' << flush; - if (i % 80 == 79) cerr << endl; for (int j = 0; j < src.size(); ++j) { const vector<int>& src_let = letters[src[j]]; for (int k = 0; k < trg.size(); ++k) { const vector<int>& trg_let = letters[trg[k]]; if (src_let.size() < min_trans_src) - tl.Forbid(src[j], trg[k]); + tl.Forbid(src[j], src_let, trg[k], trg_let); else tl.Initialize(src[j], src_let, trg[k], trg_let); } diff --git a/gi/pf/reachability.cc b/gi/pf/reachability.cc index 73dd8d39..70fb76da 100644 --- a/gi/pf/reachability.cc +++ b/gi/pf/reachability.cc @@ -47,6 +47,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras r[prevs[k].prev_src_covered][prevs[k].prev_trg_covered] = true; int src_delta = i - prevs[k].prev_src_covered; edges[prevs[k].prev_src_covered][prevs[k].prev_trg_covered][src_delta][j - prevs[k].prev_trg_covered] = true; + valid_deltas[prevs[k].prev_src_covered][prevs[k].prev_trg_covered].push_back(make_pair<short,short>(src_delta,j - prevs[k].prev_trg_covered)); short &msd = max_src_delta[prevs[k].prev_src_covered][prevs[k].prev_trg_covered]; if (src_delta > msd) msd = src_delta; } @@ -56,6 +57,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras assert(!edges[0][0][0][1]); assert(!edges[0][0][0][0]); assert(max_src_delta[0][0] > 0); + cerr << "Sentence with length (" << srclen << ',' << trglen << ") has " << valid_deltas[0][0].size() << " out edges in its root node\n"; //cerr << "First cell contains " << b[0][0].size() << " forward pointers\n"; //for (int i = 0; i < b[0][0].size(); ++i) { // cerr << " -> (" << b[0][0][i].next_src_covered << "," << b[0][0][i].next_trg_covered << ")\n"; diff --git a/gi/pf/reachability.h b/gi/pf/reachability.h index 98450ec1..fb2f4965 100644 --- a/gi/pf/reachability.h +++ b/gi/pf/reachability.h @@ -12,12 +12,14 @@ // currently forbids 0 -> n and n -> 0 alignments struct Reachability { - boost::multi_array<bool, 4> edges; // edges[src_covered][trg_covered][x][trg_delta] is this edge worth exploring? + boost::multi_array<bool, 4> edges; // edges[src_covered][trg_covered][src_delta][trg_delta] is this edge worth exploring? boost::multi_array<short, 2> max_src_delta; // msd[src_covered][trg_covered] -- the largest src delta that's valid + boost::multi_array<std::vector<std::pair<short,short> >, 2> valid_deltas; // valid_deltas[src_covered][trg_covered] list of valid transitions leaving a particular node Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) : edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]), - max_src_delta(boost::extents[srclen][trglen]) { + max_src_delta(boost::extents[srclen][trglen]), + valid_deltas(boost::extents[srclen][trglen]) { ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len); } diff --git a/gi/pf/transliterations.cc b/gi/pf/transliterations.cc index 6e0c2e93..e29334fd 100644 --- a/gi/pf/transliterations.cc +++ b/gi/pf/transliterations.cc @@ -2,173 +2,92 @@ #include <iostream> #include <vector> -#include <tr1/unordered_map> -#include "grammar.h" -#include "bottom_up_parser.h" -#include "hg.h" -#include "hg_intersect.h" +#include "boost/shared_ptr.hpp" + #include "filelib.h" #include "ccrp.h" #include "m.h" -#include "lattice.h" -#include "verbose.h" +#include "reachability.h" using namespace std; using namespace std::tr1; -static WordID kX; -static int kMAX_SRC_SIZE = 0; -static vector<vector<WordID> > cur_trg_chunks; - -vector<GrammarIter*> tlttofreelist; - -static void InitTargetChunks(int max_size, const vector<WordID>& trg) { - cur_trg_chunks.clear(); - vector<WordID> tmp; - unordered_set<vector<WordID>, boost::hash<vector<WordID> > > u; - for (int len = 1; len <= max_size; ++len) { - int end = trg.size() + 1; - end -= len; - for (int i = 0; i < end; ++i) { - tmp.clear(); - for (int j = 0; j < len; ++j) - tmp.push_back(trg[i + j]); - if (u.insert(tmp).second) cur_trg_chunks.push_back(tmp); - } - } -} - -struct TransliterationGrammarIter : public GrammarIter, public RuleBin { - TransliterationGrammarIter() { tlttofreelist.push_back(this); } - TransliterationGrammarIter(const TRulePtr& inr, int symbol) { - if (inr) { - r.reset(new TRule(*inr)); - } else { - r.reset(new TRule); - } - TRule& rr = *r; - rr.lhs_ = kX; - rr.f_.push_back(symbol); - tlttofreelist.push_back(this); - } - virtual int GetNumRules() const { - if (!r) return 0; - return cur_trg_chunks.size(); - } - virtual TRulePtr GetIthRule(int i) const { - TRulePtr nr(new TRule(*r)); - nr->e_ = cur_trg_chunks[i]; - //cerr << nr->AsString() << endl; - return nr; - } - virtual int Arity() const { - return 0; - } - virtual const RuleBin* GetRules() const { - if (!r) return NULL; else return this; - } - virtual const GrammarIter* Extend(int symbol) const { - if (symbol <= 0) return NULL; - if (!r || !kMAX_SRC_SIZE || r->f_.size() < kMAX_SRC_SIZE) - return new TransliterationGrammarIter(r, symbol); - else - return NULL; - } - TRulePtr r; -}; - -struct TransliterationGrammar : public Grammar { - virtual const GrammarIter* GetRoot() const { - return new TransliterationGrammarIter; - } - virtual bool HasRuleForSpan(int, int, int distance) const { - return (distance < kMAX_SRC_SIZE); - } -}; - -struct TInfo { - TInfo() : initialized(false) {} +struct GraphStructure { + GraphStructure() : initialized(false) {} + boost::shared_ptr<Reachability> r; bool initialized; - Hypergraph lattice; // may be empty if transliteration is not possible - prob_t est_prob; // will be zero if not possible }; struct TransliterationsImpl { TransliterationsImpl() { - kX = TD::Convert("X")*-1; - kMAX_SRC_SIZE = 4; - grammars.push_back(GrammarPtr(new TransliterationGrammar)); - grammars.push_back(GrammarPtr(new GlueGrammar("S", "X"))); - SetSilent(true); } void Initialize(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) { - if (src >= graphs.size()) graphs.resize(src + 1); - if (graphs[src][trg].initialized) return; - int kMAX_TRG_SIZE = 4; - InitTargetChunks(kMAX_TRG_SIZE, trg_lets); - ExhaustiveBottomUpParser parser("S", grammars); - Lattice lat(src_lets.size()), tlat(trg_lets.size()); - for (unsigned i = 0; i < src_lets.size(); ++i) - lat[i].push_back(LatticeArc(src_lets[i], 0.0, 1)); - for (unsigned i = 0; i < trg_lets.size(); ++i) - tlat[i].push_back(LatticeArc(trg_lets[i], 0.0, 1)); - //cerr << "Creating lattice for: " << TD::Convert(src) << " --> " << TD::Convert(trg) << endl; - //cerr << "'" << TD::GetString(src_lets) << "' --> " << TD::GetString(trg_lets) << endl; - if (!parser.Parse(lat, &graphs[src][trg].lattice)) { - //cerr << "Failed to parse " << TD::GetString(src_lets) << endl; - abort(); - } - if (HG::Intersect(tlat, &graphs[src][trg].lattice)) { - graphs[src][trg].est_prob = prob_t(1e-4); + const size_t src_len = src_lets.size(); + const size_t trg_len = trg_lets.size(); + if (src_len >= graphs.size()) graphs.resize(src_len + 1); + if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1); + if (graphs[src_len][trg_len].initialized) return; + graphs[src_len][trg_len].r.reset(new Reachability(src_len, trg_len, 4, 4)); + +#if 0 + if (HG::Intersect(tlat, &hg)) { + // TODO } else { - graphs[src][trg].lattice.clear(); - //cerr << "Failed to intersect " << TD::GetString(src_lets) << " ||| " << TD::GetString(trg_lets) << endl; - graphs[src][trg].est_prob = prob_t::Zero(); + cerr << "No transliteration lattice possible for src_len=" << src_len << " trg_len=" << trg_len << endl; + hg.clear(); } - for (unsigned i = 0; i < tlttofreelist.size(); ++i) - delete tlttofreelist[i]; - tlttofreelist.clear(); //cerr << "Number of paths: " << graphs[src][trg].lattice.NumberOfPaths() << endl; - graphs[src][trg].initialized = true; +#endif + graphs[src_len][trg_len].initialized = true; } - const prob_t& EstimateProbability(WordID src, WordID trg) const { - assert(src < graphs.size()); - const unordered_map<WordID, TInfo>& um = graphs[src]; - const unordered_map<WordID, TInfo>::const_iterator it = um.find(trg); - assert(it != um.end()); - assert(it->second.initialized); - return it->second.est_prob; + void Forbid(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) { + const size_t src_len = src_lets.size(); + const size_t trg_len = trg_lets.size(); + if (src_len >= graphs.size()) graphs.resize(src_len + 1); + if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1); + graphs[src_len][trg_len].r.reset(); + graphs[src_len][trg_len].initialized = true; } - void Forbid(WordID src, WordID trg) { - if (src >= graphs.size()) graphs.resize(src + 1); - graphs[src][trg].est_prob = prob_t::Zero(); - graphs[src][trg].initialized = true; + prob_t EstimateProbability(WordID s, const vector<WordID>& src, WordID t, const vector<WordID>& trg) const { + assert(src.size() < graphs.size()); + const vector<GraphStructure>& tv = graphs[src.size()]; + assert(trg.size() < tv.size()); + const GraphStructure& gs = tv[trg.size()]; + // TODO: do prob + return prob_t::Zero(); } void GraphSummary() const { - double tlp = 0; - int tt = 0; + double to = 0; + double tn = 0; + double tt = 0; for (int i = 0; i < graphs.size(); ++i) { - const unordered_map<WordID, TInfo>& um = graphs[i]; - unordered_map<WordID, TInfo>::const_iterator it; - for (it = um.begin(); it != um.end(); ++it) { - if (it->second.lattice.empty()) continue; - //cerr << TD::Convert(i) << " --> " << TD::Convert(it->first) << ": " << it->second.lattice.NumberOfPaths() << endl; - tlp += log(it->second.lattice.NumberOfPaths()); + const vector<GraphStructure>& vt = graphs[i]; + for (int j = 0; j < vt.size(); ++j) { + const GraphStructure& gs = vt[j]; + if (!gs.r) continue; tt++; + for (int k = 0; k < i; ++k) { + for (int l = 0; l < j; ++l) { + size_t c = gs.r->valid_deltas[k][l].size(); + if (c) { + tn += 1; + to += c; + } + } + } } } - tlp /= tt; - cerr << "E[log paths] = " << tlp << endl; - cerr << "exp(E[log paths]) = " << exp(tlp) << endl; + cerr << " Average nodes = " << (tn / tt) << endl; + cerr << "Average out-degree = " << (to / tn) << endl; + cerr << " Unique structures = " << tt << endl; } - vector<unordered_map<WordID, TInfo> > graphs; - vector<GrammarPtr> grammars; + vector<vector<GraphStructure> > graphs; // graphs[src_len][trg_len] }; Transliterations::Transliterations() : pimpl_(new TransliterationsImpl) {} @@ -178,16 +97,15 @@ void Transliterations::Initialize(WordID src, const vector<WordID>& src_lets, Wo pimpl_->Initialize(src, src_lets, trg, trg_lets); } -prob_t Transliterations::EstimateProbability(WordID src, WordID trg) const { - return pimpl_->EstimateProbability(src,trg); +prob_t Transliterations::EstimateProbability(WordID s, const vector<WordID>& src, WordID t, const vector<WordID>& trg) const { + return pimpl_->EstimateProbability(s, src,t, trg); } -void Transliterations::Forbid(WordID src, WordID trg) { - pimpl_->Forbid(src, trg); +void Transliterations::Forbid(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) { + pimpl_->Forbid(src, src_lets, trg, trg_lets); } void Transliterations::GraphSummary() const { pimpl_->GraphSummary(); } - diff --git a/gi/pf/transliterations.h b/gi/pf/transliterations.h index a548aacf..76eb2a05 100644 --- a/gi/pf/transliterations.h +++ b/gi/pf/transliterations.h @@ -10,9 +10,10 @@ struct Transliterations { explicit Transliterations(); ~Transliterations(); void Initialize(WordID src, const std::vector<WordID>& src_lets, WordID trg, const std::vector<WordID>& trg_lets); - void Forbid(WordID src, WordID trg); + void Forbid(WordID src, const std::vector<WordID>& src_lets, WordID trg, const std::vector<WordID>& trg_lets); void GraphSummary() const; - prob_t EstimateProbability(WordID src, WordID trg) const; + prob_t EstimateProbability(WordID s, const std::vector<WordID>& src, WordID t, const std::vector<WordID>& trg) const; + private: TransliterationsImpl* pimpl_; }; |