summaryrefslogtreecommitdiff
path: root/decoder/ff_csplit.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/ff_csplit.cc')
-rw-r--r--decoder/ff_csplit.cc212
1 files changed, 212 insertions, 0 deletions
diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc
new file mode 100644
index 00000000..cac4bb8e
--- /dev/null
+++ b/decoder/ff_csplit.cc
@@ -0,0 +1,212 @@
+#include "ff_csplit.h"
+
+#include <set>
+#include <cstring>
+
+#include "Vocab.h"
+#include "Ngram.h"
+
+#include "sentence_metadata.h"
+#include "lattice.h"
+#include "tdict.h"
+#include "freqdict.h"
+#include "filelib.h"
+#include "stringlib.h"
+#include "tdict.h"
+
+using namespace std;
+
+struct BasicCSplitFeaturesImpl {
+ BasicCSplitFeaturesImpl(const string& param) :
+ word_count_(FD::Convert("WordCount")),
+ letters_sq_(FD::Convert("LettersSq")),
+ letters_sqrt_(FD::Convert("LettersSqrt")),
+ in_dict_(FD::Convert("InDict")),
+ short_(FD::Convert("Short")),
+ long_(FD::Convert("Long")),
+ oov_(FD::Convert("OOV")),
+ short_range_(FD::Convert("ShortRange")),
+ high_freq_(FD::Convert("HighFreq")),
+ med_freq_(FD::Convert("MedFreq")),
+ freq_(FD::Convert("Freq")),
+ fl1_(FD::Convert("FreqLen1")),
+ fl2_(FD::Convert("FreqLen2")),
+ bad_(FD::Convert("Bad")) {
+ vector<string> argv;
+ int argc = SplitOnWhitespace(param, &argv);
+ if (argc != 1 && argc != 2) {
+ cerr << "Expected: freqdict.txt [badwords.txt]\n";
+ abort();
+ }
+ freq_dict_.Load(argv[0]);
+ if (argc == 2) {
+ ReadFile rf(argv[1]);
+ istream& in = *rf.stream();
+ while(in) {
+ string badword;
+ in >> badword;
+ if (badword.empty()) continue;
+ bad_words_.insert(TD::Convert(badword));
+ }
+ }
+ }
+
+ void TraversalFeaturesImpl(const Hypergraph::Edge& edge,
+ SparseVector<double>* features) const;
+
+ const int word_count_;
+ const int letters_sq_;
+ const int letters_sqrt_;
+ const int in_dict_;
+ const int short_;
+ const int long_;
+ const int oov_;
+ const int short_range_;
+ const int high_freq_;
+ const int med_freq_;
+ const int freq_;
+ const int fl1_;
+ const int fl2_;
+ const int bad_;
+ FreqDict freq_dict_;
+ set<WordID> bad_words_;
+};
+
+BasicCSplitFeatures::BasicCSplitFeatures(const string& param) :
+ pimpl_(new BasicCSplitFeaturesImpl(param)) {}
+
+void BasicCSplitFeaturesImpl::TraversalFeaturesImpl(
+ const Hypergraph::Edge& edge,
+ SparseVector<double>* features) const {
+ features->set_value(word_count_, 1.0);
+ features->set_value(letters_sq_, (edge.j_ - edge.i_) * (edge.j_ - edge.i_));
+ features->set_value(letters_sqrt_, sqrt(edge.j_ - edge.i_));
+ const WordID word = edge.rule_->e_[1];
+ const char* sword = TD::Convert(word);
+ const int len = strlen(sword);
+ int cur = 0;
+ int chars = 0;
+ while(cur < len) {
+ cur += UTF8Len(sword[cur]);
+ ++chars;
+ }
+
+ // these are corrections that attempt to make chars
+ // more like a phoneme count than a letter count, they
+ // are only really meaningful for german and should
+ // probably be gotten rid of
+ bool has_sch = strstr(sword, "sch");
+ bool has_ch = (!has_sch && strstr(sword, "ch"));
+ bool has_ie = strstr(sword, "ie");
+ bool has_zw = strstr(sword, "zw");
+ if (has_sch) chars -= 2;
+ if (has_ch) --chars;
+ if (has_ie) --chars;
+ if (has_zw) --chars;
+
+ float freq = freq_dict_.LookUp(word);
+ if (freq) {
+ features->set_value(freq_, freq);
+ features->set_value(in_dict_, 1.0);
+ } else {
+ features->set_value(oov_, 1.0);
+ freq = 99.0f;
+ }
+ if (bad_words_.count(word) != 0)
+ features->set_value(bad_, 1.0);
+ if (chars < 5)
+ features->set_value(short_, 1.0);
+ if (chars > 10)
+ features->set_value(long_, 1.0);
+ if (freq < 7.0f)
+ features->set_value(high_freq_, 1.0);
+ if (freq > 8.0f && freq < 10.f)
+ features->set_value(med_freq_, 1.0);
+ if (freq < 10.0f && chars < 5)
+ features->set_value(short_range_, 1.0);
+
+ // i don't understand these features, but they really help!
+ features->set_value(fl1_, sqrt(chars * freq));
+ features->set_value(fl2_, freq / chars);
+}
+
+void BasicCSplitFeatures::TraversalFeaturesImpl(
+ const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const {
+ (void) smeta;
+ (void) ant_contexts;
+ (void) out_context;
+ (void) estimated_features;
+ if (edge.Arity() == 0) return;
+ if (edge.rule_->EWords() != 1) return;
+ pimpl_->TraversalFeaturesImpl(edge, features);
+}
+
+struct ReverseCharLMCSplitFeatureImpl {
+ ReverseCharLMCSplitFeatureImpl(const string& param) :
+ order_(5),
+ vocab_(*TD::dict_),
+ ngram_(vocab_, order_) {
+ kBOS = vocab_.getIndex("<s>");
+ kEOS = vocab_.getIndex("</s>");
+ File file(param.c_str(), "r", 0);
+ assert(file);
+ cerr << "Reading " << order_ << "-gram LM from " << param << endl;
+ ngram_.read(file);
+ }
+
+ double LeftPhonotacticProb(const Lattice& inword, const int start) {
+ const int end = inword.size();
+ for (int i = 0; i < order_; ++i)
+ sc[i] = kBOS;
+ int sp = min(end - start, order_ - 1);
+ // cerr << "[" << start << "," << sp << "]\n";
+ int ci = (order_ - sp - 1);
+ int wi = start;
+ while (sp > 0) {
+ sc[ci] = inword[wi][0].label;
+ // cerr << " CHAR: " << TD::Convert(sc[ci]) << " ci=" << ci << endl;
+ ++wi;
+ ++ci;
+ --sp;
+ }
+ // cerr << " END ci=" << ci << endl;
+ sc[ci] = Vocab_None;
+ const double startprob = ngram_.wordProb(kEOS, sc);
+ // cerr << " PROB=" << startprob << endl;
+ return startprob;
+ }
+ private:
+ const int order_;
+ Vocab& vocab_;
+ VocabIndex kBOS;
+ VocabIndex kEOS;
+ Ngram ngram_;
+ VocabIndex sc[80];
+};
+
+ReverseCharLMCSplitFeature::ReverseCharLMCSplitFeature(const string& param) :
+ pimpl_(new ReverseCharLMCSplitFeatureImpl(param)),
+ fid_(FD::Convert("RevCharLM")) {}
+
+void ReverseCharLMCSplitFeature::TraversalFeaturesImpl(
+ const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const {
+ (void) ant_contexts;
+ (void) estimated_features;
+ (void) out_context;
+
+ if (edge.Arity() != 1) return;
+ if (edge.rule_->EWords() != 1) return;
+ const double lpp = pimpl_->LeftPhonotacticProb(smeta.GetSourceLattice(), edge.i_);
+ features->set_value(fid_, lpp);
+}
+