#include "ff_ngrams.h"

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include <boost/scoped_ptr.hpp>

#include "filelib.h"
#include "stringlib.h"
#include "hg.h"
#include "tdict.h"

using namespace std;

// Bits stored in the flag byte of a node's state buffer.
static const unsigned char HAS_FULL_CONTEXT = 1;
static const unsigned char HAS_EOS_ON_RIGHT = 2;
static const unsigned char MASK = 7;

namespace {

// Fixed-capacity n-gram history. For an n-gram order `order`, only the first
// (order - 1) slots are meaningful: the most recent word is in slot order - 2
// and older words are at lower indices. A 0 entry means "no word". Every
// constructor zeroes the whole array, so copies never read indeterminate values
// and the implicitly generated copy operations are correct.
template <unsigned MAX_ORDER = 5>
struct State {
  explicit State() { Clear(); }

  // Empty history (the order argument is kept for call-site compatibility).
  explicit State(int /* order */) { Clear(); }

  // Loads the (order - 1) history words stored at `mem`.
  State(char order, const WordID* mem) {
    Clear();
    if (order > 1) memcpy(state, mem, (order - 1) * sizeof(WordID));
  }

  // History obtained by appending `extend` to `other` and dropping the oldest
  // word. With order <= 1 there is no history, so the result is empty.
  State(const State& other, unsigned order, WordID extend) {
    Clear();
    if (order < 2) return;
    const unsigned om1 = order - 1;
    for (unsigned i = 1; i < om1; ++i) state[i - 1] = other.state[i];
    state[om1 - 1] = extend;
  }

  const WordID& operator[](size_t i) const { return state[i]; }
  WordID& operator[](size_t i) { return state[i]; }

  WordID state[MAX_ORDER];

 private:
  void Clear() { memset(state, 0, sizeof(state)); }
};

}  // namespace

namespace {

// Returns `x` with every '=' and ';' replaced by '_' (presumably because those
// characters are delimiters when feature strings are serialized).
string Escape(const string& x) {
  if (x.find_first_of("=;") == string::npos) return x;
  string y = x;
  for (string::size_type i = 0; i < y.size(); ++i) {
    if (y[i] == '=' || y[i] == ';') y[i] = '_';
  }
  return y;
}

}  // namespace

// Prints the NgramFeatures usage message to stderr and aborts the process.
static void PrintUsageAndAbort() {
  cerr << "Wrong parameters for NgramFeatures.\n\n"
       << "NgramFeatures Usage: \n"
       << " feature_function=NgramFeatures filename.lm [-x] [-o <order>] \n"
       << " [-c <cluster-file>] [-n <feature-name-prefix>]\n"
       << " [-U <unigram-prefix>] [-B <bigram-prefix>][-T <trigram-prefix>]\n"
       << " [-4 <4-gram-prefix>] [-5 <5-gram-prefix>] [-S <separator>]\n\n"
       << "Defaults: \n"
       << " <order> = 3\n"
       << " <unigram-prefix> = U:\n"
       << " <bigram-prefix> = B:\n"
       << " <trigram-prefix> = T:\n"
       << " <4-gram-prefix> = 4:\n"
       << " <5-gram-prefix> = 5:\n"
       << " <separator> = _\n"
       << " -x (i.e. explicit sos/eos markers) is turned off\n\n"
       << "Example configuration: \n"
       << " feature_function=NgramFeatures -o 3 -T tri: -S |\n\n"
       << "Example feature instantiation: \n"
       << " tri:a|b|c \n\n";
  abort();
}

// Reports an unrecognized option and aborts with the usage message.
static void UnknownOption(string const& s) {
  cerr << "Unknown option on NgramFeatures " << s << " ; ";
  PrintUsageAndAbort();
}

// Advances `i` to the argument of option `opt` and returns it. Aborts with the
// usage message if `opt` is the last token (there is no argument to consume).
static string const& NextArg(vector<string>::const_iterator& i,
                             vector<string>::const_iterator e,
                             string const& opt) {
  ++i;
  if (i == e) {
    cerr << "Missing argument for " << opt << ". ";
    PrintUsageAndAbort();
  }
  return *i;
}

