summaryrefslogtreecommitdiff
path: root/gi
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-03-08 01:46:32 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-03-08 01:46:32 -0500
commite2998fc79c9dd549b1c1bad537fdf1052276f82c (patch)
tree5ed5a727b030c804ddb026e61b184e179b685132 /gi
parent7fd9fe26f00cf31a7b407364399d37b4eaf04eba (diff)
simple context feature for tagger
Diffstat (limited to 'gi')
-rw-r--r--gi/pf/align-tl.cc6
-rw-r--r--gi/pf/reachability.cc2
-rw-r--r--gi/pf/reachability.h6
-rw-r--r--gi/pf/transliterations.cc198
-rw-r--r--gi/pf/transliterations.h5
5 files changed, 69 insertions, 148 deletions
diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc
index 0e0454e5..6bb8c886 100644
--- a/gi/pf/align-tl.cc
+++ b/gi/pf/align-tl.cc
@@ -310,18 +310,16 @@ int main(int argc, char** argv) {
// TODO CONFIGURE THIS
int min_trans_src = 4;
- cerr << "Initializing transliteration DPs ...\n";
+ cerr << "Initializing transliteration graph structures ...\n";
for (int i = 0; i < corpus.size(); ++i) {
const vector<int>& src = corpus[i].src;
const vector<int>& trg = corpus[i].trg;
- cerr << '.' << flush;
- if (i % 80 == 79) cerr << endl;
for (int j = 0; j < src.size(); ++j) {
const vector<int>& src_let = letters[src[j]];
for (int k = 0; k < trg.size(); ++k) {
const vector<int>& trg_let = letters[trg[k]];
if (src_let.size() < min_trans_src)
- tl.Forbid(src[j], trg[k]);
+ tl.Forbid(src[j], src_let, trg[k], trg_let);
else
tl.Initialize(src[j], src_let, trg[k], trg_let);
}
diff --git a/gi/pf/reachability.cc b/gi/pf/reachability.cc
index 73dd8d39..70fb76da 100644
--- a/gi/pf/reachability.cc
+++ b/gi/pf/reachability.cc
@@ -47,6 +47,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras
r[prevs[k].prev_src_covered][prevs[k].prev_trg_covered] = true;
int src_delta = i - prevs[k].prev_src_covered;
edges[prevs[k].prev_src_covered][prevs[k].prev_trg_covered][src_delta][j - prevs[k].prev_trg_covered] = true;
+ valid_deltas[prevs[k].prev_src_covered][prevs[k].prev_trg_covered].push_back(make_pair<short,short>(src_delta,j - prevs[k].prev_trg_covered));
short &msd = max_src_delta[prevs[k].prev_src_covered][prevs[k].prev_trg_covered];
if (src_delta > msd) msd = src_delta;
}
@@ -56,6 +57,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras
assert(!edges[0][0][0][1]);
assert(!edges[0][0][0][0]);
assert(max_src_delta[0][0] > 0);
+ cerr << "Sentence with length (" << srclen << ',' << trglen << ") has " << valid_deltas[0][0].size() << " out edges in its root node\n";
//cerr << "First cell contains " << b[0][0].size() << " forward pointers\n";
//for (int i = 0; i < b[0][0].size(); ++i) {
// cerr << " -> (" << b[0][0][i].next_src_covered << "," << b[0][0][i].next_trg_covered << ")\n";
diff --git a/gi/pf/reachability.h b/gi/pf/reachability.h
index 98450ec1..fb2f4965 100644
--- a/gi/pf/reachability.h
+++ b/gi/pf/reachability.h
@@ -12,12 +12,14 @@
// currently forbids 0 -> n and n -> 0 alignments
struct Reachability {
- boost::multi_array<bool, 4> edges; // edges[src_covered][trg_covered][x][trg_delta] is this edge worth exploring?
+ boost::multi_array<bool, 4> edges; // edges[src_covered][trg_covered][src_delta][trg_delta] is this edge worth exploring?
boost::multi_array<short, 2> max_src_delta; // msd[src_covered][trg_covered] -- the largest src delta that's valid
+ boost::multi_array<std::vector<std::pair<short,short> >, 2> valid_deltas; // valid_deltas[src_covered][trg_covered] list of valid transitions leaving a particular node
Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) :
edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]),
- max_src_delta(boost::extents[srclen][trglen]) {
+ max_src_delta(boost::extents[srclen][trglen]),
+ valid_deltas(boost::extents[srclen][trglen]) {
ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len);
}
diff --git a/gi/pf/transliterations.cc b/gi/pf/transliterations.cc
index 6e0c2e93..e29334fd 100644
--- a/gi/pf/transliterations.cc
+++ b/gi/pf/transliterations.cc
@@ -2,173 +2,92 @@
#include <iostream>
#include <vector>
-#include <tr1/unordered_map>
-#include "grammar.h"
-#include "bottom_up_parser.h"
-#include "hg.h"
-#include "hg_intersect.h"
+#include "boost/shared_ptr.hpp"
+
#include "filelib.h"
#include "ccrp.h"
#include "m.h"
-#include "lattice.h"
-#include "verbose.h"
+#include "reachability.h"
using namespace std;
using namespace std::tr1;
-static WordID kX;
-static int kMAX_SRC_SIZE = 0;
-static vector<vector<WordID> > cur_trg_chunks;
-
-vector<GrammarIter*> tlttofreelist;
-
-static void InitTargetChunks(int max_size, const vector<WordID>& trg) {
- cur_trg_chunks.clear();
- vector<WordID> tmp;
- unordered_set<vector<WordID>, boost::hash<vector<WordID> > > u;
- for (int len = 1; len <= max_size; ++len) {
- int end = trg.size() + 1;
- end -= len;
- for (int i = 0; i < end; ++i) {
- tmp.clear();
- for (int j = 0; j < len; ++j)
- tmp.push_back(trg[i + j]);
- if (u.insert(tmp).second) cur_trg_chunks.push_back(tmp);
- }
- }
-}
-
-struct TransliterationGrammarIter : public GrammarIter, public RuleBin {
- TransliterationGrammarIter() { tlttofreelist.push_back(this); }
- TransliterationGrammarIter(const TRulePtr& inr, int symbol) {
- if (inr) {
- r.reset(new TRule(*inr));
- } else {
- r.reset(new TRule);
- }
- TRule& rr = *r;
- rr.lhs_ = kX;
- rr.f_.push_back(symbol);
- tlttofreelist.push_back(this);
- }
- virtual int GetNumRules() const {
- if (!r) return 0;
- return cur_trg_chunks.size();
- }
- virtual TRulePtr GetIthRule(int i) const {
- TRulePtr nr(new TRule(*r));
- nr->e_ = cur_trg_chunks[i];
- //cerr << nr->AsString() << endl;
- return nr;
- }
- virtual int Arity() const {
- return 0;
- }
- virtual const RuleBin* GetRules() const {
- if (!r) return NULL; else return this;
- }
- virtual const GrammarIter* Extend(int symbol) const {
- if (symbol <= 0) return NULL;
- if (!r || !kMAX_SRC_SIZE || r->f_.size() < kMAX_SRC_SIZE)
- return new TransliterationGrammarIter(r, symbol);
- else
- return NULL;
- }
- TRulePtr r;
-};
-
-struct TransliterationGrammar : public Grammar {
- virtual const GrammarIter* GetRoot() const {
- return new TransliterationGrammarIter;
- }
- virtual bool HasRuleForSpan(int, int, int distance) const {
- return (distance < kMAX_SRC_SIZE);
- }
-};
-
-struct TInfo {
- TInfo() : initialized(false) {}
+struct GraphStructure {
+ GraphStructure() : initialized(false) {}
+ boost::shared_ptr<Reachability> r;
bool initialized;
- Hypergraph lattice; // may be empty if transliteration is not possible
- prob_t est_prob; // will be zero if not possible
};
struct TransliterationsImpl {
TransliterationsImpl() {
- kX = TD::Convert("X")*-1;
- kMAX_SRC_SIZE = 4;
- grammars.push_back(GrammarPtr(new TransliterationGrammar));
- grammars.push_back(GrammarPtr(new GlueGrammar("S", "X")));
- SetSilent(true);
}
void Initialize(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) {
- if (src >= graphs.size()) graphs.resize(src + 1);
- if (graphs[src][trg].initialized) return;
- int kMAX_TRG_SIZE = 4;
- InitTargetChunks(kMAX_TRG_SIZE, trg_lets);
- ExhaustiveBottomUpParser parser("S", grammars);
- Lattice lat(src_lets.size()), tlat(trg_lets.size());
- for (unsigned i = 0; i < src_lets.size(); ++i)
- lat[i].push_back(LatticeArc(src_lets[i], 0.0, 1));
- for (unsigned i = 0; i < trg_lets.size(); ++i)
- tlat[i].push_back(LatticeArc(trg_lets[i], 0.0, 1));
- //cerr << "Creating lattice for: " << TD::Convert(src) << " --> " << TD::Convert(trg) << endl;
- //cerr << "'" << TD::GetString(src_lets) << "' --> " << TD::GetString(trg_lets) << endl;
- if (!parser.Parse(lat, &graphs[src][trg].lattice)) {
- //cerr << "Failed to parse " << TD::GetString(src_lets) << endl;
- abort();
- }
- if (HG::Intersect(tlat, &graphs[src][trg].lattice)) {
- graphs[src][trg].est_prob = prob_t(1e-4);
+ const size_t src_len = src_lets.size();
+ const size_t trg_len = trg_lets.size();
+ if (src_len >= graphs.size()) graphs.resize(src_len + 1);
+ if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1);
+ if (graphs[src_len][trg_len].initialized) return;
+ graphs[src_len][trg_len].r.reset(new Reachability(src_len, trg_len, 4, 4));
+
+#if 0
+ if (HG::Intersect(tlat, &hg)) {
+ // TODO
} else {
- graphs[src][trg].lattice.clear();
- //cerr << "Failed to intersect " << TD::GetString(src_lets) << " ||| " << TD::GetString(trg_lets) << endl;
- graphs[src][trg].est_prob = prob_t::Zero();
+ cerr << "No transliteration lattice possible for src_len=" << src_len << " trg_len=" << trg_len << endl;
+ hg.clear();
}
- for (unsigned i = 0; i < tlttofreelist.size(); ++i)
- delete tlttofreelist[i];
- tlttofreelist.clear();
//cerr << "Number of paths: " << graphs[src][trg].lattice.NumberOfPaths() << endl;
- graphs[src][trg].initialized = true;
+#endif
+ graphs[src_len][trg_len].initialized = true;
}
- const prob_t& EstimateProbability(WordID src, WordID trg) const {
- assert(src < graphs.size());
- const unordered_map<WordID, TInfo>& um = graphs[src];
- const unordered_map<WordID, TInfo>::const_iterator it = um.find(trg);
- assert(it != um.end());
- assert(it->second.initialized);
- return it->second.est_prob;
+ void Forbid(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) {
+ const size_t src_len = src_lets.size();
+ const size_t trg_len = trg_lets.size();
+ if (src_len >= graphs.size()) graphs.resize(src_len + 1);
+ if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1);
+ graphs[src_len][trg_len].r.reset();
+ graphs[src_len][trg_len].initialized = true;
}
- void Forbid(WordID src, WordID trg) {
- if (src >= graphs.size()) graphs.resize(src + 1);
- graphs[src][trg].est_prob = prob_t::Zero();
- graphs[src][trg].initialized = true;
+ prob_t EstimateProbability(WordID s, const vector<WordID>& src, WordID t, const vector<WordID>& trg) const {
+ assert(src.size() < graphs.size());
+ const vector<GraphStructure>& tv = graphs[src.size()];
+ assert(trg.size() < tv.size());
+ const GraphStructure& gs = tv[trg.size()];
+ // TODO: do prob
+ return prob_t::Zero();
}
void GraphSummary() const {
- double tlp = 0;
- int tt = 0;
+ double to = 0;
+ double tn = 0;
+ double tt = 0;
for (int i = 0; i < graphs.size(); ++i) {
- const unordered_map<WordID, TInfo>& um = graphs[i];
- unordered_map<WordID, TInfo>::const_iterator it;
- for (it = um.begin(); it != um.end(); ++it) {
- if (it->second.lattice.empty()) continue;
- //cerr << TD::Convert(i) << " --> " << TD::Convert(it->first) << ": " << it->second.lattice.NumberOfPaths() << endl;
- tlp += log(it->second.lattice.NumberOfPaths());
+ const vector<GraphStructure>& vt = graphs[i];
+ for (int j = 0; j < vt.size(); ++j) {
+ const GraphStructure& gs = vt[j];
+ if (!gs.r) continue;
tt++;
+ for (int k = 0; k < i; ++k) {
+ for (int l = 0; l < j; ++l) {
+ size_t c = gs.r->valid_deltas[k][l].size();
+ if (c) {
+ tn += 1;
+ to += c;
+ }
+ }
+ }
}
}
- tlp /= tt;
- cerr << "E[log paths] = " << tlp << endl;
- cerr << "exp(E[log paths]) = " << exp(tlp) << endl;
+ cerr << " Average nodes = " << (tn / tt) << endl;
+ cerr << "Average out-degree = " << (to / tn) << endl;
+ cerr << " Unique structures = " << tt << endl;
}
- vector<unordered_map<WordID, TInfo> > graphs;
- vector<GrammarPtr> grammars;
+ vector<vector<GraphStructure> > graphs; // graphs[src_len][trg_len]
};
Transliterations::Transliterations() : pimpl_(new TransliterationsImpl) {}
@@ -178,16 +97,15 @@ void Transliterations::Initialize(WordID src, const vector<WordID>& src_lets, Wo
pimpl_->Initialize(src, src_lets, trg, trg_lets);
}
-prob_t Transliterations::EstimateProbability(WordID src, WordID trg) const {
- return pimpl_->EstimateProbability(src,trg);
+prob_t Transliterations::EstimateProbability(WordID s, const vector<WordID>& src, WordID t, const vector<WordID>& trg) const {
+ return pimpl_->EstimateProbability(s, src,t, trg);
}
-void Transliterations::Forbid(WordID src, WordID trg) {
- pimpl_->Forbid(src, trg);
+void Transliterations::Forbid(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) {
+ pimpl_->Forbid(src, src_lets, trg, trg_lets);
}
void Transliterations::GraphSummary() const {
pimpl_->GraphSummary();
}
-
diff --git a/gi/pf/transliterations.h b/gi/pf/transliterations.h
index a548aacf..76eb2a05 100644
--- a/gi/pf/transliterations.h
+++ b/gi/pf/transliterations.h
@@ -10,9 +10,10 @@ struct Transliterations {
explicit Transliterations();
~Transliterations();
void Initialize(WordID src, const std::vector<WordID>& src_lets, WordID trg, const std::vector<WordID>& trg_lets);
- void Forbid(WordID src, WordID trg);
+ void Forbid(WordID src, const std::vector<WordID>& src_lets, WordID trg, const std::vector<WordID>& trg_lets);
void GraphSummary() const;
- prob_t EstimateProbability(WordID src, WordID trg) const;
+ prob_t EstimateProbability(WordID s, const std::vector<WordID>& src, WordID t, const std::vector<WordID>& trg) const;
+ private:
TransliterationsImpl* pimpl_;
};