diff options
Diffstat (limited to 'decoder/ff_csplit.cc')
-rw-r--r-- | decoder/ff_csplit.cc | 212 |
1 files changed, 212 insertions, 0 deletions
diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc new file mode 100644 index 00000000..cac4bb8e --- /dev/null +++ b/decoder/ff_csplit.cc @@ -0,0 +1,212 @@ +#include "ff_csplit.h" + +#include <set> +#include <cstring> + +#include "Vocab.h" +#include "Ngram.h" + +#include "sentence_metadata.h" +#include "lattice.h" +#include "tdict.h" +#include "freqdict.h" +#include "filelib.h" +#include "stringlib.h" +#include "tdict.h" + +using namespace std; + +struct BasicCSplitFeaturesImpl { + BasicCSplitFeaturesImpl(const string& param) : + word_count_(FD::Convert("WordCount")), + letters_sq_(FD::Convert("LettersSq")), + letters_sqrt_(FD::Convert("LettersSqrt")), + in_dict_(FD::Convert("InDict")), + short_(FD::Convert("Short")), + long_(FD::Convert("Long")), + oov_(FD::Convert("OOV")), + short_range_(FD::Convert("ShortRange")), + high_freq_(FD::Convert("HighFreq")), + med_freq_(FD::Convert("MedFreq")), + freq_(FD::Convert("Freq")), + fl1_(FD::Convert("FreqLen1")), + fl2_(FD::Convert("FreqLen2")), + bad_(FD::Convert("Bad")) { + vector<string> argv; + int argc = SplitOnWhitespace(param, &argv); + if (argc != 1 && argc != 2) { + cerr << "Expected: freqdict.txt [badwords.txt]\n"; + abort(); + } + freq_dict_.Load(argv[0]); + if (argc == 2) { + ReadFile rf(argv[1]); + istream& in = *rf.stream(); + while(in) { + string badword; + in >> badword; + if (badword.empty()) continue; + bad_words_.insert(TD::Convert(badword)); + } + } + } + + void TraversalFeaturesImpl(const Hypergraph::Edge& edge, + SparseVector<double>* features) const; + + const int word_count_; + const int letters_sq_; + const int letters_sqrt_; + const int in_dict_; + const int short_; + const int long_; + const int oov_; + const int short_range_; + const int high_freq_; + const int med_freq_; + const int freq_; + const int fl1_; + const int fl2_; + const int bad_; + FreqDict freq_dict_; + set<WordID> bad_words_; +}; + +BasicCSplitFeatures::BasicCSplitFeatures(const string& param) : + pimpl_(new BasicCSplitFeaturesImpl(param)) {} + +void BasicCSplitFeaturesImpl::TraversalFeaturesImpl( + const Hypergraph::Edge& edge, + SparseVector<double>* features) const { + features->set_value(word_count_, 1.0); + features->set_value(letters_sq_, (edge.j_ - edge.i_) * (edge.j_ - edge.i_)); + features->set_value(letters_sqrt_, sqrt(edge.j_ - edge.i_)); + const WordID word = edge.rule_->e_[1]; + const char* sword = TD::Convert(word); + const int len = strlen(sword); + int cur = 0; + int chars = 0; + while(cur < len) { + cur += UTF8Len(sword[cur]); + ++chars; + } + + // these are corrections that attempt to make chars + // more like a phoneme count than a letter count, they + // are only really meaningful for german and should + // probably be gotten rid of + bool has_sch = strstr(sword, "sch"); + bool has_ch = (!has_sch && strstr(sword, "ch")); + bool has_ie = strstr(sword, "ie"); + bool has_zw = strstr(sword, "zw"); + if (has_sch) chars -= 2; + if (has_ch) --chars; + if (has_ie) --chars; + if (has_zw) --chars; + + float freq = freq_dict_.LookUp(word); + if (freq) { + features->set_value(freq_, freq); + features->set_value(in_dict_, 1.0); + } else { + features->set_value(oov_, 1.0); + freq = 99.0f; + } + if (bad_words_.count(word) != 0) + features->set_value(bad_, 1.0); + if (chars < 5) + features->set_value(short_, 1.0); + if (chars > 10) + features->set_value(long_, 1.0); + if (freq < 7.0f) + features->set_value(high_freq_, 1.0); + if (freq > 8.0f && freq < 10.f) + features->set_value(med_freq_, 1.0); + if (freq < 10.0f && chars < 5) + features->set_value(short_range_, 1.0); + + // i don't understand these features, but they really help! + features->set_value(fl1_, sqrt(chars * freq)); + features->set_value(fl2_, freq / chars); +} + +void BasicCSplitFeatures::TraversalFeaturesImpl( + const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* out_context) const { + (void) smeta; + (void) ant_contexts; + (void) out_context; + (void) estimated_features; + if (edge.Arity() == 0) return; + if (edge.rule_->EWords() != 1) return; + pimpl_->TraversalFeaturesImpl(edge, features); +} + +struct ReverseCharLMCSplitFeatureImpl { + ReverseCharLMCSplitFeatureImpl(const string& param) : + order_(5), + vocab_(*TD::dict_), + ngram_(vocab_, order_) { + kBOS = vocab_.getIndex("<s>"); + kEOS = vocab_.getIndex("</s>"); + File file(param.c_str(), "r", 0); + assert(file); + cerr << "Reading " << order_ << "-gram LM from " << param << endl; + ngram_.read(file); + } + + double LeftPhonotacticProb(const Lattice& inword, const int start) { + const int end = inword.size(); + for (int i = 0; i < order_; ++i) + sc[i] = kBOS; + int sp = min(end - start, order_ - 1); + // cerr << "[" << start << "," << sp << "]\n"; + int ci = (order_ - sp - 1); + int wi = start; + while (sp > 0) { + sc[ci] = inword[wi][0].label; + // cerr << " CHAR: " << TD::Convert(sc[ci]) << " ci=" << ci << endl; + ++wi; + ++ci; + --sp; + } + // cerr << " END ci=" << ci << endl; + sc[ci] = Vocab_None; + const double startprob = ngram_.wordProb(kEOS, sc); + // cerr << " PROB=" << startprob << endl; + return startprob; + } + private: + const int order_; + Vocab& vocab_; + VocabIndex kBOS; + VocabIndex kEOS; + Ngram ngram_; + VocabIndex sc[80]; +}; + +ReverseCharLMCSplitFeature::ReverseCharLMCSplitFeature(const string& param) : + pimpl_(new ReverseCharLMCSplitFeatureImpl(param)), + fid_(FD::Convert("RevCharLM")) {} + +void ReverseCharLMCSplitFeature::TraversalFeaturesImpl( + const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* out_context) const { + (void) ant_contexts; + (void) estimated_features; + (void) out_context; + + if (edge.Arity() != 1) return; + if (edge.rule_->EWords() != 1) return; + const double lpp = pimpl_->LeftPhonotacticProb(smeta.GetSourceLattice(), edge.i_); + features->set_value(fid_, lpp); +} + |