// Parses the NgramFeatures parameter string. Fills in defaults first and then
// applies the options found in `in`. Tokens that do not start with '-' are
// ignored. On any error a message is printed and the process aborts, so the
// function only ever returns true.
//   prefixes[n] is the feature-name prefix for n-grams of length n (index 0 is
//   unused); order must be in [1, 5] because only five prefixes exist and the
//   history holds State<5>.
static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order,
                      vector<string>& prefixes, string& target_separator,
                      string* cluster_file, string* featname) {
  vector<string> const& argv = SplitOnWhitespace(in);
  *featname = "";
  *explicit_markers = false;
  *order = 3;
  prefixes.push_back("NOT-USED");
  prefixes.push_back("U:");  // default unigram prefix
  prefixes.push_back("B:");  // default bigram prefix
  prefixes.push_back("T:");  // default trigram prefix
  prefixes.push_back("4:");  // default 4-gram prefix
  prefixes.push_back("5:");  // default 5-gram prefix (max allowed order)
  target_separator = "_";

  const vector<string>::const_iterator e = argv.end();
  for (vector<string>::const_iterator i = argv.begin(); i != e; ++i) {
    string const& s = *i;
    if (s.empty() || s[0] != '-') continue;
    if (s.size() != 2) UnknownOption(s);
    switch (s[1]) {
      case 'x': *explicit_markers = true; break;
      case 'n': *featname = NextArg(i, e, s); break;
      case 'U': prefixes[1] = NextArg(i, e, s); break;
      case 'B': prefixes[2] = NextArg(i, e, s); break;
      case 'T': prefixes[3] = NextArg(i, e, s); break;
      case '4': prefixes[4] = NextArg(i, e, s); break;
      case '5': prefixes[5] = NextArg(i, e, s); break;
      case 'c': *cluster_file = NextArg(i, e, s); break;
      case 'S': target_separator = NextArg(i, e, s); break;
      case 'o': {
        const int o = atoi(NextArg(i, e, s).c_str());
        if (o < 1 || o > 5) {
          cerr << "Order " << o << " not supported by NgramFeatures (must be 1-5). ";
          PrintUsageAndAbort();
        }
        *order = static_cast<unsigned>(o);
        break;
      }
      default: UnknownOption(s);
    }
  }
  return true;
}

class NgramDetectorImpl {
  // Layout of a node's state buffer (ReserveStateSize() bytes):
  //   [0, (order-1)*sizeof(WordID))  right-edge history (State<5> slots 0..order-2)
  //   unscored_size_offset_          char: number of unscored left-edge words
  //   is_complete_offset_            char: HAS_FULL_CONTEXT | HAS_EOS_ON_RIGHT bits
  //   unscored_words_offset_ ...     up to (order-1) unscored WordIDs. This region
  //                                  is not WordID-aligned, so it is accessed with memcpy.

  // Returns the number of unscored words at the left edge of a span.
  int UnscoredSize(const void* state) const {
    return *(static_cast<const char*>(state) + unscored_size_offset_);
  }

  void SetUnscoredSize(int size, void* state) const {
    *(static_cast<char*>(state) + unscored_size_offset_) = static_cast<char>(size);
  }

  // Reads the right-edge history stored at the start of a state buffer.
  State<5> RemnantLMState(const void* cstate) const {
    return State<5>(order_, static_cast<const WordID*>(cstate));
  }

  // History right after a sentence start. <s> goes into slot order_-2, the
  // most-recent slot, so n-grams beginning with <s> are fired for every order.
  // Order 1 keeps no history; slot 0 is then simply unused.
  State<5> BeginSentenceState() const {
    State<5> state(order_);
    state.state[order_ >= 2 ? order_ - 2 : 0] = kSOS_;
    return state;
  }

  void SetRemnantLMState(const State<5>& lmstate, void* state) const {
    // if we were clever, we could use the memory pointed to by state to do all
    // the work, avoiding this copy
    memcpy(state, lmstate.state, (order_ - 1) * sizeof(WordID));
  }

  WordID IthUnscoredWord(int i, const void* state) const {
    WordID w;
    memcpy(&w, static_cast<const char*>(state) + unscored_words_offset_ + i * sizeof(WordID),
           sizeof(WordID));
    return w;
  }

