From e2998fc79c9dd549b1c1bad537fdf1052276f82c Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 8 Mar 2012 01:46:32 -0500 Subject: simple context feature for tagger --- gi/pf/align-tl.cc | 6 +- gi/pf/reachability.cc | 2 + gi/pf/reachability.h | 6 +- gi/pf/transliterations.cc | 198 ++++++++++++++-------------------------------- gi/pf/transliterations.h | 5 +- 5 files changed, 69 insertions(+), 148 deletions(-) (limited to 'gi/pf') diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc index 0e0454e5..6bb8c886 100644 --- a/gi/pf/align-tl.cc +++ b/gi/pf/align-tl.cc @@ -310,18 +310,16 @@ int main(int argc, char** argv) { // TODO CONFIGURE THIS int min_trans_src = 4; - cerr << "Initializing transliteration DPs ...\n"; + cerr << "Initializing transliteration graph structures ...\n"; for (int i = 0; i < corpus.size(); ++i) { const vector& src = corpus[i].src; const vector& trg = corpus[i].trg; - cerr << '.' << flush; - if (i % 80 == 79) cerr << endl; for (int j = 0; j < src.size(); ++j) { const vector& src_let = letters[src[j]]; for (int k = 0; k < trg.size(); ++k) { const vector& trg_let = letters[trg[k]]; if (src_let.size() < min_trans_src) - tl.Forbid(src[j], trg[k]); + tl.Forbid(src[j], src_let, trg[k], trg_let); else tl.Initialize(src[j], src_let, trg[k], trg_let); } diff --git a/gi/pf/reachability.cc b/gi/pf/reachability.cc index 73dd8d39..70fb76da 100644 --- a/gi/pf/reachability.cc +++ b/gi/pf/reachability.cc @@ -47,6 +47,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras r[prevs[k].prev_src_covered][prevs[k].prev_trg_covered] = true; int src_delta = i - prevs[k].prev_src_covered; edges[prevs[k].prev_src_covered][prevs[k].prev_trg_covered][src_delta][j - prevs[k].prev_trg_covered] = true; + valid_deltas[prevs[k].prev_src_covered][prevs[k].prev_trg_covered].push_back(make_pair(src_delta,j - prevs[k].prev_trg_covered)); short &msd = max_src_delta[prevs[k].prev_src_covered][prevs[k].prev_trg_covered]; if (src_delta > msd) msd = src_delta; } @@ -56,6 +57,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras assert(!edges[0][0][0][1]); assert(!edges[0][0][0][0]); assert(max_src_delta[0][0] > 0); + cerr << "Sentence with length (" << srclen << ',' << trglen << ") has " << valid_deltas[0][0].size() << " out edges in its root node\n"; //cerr << "First cell contains " << b[0][0].size() << " forward pointers\n"; //for (int i = 0; i < b[0][0].size(); ++i) { // cerr << " -> (" << b[0][0][i].next_src_covered << "," << b[0][0][i].next_trg_covered << ")\n"; diff --git a/gi/pf/reachability.h b/gi/pf/reachability.h index 98450ec1..fb2f4965 100644 --- a/gi/pf/reachability.h +++ b/gi/pf/reachability.h @@ -12,12 +12,14 @@ // currently forbids 0 -> n and n -> 0 alignments struct Reachability { - boost::multi_array edges; // edges[src_covered][trg_covered][x][trg_delta] is this edge worth exploring? + boost::multi_array edges; // edges[src_covered][trg_covered][src_delta][trg_delta] is this edge worth exploring? boost::multi_array max_src_delta; // msd[src_covered][trg_covered] -- the largest src delta that's valid + boost::multi_array >, 2> valid_deltas; // valid_deltas[src_covered][trg_covered] list of valid transitions leaving a particular node Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) : edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]), - max_src_delta(boost::extents[srclen][trglen]) { + max_src_delta(boost::extents[srclen][trglen]), + valid_deltas(boost::extents[srclen][trglen]) { ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len); } diff --git a/gi/pf/transliterations.cc b/gi/pf/transliterations.cc index 6e0c2e93..e29334fd 100644 --- a/gi/pf/transliterations.cc +++ b/gi/pf/transliterations.cc @@ -2,173 +2,92 @@ #include #include -#include -#include "grammar.h" -#include "bottom_up_parser.h" -#include "hg.h" -#include "hg_intersect.h" +#include "boost/shared_ptr.hpp" + #include "filelib.h" #include "ccrp.h" #include "m.h" -#include "lattice.h" -#include "verbose.h" +#include "reachability.h" using namespace std; using namespace std::tr1; -static WordID kX; -static int kMAX_SRC_SIZE = 0; -static vector > cur_trg_chunks; - -vector tlttofreelist; - -static void InitTargetChunks(int max_size, const vector& trg) { - cur_trg_chunks.clear(); - vector tmp; - unordered_set, boost::hash > > u; - for (int len = 1; len <= max_size; ++len) { - int end = trg.size() + 1; - end -= len; - for (int i = 0; i < end; ++i) { - tmp.clear(); - for (int j = 0; j < len; ++j) - tmp.push_back(trg[i + j]); - if (u.insert(tmp).second) cur_trg_chunks.push_back(tmp); - } - } -} - -struct TransliterationGrammarIter : public GrammarIter, public RuleBin { - TransliterationGrammarIter() { tlttofreelist.push_back(this); } - TransliterationGrammarIter(const TRulePtr& inr, int symbol) { - if (inr) { - r.reset(new TRule(*inr)); - } else { - r.reset(new TRule); - } - TRule& rr = *r; - rr.lhs_ = kX; - rr.f_.push_back(symbol); - tlttofreelist.push_back(this); - } - virtual int GetNumRules() const { - if (!r) return 0; - return cur_trg_chunks.size(); - } - virtual TRulePtr GetIthRule(int i) const { - TRulePtr nr(new TRule(*r)); - nr->e_ = cur_trg_chunks[i]; - //cerr << nr->AsString() << endl; - return nr; - } - virtual int Arity() const { - return 0; - } - virtual const RuleBin* GetRules() const { - if (!r) return NULL; else return this; - } - virtual const GrammarIter* Extend(int symbol) const { - if (symbol <= 0) return NULL; - if (!r || !kMAX_SRC_SIZE || r->f_.size() < kMAX_SRC_SIZE) - return new TransliterationGrammarIter(r, symbol); - else - return NULL; - } - TRulePtr r; -}; - -struct TransliterationGrammar : public Grammar { - virtual const GrammarIter* GetRoot() const { - return new TransliterationGrammarIter; - } - virtual bool HasRuleForSpan(int, int, int distance) const { - return (distance < kMAX_SRC_SIZE); - } -}; - -struct TInfo { - TInfo() : initialized(false) {} +struct GraphStructure { + GraphStructure() : initialized(false) {} + boost::shared_ptr r; bool initialized; - Hypergraph lattice; // may be empty if transliteration is not possible - prob_t est_prob; // will be zero if not possible }; struct TransliterationsImpl { TransliterationsImpl() { - kX = TD::Convert("X")*-1; - kMAX_SRC_SIZE = 4; - grammars.push_back(GrammarPtr(new TransliterationGrammar)); - grammars.push_back(GrammarPtr(new GlueGrammar("S", "X"))); - SetSilent(true); } void Initialize(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { - if (src >= graphs.size()) graphs.resize(src + 1); - if (graphs[src][trg].initialized) return; - int kMAX_TRG_SIZE = 4; - InitTargetChunks(kMAX_TRG_SIZE, trg_lets); - ExhaustiveBottomUpParser parser("S", grammars); - Lattice lat(src_lets.size()), tlat(trg_lets.size()); - for (unsigned i = 0; i < src_lets.size(); ++i) - lat[i].push_back(LatticeArc(src_lets[i], 0.0, 1)); - for (unsigned i = 0; i < trg_lets.size(); ++i) - tlat[i].push_back(LatticeArc(trg_lets[i], 0.0, 1)); - //cerr << "Creating lattice for: " << TD::Convert(src) << " --> " << TD::Convert(trg) << endl; - //cerr << "'" << TD::GetString(src_lets) << "' --> " << TD::GetString(trg_lets) << endl; - if (!parser.Parse(lat, &graphs[src][trg].lattice)) { - //cerr << "Failed to parse " << TD::GetString(src_lets) << endl; - abort(); - } - if (HG::Intersect(tlat, &graphs[src][trg].lattice)) { - graphs[src][trg].est_prob = prob_t(1e-4); + const size_t src_len = src_lets.size(); + const size_t trg_len = trg_lets.size(); + if (src_len >= graphs.size()) graphs.resize(src_len + 1); + if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1); + if (graphs[src_len][trg_len].initialized) return; + graphs[src_len][trg_len].r.reset(new Reachability(src_len, trg_len, 4, 4)); + +#if 0 + if (HG::Intersect(tlat, &hg)) { + // TODO } else { - graphs[src][trg].lattice.clear(); - //cerr << "Failed to intersect " << TD::GetString(src_lets) << " ||| " << TD::GetString(trg_lets) << endl; - graphs[src][trg].est_prob = prob_t::Zero(); + cerr << "No transliteration lattice possible for src_len=" << src_len << " trg_len=" << trg_len << endl; + hg.clear(); } - for (unsigned i = 0; i < tlttofreelist.size(); ++i) - delete tlttofreelist[i]; - tlttofreelist.clear(); //cerr << "Number of paths: " << graphs[src][trg].lattice.NumberOfPaths() << endl; - graphs[src][trg].initialized = true; +#endif + graphs[src_len][trg_len].initialized = true; } - const prob_t& EstimateProbability(WordID src, WordID trg) const { - assert(src < graphs.size()); - const unordered_map& um = graphs[src]; - const unordered_map::const_iterator it = um.find(trg); - assert(it != um.end()); - assert(it->second.initialized); - return it->second.est_prob; + void Forbid(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { + const size_t src_len = src_lets.size(); + const size_t trg_len = trg_lets.size(); + if (src_len >= graphs.size()) graphs.resize(src_len + 1); + if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1); + graphs[src_len][trg_len].r.reset(); + graphs[src_len][trg_len].initialized = true; } - void Forbid(WordID src, WordID trg) { - if (src >= graphs.size()) graphs.resize(src + 1); - graphs[src][trg].est_prob = prob_t::Zero(); - graphs[src][trg].initialized = true; + prob_t EstimateProbability(WordID s, const vector& src, WordID t, const vector& trg) const { + assert(src.size() < graphs.size()); + const vector& tv = graphs[src.size()]; + assert(trg.size() < tv.size()); + const GraphStructure& gs = tv[trg.size()]; + // TODO: do prob + return prob_t::Zero(); } void GraphSummary() const { - double tlp = 0; - int tt = 0; + double to = 0; + double tn = 0; + double tt = 0; for (int i = 0; i < graphs.size(); ++i) { - const unordered_map& um = graphs[i]; - unordered_map::const_iterator it; - for (it = um.begin(); it != um.end(); ++it) { - if (it->second.lattice.empty()) continue; - //cerr << TD::Convert(i) << " --> " << TD::Convert(it->first) << ": " << it->second.lattice.NumberOfPaths() << endl; - tlp += log(it->second.lattice.NumberOfPaths()); + const vector& vt = graphs[i]; + for (int j = 0; j < vt.size(); ++j) { + const GraphStructure& gs = vt[j]; + if (!gs.r) continue; tt++; + for (int k = 0; k < i; ++k) { + for (int l = 0; l < j; ++l) { + size_t c = gs.r->valid_deltas[k][l].size(); + if (c) { + tn += 1; + to += c; + } + } + } } } - tlp /= tt; - cerr << "E[log paths] = " << tlp << endl; - cerr << "exp(E[log paths]) = " << exp(tlp) << endl; + cerr << " Average nodes = " << (tn / tt) << endl; + cerr << "Average out-degree = " << (to / tn) << endl; + cerr << " Unique structures = " << tt << endl; } - vector > graphs; - vector grammars; + vector > graphs; // graphs[src_len][trg_len] }; Transliterations::Transliterations() : pimpl_(new TransliterationsImpl) {} @@ -178,16 +97,15 @@ void Transliterations::Initialize(WordID src, const vector& src_lets, Wo pimpl_->Initialize(src, src_lets, trg, trg_lets); } -prob_t Transliterations::EstimateProbability(WordID src, WordID trg) const { - return pimpl_->EstimateProbability(src,trg); +prob_t Transliterations::EstimateProbability(WordID s, const vector& src, WordID t, const vector& trg) const { + return pimpl_->EstimateProbability(s, src,t, trg); } -void Transliterations::Forbid(WordID src, WordID trg) { - pimpl_->Forbid(src, trg); +void Transliterations::Forbid(WordID src, const vector& src_lets, WordID trg, const vector& trg_lets) { + pimpl_->Forbid(src, src_lets, trg, trg_lets); } void Transliterations::GraphSummary() const { pimpl_->GraphSummary(); } - diff --git a/gi/pf/transliterations.h b/gi/pf/transliterations.h index a548aacf..76eb2a05 100644 --- a/gi/pf/transliterations.h +++ b/gi/pf/transliterations.h @@ -10,9 +10,10 @@ struct Transliterations { explicit Transliterations(); ~Transliterations(); void Initialize(WordID src, const std::vector& src_lets, WordID trg, const std::vector& trg_lets); - void Forbid(WordID src, WordID trg); + void Forbid(WordID src, const std::vector& src_lets, WordID trg, const std::vector& trg_lets); void GraphSummary() const; - prob_t EstimateProbability(WordID src, WordID trg) const; + prob_t EstimateProbability(WordID s, const std::vector& src, WordID t, const std::vector& trg) const; + private: TransliterationsImpl* pimpl_; }; -- cgit v1.2.3