diff options
Diffstat (limited to 'gi/pyp-topics')
-rw-r--r-- | gi/pyp-topics/src/contexts_corpus.cc | 69 | ||||
-rw-r--r-- | gi/pyp-topics/src/contexts_corpus.hh | 35 | ||||
-rw-r--r-- | gi/pyp-topics/src/corpus.cc | 2 | ||||
-rw-r--r-- | gi/pyp-topics/src/corpus.hh | 28 | ||||
-rw-r--r-- | gi/pyp-topics/src/pyp-topics.cc | 2 | ||||
-rw-r--r-- | gi/pyp-topics/src/pyp-topics.hh | 2 | ||||
-rw-r--r-- | gi/pyp-topics/src/pyp.hh | 2 | ||||
-rw-r--r-- | gi/pyp-topics/src/train.cc | 21 |
8 files changed, 138 insertions, 23 deletions
diff --git a/gi/pyp-topics/src/contexts_corpus.cc b/gi/pyp-topics/src/contexts_corpus.cc index 0b3ec644..afa1e19a 100644 --- a/gi/pyp-topics/src/contexts_corpus.cc +++ b/gi/pyp-topics/src/contexts_corpus.cc @@ -15,27 +15,59 @@ using namespace std; void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* extra) { assert(new_contexts.contexts.size() == new_contexts.counts.size()); - ContextsCorpus* corpus_ptr = static_cast<ContextsCorpus*>(extra); + std::pair<ContextsCorpus*, BackoffGenerator*>* extra_pair + = static_cast< std::pair<ContextsCorpus*, BackoffGenerator*>* >(extra); + + ContextsCorpus* corpus_ptr = extra_pair->first; + BackoffGenerator* backoff_gen = extra_pair->second; + Document* doc(new Document()); //std::cout << "READ: " << new_contexts.phrase << "\t"; - for (int i=0; i < new_contexts.contexts.size(); ++i) { - std::string context_str = ""; - for (ContextsLexer::Context::const_iterator it=new_contexts.contexts[i].begin(); - it != new_contexts.contexts[i].end(); ++it) { - //std::cout << *it << " "; - if (it != new_contexts.contexts[i].begin()) - context_str += "__"; - context_str += *it; + int cache_word_count = corpus_ptr->m_dict.max(); + WordID id = corpus_ptr->m_dict.Convert(new_contexts.contexts[i]); + if (cache_word_count != corpus_ptr->m_dict.max()) { + corpus_ptr->m_backoff->terms_at_level(0)++; + corpus_ptr->m_num_types++; } - WordID id = corpus_ptr->m_dict.Convert(context_str); int count = new_contexts.counts[i]; - for (int i=0; i<count; ++i) + for (int j=0; j<count; ++j) doc->push_back(id); corpus_ptr->m_num_terms += count; + // generate the backoff map + if (backoff_gen) { + int order = 1; + WordID backoff_id = id; + ContextsLexer::Context backedoff_context = new_contexts.contexts[i]; + while (true) { + if (!corpus_ptr->m_backoff->has_backoff(backoff_id)) { + //std::cerr << "Backing off from " << corpus_ptr->m_dict.Convert(backoff_id) << " to "; + backedoff_context = (*backoff_gen)(backedoff_context); + + if (backedoff_context.empty()) { + //std::cerr << "Nothing." << std::endl; + (*corpus_ptr->m_backoff)[backoff_id] = -1; + break; + } + + if (++order > corpus_ptr->m_backoff->order()) + corpus_ptr->m_backoff->order(order); + + int cache_word_count = corpus_ptr->m_dict.max(); + int new_backoff_id = corpus_ptr->m_dict.Convert(backedoff_context); + if (cache_word_count != corpus_ptr->m_dict.max()) + corpus_ptr->m_backoff->terms_at_level(order-1)++; + + //std::cerr << corpus_ptr->m_dict.Convert(new_backoff_id) << " ." << std::endl; + + backoff_id = ((*corpus_ptr->m_backoff)[backoff_id] = new_backoff_id); + } + else break; + } + } //std::cout << context_str << " (" << id << ") ||| C=" << count << " ||| "; } //std::cout << std::endl; @@ -43,14 +75,23 @@ void read_callback(const ContextsLexer::PhraseContextsType& new_contexts, void* corpus_ptr->m_documents.push_back(doc); } -unsigned ContextsCorpus::read_contexts(const std::string &filename) { +unsigned ContextsCorpus::read_contexts(const std::string &filename, + BackoffGenerator* backoff_gen_ptr) { m_num_terms = 0; m_num_types = 0; igzstream in(filename.c_str()); - ContextsLexer::ReadContexts(&in, read_callback, this); + std::pair<ContextsCorpus*, BackoffGenerator*> extra_pair(this,backoff_gen_ptr); + ContextsLexer::ReadContexts(&in, + read_callback, + &extra_pair); + + //m_num_types = m_dict.max(); - m_num_types = m_dict.max(); + std::cerr << "Read backoff with order " << m_backoff->order() << "\n"; + for (int o=0; o<m_backoff->order(); o++) + std::cerr << " Terms at " << o << " = " << m_backoff->terms_at_level(o) << std::endl; + std::cerr << std::endl; return m_documents.size(); } diff --git a/gi/pyp-topics/src/contexts_corpus.hh b/gi/pyp-topics/src/contexts_corpus.hh index e680cef5..bd0cd34c 100644 --- a/gi/pyp-topics/src/contexts_corpus.hh +++ b/gi/pyp-topics/src/contexts_corpus.hh @@ -11,6 +11,36 @@ #include "contexts_lexer.h" #include "../../../decoder/dict.h" + +class BackoffGenerator { +public: + virtual ContextsLexer::Context + operator()(const ContextsLexer::Context& c) = 0; + +protected: + ContextsLexer::Context strip_edges(const ContextsLexer::Context& c) { + if (c.size() <= 1) return ContextsLexer::Context(); + assert(c.size() % 2 == 1); + return ContextsLexer::Context(c.begin() + 1, c.end() - 1); + } +}; + +class NullBackoffGenerator : public BackoffGenerator { + virtual ContextsLexer::Context + operator()(const ContextsLexer::Context&) + { return ContextsLexer::Context(); } +}; + +class SimpleBackoffGenerator : public BackoffGenerator { + virtual ContextsLexer::Context + operator()(const ContextsLexer::Context& c) { + if (c.size() <= 3) + return ContextsLexer::Context(); + return strip_edges(c); + } +}; + + //////////////////////////////////////////////////////////////// // ContextsCorpus //////////////////////////////////////////////////////////////// @@ -22,10 +52,11 @@ public: typedef boost::ptr_vector<Document>::const_iterator const_iterator; public: - ContextsCorpus() {} + ContextsCorpus() : m_backoff(new TermBackoff) {} virtual ~ContextsCorpus() {} - unsigned read_contexts(const std::string &filename); + unsigned read_contexts(const std::string &filename, + BackoffGenerator* backoff_gen=0); TermBackoffPtr backoff_index() { return m_backoff; diff --git a/gi/pyp-topics/src/corpus.cc b/gi/pyp-topics/src/corpus.cc index 24b93a03..f182381f 100644 --- a/gi/pyp-topics/src/corpus.cc +++ b/gi/pyp-topics/src/corpus.cc @@ -11,7 +11,7 @@ using namespace std; // Corpus ////////////////////////////////////////////////// -Corpus::Corpus() {} +Corpus::Corpus() : m_num_terms(0), m_num_types(0) {} unsigned Corpus::read(const std::string &filename) { m_num_terms = 0; diff --git a/gi/pyp-topics/src/corpus.hh b/gi/pyp-topics/src/corpus.hh index 243f7e2c..c2f37130 100644 --- a/gi/pyp-topics/src/corpus.hh +++ b/gi/pyp-topics/src/corpus.hh @@ -22,7 +22,7 @@ public: public: Corpus(); - ~Corpus() {} + virtual ~Corpus() {} unsigned read(const std::string &filename); @@ -71,9 +71,10 @@ class TermBackoff { public: typedef std::vector<Term> dictionary_type; typedef dictionary_type::const_iterator const_iterator; + const static int NullBackoff=-1; public: - TermBackoff() : m_backoff_order(-1) {} + TermBackoff() { order(1); } ~TermBackoff() {} void read(const std::string &filename); @@ -86,12 +87,33 @@ public: return m_dict[t]; } + Term& operator[](const Term& t) { + if (t >= static_cast<int>(m_dict.size())) + m_dict.resize(t+1, -1); + return m_dict[t]; + } + + bool has_backoff(const Term& t) { + return t >= 0 && t < static_cast<int>(m_dict.size()) && m_dict[t] >= 0; + } + int order() const { return m_backoff_order; } + void order(int o) { + if (o >= (int)m_terms_at_order.size()) + m_terms_at_order.resize(o, 0); + m_backoff_order = o; + } + // int levels() const { return m_terms_at_order.size(); } bool is_null(const Term& term) const { return term < 0; } int terms_at_level(int level) const { assert (level < (int)m_terms_at_order.size()); - return m_terms_at_order[level]; + return m_terms_at_order.at(level); + } + + int& terms_at_level(int level) { + assert (level < (int)m_terms_at_order.size()); + return m_terms_at_order.at(level); } int size() const { return m_dict.size(); } diff --git a/gi/pyp-topics/src/pyp-topics.cc b/gi/pyp-topics/src/pyp-topics.cc index f3369f2e..c5fd728e 100644 --- a/gi/pyp-topics/src/pyp-topics.cc +++ b/gi/pyp-topics/src/pyp-topics.cc @@ -173,7 +173,7 @@ PYPTopics::F PYPTopics::word_pyps_p0(const Term& term, int topic, int level) con Term backoff_term = (*m_backoff)[term]; if (!m_backoff->is_null(backoff_term)) { assert (level < m_backoff->order()); - p0 = m_backoff->terms_at_level(level)*prob(backoff_term, topic, level+1); + p0 = (1.0/(double)m_backoff->terms_at_level(level))*prob(backoff_term, topic, level+1); } else p0 = m_term_p0; diff --git a/gi/pyp-topics/src/pyp-topics.hh b/gi/pyp-topics/src/pyp-topics.hh index 92d6f292..6b0b15f9 100644 --- a/gi/pyp-topics/src/pyp-topics.hh +++ b/gi/pyp-topics/src/pyp-topics.hh @@ -29,6 +29,8 @@ public: } void set_backoff(TermBackoffPtr backoff) { m_backoff = backoff; + m_word_pyps.clear(); + m_word_pyps.resize(m_backoff->order(), PYPs()); } F prob(const Term& term, int topic, int level=0) const; diff --git a/gi/pyp-topics/src/pyp.hh b/gi/pyp-topics/src/pyp.hh index b06c6021..874d84f5 100644 --- a/gi/pyp-topics/src/pyp.hh +++ b/gi/pyp-topics/src/pyp.hh @@ -121,7 +121,7 @@ private: }; template <typename Dish, typename Hash> -PYP<Dish,Hash>::PYP(long double a, long double b, Hash cmp) +PYP<Dish,Hash>::PYP(long double a, long double b, Hash) : std::tr1::unordered_map<Dish, int, Hash>(), _a(a), _b(b), _a_beta_a(1), _a_beta_b(1), _b_gamma_s(1), _b_gamma_c(1), //_a_beta_a(1), _a_beta_b(1), _b_gamma_s(10), _b_gamma_c(0.1), diff --git a/gi/pyp-topics/src/train.cc b/gi/pyp-topics/src/train.cc index 01ada182..a8fd994c 100644 --- a/gi/pyp-topics/src/train.cc +++ b/gi/pyp-topics/src/train.cc @@ -46,6 +46,7 @@ int main(int argc, char **argv) ("samples,s", value<int>()->default_value(10), "number of sampling passes through the data") ("test-corpus", value<string>(), "file containing the test data") ("backoff-paths", value<string>(), "file containing the term backoff paths") + ("backoff-type", value<string>(), "backoff type: none|simple") ; options_description config_options, cmdline_options; config_options.add(generic); @@ -91,10 +92,27 @@ int main(int argc, char **argv) } else { + BackoffGenerator* backoff_gen=0; + if (vm.count("backoff-type")) { + if (vm["backoff-type"].as<std::string>() == "none") { + backoff_gen = 0; + } + else if (vm["backoff-type"].as<std::string>() == "simple") { + backoff_gen = new SimpleBackoffGenerator(); + } + else { + std::cerr << "Backoff type (--backoff-type) must be one of none|simple." << std::endl; + return(1); + } + } + boost::shared_ptr<ContextsCorpus> contexts_corpus(new ContextsCorpus); - contexts_corpus->read_contexts(vm["contexts"].as<string>()); + contexts_corpus->read_contexts(vm["contexts"].as<string>(), backoff_gen); corpus = contexts_corpus; model.set_backoff(contexts_corpus->backoff_index()); + + if (backoff_gen) + delete backoff_gen; } // train the sampler @@ -146,6 +164,7 @@ int main(int argc, char **argv) } topics_out.close(); } + std::cout << std::endl; return 0; } |