diff options
Diffstat (limited to 'decoder')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | decoder/apply_models.cc | 1 |
| -rw-r--r-- | decoder/csplit.cc | 140 |
| -rw-r--r-- | decoder/decoder.cc | 25 |
| -rw-r--r-- | decoder/ff_csplit.cc | 21 |
4 files changed, 100 insertions, 87 deletions
| diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index 18460950..9390c809 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -177,6 +177,7 @@ public:    void Apply() {      int num_nodes = in.nodes_.size(); +    assert(num_nodes >= 2);      int goal_id = num_nodes - 1;      int pregoal = goal_id - 1;      int every = 1; diff --git a/decoder/csplit.cc b/decoder/csplit.cc index 7d50e3af..4a723822 100644 --- a/decoder/csplit.cc +++ b/decoder/csplit.cc @@ -13,14 +13,16 @@ using namespace std;  struct CompoundSplitImpl {    CompoundSplitImpl(const boost::program_options::variables_map& conf) : -      fugen_elements_(true),   // TODO configure +      fugen_elements_(true),        min_size_(3),        kXCAT(TD::Convert("X")*-1),        kWORDBREAK_RULE(new TRule("[X] ||| # ||| #")),        kTEMPLATE_RULE(new TRule("[X] ||| [X,1] ? ||| [1] ?")),        kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")),        kFUGEN_S(FD::Convert("FugS")), -      kFUGEN_N(FD::Convert("FugN")) {} +      kFUGEN_N(FD::Convert("FugN")) { +    // TODO: use conf to turn fugenelements on and off +  }    void PasteTogetherStrings(const vector<string>& chars,                              const int i, @@ -40,73 +42,73 @@ struct CompoundSplitImpl {    void BuildTrellis(const vector<string>& chars,                      Hypergraph* forest) { -//    vector<int> nodes(chars.size()+1, -1); -//    nodes[0] = forest->AddNode(kXCAT)->id_;       // source -//    const int left_rule = forest->AddEdge(kWORDBREAK_RULE, Hypergraph::TailNodeVector())->id_; -//    forest->ConnectEdgeToHeadNode(left_rule, nodes[0]); -// -//    const int max_split_ = max(static_cast<int>(chars.size()) - min_size_ + 1, 1); -//    cerr << "max: " << max_split_ << "  " << " min: " << min_size_ << endl; -//    for (int i = min_size_; i < max_split_; ++i) -//      nodes[i] = forest->AddNode(kXCAT)->id_; -//    assert(nodes.back() == -1); -//    nodes.back() = forest->AddNode(kXCAT)->id_;   // sink -// -// 
   for (int i = 0; i < max_split_; ++i) { -//      if (nodes[i] < 0) continue; -//      const int start = min(i + min_size_, static_cast<int>(chars.size())); -//      for (int j = start; j <= chars.size(); ++j) { -//        if (nodes[j] < 0) continue; -//        string yield; -//        PasteTogetherStrings(chars, i, j, &yield); -//        // cerr << "[" << i << "," << j << "] " << yield << endl; -//        TRulePtr rule = TRulePtr(new TRule(*kTEMPLATE_RULE)); -//        rule->e_[1] = rule->f_[1] = TD::Convert(yield); -//        // cerr << rule->AsString() << endl; -//        int edge = forest->AddEdge( -//          rule, -//          Hypergraph::TailNodeVector(1, nodes[i]))->id_; -//        forest->ConnectEdgeToHeadNode(edge, nodes[j]); -//        forest->edges_[edge].i_ = i; -//        forest->edges_[edge].j_ = j; -// -//        // handle "fugenelemente" here -//        // don't delete "fugenelemente" at the end of words -//        if (fugen_elements_ && j != chars.size()) { -//          const int len = yield.size(); -//          string alt; -//          int fid = 0; -//          if (len > (min_size_ + 2) && yield[len-1] == 's' && yield[len-2] == 'e') { -//            alt = yield.substr(0, len - 2); -//            fid = kFUGEN_S; -//          } else if (len > (min_size_ + 1) && yield[len-1] == 's') { -//            alt = yield.substr(0, len - 1); -//            fid = kFUGEN_S; -//          } else if (len > (min_size_ + 2) && yield[len-2] == 'e' && yield[len-1] == 'n') { -//            alt = yield.substr(0, len - 1); -//            fid = kFUGEN_N; -//          } -//          if (alt.size()) { -//            TRulePtr altrule = TRulePtr(new TRule(*rule)); -//            altrule->e_[1] = TD::Convert(alt); -//            // cerr << altrule->AsString() << endl; -//            int edge = forest->AddEdge( -//              altrule, -//              Hypergraph::TailNodeVector(1, nodes[i]))->id_; -//            forest->ConnectEdgeToHeadNode(edge, nodes[j]); -//            
forest->edges_[edge].feature_values_.set_value(fid, 1.0); -//            forest->edges_[edge].i_ = i; -//            forest->edges_[edge].j_ = j; -//          } -//        } -//      } -//    } -// -//    // add goal rule -//    Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); -//    Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); -//    Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); -//    forest->ConnectEdgeToHeadNode(hg_edge, goal); +    vector<int> nodes(chars.size()+1, -1); +    nodes[0] = forest->AddNode(kXCAT)->id_;       // source +    const int left_rule = forest->AddEdge(kWORDBREAK_RULE, Hypergraph::TailNodeVector())->id_; +    forest->ConnectEdgeToHeadNode(left_rule, nodes[0]); + +    const int max_split_ = max(static_cast<int>(chars.size()) - min_size_ + 1, 1); +    // cerr << "max: " << max_split_ << "  " << " min: " << min_size_ << endl; +    for (int i = min_size_; i < max_split_; ++i) +      nodes[i] = forest->AddNode(kXCAT)->id_; +    assert(nodes.back() == -1); +    nodes.back() = forest->AddNode(kXCAT)->id_;   // sink + +    for (int i = 0; i < max_split_; ++i) { +      if (nodes[i] < 0) continue; +      const int start = min(i + min_size_, static_cast<int>(chars.size())); +      for (int j = start; j <= chars.size(); ++j) { +        if (nodes[j] < 0) continue; +        string yield; +        PasteTogetherStrings(chars, i, j, &yield); +        // cerr << "[" << i << "," << j << "] " << yield << endl; +        TRulePtr rule = TRulePtr(new TRule(*kTEMPLATE_RULE)); +        rule->e_[1] = rule->f_[1] = TD::Convert(yield); +        // cerr << rule->AsString() << endl; +        int edge = forest->AddEdge( +          rule, +          Hypergraph::TailNodeVector(1, nodes[i]))->id_; +        forest->ConnectEdgeToHeadNode(edge, nodes[j]); +        forest->edges_[edge].i_ = i; +        forest->edges_[edge].j_ = j; + +        // handle "fugenelemente" here +        // don't delete "fugenelemente" at the end of 
words +        if (fugen_elements_ && j != chars.size()) { +          const int len = yield.size(); +          string alt; +          int fid = 0; +          if (len > (min_size_ + 2) && yield[len-1] == 's' && yield[len-2] == 'e') { +            alt = yield.substr(0, len - 2); +            fid = kFUGEN_S; +          } else if (len > (min_size_ + 1) && yield[len-1] == 's') { +            alt = yield.substr(0, len - 1); +            fid = kFUGEN_S; +          } else if (len > (min_size_ + 2) && yield[len-2] == 'e' && yield[len-1] == 'n') { +            alt = yield.substr(0, len - 1); +            fid = kFUGEN_N; +          } +          if (alt.size()) { +            TRulePtr altrule = TRulePtr(new TRule(*rule)); +            altrule->e_[1] = TD::Convert(alt); +            // cerr << altrule->AsString() << endl; +            int edge = forest->AddEdge( +              altrule, +              Hypergraph::TailNodeVector(1, nodes[i]))->id_; +            forest->ConnectEdgeToHeadNode(edge, nodes[j]); +            forest->edges_[edge].feature_values_.set_value(fid, 1.0); +            forest->edges_[edge].i_ = i; +            forest->edges_[edge].j_ = j; +          } +        } +      } +    } + +    // add goal rule +    Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); +    Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); +    Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); +    forest->ConnectEdgeToHeadNode(hg_edge, goal);    }   private:    const bool fugen_elements_; diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 3551b584..e28080aa 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -279,7 +279,6 @@ struct DecoderImpl {    bool encode_b64;    bool kbest;    bool unique_kbest; -  bool crf_uniform_empirical;    bool get_oracle_forest;    shared_ptr<WriteFile> extract_file;    int combine_size; @@ -379,7 +378,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream          
("max_translation_sample,X", po::value<int>(), "Sample the max translation from the chart")          ("pb_max_distortion,D", po::value<int>()->default_value(4), "Phrase-based decoder: maximum distortion")          ("cll_gradient,G","Compute conditional log-likelihood gradient and write to STDOUT (src & ref required)") -        ("crf_uniform_empirical", "If there are multple references use (i.e., lattice) a uniform distribution rather than posterior weighting a la EM")      ("get_oracle_forest,o", "Calculate rescored hypregraph using approximate BLEU scoring of rules")      ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)")          ("vector_format",po::value<string>()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") @@ -611,7 +609,6 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream    encode_b64 = str("vector_format",conf) == "b64";    kbest = conf.count("k_best");    unique_kbest = conf.count("unique_k_best"); -  crf_uniform_empirical = conf.count("crf_uniform_empirical");    get_oracle_forest = conf.count("get_oracle_forest");    cfg_options.Validate(); @@ -842,14 +839,12 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {    if (has_ref) {      if (HG::Intersect(ref, &forest)) {        if (!SILENT) forest_stats(forest,"  Constr. forest",show_tree_structure,show_features,feature_weights,oracle.show_derivation); -      if (crf_uniform_empirical) { -        if (!SILENT) cerr << "  USING UNIFORM WEIGHTS\n"; -        for (int i = 0; i < forest.edges_.size(); ++i) -          forest.edges_[i].edge_prob_=prob_t::One(); -      } else { -        forest.Reweight(feature_weights); -        if (!SILENT) cerr << "  Constr. 
VitTree: " << ViterbiFTree(forest) << endl; -      } +//      if (crf_uniform_empirical) { +//        if (!SILENT) cerr << "  USING UNIFORM WEIGHTS\n"; +//        for (int i = 0; i < forest.edges_.size(); ++i) +//          forest.edges_[i].edge_prob_=prob_t::One(); } +      forest.Reweight(feature_weights); +      if (!SILENT) cerr << "  Constr. VitTree: " << ViterbiFTree(forest) << endl;        if (conf.count("show_partition")) {           const prob_t z = Inside<prob_t, EdgeProb>(forest);           cerr << "  Contst. partition  log(Z): " << log(z) << endl; @@ -878,11 +873,9 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {        if (write_gradient) {          const prob_t ref_z = InsideOutside<prob_t, EdgeProb, SparseVector<prob_t>, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp);          ref_exp /= ref_z; -        if (crf_uniform_empirical) { -          log_ref_z = ref_exp.dot(feature_weights); -        } else { -          log_ref_z = log(ref_z); -        } +//        if (crf_uniform_empirical) +//          log_ref_z = ref_exp.dot(feature_weights); +        log_ref_z = log(ref_z);          //cerr << "      MODEL LOG Z: " << log_z << endl;          //cerr << "  EMPIRICAL LOG Z: " << log_ref_z << endl;          if ((log_z - log_ref_z) < kMINUS_EPSILON) { diff --git a/decoder/ff_csplit.cc b/decoder/ff_csplit.cc index 1485009b..204b7ce6 100644 --- a/decoder/ff_csplit.cc +++ b/decoder/ff_csplit.cc @@ -22,9 +22,11 @@ struct BasicCSplitFeaturesImpl {        letters_sq_(FD::Convert("LettersSq")),        letters_sqrt_(FD::Convert("LettersSqrt")),        in_dict_(FD::Convert("InDict")), +      in_dict_sub_word_(FD::Convert("InDictSubWord")),        short_(FD::Convert("Short")),        long_(FD::Convert("Long")),        oov_(FD::Convert("OOV")), +      oov_sub_word_(FD::Convert("OOVSubWord")),        short_range_(FD::Convert("ShortRange")),        high_freq_(FD::Convert("HighFreq")),        med_freq_(FD::Convert("MedFreq")), @@ -52,15 +54,18 @@ 
struct BasicCSplitFeaturesImpl {    }    void TraversalFeaturesImpl(const Hypergraph::Edge& edge, +                             const int src_word_size,                               SparseVector<double>* features) const;    const int word_count_;    const int letters_sq_;    const int letters_sqrt_;    const int in_dict_; +  const int in_dict_sub_word_;    const int short_;    const int long_;    const int oov_; +  const int oov_sub_word_;    const int short_range_;    const int high_freq_;    const int med_freq_; @@ -77,7 +82,9 @@ BasicCSplitFeatures::BasicCSplitFeatures(const string& param) :  void BasicCSplitFeaturesImpl::TraversalFeaturesImpl(                                       const Hypergraph::Edge& edge, +                                     const int src_word_length,                                       SparseVector<double>* features) const { +  const bool subword = (edge.i_ > 0) || (edge.j_ < src_word_length);    features->set_value(word_count_, 1.0);    features->set_value(letters_sq_, (edge.j_ - edge.i_) * (edge.j_ - edge.i_));    features->set_value(letters_sqrt_, sqrt(edge.j_ - edge.i_)); @@ -108,8 +115,10 @@ void BasicCSplitFeaturesImpl::TraversalFeaturesImpl(    if (freq) {      features->set_value(freq_, freq);      features->set_value(in_dict_, 1.0); +    if (subword) features->set_value(in_dict_sub_word_, 1.0);    } else {      features->set_value(oov_, 1.0); +    if (subword) features->set_value(oov_sub_word_, 1.0);      freq = 99.0f;    }    if (bad_words_.count(word) != 0) @@ -143,7 +152,7 @@ void BasicCSplitFeatures::TraversalFeaturesImpl(    (void) estimated_features;    if (edge.Arity() == 0) return;    if (edge.rule_->EWords() != 1) return; -  pimpl_->TraversalFeaturesImpl(edge, features); +  pimpl_->TraversalFeaturesImpl(edge, smeta.GetSourceLattice().size(), features);  }  struct ReverseCharLMCSplitFeatureImpl { @@ -208,9 +217,17 @@ void ReverseCharLMCSplitFeature::TraversalFeaturesImpl(    if (edge.rule_->EWords() != 1) return;    
const double lpp = pimpl_->LeftPhonotacticProb(smeta.GetSourceLattice(), edge.i_);    features->set_value(fid_, lpp); +#if 0    WordID neighbor_word = 0;    const WordID word = edge.rule_->e_[1]; -#if 0 +  const char* sword = TD::Convert(word); +  const int len = strlen(sword); +  int cur = 0; +  int chars = 0; +  while(cur < len) { +    cur += UTF8Len(sword[cur]); +    ++chars; +  }    if (chars > 4 && (sword[0] == 's' || sword[0] == 'n')) {      neighbor_word = TD::Convert(string(&sword[1]));    } | 