  void SetIthUnscoredWord(int i, const WordID index, void* state) const {
    memcpy(static_cast<char*>(state) + unscored_words_offset_ + i * sizeof(WordID),
           &index, sizeof(WordID));
  }

  bool GetFlag(const void* state, unsigned char flag) const {
    return (*(static_cast<const char*>(state) + is_complete_offset_) & flag);
  }

  void SetFlag(bool on, unsigned char flag, void* state) const {
    if (on) {
      *(static_cast<char*>(state) + is_complete_offset_) |= flag;
    } else {
      *(static_cast<char*>(state) + is_complete_offset_) &= (MASK ^ flag);
    }
  }

  bool HasFullContext(const void* state) const {
    return GetFlag(state, HAS_FULL_CONTEXT);
  }

  void SetHasFullContext(bool flag, void* state) const {
    SetFlag(flag, HAS_FULL_CONTEXT, state);
  }

  // Maps a word to its cluster id when a cluster file was loaded. Words outside
  // the map become <unk>; without a cluster file the word is returned as is.
  WordID MapToClusterIfNecessary(WordID w) const {
    if (cluster_map.empty()) return w;
    if (static_cast<size_t>(w) >= cluster_map.size()) return kCDEC_UNK;
    return cluster_map[w];
  }

  // Fires an indicator feature (value 1) for every n-gram that ends in `cur`:
  // the unigram, then longer n-grams extended through `state` (most recent word
  // first), stopping at an empty slot or at order_ words. Each feature is named
  // featname_ + prefixes_[n] + the n words (oldest first) joined by
  // target_separator_. Feature ids are cached in the fidroot_ trie.
  void FireFeatures(const State<5>& state, WordID cur, SparseVector<double>* feats) {
    FidTree* ft = &fidroot_;
    WordID buf[10];  // order_ <= 5 (checked in ParseArgs), so n never exceeds 5
    int n = 0;
    int ci = order_ - 1;
    WordID curword = cur;
    while (curword) {
      buf[n++] = curword;
      int& fid = ft->fids[curword];
      if (!fid) {
        ostringstream os;
        os << featname_ << prefixes_[n];
        for (int i = n - 1; i >= 0; --i) {
          if (i != n - 1) os << target_separator_;
          os << Escape(TD::Convert(buf[i]));
        }
        fid = FD::Convert(os.str());
      }
      feats->set_value(fid, 1);
      ft = &ft->levels[curword];
      if (--ci < 0) break;
      curword = state[ci];
    }
  }

  // Running state of a left-to-right scan over a rule's target side.
  struct Scan {
    State<5> lm;            // current n-gram history
    bool saw_eos;           // the previous word was </s>
    bool has_some_history;  // at least one word has been scanned
    bool context_complete;  // later features no longer depend on outer context
    int num_scored;         // words scanned (reset after a leading <s>)
    int num_estimated;      // words recorded as unscored in the remnant
  };

  // Scores one target word. While the left context is still incomplete, the
  // word's features go to `est_feats` and the word is recorded in `remnant` (if
  // given) so the parent can rescore it. Once the context is complete, the
  // features go to `feats`. A misplaced <s> or any word after </s> fires the
  // "Malformed" indicator.
  void ScoreWord(WordID cur_word, Scan* scan, SparseVector<double>* feats,
                 SparseVector<double>* est_feats, void* remnant) {
    SparseVector<double> p;
    if (cur_word == kSOS_) {
      scan->lm = BeginSentenceState();
      if (scan->has_some_history) {
        // this is immediately fully scored, and bad
        p.set_value(FD::Convert("Malformed"), 1.0);
        scan->context_complete = true;
      } else {
        // this might be a real <s>
        scan->num_scored = max(0, order_ - 2);
      }
    } else {
      FireFeatures(scan->lm, cur_word, &p);
      scan->lm = State<5>(scan->lm, order_, cur_word);
      if (scan->saw_eos) p.set_value(FD::Convert("Malformed"), 1.0);
      scan->saw_eos = (cur_word == kEOS_);
    }
    scan->has_some_history = true;
    ++scan->num_scored;
    if (!scan->context_complete && scan->num_scored >= order_) {
      scan->context_complete = true;
    }
    if (scan->context_complete) {
      (*feats) += p;
    } else {
      if (remnant) SetIthUnscoredWord(scan->num_estimated, cur_word, remnant);
      ++scan->num_estimated;
      if (est_feats) (*est_feats) += p;
    }
  }

