diff options
Diffstat (limited to 'decoder')
-rw-r--r-- | decoder/kbest.h | 18 | ||||
-rw-r--r-- | decoder/scfg_translator.cc | 7 |
2 files changed, 20 insertions, 5 deletions
diff --git a/decoder/kbest.h b/decoder/kbest.h index 9a55f653..44c23151 100644 --- a/decoder/kbest.h +++ b/decoder/kbest.h @@ -6,6 +6,7 @@ #include <tr1/unordered_set> #include <boost/shared_ptr.hpp> +#include <boost/type_traits.hpp> #include "wordid.h" #include "hg.h" @@ -134,7 +135,7 @@ namespace KBest { } add_next = false; - if (cand.size() > 0) { + while (!add_next && cand.size() > 0) { std::pop_heap(cand.begin(), cand.end(), HeapCompare()); Derivation* d = cand.back(); cand.pop_back(); @@ -145,10 +146,15 @@ namespace KBest { if (!filter(d->yield)) { D.push_back(d); add_next = true; + } else { + // just because a node already derived a string (or whatever + // equivalent derivation class), you need to add its successors + // to the node's candidate pool + LazyNext(d, &cand, &s.ds); } - } else { - break; } + if (!add_next) + break; } if (k < D.size()) return D[k]; else return NULL; } @@ -184,7 +190,11 @@ namespace KBest { s.cand.push_back(d); } - const unsigned effective_k = std::min(k_prime, s.cand.size()); + unsigned effective_k = s.cand.size(); + if (boost::is_same<DerivationFilter,NoFilter<T> >::value) { + // if there's no filter you can use this optimization + effective_k = std::min(k_prime, s.cand.size()); + } const typename CandidateHeap::iterator kth = s.cand.begin() + effective_k; std::nth_element(s.cand.begin(), kth, s.cand.end(), DerivationCompare()); s.cand.resize(effective_k); diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 3b43b586..6f0b003b 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -12,6 +12,7 @@ #include "grammar.h" #include "bottom_up_parser.h" #include "sentence_metadata.h" +#include "stringlib.h" #include "tdict.h" #include "viterbi.h" #include "verbose.h" @@ -68,7 +69,11 @@ PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, const int j = alts[k].dist2next + i; const string& src = TD::Convert(alts[k].label); if (ss.count(alts[k].label) == 0) { - TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1")); + int length = static_cast<int>(log(UTF8StringLen(src)) / log(1.6)) + 1; + if (length > 6) length = 6; + string len_feat = "PassThrough_0=1"; + len_feat[12] += length; + TRulePtr pt(new TRule("[" + cat + "] ||| " + src + " ||| " + src + " ||| PassThrough=1 " + len_feat)); pt->a_.push_back(AlignmentPoint(0,0)); AddRule(pt); RefineRule(pt, ctf_level); |