Diffstat (limited to 'gi/pyp-topics')
-rw-r--r--  gi/pyp-topics/src/contexts_corpus.cc | 69
-rw-r--r--  gi/pyp-topics/src/contexts_corpus.hh | 35
-rw-r--r--  gi/pyp-topics/src/corpus.cc          |  2
-rw-r--r--  gi/pyp-topics/src/corpus.hh          | 28
-rw-r--r--  gi/pyp-topics/src/pyp-topics.cc      |  2
-rw-r--r--  gi/pyp-topics/src/pyp-topics.hh      |  2
-rw-r--r--  gi/pyp-topics/src/pyp.hh             |  2
-rw-r--r--  gi/pyp-topics/src/train.cc           | 21
8 files changed, 138 insertions(+), 23 deletions(-)
diff --git a/gi/pyp-topics/src/contexts_corpus.cc b/gi/pyp-topics/src/contexts_corpus.cc
index 0b3ec644..afa1e19a 100644
--- a/gi/pyp-topics/src/contexts_corpus.cc
+++ b/gi/pyp-topics/src/contexts_corpus.cc
@@ -15,27 +15,59 @@ using namespace std;
void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* extra) {
assert(new_contexts.contexts.size() == new_contexts.counts.size());
- ContextsCorpus* corpus_ptr = static_cast<ContextsCorpus*>(extra);
+ std::pair<ContextsCorpus*, BackoffGenerator*>* extra_pair
+ = static_cast< std::pair<ContextsCorpus*, BackoffGenerator*>* >(extra);
+
+ ContextsCorpus* corpus_ptr = extra_pair->first;
+ BackoffGenerator* backoff_gen = extra_pair->second;
+
Document* doc(new Document());
//std::cout << "READ: " << new_contexts.phrase << "\t";
-
for (int i=0; i < new_contexts.contexts.size(); ++i) {
- std::string context_str = "";
- for (ContextsLexer::Context::const_iterator it=new_contexts.contexts[i].begin();
- it != new_contexts.contexts[i].end(); ++it) {
- //std::cout << *it << " ";
- if (it != new_contexts.contexts[i].begin())
- context_str += "__";
- context_str += *it;
+ int cache_word_count = corpus_ptr->m_dict.max();
+ WordID id = corpus_ptr->m_dict.Convert(new_contexts.contexts[i]);
+ if (cache_word_count != corpus_ptr->m_dict.max()) {
+ corpus_ptr->m_backoff->terms_at_level(0)++;
+ corpus_ptr->m_num_types++;
}
- WordID id = corpus_ptr->m_dict.Convert(context_str);
int count = new_contexts.counts[i];
- for (int i=0; i<count; ++i)
+ for (int j=0; j<count; ++j)
doc->push_back(id);
corpus_ptr->m_num_terms += count;
+ // generate the backoff map
+ if (backoff_gen) {
+ int order = 1;
+ WordID backoff_id = id;
+ ContextsLexer::Context backedoff_context = new_contexts.contexts[i];
+ while (true) {
+ if (!corpus_ptr->m_backoff->has_backoff(backoff_id)) {
+ //std::cerr << "Backing off from " << corpus_ptr->m_dict.Convert(backoff_id) << " to ";
+ backedoff_context = (*backoff_gen)(backedoff_context);
+
+ if (backedoff_context.empty()) {
+ //std::cerr << "Nothing." << std::endl;
+ (*corpus_ptr->m_backoff)[backoff_id] = -1;
+ break;
+ }
+
+ if (++order > corpus_ptr->m_backoff->order())
+ corpus_ptr->m_backoff->order(order);
+
+ int cache_word_count = corpus_ptr->m_dict.max();
+ int new_backoff_id = corpus_ptr->m_dict.Convert(backedoff_context);
+ if (cache_word_count != corpus_ptr->m_dict.max())
+ corpus_ptr->m_backoff->terms_at_level(order-1)++;
+
+ //std::cerr << corpus_ptr->m_dict.Convert(new_backoff_id) << " ." << std::endl;
+
+ backoff_id = ((*corpus_ptr->m_backoff)[backoff_id] = new_backoff_id);
+ }
+ else break;
+ }
+ }
//std::cout << context_str << " (" << id << ") ||| C=" << count << " ||| ";
}
//std::cout << std::endl;
@@ -43,14 +75,23 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void*
corpus_ptr->m_documents.push_back(doc);
}
-unsigned ContextsCorpus::read_contexts(const std::string &filename) {
+unsigned ContextsCorpus::read_contexts(const std::string &filename,
+ BackoffGenerator* backoff_gen_ptr) {
m_num_terms = 0;
m_num_types = 0;
igzstream in(filename.c_str());
- ContextsLexer::ReadContexts(&in, read_callback, this);
+ std::pair<ContextsCorpus*, BackoffGenerator*> extra_pair(this,backoff_gen_ptr);
+ ContextsLexer::ReadContexts(&in,
+ read_callback,
+ &extra_pair);
+
+ //m_num_types = m_dict.max();
- m_num_types = m_dict.max();
+ std::cerr << "Read backoff with order " << m_backoff->order() << "\n";
+ for (int o=0; o<m_backoff->order(); o++)
+ std::cerr << " Terms at " << o << " = " << m_backoff->terms_at_level(o) << std::endl;
+ std::cerr << std::endl;
return m_documents.size();
}
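
Note on the hunk above: ContextsLexer::ReadContexts hands its callback an opaque void*, so the commit now bundles the corpus and the backoff generator into a std::pair and passes its address. A minimal standalone sketch of that pattern, with stand-in types (Corpus, Backoff, fake_read and on_item are illustrations, not the project's API):

    #include <iostream>
    #include <utility>

    struct Corpus  { int docs; };
    struct Backoff { int order; };

    // stand-ins for ContextsLexer's callback type and ReadContexts
    typedef void (*Callback)(int item, void* extra);
    void fake_read(Callback cb, void* extra) {
      for (int i = 0; i < 3; ++i) cb(i, extra);
    }

    void on_item(int item, void* extra) {
      // same unpacking as read_callback above
      std::pair<Corpus*, Backoff*>* p = static_cast< std::pair<Corpus*, Backoff*>* >(extra);
      p->first->docs++;
      if (p->second) p->second->order = 1;
      std::cout << "item " << item << "\n";
    }

    int main() {
      Corpus c = {0};
      Backoff b = {0};
      std::pair<Corpus*, Backoff*> extra(&c, &b);
      fake_read(on_item, &extra);
      std::cout << "docs=" << c.docs << " order=" << b.order << "\n";
      return 0;
    }
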
diff --git a/gi/pyp-topics/src/contexts_corpus.hh b/gi/pyp-topics/src/contexts_corpus.hh
index e680cef5..bd0cd34c 100644
--- a/gi/pyp-topics/src/contexts_corpus.hh
+++ b/gi/pyp-topics/src/contexts_corpus.hh
@@ -11,6 +11,36 @@
#include "contexts_lexer.h"
#include "../../../decoder/dict.h"
+
+class BackoffGenerator {
+public:
+ virtual ContextsLexer::Context
+ operator()(const ContextsLexer::Context& c) = 0;
+
+protected:
+ ContextsLexer::Context strip_edges(const ContextsLexer::Context& c) {
+ if (c.size() <= 1) return ContextsLexer::Context();
+ assert(c.size() % 2 == 1);
+ return ContextsLexer::Context(c.begin() + 1, c.end() - 1);
+ }
+};
+
+class NullBackoffGenerator : public BackoffGenerator {
+ virtual ContextsLexer::Context
+ operator()(const ContextsLexer::Context&)
+ { return ContextsLexer::Context(); }
+};
+
+class SimpleBackoffGenerator : public BackoffGenerator {
+ virtual ContextsLexer::Context
+ operator()(const ContextsLexer::Context& c) {
+ if (c.size() <= 3)
+ return ContextsLexer::Context();
+ return strip_edges(c);
+ }
+};
+
+
////////////////////////////////////////////////////////////////
// ContextsCorpus
////////////////////////////////////////////////////////////////
@@ -22,10 +52,11 @@ public:
typedef boost::ptr_vector<Document>::const_iterator const_iterator;
public:
- ContextsCorpus() {}
+ ContextsCorpus() : m_backoff(new TermBackoff) {}
virtual ~ContextsCorpus() {}
- unsigned read_contexts(const std::string &filename);
+ unsigned read_contexts(const std::string &filename,
+ BackoffGenerator* backoff_gen=0);
TermBackoffPtr backoff_index() {
return m_backoff;
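
The generator classes added above define the backoff policy used in read_callback: SimpleBackoffGenerator peels one word off each edge of the context until fewer than four words remain, at which point the chain ends with the empty (null) context. A standalone sketch of that behaviour, assuming ContextsLexer::Context is a std::vector<std::string> and using "<PHRASE>" as a stand-in for the gap token:

    #include <cassert>
    #include <iostream>
    #include <string>
    #include <vector>

    typedef std::vector<std::string> Context;  // stand-in for ContextsLexer::Context

    // same rule as SimpleBackoffGenerator above
    Context simple_backoff(const Context& c) {
      if (c.size() <= 3) return Context();
      assert(c.size() % 2 == 1);               // contexts are symmetric around the gap
      return Context(c.begin() + 1, c.end() - 1);
    }

    int main() {
      Context c = {"the", "big", "<PHRASE>", "dog", "barked"};
      while (!c.empty()) {                     // the chain built in read_callback
        for (size_t i = 0; i < c.size(); ++i) std::cout << c[i] << ' ';
        std::cout << '\n';
        c = simple_backoff(c);
      }
      return 0;
    }
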
diff --git a/gi/pyp-topics/src/corpus.cc b/gi/pyp-topics/src/corpus.cc
index 24b93a03..f182381f 100644
--- a/gi/pyp-topics/src/corpus.cc
+++ b/gi/pyp-topics/src/corpus.cc
@@ -11,7 +11,7 @@ using namespace std;
// Corpus
//////////////////////////////////////////////////
-Corpus::Corpus() {}
+Corpus::Corpus() : m_num_terms(0), m_num_types(0) {}
unsigned Corpus::read(const std::string &filename) {
m_num_terms = 0;
diff --git a/gi/pyp-topics/src/corpus.hh b/gi/pyp-topics/src/corpus.hh
index 243f7e2c..c2f37130 100644
--- a/gi/pyp-topics/src/corpus.hh
+++ b/gi/pyp-topics/src/corpus.hh
@@ -22,7 +22,7 @@ public:
public:
Corpus();
- ~Corpus() {}
+ virtual ~Corpus() {}
unsigned read(const std::string &filename);
@@ -71,9 +71,10 @@ class TermBackoff {
public:
typedef std::vector<Term> dictionary_type;
typedef dictionary_type::const_iterator const_iterator;
+ const static int NullBackoff=-1;
public:
- TermBackoff() : m_backoff_order(-1) {}
+ TermBackoff() { order(1); }
~TermBackoff() {}
void read(const std::string &filename);
@@ -86,12 +87,33 @@ public:
return m_dict[t];
}
+ Term& operator[](const Term& t) {
+ if (t >= static_cast<int>(m_dict.size()))
+ m_dict.resize(t+1, -1);
+ return m_dict[t];
+ }
+
+ bool has_backoff(const Term& t) {
+ return t >= 0 && t < static_cast<int>(m_dict.size()) && m_dict[t] >= 0;
+ }
+
int order() const { return m_backoff_order; }
+ void order(int o) {
+ if (o >= (int)m_terms_at_order.size())
+ m_terms_at_order.resize(o, 0);
+ m_backoff_order = o;
+ }
+
// int levels() const { return m_terms_at_order.size(); }
bool is_null(const Term& term) const { return term < 0; }
int terms_at_level(int level) const {
assert (level < (int)m_terms_at_order.size());
- return m_terms_at_order[level];
+ return m_terms_at_order.at(level);
+ }
+
+ int& terms_at_level(int level) {
+ assert (level < (int)m_terms_at_order.size());
+ return m_terms_at_order.at(level);
}
int size() const { return m_dict.size(); }
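
The new non-const operator[] and has_backoff above spell out the contract read_callback relies on: the table maps a term id to the id it backs off to, unseen slots default to -1 (NullBackoff), and the table grows on demand. A cut-down illustration (TinyBackoff is a stand-in, not the real TermBackoff):

    #include <cassert>
    #include <vector>

    typedef int Term;

    struct TinyBackoff {
      std::vector<Term> m_dict;
      Term& operator[](const Term& t) {
        if (t >= static_cast<int>(m_dict.size()))
          m_dict.resize(t + 1, -1);            // new slots start with no backoff
        return m_dict[t];
      }
      bool has_backoff(const Term& t) {
        return t >= 0 && t < static_cast<int>(m_dict.size()) && m_dict[t] >= 0;
      }
    };

    int main() {
      TinyBackoff b;
      b[7] = 3;                                // term 7 backs off to term 3
      assert(b.has_backoff(7));
      assert(!b.has_backoff(3));               // 3 exists but still maps to -1
      assert(!b.has_backoff(42));              // out of range: no backoff
      return 0;
    }
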
diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc
index f3369f2e..c5fd728e 100644
--- a/gi/pyp-topics/src/pyp-topics.cc
+++ b/gi/pyp-topics/src/pyp-topics.cc
@@ -173,7 +173,7 @@ PYPTopics::F PYPTopics::word_pyps_p0(const Term& term, int topic, int level) con
Term backoff_term = (*m_backoff)[term];
if (!m_backoff->is_null(backoff_term)) {
assert (level < m_backoff->order());
- p0 = m_backoff->terms_at_level(level)*prob(backoff_term, topic, level+1);
+ p0 = (1.0/(double)m_backoff->terms_at_level(level))*prob(backoff_term, topic, level+1);
}
else
p0 = m_term_p0;
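
A note on the pyp-topics.cc hunk above (my reading of the change, not spelled out in the commit): the old code multiplied the backed-off probability by the number of terms at the current level, which can push p0 above 1; the fix divides by it instead, i.e.

    p0(term, topic, level) = prob(backoff_term, topic, level+1) / terms_at_level(level)

For example, with 50 terms at level 0 and prob(backoff_term, topic, 1) = 0.2, the old code gave p0 = 0.2 * 50 = 10 (not a probability), while the corrected code gives p0 = 0.2 / 50 = 0.004, spreading the backed-off mass uniformly over that level's terms.
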
diff --git a/gi/pyp-topics/src/pyp-topics.hh b/gi/pyp-topics/src/pyp-topics.hh
index 92d6f292..6b0b15f9 100644
--- a/gi/pyp-topics/src/pyp-topics.hh
+++ b/gi/pyp-topics/src/pyp-topics.hh
@@ -29,6 +29,8 @@ public:
}
void set_backoff(TermBackoffPtr backoff) {
m_backoff = backoff;
+ m_word_pyps.clear();
+ m_word_pyps.resize(m_backoff->order(), PYPs());
}
F prob(const Term& term, int topic, int level=0) const;
diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh
index b06c6021..874d84f5 100644
--- a/gi/pyp-topics/src/pyp.hh
+++ b/gi/pyp-topics/src/pyp.hh
@@ -121,7 +121,7 @@ private:
};
template <typename Dish, typename Hash>
-PYP<Dish,Hash>::PYP(long double a, long double b, Hash cmp)
+PYP<Dish,Hash>::PYP(long double a, long double b, Hash)
: std::tr1::unordered_map<Dish, int, Hash>(), _a(a), _b(b),
_a_beta_a(1), _a_beta_b(1), _b_gamma_s(1), _b_gamma_c(1),
//_a_beta_a(1), _a_beta_b(1), _b_gamma_s(10), _b_gamma_c(0.1),
diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc
index 01ada182..a8fd994c 100644
--- a/gi/pyp-topics/src/train.cc
+++ b/gi/pyp-topics/src/train.cc
@@ -46,6 +46,7 @@ int main(int argc, char **argv)
("samples,s", value<int>()->default_value(10), "number of sampling passes through the data")
("test-corpus", value<string>(), "file containing the test data")
("backoff-paths", value<string>(), "file containing the term backoff paths")
+ ("backoff-type", value<string>(), "backoff type: none|simple")
;
options_description config_options, cmdline_options;
config_options.add(generic);
@@ -91,10 +92,27 @@ int main(int argc, char **argv)
}
else {
+ BackoffGenerator* backoff_gen=0;
+ if (vm.count("backoff-type")) {
+ if (vm["backoff-type"].as<std::string>() == "none") {
+ backoff_gen = 0;
+ }
+ else if (vm["backoff-type"].as<std::string>() == "simple") {
+ backoff_gen = new SimpleBackoffGenerator();
+ }
+ else {
+ std::cerr << "Backoff type (--backoff-type) must be one of none|simple." << std::endl;
+ return(1);
+ }
+ }
+
boost::shared_ptr<ContextsCorpus> contexts_corpus(new ContextsCorpus);
- contexts_corpus->read_contexts(vm["contexts"].as<string>());
+ contexts_corpus->read_contexts(vm["contexts"].as<string>(), backoff_gen);
corpus = contexts_corpus;
model.set_backoff(contexts_corpus->backoff_index());
+
+ if (backoff_gen)
+ delete backoff_gen;
}
// train the sampler
@@ -146,6 +164,7 @@ int main(int argc, char **argv)
}
topics_out.close();
}
+ std::cout << std::endl;
return 0;
}
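
For reference, with the new option the trainer is invoked along these lines (pyp-topics-train is a guess at the binary name built from train.cc, phrase-contexts.gz stands in for your contexts file, and only flags visible in this diff are shown):

    ./pyp-topics-train --contexts phrase-contexts.gz --backoff-type simple --samples 10

Passing --backoff-type none, or omitting the flag, keeps the previous behaviour of reading the contexts with no backoff generator.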