  // Loads "cluster<ws>word[...]" lines into cluster_map (word id -> cluster id).
  // Lines that are empty or start with '#' are skipped; a line without a word
  // column aborts. <s> and </s> always map to themselves.
  void ReadClusterFile(const string& clusters) {
    ReadFile rf(clusters);
    istream& in = *rf.stream();
    string line;
    int lc = 0;
    while (getline(in, line)) {
      ++lc;
      if (line.empty() || line[0] == '#') continue;
      string::size_type cend = 1;
      while (cend < line.size() && line[cend] != ' ' && line[cend] != '\t') ++cend;
      if (cend == line.size()) {
        cerr << "Line " << lc << " in " << clusters << " malformed: " << line << endl;
        abort();
      }
      string::size_type wbeg = cend + 1;
      while (wbeg < line.size() && (line[wbeg] == ' ' || line[wbeg] == '\t')) ++wbeg;
      if (wbeg == line.size()) {
        cerr << "Line " << lc << " in " << clusters << " malformed: " << line << endl;
        abort();
      }
      string::size_type wend = wbeg + 1;
      while (wend < line.size() && line[wend] != ' ' && line[wend] != '\t') ++wend;
      const WordID clusterid = TD::Convert(line.substr(0, cend));
      const WordID wordid = TD::Convert(line.substr(wbeg, wend - wbeg));
      if (static_cast<size_t>(wordid) >= cluster_map.size())
        cluster_map.resize(wordid + 10, kCDEC_UNK);
      cluster_map[wordid] = clusterid;
    }
    // The map may be empty or shorter than the marker ids (e.g. an empty file).
    const WordID max_marker = max(kSOS_, kEOS_);
    if (static_cast<size_t>(max_marker) >= cluster_map.size())
      cluster_map.resize(max_marker + 1, kCDEC_UNK);
    cluster_map[kSOS_] = kSOS_;
    cluster_map[kEOS_] = kEOS_;
  }

  // Not copyable: dummy_ants_ points into this object's own dummy_state_.
  NgramDetectorImpl(const NgramDetectorImpl&);
  NgramDetectorImpl& operator=(const NgramDetectorImpl&);

 public:
  // Scores the target side of `rule`, replaying antecedents' unscored words.
  // Exact features go to `feats` and context-dependent ones to `est_feats`
  // (may be NULL). If `remnant` is non-NULL, it receives this node's state.
  void LookupWords(const TRule& rule, const vector<const void*>& ant_states,
                   SparseVector<double>* feats, SparseVector<double>* est_feats,
                   void* remnant) {
    Scan scan;
    scan.saw_eos = false;
    scan.has_some_history = false;
    scan.context_complete = false;
    scan.num_scored = 0;
    scan.num_estimated = 0;
    const vector<WordID>& e = rule.e();
    for (size_t j = 0; j < e.size(); ++j) {
      if (e[j] < 1) {  // handle non-terminal substitution
        const void* astate = ant_states[-e[j]];
        const int unscored_ant_len = UnscoredSize(astate);
        for (int k = 0; k < unscored_ant_len; ++k) {
          ScoreWord(IthUnscoredWord(k, astate), &scan, feats, est_feats, remnant);
        }
        scan.saw_eos = GetFlag(astate, HAS_EOS_ON_RIGHT);
        if (HasFullContext(astate)) {  // equivalent to the "star" in Chiang 2007
          scan.lm = RemnantLMState(astate);
          scan.context_complete = true;
        }
      } else {  // handle terminal
        ScoreWord(MapToClusterIfNecessary(e[j]), &scan, feats, est_feats, remnant);
      }
    }
    if (remnant) {
      SetFlag(scan.saw_eos, HAS_EOS_ON_RIGHT, remnant);
      SetRemnantLMState(scan.lm, remnant);
      SetUnscoredSize(scan.num_estimated, remnant);
      SetHasFullContext(scan.context_complete || (scan.num_scored >= order_), remnant);
    }
  }

