diff options
author | Chris Dyer <redpony@gmail.com> | 2009-12-06 22:25:25 -0500 |
---|---|---|
committer | Chris Dyer <redpony@gmail.com> | 2009-12-06 22:25:25 -0500 |
commit | 2a18010e255810cc2b5bcbe688f3db8eabda23ca (patch) | |
tree | e310286257e5445072303dcca03acb85a865c26a | |
parent | 59ea352f3dcf3bf58969f404615fed4ff6b931f7 (diff) |
add compound splitting logic and features (Dyer 2008, NAACL)
-rw-r--r-- | src/Makefile.am | 3 | ||||
-rw-r--r-- | src/cdec.cc | 11 | ||||
-rw-r--r-- | src/cdec_ff.cc | 5 | ||||
-rw-r--r-- | src/csplit.cc | 152 | ||||
-rw-r--r-- | src/csplit.h | 18 | ||||
-rw-r--r-- | src/ff_csplit.cc | 197 | ||||
-rw-r--r-- | src/ff_csplit.h | 39 | ||||
-rw-r--r-- | src/freqdict.cc | 14 | ||||
-rw-r--r-- | src/freqdict.h | 9 | ||||
-rw-r--r-- | src/stringlib.h | 10 |
10 files changed, 445 insertions, 13 deletions
diff --git a/src/Makefile.am b/src/Makefile.am index 34ffb170..4d0459ef 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -36,6 +36,7 @@ noinst_LIBRARIES = libhg.a libhg_a_SOURCES = \ fst_translator.cc \ + csplit.cc \ scfg_translator.cc \ hg.cc \ hg_io.cc \ @@ -58,6 +59,8 @@ libhg_a_SOURCES = \ ff.cc \ ff_lm.cc \ ff_wordalign.cc \ + ff_csplit.cc \ + freqdict.cc \ lexcrf.cc \ bottom_up_parser.cc \ phrasebased_translator.cc \ diff --git a/src/cdec.cc b/src/cdec.cc index c5780cef..7bdf7bcc 100644 --- a/src/cdec.cc +++ b/src/cdec.cc @@ -17,6 +17,7 @@ #include "sampler.h" #include "sparse_vector.h" #include "lexcrf.h" +#include "csplit.h" #include "weights.h" #include "tdict.h" #include "ff.h" @@ -46,7 +47,7 @@ void ShowBanner() { void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() - ("formalism,f",po::value<string>()->default_value("scfg"),"Translation formalism; values include SCFG, FST, or PB. Specify LexicalCRF for experimental unsupervised CRF word alignment") + ("formalism,f",po::value<string>()->default_value("scfg"),"Translation formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSPLIT (compound splitting)") ("input,i",po::value<string>()->default_value("-"),"Source file") ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") ("weights,w",po::value<string>(),"Feature weights file") @@ -100,14 +101,14 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { exit(1); } - if (conf->count("help") || conf->count("grammar") == 0) { + if (conf->count("help") || conf->count("formalism") == 0) { cerr << dcmdline_options << endl; exit(1); } const string formalism = LowercaseString((*conf)["formalism"].as<string>()); - if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', or 'lexcrf'\n"; + if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb" && formalism != "csplit") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit' or 'lexcrf'\n"; cerr << dcmdline_options << endl; exit(1); } @@ -231,6 +232,8 @@ int main(int argc, char** argv) { translator.reset(new FSTTranslator(conf)); else if (formalism == "pb") translator.reset(new PhraseBasedTranslator(conf)); + else if (formalism == "csplit") + translator.reset(new CompoundSplit(conf)); else if (formalism == "lexcrf") translator.reset(new LexicalCRF(conf)); else diff --git a/src/cdec_ff.cc b/src/cdec_ff.cc index 89353f5f..846a908e 100644 --- a/src/cdec_ff.cc +++ b/src/cdec_ff.cc @@ -2,8 +2,9 @@ #include "ff.h" #include "ff_lm.h" -#include "ff_factory.h" +#include "ff_csplit.h" #include "ff_wordalign.h" +#include "ff_factory.h" boost::shared_ptr<FFRegistry> global_ff_registry; @@ -14,5 +15,7 @@ void register_feature_functions() { global_ff_registry->Register("MarkovJump", new FFFactory<MarkovJump>); global_ff_registry->Register("BlunsomSynchronousParseHack", new FFFactory<BlunsomSynchronousParseHack>); global_ff_registry->Register("AlignerResults", new FFFactory<AlignerResults>); + global_ff_registry->Register("CSplit_BasicFeatures", new FFFactory<BasicCSplitFeatures>); + global_ff_registry->Register("CSplit_ReverseCharLM", new FFFactory<ReverseCharLMCSplitFeature>); }; diff --git a/src/csplit.cc b/src/csplit.cc new file mode 100644 index 00000000..788f3112 --- /dev/null +++ b/src/csplit.cc @@ -0,0 +1,152 @@ +#include "csplit.h" + +#include <iostream> + +#include "filelib.h" +#include "stringlib.h" +#include "hg.h" +#include "tdict.h" +#include "grammar.h" +#include "sentence_metadata.h" + +using namespace std; + +struct CompoundSplitImpl { + CompoundSplitImpl(const boost::program_options::variables_map& conf) : + fugen_elements_(true), // TODO configure + min_size_(3), + kXCAT(TD::Convert("X")*-1), + kWORDBREAK_RULE(new TRule("[X] ||| # ||| #")), + kTEMPLATE_RULE(new TRule("[X] ||| [X,1] ? ||| [1] ?")), + kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")), + kFUGEN_S(FD::Convert("FugS")), + kFUGEN_N(FD::Convert("FugN")) {} + + void PasteTogetherStrings(const vector<string>& chars, + const int i, + const int j, + string* yield) { + int size = 0; + for (int k=i; k<j; ++k) + size += chars[k].size(); + yield->resize(size); + int cur = 0; + for (int k=i; k<j; ++k) { + const string& cs = chars[k]; + for (int l = 0; l < cs.size(); ++l) + (*yield)[cur++] = cs[l]; + } + } + + void BuildTrellis(const vector<string>& chars, + Hypergraph* forest) { + vector<int> nodes(chars.size()+1, -1); + nodes[0] = forest->AddNode(kXCAT)->id_; // source + const int left_rule = forest->AddEdge(kWORDBREAK_RULE, Hypergraph::TailNodeVector())->id_; + forest->ConnectEdgeToHeadNode(left_rule, nodes[0]); + + const int max_split_ = chars.size() - min_size_ + 1; + for (int i = min_size_; i < max_split_; ++i) + nodes[i] = forest->AddNode(kXCAT)->id_; + assert(nodes.back() == -1); + nodes.back() = forest->AddNode(kXCAT)->id_; // sink + + for (int i = 0; i < max_split_; ++i) { + if (nodes[i] < 0) continue; + for (int j = i + min_size_; j <= chars.size(); ++j) { + if (nodes[j] < 0) continue; + string yield; + PasteTogetherStrings(chars, i, j, &yield); + // cerr << "[" << i << "," << j << "] " << yield << endl; + TRulePtr rule = TRulePtr(new TRule(*kTEMPLATE_RULE)); + rule->e_[1] = rule->f_[1] = TD::Convert(yield); + // cerr << rule->AsString() << endl; + int edge = forest->AddEdge( + rule, + Hypergraph::TailNodeVector(1, nodes[i]))->id_; + forest->ConnectEdgeToHeadNode(edge, nodes[j]); + forest->edges_[edge].i_ = i; + forest->edges_[edge].j_ = j; + + // handle "fugenelemente" here + // don't delete "fugenelemente" at the end of words + if (fugen_elements_ && j != chars.size()) { + const int len = yield.size(); + string alt; + int fid = 0; + if (len > (min_size_ + 2) && yield[len-1] == 's' && yield[len-2] == 'e') { + alt = yield.substr(0, len - 2); + fid = kFUGEN_S; + } else if (len > (min_size_ + 1) && yield[len-1] == 's') { + alt = yield.substr(0, len - 1); + fid = kFUGEN_S; + } else if (len > (min_size_ + 2) && yield[len-2] == 'e' && yield[len-1] == 'n') { + alt = yield.substr(0, len - 1); + fid = kFUGEN_N; + } + if (alt.size()) { + TRulePtr altrule = TRulePtr(new TRule(*rule)); + altrule->e_[1] = TD::Convert(alt); + // cerr << altrule->AsString() << endl; + int edge = forest->AddEdge( + altrule, + Hypergraph::TailNodeVector(1, nodes[i]))->id_; + forest->ConnectEdgeToHeadNode(edge, nodes[j]); + forest->edges_[edge].feature_values_.set_value(fid, 1.0); + forest->edges_[edge].i_ = i; + forest->edges_[edge].j_ = j; + } + } + } + } + + // add goal rule + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + } + private: + const bool fugen_elements_; + const int min_size_; + const WordID kXCAT; + const TRulePtr kWORDBREAK_RULE; + const TRulePtr kTEMPLATE_RULE; + const TRulePtr kGOAL_RULE; + const int kFUGEN_S; + const int kFUGEN_N; +}; + +CompoundSplit::CompoundSplit(const boost::program_options::variables_map& conf) : + pimpl_(new CompoundSplitImpl(conf)) {} + +static void SplitUTF8String(const string& in, vector<string>* out) { + out->resize(in.size()); + int i = 0; + int c = 0; + while (i < in.size()) { + const int len = UTF8Len(in[i]); + assert(len); + (*out)[c] = in.substr(i, len); + ++c; + i += len; + } + out->resize(c); +} + +bool CompoundSplit::Translate(const string& input, + SentenceMetadata* smeta, + const vector<double>& weights, + Hypergraph* forest) { + if (input.find(" ") != string::npos) { + cerr << " BAD INPUT: " << input << "\n CompoundSplit expects single words\n"; + abort(); + } + vector<string> in; + SplitUTF8String(input, &in); + smeta->SetSourceLength(in.size()); // TODO do utf8 or somethign + pimpl_->BuildTrellis(in, forest); + forest->Reweight(weights); + return true; +} + diff --git a/src/csplit.h b/src/csplit.h new file mode 100644 index 00000000..5911af77 --- /dev/null +++ b/src/csplit.h @@ -0,0 +1,18 @@ +#ifndef _CSPLIT_H_ +#define _CSPLIT_H_ + +#include "translator.h" +#include "lattice.h" + +struct CompoundSplitImpl; +struct CompoundSplit : public Translator { + CompoundSplit(const boost::program_options::variables_map& conf); + bool Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector<double>& weights, + Hypergraph* forest); + private: + boost::shared_ptr<CompoundSplitImpl> pimpl_; +}; + +#endif diff --git a/src/ff_csplit.cc b/src/ff_csplit.cc new file mode 100644 index 00000000..5d8dfefb --- /dev/null +++ b/src/ff_csplit.cc @@ -0,0 +1,197 @@ +#include "ff_csplit.h" + +#include <set> +#include <cstring> + +#include "tdict.h" +#include "freqdict.h" +#include "filelib.h" +#include "stringlib.h" +#include "tdict.h" + +#include "Vocab.h" +#include "Ngram.h" + +using namespace std; + +struct BasicCSplitFeaturesImpl { + BasicCSplitFeaturesImpl(const string& param) : + word_count_(FD::Convert("WordCount")), + in_dict_(FD::Convert("InDict")), + short_(FD::Convert("Short")), + long_(FD::Convert("Long")), + oov_(FD::Convert("OOV")), + short_range_(FD::Convert("ShortRange")), + high_freq_(FD::Convert("HighFreq")), + med_freq_(FD::Convert("MedFreq")), + freq_(FD::Convert("Freq")), + bad_(FD::Convert("Bad")) { + vector<string> argv; + int argc = SplitOnWhitespace(param, &argv); + if (argc != 1 && argc != 2) { + cerr << "Expected: freqdict.txt [badwords.txt]\n"; + abort(); + } + freq_dict_.Load(argv[0]); + if (argc == 2) { + ReadFile rf(argv[1]); + istream& in = *rf.stream(); + while(in) { + string badword; + in >> badword; + if (badword.empty()) continue; + bad_words_.insert(TD::Convert(badword)); + } + } + } + + void TraversalFeaturesImpl(const Hypergraph::Edge& edge, + SparseVector<double>* features) const; + + const int word_count_; + const int in_dict_; + const int short_; + const int long_; + const int oov_; + const int short_range_; + const int high_freq_; + const int med_freq_; + const int freq_; + const int bad_; + FreqDict freq_dict_; + set<WordID> bad_words_; +}; + +BasicCSplitFeatures::BasicCSplitFeatures(const string& param) : + pimpl_(new BasicCSplitFeaturesImpl(param)) {} + +void BasicCSplitFeaturesImpl::TraversalFeaturesImpl( + const Hypergraph::Edge& edge, + SparseVector<double>* features) const { + features->set_value(word_count_, 1.0); + const WordID word = edge.rule_->e_[1]; + const char* sword = TD::Convert(word); + const int len = strlen(sword); + int cur = 0; + int chars = 0; + while(cur < len) { + cur += UTF8Len(sword[cur]); + ++chars; + } + bool has_sch = strstr(sword, "sch"); + bool has_ch = (!has_sch && strstr(sword, "ch")); + bool has_ie = strstr(sword, "ie"); + bool has_zw = strstr(sword, "zw"); + if (has_sch) chars -= 2; + if (has_ch) --chars; + if (has_ie) --chars; + if (has_zw) --chars; + + float freq = freq_dict_.LookUp(word); + if (freq) { + features->set_value(freq_, freq); + features->set_value(in_dict_, 1.0); + } else { + features->set_value(oov_, 1.0); + freq = 99.0f; + } + if (bad_words_.count(word) != 0) + features->set_value(bad_, 1.0); + if (chars < 5) + features->set_value(short_, 1.0); + if (chars > 10) + features->set_value(long_, 1.0); + if (freq < 7.0f) + features->set_value(high_freq_, 1.0); + if (freq > 8.0f && freq < 10.f) + features->set_value(med_freq_, 1.0); + if (freq < 10.0f && chars < 5) + features->set_value(short_range_, 1.0); +} + +void BasicCSplitFeatures::TraversalFeaturesImpl( + const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* out_context) const { + (void) smeta; + (void) ant_contexts; + (void) out_context; + (void) estimated_features; + if (edge.Arity() == 0) return; + if (edge.rule_->EWords() != 1) return; + pimpl_->TraversalFeaturesImpl(edge, features); +} + +struct ReverseCharLMCSplitFeatureImpl { + ReverseCharLMCSplitFeatureImpl(const string& param) : + order_(5), + ngram_(vocab_, order_) { + kBOS = vocab_.getIndex("<s>"); + kEOS = vocab_.getIndex("</s>"); + File file(param.c_str(), "r", 0); + assert(file); + cerr << "Reading " << order_ << "-gram LM from " << param << endl; + ngram_.read(file); + } + + double LeftPhonotacticProb(const char* word) { + for (int i = 0; i < order_; ++i) + sc[i] = kBOS; + const int len = strlen(word); + int cur = 0; + int chars = 0; + while(cur < len) { + cur += UTF8Len(word[cur]); + ++chars; + } + const int sp = min(chars, order_-1); + int wend = 0; cur = 0; + while(cur < sp) { + wend += UTF8Len(word[wend]); + ++cur; + } + int wi = 0; + int ci = (order_ - sp - 1); + // cerr << "WORD: " << word << endl; + while (wi != wend) { + const int clen = UTF8Len(word[wi]); + string cur_char(&word[wi], clen); + wi += clen; + // cerr << " char: " << cur_char << " ci=" << ci << endl; + sc[ci++] = vocab_.getIndex(cur_char.c_str()); + } + // cerr << " END sp=" << sp << endl; + sc[sp] = Vocab_None; + const double startprob = -ngram_.wordProb(kEOS, sc); + // cerr << " PROB=" << startprob << endl; + return startprob; + } + private: + const int order_; + Vocab vocab_; + VocabIndex kBOS; + VocabIndex kEOS; + Ngram ngram_; + VocabIndex sc[80]; +}; + +ReverseCharLMCSplitFeature::ReverseCharLMCSplitFeature(const string& param) : + pimpl_(new ReverseCharLMCSplitFeatureImpl(param)), + fid_(FD::Convert("RevCharLM")) {} + +void ReverseCharLMCSplitFeature::TraversalFeaturesImpl( + const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* out_context) const { + if (edge.Arity() != 1) return; + if (edge.rule_->EWords() != 1) return; + const double lpp = pimpl_->LeftPhonotacticProb(TD::Convert(edge.rule_->e_[1])); + features->set_value(fid_, lpp); +} + diff --git a/src/ff_csplit.h b/src/ff_csplit.h new file mode 100644 index 00000000..c1cfb64b --- /dev/null +++ b/src/ff_csplit.h @@ -0,0 +1,39 @@ +#ifndef _FF_CSPLIT_H_ +#define _FF_CSPLIT_H_ + +#include <boost/shared_ptr.hpp> + +#include "ff.h" + +class BasicCSplitFeaturesImpl; +class BasicCSplitFeatures : public FeatureFunction { + public: + BasicCSplitFeatures(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* out_context) const; + private: + boost::shared_ptr<BasicCSplitFeaturesImpl> pimpl_; +}; + +class ReverseCharLMCSplitFeatureImpl; +class ReverseCharLMCSplitFeature : public FeatureFunction { + public: + ReverseCharLMCSplitFeature(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* out_context) const; + private: + boost::shared_ptr<ReverseCharLMCSplitFeatureImpl> pimpl_; + const int fid_; +}; + +#endif diff --git a/src/freqdict.cc b/src/freqdict.cc index 4cfffe58..9e25d346 100644 --- a/src/freqdict.cc +++ b/src/freqdict.cc @@ -2,11 +2,17 @@ #include <fstream> #include <cassert> #include "freqdict.h" +#include "tdict.h" +#include "filelib.h" -void FreqDict::load(const std::string& fname) { - std::ifstream ifs(fname.c_str()); +using namespace std; + +void FreqDict::Load(const std::string& fname) { + cerr << "Reading word frequencies: " << fname << endl; + ReadFile rf(fname); + istream& ifs = *rf.stream(); int cc=0; - while (!ifs.eof()) { + while (ifs) { std::string word; ifs >> word; if (word.size() == 0) continue; @@ -14,7 +20,7 @@ void FreqDict::load(const std::string& fname) { double count = 0; ifs >> count; assert(count > 0.0); // use -log(f) - counts_[word]=count; + counts_[TD::Convert(word)]=count; ++cc; if (cc % 10000 == 0) { std::cerr << "."; } } diff --git a/src/freqdict.h b/src/freqdict.h index c9bb4c42..9acf0c33 100644 --- a/src/freqdict.h +++ b/src/freqdict.h @@ -3,17 +3,18 @@ #include <map> #include <string> +#include "wordid.h" class FreqDict { public: - void load(const std::string& fname); - float frequency(const std::string& word) const { - std::map<std::string,float>::const_iterator i = counts_.find(word); + void Load(const std::string& fname); + float LookUp(const WordID& word) const { + std::map<WordID,float>::const_iterator i = counts_.find(word); if (i == counts_.end()) return 0; return i->second; } private: - std::map<std::string, float> counts_; + std::map<WordID, float> counts_; }; #endif diff --git a/src/stringlib.h b/src/stringlib.h index d26952c7..76efee8f 100644 --- a/src/stringlib.h +++ b/src/stringlib.h @@ -88,4 +88,14 @@ inline void SplitCommandAndParam(const std::string& in, std::string* cmd, std::s void ProcessAndStripSGML(std::string* line, std::map<std::string, std::string>* out); +// given the first character of a UTF8 block, find out how wide it is +// see http://en.wikipedia.org/wiki/UTF-8 for more info +inline unsigned int UTF8Len(unsigned char x) { + if (x < 0x80) return 1; + else if ((x >> 5) == 0x06) return 2; + else if ((x >> 4) == 0x0e) return 3; + else if ((x >> 3) == 0x1e) return 4; + else return 0; +} + #endif |