author     Chris Dyer <cdyer@cs.cmu.edu>   2012-03-08 01:46:32 -0500
committer  Chris Dyer <cdyer@cs.cmu.edu>   2012-03-08 01:46:32 -0500
commit     e2998fc79c9dd549b1c1bad537fdf1052276f82c (patch)
tree       5ed5a727b030c804ddb026e61b184e179b685132
parent     7fd9fe26f00cf31a7b407364399d37b4eaf04eba (diff)
simple context feature for tagger
-rw-r--r--  decoder/Makefile.am         1
-rw-r--r--  decoder/cdec_ff.cc          2
-rw-r--r--  decoder/ff_context.cc      99
-rw-r--r--  decoder/ff_context.h       23
-rw-r--r--  gi/pf/align-tl.cc           6
-rw-r--r--  gi/pf/reachability.cc       2
-rw-r--r--  gi/pf/reachability.h        6
-rw-r--r--  gi/pf/transliterations.cc 198
-rw-r--r--  gi/pf/transliterations.h    5
9 files changed, 194 insertions, 148 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 30eaf04d..a00b18af 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -63,6 +63,7 @@ libcdec_a_SOURCES = \
ff.cc \
ff_rules.cc \
ff_wordset.cc \
+ ff_context.cc \
ff_charset.cc \
ff_lm.cc \
ff_klm.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 4ce5749e..b516c386 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -1,6 +1,7 @@
#include <boost/shared_ptr.hpp>
#include "ff.h"
+#include "ff_context.h"
#include "ff_spans.h"
#include "ff_lm.h"
#include "ff_klm.h"
@@ -42,6 +43,7 @@ void register_feature_functions() {
#endif
ff_registry.Register("SpanFeatures", new FFFactory<SpanFeatures>());
ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>());
+ ff_registry.Register("RuleContextFeatures", new FFFactory<RuleContextFeatures>());
ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>());
ff_registry.Register("SourceSyntaxFeatures", new FFFactory<SourceSyntaxFeatures>);
ff_registry.Register("SourceSpanSizeFeatures", new FFFactory<SourceSpanSizeFeatures>);
diff --git a/decoder/ff_context.cc b/decoder/ff_context.cc
new file mode 100644
index 00000000..19f9a413
--- /dev/null
+++ b/decoder/ff_context.cc
@@ -0,0 +1,99 @@
+#include "ff_context.h"
+
+#include <sstream>
+#include <cassert>
+#include <cmath>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "sentence_metadata.h"
+#include "lattice.h"
+#include "fdict.h"
+#include "verbose.h"
+
+using namespace std;
+
+namespace {
+ string Escape(const string& x) {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+ }
+}
+
+RuleContextFeatures::RuleContextFeatures(const std::string& param) {
+ kSOS = TD::Convert("<s>");
+ kEOS = TD::Convert("</s>");
+
+ // TODO param lets you pass in a string from the cdec.ini file
+}
+
+void RuleContextFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+ const Lattice& sl = smeta.GetSourceLattice();
+ current_input.resize(sl.size());
+ for (unsigned i = 0; i < sl.size(); ++i) {
+ if (sl[i].size() != 1) {
+ cerr << "Context features not supported with lattice inputs!\nid=" << smeta.GetSentenceId() << endl;
+ abort();
+ }
+ current_input[i] = sl[i][0].label;
+ }
+}
+
+void RuleContextFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ const TRule& rule = *edge.rule_;
+
+ if (rule.Arity() != 0 || // arity = 0, no nonterminals
+ rule.e_.size() != 1) return; // size = 1, predicted label is a single token
+
+ // you can see the current label "for free"
+ const WordID cur_label = rule.e_[0];
+ // (if you want to see more labels, you have to be very careful, and muck
+ // about with contexts and ant_contexts)
+
+ // but... you can look at as much of the source as you want!
+ const int from_src_index = edge.i_; // start of the span in the input being labeled
+ const int to_src_index = edge.j_; // end of the span in the input
+ // (note: in the case of tagging the size of the spans being labeled will
+ // always be 1, but in other formalisms, you can have bigger spans.)
+
+ // this is the current token being labeled:
+ const WordID cur_input = current_input[from_src_index];
+
+ // let's get the previous token in the input (may be to the left of the start
+ // of the sentence!)
+ WordID prev_input = kSOS;
+ if (from_src_index > 0) { prev_input = current_input[from_src_index - 1]; }
+ // let's get the next token (may be past the end of the sentence!)
+ WordID next_input = kEOS;
+ if (to_src_index < current_input.size()) { next_input = current_input[to_src_index]; }
+
+ // now, build a feature string
+ ostringstream os;
+ // TD::Convert converts from the internal integer representation of a token
+ // to the actual token
+ os << "C1:" << TD::Convert(prev_input) << '_'
+ << TD::Convert(cur_input) << '|' << TD::Convert(cur_label);
+ // the "C1:" prefix just prevents name clashes with other features
+
+ // pick a value
+ double fval = 1.0; // can be any real value
+
+ // add it to the feature vector. FD::Convert converts the feature string to
+ // a feature int; Escape makes sure the feature string doesn't contain any
+ // symbols that could confuse a parser somewhere downstream
+ features->add_value(FD::Convert(Escape(os.str())), fval);
+ // that's it!
+
+ // create more features if you like...
+}
+
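For concreteness, here is a minimal standalone sketch (not part of the patch) of the feature strings the code above emits; the tokens and tag are invented:

    #include <iostream>
    #include <sstream>
    #include <string>
    using namespace std;

    // same replacement the anonymous-namespace Escape() above performs
    string Escape(const string& x) {
      string y = x;
      for (size_t i = 0; i < y.size(); ++i)
        if (y[i] == '=' || y[i] == ';') y[i] = '_';
      return y;
    }

    int main() {
      // labeling the token "runs" with tag "VBZ", preceded by "dog"
      ostringstream os;
      os << "C1:" << "dog" << '_' << "runs" << '|' << "VBZ";
      cout << Escape(os.str()) << endl;  // prints: C1:dog_runs|VBZ
      return 0;
    }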
diff --git a/decoder/ff_context.h b/decoder/ff_context.h
new file mode 100644
index 00000000..0d22b027
--- /dev/null
+++ b/decoder/ff_context.h
@@ -0,0 +1,23 @@
+#ifndef _FF_CONTEXT_H_
+#define _FF_CONTEXT_H_
+
+#include <vector>
+#include "ff.h"
+
+class RuleContextFeatures : public FeatureFunction {
+ public:
+ RuleContextFeatures(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ virtual void PrepareForInput(const SentenceMetadata& smeta);
+ private:
+ std::vector<WordID> current_input;
+ WordID kSOS, kEOS;
+};
+
+#endif
diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc
index 0e0454e5..6bb8c886 100644
--- a/gi/pf/align-tl.cc
+++ b/gi/pf/align-tl.cc
@@ -310,18 +310,16 @@ int main(int argc, char** argv) {
// TODO CONFIGURE THIS
int min_trans_src = 4;
- cerr << "Initializing transliteration DPs ...\n";
+ cerr << "Initializing transliteration graph structures ...\n";
for (int i = 0; i < corpus.size(); ++i) {
const vector<int>& src = corpus[i].src;
const vector<int>& trg = corpus[i].trg;
- cerr << '.' << flush;
- if (i % 80 == 79) cerr << endl;
for (int j = 0; j < src.size(); ++j) {
const vector<int>& src_let = letters[src[j]];
for (int k = 0; k < trg.size(); ++k) {
const vector<int>& trg_let = letters[trg[k]];
if (src_let.size() < min_trans_src)
- tl.Forbid(src[j], trg[k]);
+ tl.Forbid(src[j], src_let, trg[k], trg_let);
else
tl.Initialize(src[j], src_let, trg[k], trg_let);
}
diff --git a/gi/pf/reachability.cc b/gi/pf/reachability.cc
index 73dd8d39..70fb76da 100644
--- a/gi/pf/reachability.cc
+++ b/gi/pf/reachability.cc
@@ -47,6 +47,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras
r[prevs[k].prev_src_covered][prevs[k].prev_trg_covered] = true;
int src_delta = i - prevs[k].prev_src_covered;
edges[prevs[k].prev_src_covered][prevs[k].prev_trg_covered][src_delta][j - prevs[k].prev_trg_covered] = true;
+ valid_deltas[prevs[k].prev_src_covered][prevs[k].prev_trg_covered].push_back(make_pair<short,short>(src_delta,j - prevs[k].prev_trg_covered));
short &msd = max_src_delta[prevs[k].prev_src_covered][prevs[k].prev_trg_covered];
if (src_delta > msd) msd = src_delta;
}
@@ -56,6 +57,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras
assert(!edges[0][0][0][1]);
assert(!edges[0][0][0][0]);
assert(max_src_delta[0][0] > 0);
+ cerr << "Sentence with length (" << srclen << ',' << trglen << ") has " << valid_deltas[0][0].size() << " out edges in its root node\n";
//cerr << "First cell contains " << b[0][0].size() << " forward pointers\n";
//for (int i = 0; i < b[0][0].size(); ++i) {
// cerr << " -> (" << b[0][0][i].next_src_covered << "," << b[0][0][i].next_trg_covered << ")\n";
diff --git a/gi/pf/reachability.h b/gi/pf/reachability.h
index 98450ec1..fb2f4965 100644
--- a/gi/pf/reachability.h
+++ b/gi/pf/reachability.h
@@ -12,12 +12,14 @@
// currently forbids 0 -> n and n -> 0 alignments
struct Reachability {
- boost::multi_array<bool, 4> edges; // edges[src_covered][trg_covered][x][trg_delta] is this edge worth exploring?
+ boost::multi_array<bool, 4> edges; // edges[src_covered][trg_covered][src_delta][trg_delta] is this edge worth exploring?
boost::multi_array<short, 2> max_src_delta; // msd[src_covered][trg_covered] -- the largest src delta that's valid
+ boost::multi_array<std::vector<std::pair<short,short> >, 2> valid_deltas; // valid_deltas[src_covered][trg_covered] list of valid transitions leaving a particular node
Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) :
edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]),
- max_src_delta(boost::extents[srclen][trglen]) {
+ max_src_delta(boost::extents[srclen][trglen]),
+ valid_deltas(boost::extents[srclen][trglen]) {
ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len);
}
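A hedged sketch of how the new valid_deltas member is meant to be consumed (it mirrors the walk GraphSummary in transliterations.cc below performs; the function name is invented):

    #include <cstddef>
    #include <iostream>
    #include <utility>
    #include <vector>
    #include "reachability.h"

    // enumerate the transitions leaving the node (src_covered, trg_covered)
    void PrintOutEdges(const Reachability& r, int src_covered, int trg_covered) {
      const std::vector<std::pair<short,short> >& ds =
          r.valid_deltas[src_covered][trg_covered];
      for (size_t k = 0; k < ds.size(); ++k)
        std::cout << "  consume " << ds[k].first << " src / "
                  << ds[k].second << " trg symbols\n";
    }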
diff --git a/gi/pf/transliterations.cc b/gi/pf/transliterations.cc
index 6e0c2e93..e29334fd 100644
--- a/gi/pf/transliterations.cc
+++ b/gi/pf/transliterations.cc
@@ -2,173 +2,92 @@
#include <iostream>
#include <vector>
-#include <tr1/unordered_map>
-#include "grammar.h"
-#include "bottom_up_parser.h"
-#include "hg.h"
-#include "hg_intersect.h"
+#include "boost/shared_ptr.hpp"
+
#include "filelib.h"
#include "ccrp.h"
#include "m.h"
-#include "lattice.h"
-#include "verbose.h"
+#include "reachability.h"
using namespace std;
using namespace std::tr1;
-static WordID kX;
-static int kMAX_SRC_SIZE = 0;
-static vector<vector<WordID> > cur_trg_chunks;
-
-vector<GrammarIter*> tlttofreelist;
-
-static void InitTargetChunks(int max_size, const vector<WordID>& trg) {
- cur_trg_chunks.clear();
- vector<WordID> tmp;
- unordered_set<vector<WordID>, boost::hash<vector<WordID> > > u;
- for (int len = 1; len <= max_size; ++len) {
- int end = trg.size() + 1;
- end -= len;
- for (int i = 0; i < end; ++i) {
- tmp.clear();
- for (int j = 0; j < len; ++j)
- tmp.push_back(trg[i + j]);
- if (u.insert(tmp).second) cur_trg_chunks.push_back(tmp);
- }
- }
-}
-
-struct TransliterationGrammarIter : public GrammarIter, public RuleBin {
- TransliterationGrammarIter() { tlttofreelist.push_back(this); }
- TransliterationGrammarIter(const TRulePtr& inr, int symbol) {
- if (inr) {
- r.reset(new TRule(*inr));
- } else {
- r.reset(new TRule);
- }
- TRule& rr = *r;
- rr.lhs_ = kX;
- rr.f_.push_back(symbol);
- tlttofreelist.push_back(this);
- }
- virtual int GetNumRules() const {
- if (!r) return 0;
- return cur_trg_chunks.size();
- }
- virtual TRulePtr GetIthRule(int i) const {
- TRulePtr nr(new TRule(*r));
- nr->e_ = cur_trg_chunks[i];
- //cerr << nr->AsString() << endl;
- return nr;
- }
- virtual int Arity() const {
- return 0;
- }
- virtual const RuleBin* GetRules() const {
- if (!r) return NULL; else return this;
- }
- virtual const GrammarIter* Extend(int symbol) const {
- if (symbol <= 0) return NULL;
- if (!r || !kMAX_SRC_SIZE || r->f_.size() < kMAX_SRC_SIZE)
- return new TransliterationGrammarIter(r, symbol);
- else
- return NULL;
- }
- TRulePtr r;
-};
-
-struct TransliterationGrammar : public Grammar {
- virtual const GrammarIter* GetRoot() const {
- return new TransliterationGrammarIter;
- }
- virtual bool HasRuleForSpan(int, int, int distance) const {
- return (distance < kMAX_SRC_SIZE);
- }
-};
-
-struct TInfo {
- TInfo() : initialized(false) {}
+struct GraphStructure {
+ GraphStructure() : initialized(false) {}
+ boost::shared_ptr<Reachability> r;
bool initialized;
- Hypergraph lattice; // may be empty if transliteration is not possible
- prob_t est_prob; // will be zero if not possible
};
struct TransliterationsImpl {
TransliterationsImpl() {
- kX = TD::Convert("X")*-1;
- kMAX_SRC_SIZE = 4;
- grammars.push_back(GrammarPtr(new TransliterationGrammar));
- grammars.push_back(GrammarPtr(new GlueGrammar("S", "X")));
- SetSilent(true);
}
void Initialize(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) {
- if (src >= graphs.size()) graphs.resize(src + 1);
- if (graphs[src][trg].initialized) return;
- int kMAX_TRG_SIZE = 4;
- InitTargetChunks(kMAX_TRG_SIZE, trg_lets);
- ExhaustiveBottomUpParser parser("S", grammars);
- Lattice lat(src_lets.size()), tlat(trg_lets.size());
- for (unsigned i = 0; i < src_lets.size(); ++i)
- lat[i].push_back(LatticeArc(src_lets[i], 0.0, 1));
- for (unsigned i = 0; i < trg_lets.size(); ++i)
- tlat[i].push_back(LatticeArc(trg_lets[i], 0.0, 1));
- //cerr << "Creating lattice for: " << TD::Convert(src) << " --> " << TD::Convert(trg) << endl;
- //cerr << "'" << TD::GetString(src_lets) << "' --> " << TD::GetString(trg_lets) << endl;
- if (!parser.Parse(lat, &graphs[src][trg].lattice)) {
- //cerr << "Failed to parse " << TD::GetString(src_lets) << endl;
- abort();
- }
- if (HG::Intersect(tlat, &graphs[src][trg].lattice)) {
- graphs[src][trg].est_prob = prob_t(1e-4);
+ const size_t src_len = src_lets.size();
+ const size_t trg_len = trg_lets.size();
+ if (src_len >= graphs.size()) graphs.resize(src_len + 1);
+ if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1);
+ if (graphs[src_len][trg_len].initialized) return;
+ graphs[src_len][trg_len].r.reset(new Reachability(src_len, trg_len, 4, 4));
+
+#if 0
+ if (HG::Intersect(tlat, &hg)) {
+ // TODO
} else {
- graphs[src][trg].lattice.clear();
- //cerr << "Failed to intersect " << TD::GetString(src_lets) << " ||| " << TD::GetString(trg_lets) << endl;
- graphs[src][trg].est_prob = prob_t::Zero();
+ cerr << "No transliteration lattice possible for src_len=" << src_len << " trg_len=" << trg_len << endl;
+ hg.clear();
}
- for (unsigned i = 0; i < tlttofreelist.size(); ++i)
- delete tlttofreelist[i];
- tlttofreelist.clear();
//cerr << "Number of paths: " << graphs[src][trg].lattice.NumberOfPaths() << endl;
- graphs[src][trg].initialized = true;
+#endif
+ graphs[src_len][trg_len].initialized = true;
}
- const prob_t& EstimateProbability(WordID src, WordID trg) const {
- assert(src < graphs.size());
- const unordered_map<WordID, TInfo>& um = graphs[src];
- const unordered_map<WordID, TInfo>::const_iterator it = um.find(trg);
- assert(it != um.end());
- assert(it->second.initialized);
- return it->second.est_prob;
+ void Forbid(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) {
+ const size_t src_len = src_lets.size();
+ const size_t trg_len = trg_lets.size();
+ if (src_len >= graphs.size()) graphs.resize(src_len + 1);
+ if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1);
+ graphs[src_len][trg_len].r.reset();
+ graphs[src_len][trg_len].initialized = true;
}
- void Forbid(WordID src, WordID trg) {
- if (src >= graphs.size()) graphs.resize(src + 1);
- graphs[src][trg].est_prob = prob_t::Zero();
- graphs[src][trg].initialized = true;
+ prob_t EstimateProbability(WordID s, const vector<WordID>& src, WordID t, const vector<WordID>& trg) const {
+ assert(src.size() < graphs.size());
+ const vector<GraphStructure>& tv = graphs[src.size()];
+ assert(trg.size() < tv.size());
+ const GraphStructure& gs = tv[trg.size()];
+ // TODO: do prob
+ return prob_t::Zero();
}
void GraphSummary() const {
- double tlp = 0;
- int tt = 0;
+ double to = 0; // total out-degree summed over all non-empty nodes
+ double tn = 0; // number of nodes with at least one out edge
+ double tt = 0; // number of (src_len, trg_len) structures with a graph
for (int i = 0; i < graphs.size(); ++i) {
- const unordered_map<WordID, TInfo>& um = graphs[i];
- unordered_map<WordID, TInfo>::const_iterator it;
- for (it = um.begin(); it != um.end(); ++it) {
- if (it->second.lattice.empty()) continue;
- //cerr << TD::Convert(i) << " --> " << TD::Convert(it->first) << ": " << it->second.lattice.NumberOfPaths() << endl;
- tlp += log(it->second.lattice.NumberOfPaths());
+ const vector<GraphStructure>& vt = graphs[i];
+ for (int j = 0; j < vt.size(); ++j) {
+ const GraphStructure& gs = vt[j];
+ if (!gs.r) continue;
tt++;
+ for (int k = 0; k < i; ++k) {
+ for (int l = 0; l < j; ++l) {
+ size_t c = gs.r->valid_deltas[k][l].size();
+ if (c) {
+ tn += 1;
+ to += c;
+ }
+ }
+ }
}
}
- tlp /= tt;
- cerr << "E[log paths] = " << tlp << endl;
- cerr << "exp(E[log paths]) = " << exp(tlp) << endl;
+ cerr << " Average nodes = " << (tn / tt) << endl;
+ cerr << "Average out-degree = " << (to / tn) << endl;
+ cerr << " Unique structures = " << tt << endl;
}
- vector<unordered_map<WordID, TInfo> > graphs;
- vector<GrammarPtr> grammars;
+ vector<vector<GraphStructure> > graphs; // graphs[src_len][trg_len]
};
Transliterations::Transliterations() : pimpl_(new TransliterationsImpl) {}
@@ -178,16 +97,15 @@ void Transliterations::Initialize(WordID src, const vector<WordID>& src_lets, Wo
pimpl_->Initialize(src, src_lets, trg, trg_lets);
}
-prob_t Transliterations::EstimateProbability(WordID src, WordID trg) const {
- return pimpl_->EstimateProbability(src,trg);
+prob_t Transliterations::EstimateProbability(WordID s, const vector<WordID>& src, WordID t, const vector<WordID>& trg) const {
+ return pimpl_->EstimateProbability(s, src,t, trg);
}
-void Transliterations::Forbid(WordID src, WordID trg) {
- pimpl_->Forbid(src, trg);
+void Transliterations::Forbid(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) {
+ pimpl_->Forbid(src, src_lets, trg, trg_lets);
}
void Transliterations::GraphSummary() const {
pimpl_->GraphSummary();
}
-
diff --git a/gi/pf/transliterations.h b/gi/pf/transliterations.h
index a548aacf..76eb2a05 100644
--- a/gi/pf/transliterations.h
+++ b/gi/pf/transliterations.h
@@ -10,9 +10,10 @@ struct Transliterations {
explicit Transliterations();
~Transliterations();
void Initialize(WordID src, const std::vector<WordID>& src_lets, WordID trg, const std::vector<WordID>& trg_lets);
- void Forbid(WordID src, WordID trg);
+ void Forbid(WordID src, const std::vector<WordID>& src_lets, WordID trg, const std::vector<WordID>& trg_lets);
void GraphSummary() const;
- prob_t EstimateProbability(WordID src, WordID trg) const;
+ prob_t EstimateProbability(WordID s, const std::vector<WordID>& src, WordID t, const std::vector<WordID>& trg) const;
+ private:
TransliterationsImpl* pimpl_;
};
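Taken together, the updated interface is driven as in gi/pf/align-tl.cc above; a hypothetical end-to-end sketch (the length cutoff mirrors min_trans_src there):

    #include <vector>
    #include "transliterations.h"

    void BuildGraphs(Transliterations& tl,
                     WordID src, const std::vector<WordID>& src_let,
                     WordID trg, const std::vector<WordID>& trg_let) {
      if (src_let.size() < 4)             // too short to transliterate
        tl.Forbid(src, src_let, trg, trg_let);
      else
        tl.Initialize(src, src_let, trg, trg_let);
      // currently a stub that returns prob_t::Zero() (see the TODO above)
      prob_t p = tl.EstimateProbability(src, src_let, trg, trg_let);
      (void) p;
      tl.GraphSummary();                  // prints node/out-degree statistics
    }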