  // Called at the goal node. When the hypergraph does not produce <s> and </s>
  // (add_sos_eos_), the goal state is rescored as "<s> [state] </s>" through a
  // dummy rule whose first antecedent is a fully contextualized <s> history.
  // This assumes no target words on the final unary -> goal rule. With explicit
  // markers nothing is added here.
  void FinalTraversal(const void* state, SparseVector<double>* feats) {
    if (!add_sos_eos_) return;
    void* bos = &dummy_state_[0];
    SetRemnantLMState(BeginSentenceState(), bos);
    SetHasFullContext(true, bos);
    SetUnscoredSize(0, bos);
    dummy_ants_[1] = state;
    LookupWords(*dummy_rule_, dummy_ants_, feats, NULL, NULL);
  }

  // `order` must be in [1, 5] (enforced by ParseArgs).
  explicit NgramDetectorImpl(bool explicit_markers, unsigned order,
                             const vector<string>& prefixes,
                             const string& target_separator,
                             const string& clusters, const string& featname)
      : kCDEC_UNK(TD::Convert("<unk>")),
        add_sos_eos_(!explicit_markers),
        order_(order),
        prefixes_(prefixes),
        target_separator_(target_separator),
        featname_(featname) {
    state_size_ = (order_ - 1) * sizeof(WordID) + 2 + (order_ - 1) * sizeof(WordID);
    unscored_size_offset_ = (order_ - 1) * sizeof(WordID);
    is_complete_offset_ = unscored_size_offset_ + 1;
    unscored_words_offset_ = is_complete_offset_ + 1;

    // special handling of beginning / ending sentence markers
    dummy_state_.assign(state_size_, 0);
    dummy_ants_.push_back(&dummy_state_[0]);
    dummy_ants_.push_back(NULL);
    dummy_rule_.reset(new TRule("[DUMMY] ||| [BOS] [DUMMY] ||| [1] [2] </s> ||| X=0"));
    kSOS_ = TD::Convert("<s>");
    kEOS_ = TD::Convert("</s>");
    if (!clusters.empty()) ReadClusterFile(clusters);
  }

  int ReserveStateSize() const { return state_size_; }

 private:
  const WordID kCDEC_UNK;
  WordID kSOS_;  // <s> - requires special handling.
  WordID kEOS_;  // </s>
  // True when the hypergraph does not produce <s> and </s> itself; then
  // FinalTraversal scores the sentence as "<s> ... </s>". When false, the
  // markers are expected in the rules and FinalTraversal adds nothing.
  const bool add_sos_eos_;

  int order_;
  int state_size_;
  int unscored_size_offset_;
  int is_complete_offset_;
  int unscored_words_offset_;
  vector<char> dummy_state_;  // owned buffer for the <s> antecedent
  vector<const void*> dummy_ants_;
  TRulePtr dummy_rule_;
  vector<string> prefixes_;
  string target_separator_;
  string featname_;
  struct FidTree {
    map<WordID, int> fids;
    map<WordID, FidTree> levels;
  };
  mutable FidTree fidroot_;
  vector<WordID> cluster_map;
};

NgramDetector::NgramDetector(const string& param) {
  string featname, target_separator, clusters;
  vector<string> prefixes;
  bool explicit_markers = false;
  unsigned order = 3;
  ParseArgs(param, &explicit_markers, &order, prefixes, target_separator, &clusters, &featname);
  pimpl_ = new NgramDetectorImpl(explicit_markers, order, prefixes, target_separator,
                                 clusters, featname);
  SetStateSize(pimpl_->ReserveStateSize());
}

NgramDetector::~NgramDetector() {
  delete pimpl_;
}

void NgramDetector::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
                                          const Hypergraph::Edge& edge,
                                          const vector<const void*>& ant_states,
                                          SparseVector<double>* features,
                                          SparseVector<double>* estimated_features,
                                          void* state) const {
  pimpl_->LookupWords(*edge.rule_, ant_states, features, estimated_features, state);
}

void NgramDetector::FinalTraversalFeatures(const void* ant_state,
                                           SparseVector<double>* features) const {
  pimpl_->FinalTraversal(ant_state, features);
}