diff options
Diffstat (limited to 'decoder')
| -rw-r--r-- | decoder/ff_csplit.cc | 36 | ||||
| -rw-r--r-- | decoder/ff_csplit.h | 1 | 
2 files changed, 35 insertions, 2 deletions
| diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc index e6f78f84..33b6cea8 100644 --- a/decoder/ff_csplit.cc +++ b/decoder/ff_csplit.cc @@ -14,12 +14,19 @@  #include "stringlib.h"  #include "tdict.h" +#ifndef HAVE_OLD_CPP +# include <unordered_set> +#else +# include <tr1/unordered_set> +namespace std { using std::tr1::unordered_set; } +#endif  using namespace std;  struct BasicCSplitFeaturesImpl {    BasicCSplitFeaturesImpl(const string& param) :        word_count_(FD::Convert("WordCount")),        letters_sq_(FD::Convert("LettersSq")), +      letters_log_(FD::Convert("LettersLog")),        letters_sqrt_(FD::Convert("LettersSqrt")),        in_dict_(FD::Convert("InDict")),        in_dict_sub_word_(FD::Convert("InDictSubWord")), @@ -31,13 +38,14 @@ struct BasicCSplitFeaturesImpl {        high_freq_(FD::Convert("HighFreq")),        med_freq_(FD::Convert("MedFreq")),        freq_(FD::Convert("Freq")), +      in_dict_full_word_(FD::Convert("InDictFullWord")),        fl1_(FD::Convert("FreqLen1")),        fl2_(FD::Convert("FreqLen2")),        bad_(FD::Convert("Bad")) {      vector<string> argv;      int argc = SplitOnWhitespace(param, &argv); -    if (argc != 1 && argc != 2) { -      cerr << "Expected: freqdict.txt [badwords.txt]\n"; +    if (argc != 1 && argc != 2 && argc != 3) { +      cerr << "Expected: freqdict.txt [badwords.txt] [sensitvewords.txt]\n";        abort();      }      freq_dict_.Load(argv[0]); @@ -51,6 +59,14 @@ struct BasicCSplitFeaturesImpl {          bad_words_.insert(TD::Convert(badword));        }      } +    if (argc == 3) { +      ReadFile rf(argv[2]); +      istream& in = *rf.stream(); +      string line; +      while(getline(in, line)) { +        special_feats_[TD::Convert(line)] = FD::Convert("CS:"+line); +      } +    }    }    void TraversalFeaturesImpl(const Hypergraph::Edge& edge, @@ -59,6 +75,7 @@ struct BasicCSplitFeaturesImpl {    const int word_count_;    const int letters_sq_; +  const int letters_log_;    const int letters_sqrt_;    const int in_dict_;    const int in_dict_sub_word_; @@ -70,11 +87,13 @@ struct BasicCSplitFeaturesImpl {    const int high_freq_;    const int med_freq_;    const int freq_; +  const int in_dict_full_word_;    const int fl1_;    const int fl2_;    const int bad_;    FreqDict<float> freq_dict_;    set<WordID> bad_words_; +  unordered_map<WordID, int> special_feats_;  };  BasicCSplitFeatures::BasicCSplitFeatures(const string& param) : @@ -85,8 +104,15 @@ void BasicCSplitFeaturesImpl::TraversalFeaturesImpl(                                       const int src_word_length,                                       SparseVector<double>* features) const {    const bool subword = (edge.i_ > 0) || (edge.j_ < src_word_length); +  string len_bias = "LenBias_0"; +  int swlen = log(src_word_length) / log(1.69); +  if (swlen > 9) swlen = 9; +  len_bias[8] += swlen; +  int fid_len_bias_ = FD::Convert(len_bias); +  features->set_value(fid_len_bias_, 1.0);     features->set_value(word_count_, 1.0);    features->set_value(letters_sq_, (edge.j_ - edge.i_) * (edge.j_ - edge.i_)); +  features->set_value(letters_log_, log(edge.j_ - edge.i_));    features->set_value(letters_sqrt_, sqrt(edge.j_ - edge.i_));    const WordID word = edge.rule_->e_[1];    const char* sword = TD::Convert(word).c_str(); @@ -117,10 +143,14 @@ void BasicCSplitFeaturesImpl::TraversalFeaturesImpl(      features->set_value(in_dict_, 1.0);      if (subword) features->set_value(in_dict_sub_word_, 1.0);    } else { +    if (!subword) features->set_value(in_dict_full_word_, 1.0);      features->set_value(oov_, 1.0);      if (subword) features->set_value(oov_sub_word_, 1.0);      freq = 99.0f;    } +  const unordered_map<WordID, int>::const_iterator it = special_feats_.find(word); +  if (it != special_feats_.end()) +    features->set_value(it->second, 1.0);    if (bad_words_.count(word) != 0)      features->set_value(bad_, 1.0);    if (chars < 5) @@ -139,6 +169,8 @@ void BasicCSplitFeaturesImpl::TraversalFeaturesImpl(    features->set_value(fl2_, freq / chars);  } +void BasicCSplitFeatures::PrepareForInput(const SentenceMetadata& smeta) {} +  void BasicCSplitFeatures::TraversalFeaturesImpl(                                       const SentenceMetadata& smeta,                                       const Hypergraph::Edge& edge, diff --git a/decoder/ff_csplit.h b/decoder/ff_csplit.h index 64d42526..79bf2886 100644 --- a/decoder/ff_csplit.h +++ b/decoder/ff_csplit.h @@ -10,6 +10,7 @@ class BasicCSplitFeaturesImpl;  class BasicCSplitFeatures : public FeatureFunction {   public:    BasicCSplitFeatures(const std::string& param); +  virtual void PrepareForInput(const SentenceMetadata& smeta);   protected:    virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,                                       const HG::Edge& edge, | 
