summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <cdyer@cs.cmu.edu>2012-03-08 13:32:41 -0500
committerChris Dyer <cdyer@cs.cmu.edu>2012-03-08 13:32:41 -0500
commit9a8256604686a9283e7afce04e6feaab4922dd45 (patch)
treea49c4d714d7ac59907e91a4a559fcd86e5951324
parente2998fc79c9dd549b1c1bad537fdf1052276f82c (diff)
tl stuff
-rw-r--r--gi/pf/Makefile.am8
-rw-r--r--gi/pf/align-tl.cc8
-rw-r--r--gi/pf/reachability.cc17
-rw-r--r--gi/pf/reachability.h4
-rw-r--r--gi/pf/transliterations.cc82
-rw-r--r--gi/pf/transliterations.h3
6 files changed, 88 insertions, 34 deletions
diff --git a/gi/pf/Makefile.am b/gi/pf/Makefile.am
index 5e89f02a..9888a70b 100644
--- a/gi/pf/Makefile.am
+++ b/gi/pf/Makefile.am
@@ -2,15 +2,17 @@ bin_PROGRAMS = cbgi brat dpnaive pfbrat pfdist itg pfnaive condnaive align-lexon
noinst_LIBRARIES = libpf.a
-libpf_a_SOURCES = base_distributions.cc reachability.cc cfg_wfst_composer.cc corpus.cc unigrams.cc ngram_base.cc
+libpf_a_SOURCES = base_distributions.cc reachability.cc cfg_wfst_composer.cc corpus.cc unigrams.cc ngram_base.cc transliterations.cc
-nuisance_test_SOURCES = nuisance_test.cc transliterations.cc
+nuisance_test_SOURCES = nuisance_test.cc
+nuisance_test_LDADD = libpf.a
align_lexonly_SOURCES = align-lexonly.cc
align_lexonly_pyp_SOURCES = align-lexonly-pyp.cc
-align_tl_SOURCES = align-tl.cc transliterations.cc
+align_tl_SOURCES = align-tl.cc
+align_tl_LDADD = libpf.a
itg_SOURCES = itg.cc
diff --git a/gi/pf/align-tl.cc b/gi/pf/align-tl.cc
index 6bb8c886..fe8950b5 100644
--- a/gi/pf/align-tl.cc
+++ b/gi/pf/align-tl.cc
@@ -305,7 +305,10 @@ int main(int argc, char** argv) {
ExtractLetters(vocabf, &letters, NULL);
letters[TD::Convert("NULL")].clear();
- Transliterations tl;
+ // TODO configure this
+ int max_src_chunk = 4;
+ int max_trg_chunk = 4;
+ Transliterations tl(max_src_chunk, max_trg_chunk);
// TODO CONFIGURE THIS
int min_trans_src = 4;
@@ -318,10 +321,9 @@ int main(int argc, char** argv) {
const vector<int>& src_let = letters[src[j]];
for (int k = 0; k < trg.size(); ++k) {
const vector<int>& trg_let = letters[trg[k]];
+ tl.Initialize(src[j], src_let, trg[k], trg_let);
if (src_let.size() < min_trans_src)
tl.Forbid(src[j], src_let, trg[k], trg_let);
- else
- tl.Initialize(src[j], src_let, trg[k], trg_let);
}
}
}
diff --git a/gi/pf/reachability.cc b/gi/pf/reachability.cc
index 70fb76da..59bc6ace 100644
--- a/gi/pf/reachability.cc
+++ b/gi/pf/reachability.cc
@@ -39,6 +39,7 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras
typedef boost::multi_array<bool, 2> rarray_type;
rarray_type r(boost::extents[srclen + 1][trglen + 1]);
r[srclen][trglen] = true;
+ nodes = 0;
for (int i = srclen; i >= 0; --i) {
for (int j = trglen; j >= 0; --j) {
vector<SState>& prevs = a[i][j];
@@ -57,10 +58,16 @@ void Reachability::ComputeReachability(int srclen, int trglen, int src_max_phras
assert(!edges[0][0][0][1]);
assert(!edges[0][0][0][0]);
assert(max_src_delta[0][0] > 0);
- cerr << "Sentence with length (" << srclen << ',' << trglen << ") has " << valid_deltas[0][0].size() << " out edges in its root node\n";
- //cerr << "First cell contains " << b[0][0].size() << " forward pointers\n";
- //for (int i = 0; i < b[0][0].size(); ++i) {
- // cerr << " -> (" << b[0][0][i].next_src_covered << "," << b[0][0][i].next_trg_covered << ")\n";
- //}
+ nodes = 0;
+ for (int i = 0; i < srclen; ++i) {
+ for (int j = 0; j < trglen; ++j) {
+ if (valid_deltas[i][j].size() > 0) {
+ node_addresses[i][j] = nodes++;
+ } else {
+ node_addresses[i][j] = -1;
+ }
+ }
+ }
+ cerr << "Sequence pair with lengths (" << srclen << ',' << trglen << ") has " << valid_deltas[0][0].size() << " out edges in its root node, " << nodes << " nodes in total, and outside estimate matrix will require " << sizeof(float)*nodes << " bytes\n";
}
diff --git a/gi/pf/reachability.h b/gi/pf/reachability.h
index fb2f4965..1e22c76a 100644
--- a/gi/pf/reachability.h
+++ b/gi/pf/reachability.h
@@ -12,13 +12,17 @@
// currently forbids 0 -> n and n -> 0 alignments
struct Reachability {
+ unsigned nodes;
boost::multi_array<bool, 4> edges; // edges[src_covered][trg_covered][src_delta][trg_delta] is this edge worth exploring?
boost::multi_array<short, 2> max_src_delta; // msd[src_covered][trg_covered] -- the largest src delta that's valid
+ boost::multi_array<short, 2> node_addresses; // na[src_covered][trg_covered] -- the index of the node in a one-dimensional array (of size "nodes")
boost::multi_array<std::vector<std::pair<short,short> >, 2> valid_deltas; // valid_deltas[src_covered][trg_covered] list of valid transitions leaving a particular node
Reachability(int srclen, int trglen, int src_max_phrase_len, int trg_max_phrase_len) :
+ nodes(),
edges(boost::extents[srclen][trglen][src_max_phrase_len+1][trg_max_phrase_len+1]),
max_src_delta(boost::extents[srclen][trglen]),
+ node_addresses(boost::extents[srclen][trglen]),
valid_deltas(boost::extents[srclen][trglen]) {
ComputeReachability(srclen, trglen, src_max_phrase_len, trg_max_phrase_len);
}
diff --git a/gi/pf/transliterations.cc b/gi/pf/transliterations.cc
index e29334fd..61e95b82 100644
--- a/gi/pf/transliterations.cc
+++ b/gi/pf/transliterations.cc
@@ -14,42 +14,75 @@ using namespace std;
using namespace std::tr1;
struct GraphStructure {
- GraphStructure() : initialized(false) {}
- boost::shared_ptr<Reachability> r;
- bool initialized;
+ GraphStructure() : r() {}
+ // leak memory - these are basically static
+ const Reachability* r;
+ bool IsReachable() const { return r->nodes > 0; }
+};
+
+struct BackwardEstimates {
+ BackwardEstimates() : gs(), backward() {}
+ explicit BackwardEstimates(const GraphStructure& g) :
+ gs(&g), backward() {
+ if (g.r->nodes > 0)
+ backward = new float[g.r->nodes];
+ }
+ // leak memory, these are static
+
+ // returns an estimate of the marginal probability
+ double MarginalEstimate() const {
+ if (!backward) return 0;
+ return backward[0];
+ }
+
+ // returns an backward estimate
+ double operator()(int src_covered, int trg_covered) const {
+ if (!backward) return 0;
+ int ind = gs->r->node_addresses[src_covered][trg_covered];
+ if (ind < 0) return 0;
+ return backward[ind];
+ }
+ private:
+ const GraphStructure* gs;
+ float* backward;
};
struct TransliterationsImpl {
- TransliterationsImpl() {
+ TransliterationsImpl(int max_src, int max_trg) :
+ kMAX_SRC_CHUNK(max_src),
+ kMAX_TRG_CHUNK(max_trg),
+ tot_pairs() {
}
void Initialize(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) {
const size_t src_len = src_lets.size();
const size_t trg_len = trg_lets.size();
+
+ // init graph structure
if (src_len >= graphs.size()) graphs.resize(src_len + 1);
if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1);
- if (graphs[src_len][trg_len].initialized) return;
- graphs[src_len][trg_len].r.reset(new Reachability(src_len, trg_len, 4, 4));
-
-#if 0
- if (HG::Intersect(tlat, &hg)) {
- // TODO
- } else {
- cerr << "No transliteration lattice possible for src_len=" << src_len << " trg_len=" << trg_len << endl;
- hg.clear();
- }
- //cerr << "Number of paths: " << graphs[src][trg].lattice.NumberOfPaths() << endl;
-#endif
- graphs[src_len][trg_len].initialized = true;
+ GraphStructure& gs = graphs[src_len][trg_len];
+ if (!gs.r)
+ gs.r = new Reachability(src_len, trg_len, kMAX_SRC_CHUNK, kMAX_TRG_CHUNK);
+ const Reachability& r = *gs.r;
+
+ // init backward estimates
+ if (src >= bes.size()) bes.resize(src + 1);
+ unordered_map<WordID, BackwardEstimates>::iterator it = bes[src].find(trg);
+ if (it != bes[src].end()) return; // already initialized
+
+ it = bes[src].insert(make_pair(trg, BackwardEstimates(gs))).first;
+ BackwardEstimates& b = it->second;
+ if (!gs.r->nodes) return; // not derivable subject to length constraints
+
+ // TODO
+ tot_pairs++;
}
void Forbid(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) {
const size_t src_len = src_lets.size();
const size_t trg_len = trg_lets.size();
- if (src_len >= graphs.size()) graphs.resize(src_len + 1);
- if (trg_len >= graphs[src_len].size()) graphs[src_len].resize(trg_len + 1);
- graphs[src_len][trg_len].r.reset();
- graphs[src_len][trg_len].initialized = true;
+ // TODO
}
prob_t EstimateProbability(WordID s, const vector<WordID>& src, WordID t, const vector<WordID>& trg) const {
@@ -85,12 +118,17 @@ struct TransliterationsImpl {
cerr << " Average nodes = " << (tn / tt) << endl;
cerr << "Average out-degree = " << (to / tn) << endl;
cerr << " Unique structures = " << tt << endl;
+ cerr << " Unique pairs = " << tot_pairs << endl;
}
+ const int kMAX_SRC_CHUNK;
+ const int kMAX_TRG_CHUNK;
+ unsigned tot_pairs;
vector<vector<GraphStructure> > graphs; // graphs[src_len][trg_len]
+ vector<unordered_map<WordID, BackwardEstimates> > bes; // bes[src][trg]
};
-Transliterations::Transliterations() : pimpl_(new TransliterationsImpl) {}
+Transliterations::Transliterations(int max_src, int max_trg) : pimpl_(new TransliterationsImpl(max_src, max_trg)) {}
Transliterations::~Transliterations() { delete pimpl_; }
void Transliterations::Initialize(WordID src, const vector<WordID>& src_lets, WordID trg, const vector<WordID>& trg_lets) {
diff --git a/gi/pf/transliterations.h b/gi/pf/transliterations.h
index 76eb2a05..e025547e 100644
--- a/gi/pf/transliterations.h
+++ b/gi/pf/transliterations.h
@@ -7,7 +7,8 @@
struct TransliterationsImpl;
struct Transliterations {
- explicit Transliterations();
+ // max_src and max_trg indicate how big the transliteration phrases can be
+ explicit Transliterations(int max_src, int max_trg);
~Transliterations();
void Initialize(WordID src, const std::vector<WordID>& src_lets, WordID trg, const std::vector<WordID>& trg_lets);
void Forbid(WordID src, const std::vector<WordID>& src_lets, WordID trg, const std::vector<WordID>& trg_lets);