author     Chris Dyer <redpony@gmail.com>    2009-12-06 22:25:25 -0500
committer  Chris Dyer <redpony@gmail.com>    2009-12-06 22:25:25 -0500
commit     2a18010e255810cc2b5bcbe688f3db8eabda23ca (patch)
tree       e310286257e5445072303dcca03acb85a865c26a
parent     59ea352f3dcf3bf58969f404615fed4ff6b931f7 (diff)
add compound splitting logic and features (Dyer 2008, NAACL)
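
This adds a new "csplit" formalism that builds a split trellis over the UTF-8
characters of an input word (csplit.cc), plus two feature functions for scoring
candidate segments: BasicCSplitFeatures (word frequency, dictionary, and length
heuristics) and ReverseCharLMCSplitFeature (a reverse character LM over segment
prefixes), both in ff_csplit.cc.

A rough invocation sketch -- only -f/-i/-w and the CSplit_* feature names come
from this patch; the --feature_function flag and the file names below are
assumptions for illustration, not part of the commit:

    # input: one compound per line, e.g. "bundesverfassungsgericht"
    ./cdec -f csplit \
           -i compounds.txt \
           -w csplit.weights \
           --feature_function "CSplit_BasicFeatures freqdict.txt badwords.txt" \
           --feature_function "CSplit_ReverseCharLM rev.char.5gram.lm"
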
-rw-r--r--  src/Makefile.am     3
-rw-r--r--  src/cdec.cc        11
-rw-r--r--  src/cdec_ff.cc      5
-rw-r--r--  src/csplit.cc     152
-rw-r--r--  src/csplit.h       18
-rw-r--r--  src/ff_csplit.cc  197
-rw-r--r--  src/ff_csplit.h    39
-rw-r--r--  src/freqdict.cc    14
-rw-r--r--  src/freqdict.h      9
-rw-r--r--  src/stringlib.h    10
10 files changed, 445 insertions, 13 deletions
diff --git a/src/Makefile.am b/src/Makefile.am
index 34ffb170..4d0459ef 100644
--- a/src/Makefile.am
+++ b/src/Makefile.am
@@ -36,6 +36,7 @@ noinst_LIBRARIES = libhg.a
libhg_a_SOURCES = \
fst_translator.cc \
+ csplit.cc \
scfg_translator.cc \
hg.cc \
hg_io.cc \
@@ -58,6 +59,8 @@ libhg_a_SOURCES = \
ff.cc \
ff_lm.cc \
ff_wordalign.cc \
+ ff_csplit.cc \
+ freqdict.cc \
lexcrf.cc \
bottom_up_parser.cc \
phrasebased_translator.cc \
diff --git a/src/cdec.cc b/src/cdec.cc
index c5780cef..7bdf7bcc 100644
--- a/src/cdec.cc
+++ b/src/cdec.cc
@@ -17,6 +17,7 @@
#include "sampler.h"
#include "sparse_vector.h"
#include "lexcrf.h"
+#include "csplit.h"
#include "weights.h"
#include "tdict.h"
#include "ff.h"
@@ -46,7 +47,7 @@ void ShowBanner() {
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
- ("formalism,f",po::value<string>()->default_value("scfg"),"Translation formalism; values include SCFG, FST, or PB. Specify LexicalCRF for experimental unsupervised CRF word alignment")
+ ("formalism,f",po::value<string>()->default_value("scfg"),"Translation formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSPLIT (compound splitting)")
("input,i",po::value<string>()->default_value("-"),"Source file")
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("weights,w",po::value<string>(),"Feature weights file")
@@ -100,14 +101,14 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
exit(1);
}
- if (conf->count("help") || conf->count("grammar") == 0) {
+ if (conf->count("help") || conf->count("formalism") == 0) {
cerr << dcmdline_options << endl;
exit(1);
}
const string formalism = LowercaseString((*conf)["formalism"].as<string>());
- if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb") {
- cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', or 'lexcrf'\n";
+ if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb" && formalism != "csplit") {
+ cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit' or 'lexcrf'\n";
cerr << dcmdline_options << endl;
exit(1);
}
@@ -231,6 +232,8 @@ int main(int argc, char** argv) {
translator.reset(new FSTTranslator(conf));
else if (formalism == "pb")
translator.reset(new PhraseBasedTranslator(conf));
+ else if (formalism == "csplit")
+ translator.reset(new CompoundSplit(conf));
else if (formalism == "lexcrf")
translator.reset(new LexicalCRF(conf));
else
diff --git a/src/cdec_ff.cc b/src/cdec_ff.cc
index 89353f5f..846a908e 100644
--- a/src/cdec_ff.cc
+++ b/src/cdec_ff.cc
@@ -2,8 +2,9 @@
#include "ff.h"
#include "ff_lm.h"
-#include "ff_factory.h"
+#include "ff_csplit.h"
#include "ff_wordalign.h"
+#include "ff_factory.h"
boost::shared_ptr<FFRegistry> global_ff_registry;
@@ -14,5 +15,7 @@ void register_feature_functions() {
global_ff_registry->Register("MarkovJump", new FFFactory<MarkovJump>);
global_ff_registry->Register("BlunsomSynchronousParseHack", new FFFactory<BlunsomSynchronousParseHack>);
global_ff_registry->Register("AlignerResults", new FFFactory<AlignerResults>);
+ global_ff_registry->Register("CSplit_BasicFeatures", new FFFactory<BasicCSplitFeatures>);
+ global_ff_registry->Register("CSplit_ReverseCharLM", new FFFactory<ReverseCharLMCSplitFeature>);
};
diff --git a/src/csplit.cc b/src/csplit.cc
new file mode 100644
index 00000000..788f3112
--- /dev/null
+++ b/src/csplit.cc
@@ -0,0 +1,152 @@
+#include "csplit.h"
+
+#include <iostream>
+
+#include "filelib.h"
+#include "stringlib.h"
+#include "hg.h"
+#include "tdict.h"
+#include "grammar.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+struct CompoundSplitImpl {
+ CompoundSplitImpl(const boost::program_options::variables_map& conf) :
+ fugen_elements_(true), // TODO configure
+ min_size_(3),
+ kXCAT(TD::Convert("X")*-1),
+ kWORDBREAK_RULE(new TRule("[X] ||| # ||| #")),
+ kTEMPLATE_RULE(new TRule("[X] ||| [X,1] ? ||| [1] ?")),
+ kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")),
+ kFUGEN_S(FD::Convert("FugS")),
+ kFUGEN_N(FD::Convert("FugN")) {}
+
+ void PasteTogetherStrings(const vector<string>& chars,
+ const int i,
+ const int j,
+ string* yield) {
+ int size = 0;
+ for (int k=i; k<j; ++k)
+ size += chars[k].size();
+ yield->resize(size);
+ int cur = 0;
+ for (int k=i; k<j; ++k) {
+ const string& cs = chars[k];
+ for (int l = 0; l < cs.size(); ++l)
+ (*yield)[cur++] = cs[l];
+ }
+ }
+
+ void BuildTrellis(const vector<string>& chars,
+ Hypergraph* forest) {
+ vector<int> nodes(chars.size()+1, -1);
+ nodes[0] = forest->AddNode(kXCAT)->id_; // source
+ const int left_rule = forest->AddEdge(kWORDBREAK_RULE, Hypergraph::TailNodeVector())->id_;
+ forest->ConnectEdgeToHeadNode(left_rule, nodes[0]);
+
+ const int max_split_ = chars.size() - min_size_ + 1;
+ for (int i = min_size_; i < max_split_; ++i)
+ nodes[i] = forest->AddNode(kXCAT)->id_;
+ assert(nodes.back() == -1);
+ nodes.back() = forest->AddNode(kXCAT)->id_; // sink
+
+ for (int i = 0; i < max_split_; ++i) {
+ if (nodes[i] < 0) continue;
+ for (int j = i + min_size_; j <= chars.size(); ++j) {
+ if (nodes[j] < 0) continue;
+ string yield;
+ PasteTogetherStrings(chars, i, j, &yield);
+ // cerr << "[" << i << "," << j << "] " << yield << endl;
+ TRulePtr rule = TRulePtr(new TRule(*kTEMPLATE_RULE));
+ rule->e_[1] = rule->f_[1] = TD::Convert(yield);
+ // cerr << rule->AsString() << endl;
+ int edge = forest->AddEdge(
+ rule,
+ Hypergraph::TailNodeVector(1, nodes[i]))->id_;
+ forest->ConnectEdgeToHeadNode(edge, nodes[j]);
+ forest->edges_[edge].i_ = i;
+ forest->edges_[edge].j_ = j;
+
+ // handle "fugenelemente" here
+ // don't delete "fugenelemente" at the end of words
+ if (fugen_elements_ && j != chars.size()) {
+ const int len = yield.size();
+ string alt;
+ int fid = 0;
+ if (len > (min_size_ + 2) && yield[len-1] == 's' && yield[len-2] == 'e') {
+ alt = yield.substr(0, len - 2);
+ fid = kFUGEN_S;
+ } else if (len > (min_size_ + 1) && yield[len-1] == 's') {
+ alt = yield.substr(0, len - 1);
+ fid = kFUGEN_S;
+ } else if (len > (min_size_ + 2) && yield[len-2] == 'e' && yield[len-1] == 'n') {
+ alt = yield.substr(0, len - 1);
+ fid = kFUGEN_N;
+ }
+ if (alt.size()) {
+ TRulePtr altrule = TRulePtr(new TRule(*rule));
+ altrule->e_[1] = TD::Convert(alt);
+ // cerr << altrule->AsString() << endl;
+ int edge = forest->AddEdge(
+ altrule,
+ Hypergraph::TailNodeVector(1, nodes[i]))->id_;
+ forest->ConnectEdgeToHeadNode(edge, nodes[j]);
+ forest->edges_[edge].feature_values_.set_value(fid, 1.0);
+ forest->edges_[edge].i_ = i;
+ forest->edges_[edge].j_ = j;
+ }
+ }
+ }
+ }
+
+ // add goal rule
+ Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
+ Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1);
+ Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
+ forest->ConnectEdgeToHeadNode(hg_edge, goal);
+ }
+ private:
+ const bool fugen_elements_;
+ const int min_size_;
+ const WordID kXCAT;
+ const TRulePtr kWORDBREAK_RULE;
+ const TRulePtr kTEMPLATE_RULE;
+ const TRulePtr kGOAL_RULE;
+ const int kFUGEN_S;
+ const int kFUGEN_N;
+};
+
+CompoundSplit::CompoundSplit(const boost::program_options::variables_map& conf) :
+ pimpl_(new CompoundSplitImpl(conf)) {}
+
+static void SplitUTF8String(const string& in, vector<string>* out) {
+ out->resize(in.size());
+ int i = 0;
+ int c = 0;
+ while (i < in.size()) {
+ const int len = UTF8Len(in[i]);
+ assert(len);
+ (*out)[c] = in.substr(i, len);
+ ++c;
+ i += len;
+ }
+ out->resize(c);
+}
+
+bool CompoundSplit::Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* forest) {
+ if (input.find(" ") != string::npos) {
+ cerr << " BAD INPUT: " << input << "\n CompoundSplit expects single words\n";
+ abort();
+ }
+ vector<string> in;
+ SplitUTF8String(input, &in);
+ smeta->SetSourceLength(in.size()); // TODO do utf8 or something
+ pimpl_->BuildTrellis(in, forest);
+ forest->Reweight(weights);
+ return true;
+}
+
diff --git a/src/csplit.h b/src/csplit.h
new file mode 100644
index 00000000..5911af77
--- /dev/null
+++ b/src/csplit.h
@@ -0,0 +1,18 @@
+#ifndef _CSPLIT_H_
+#define _CSPLIT_H_
+
+#include "translator.h"
+#include "lattice.h"
+
+struct CompoundSplitImpl;
+struct CompoundSplit : public Translator {
+ CompoundSplit(const boost::program_options::variables_map& conf);
+ bool Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* forest);
+ private:
+ boost::shared_ptr<CompoundSplitImpl> pimpl_;
+};
+
+#endif
diff --git a/src/ff_csplit.cc b/src/ff_csplit.cc
new file mode 100644
index 00000000..5d8dfefb
--- /dev/null
+++ b/src/ff_csplit.cc
@@ -0,0 +1,197 @@
+#include "ff_csplit.h"
+
+#include <set>
+#include <cstring>
+
+#include "tdict.h"
+#include "freqdict.h"
+#include "filelib.h"
+#include "stringlib.h"
+#include "tdict.h"
+
+#include "Vocab.h"
+#include "Ngram.h"
+
+using namespace std;
+
+struct BasicCSplitFeaturesImpl {
+ BasicCSplitFeaturesImpl(const string& param) :
+ word_count_(FD::Convert("WordCount")),
+ in_dict_(FD::Convert("InDict")),
+ short_(FD::Convert("Short")),
+ long_(FD::Convert("Long")),
+ oov_(FD::Convert("OOV")),
+ short_range_(FD::Convert("ShortRange")),
+ high_freq_(FD::Convert("HighFreq")),
+ med_freq_(FD::Convert("MedFreq")),
+ freq_(FD::Convert("Freq")),
+ bad_(FD::Convert("Bad")) {
+ vector<string> argv;
+ int argc = SplitOnWhitespace(param, &argv);
+ if (argc != 1 && argc != 2) {
+ cerr << "Expected: freqdict.txt [badwords.txt]\n";
+ abort();
+ }
+ freq_dict_.Load(argv[0]);
+ if (argc == 2) {
+ ReadFile rf(argv[1]);
+ istream& in = *rf.stream();
+ while(in) {
+ string badword;
+ in >> badword;
+ if (badword.empty()) continue;
+ bad_words_.insert(TD::Convert(badword));
+ }
+ }
+ }
+
+ void TraversalFeaturesImpl(const Hypergraph::Edge& edge,
+ SparseVector<double>* features) const;
+
+ const int word_count_;
+ const int in_dict_;
+ const int short_;
+ const int long_;
+ const int oov_;
+ const int short_range_;
+ const int high_freq_;
+ const int med_freq_;
+ const int freq_;
+ const int bad_;
+ FreqDict freq_dict_;
+ set<WordID> bad_words_;
+};
+
+BasicCSplitFeatures::BasicCSplitFeatures(const string& param) :
+ pimpl_(new BasicCSplitFeaturesImpl(param)) {}
+
+void BasicCSplitFeaturesImpl::TraversalFeaturesImpl(
+ const Hypergraph::Edge& edge,
+ SparseVector<double>* features) const {
+ features->set_value(word_count_, 1.0);
+ const WordID word = edge.rule_->e_[1];
+ const char* sword = TD::Convert(word);
+ const int len = strlen(sword);
+ int cur = 0;
+ int chars = 0;
+ while(cur < len) {
+ cur += UTF8Len(sword[cur]);
+ ++chars;
+ }
+ bool has_sch = strstr(sword, "sch");
+ bool has_ch = (!has_sch && strstr(sword, "ch"));
+ bool has_ie = strstr(sword, "ie");
+ bool has_zw = strstr(sword, "zw");
+ if (has_sch) chars -= 2;
+ if (has_ch) --chars;
+ if (has_ie) --chars;
+ if (has_zw) --chars;
+
+ float freq = freq_dict_.LookUp(word);
+ if (freq) {
+ features->set_value(freq_, freq);
+ features->set_value(in_dict_, 1.0);
+ } else {
+ features->set_value(oov_, 1.0);
+ freq = 99.0f;
+ }
+ if (bad_words_.count(word) != 0)
+ features->set_value(bad_, 1.0);
+ if (chars < 5)
+ features->set_value(short_, 1.0);
+ if (chars > 10)
+ features->set_value(long_, 1.0);
+ if (freq < 7.0f)
+ features->set_value(high_freq_, 1.0);
+ if (freq > 8.0f && freq < 10.f)
+ features->set_value(med_freq_, 1.0);
+ if (freq < 10.0f && chars < 5)
+ features->set_value(short_range_, 1.0);
+}
+
+void BasicCSplitFeatures::TraversalFeaturesImpl(
+ const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const {
+ (void) smeta;
+ (void) ant_contexts;
+ (void) out_context;
+ (void) estimated_features;
+ if (edge.Arity() == 0) return;
+ if (edge.rule_->EWords() != 1) return;
+ pimpl_->TraversalFeaturesImpl(edge, features);
+}
+
+struct ReverseCharLMCSplitFeatureImpl {
+ ReverseCharLMCSplitFeatureImpl(const string& param) :
+ order_(5),
+ ngram_(vocab_, order_) {
+ kBOS = vocab_.getIndex("<s>");
+ kEOS = vocab_.getIndex("</s>");
+ File file(param.c_str(), "r", 0);
+ assert(file);
+ cerr << "Reading " << order_ << "-gram LM from " << param << endl;
+ ngram_.read(file);
+ }
+
+ double LeftPhonotacticProb(const char* word) {
+ for (int i = 0; i < order_; ++i)
+ sc[i] = kBOS;
+ const int len = strlen(word);
+ int cur = 0;
+ int chars = 0;
+ while(cur < len) {
+ cur += UTF8Len(word[cur]);
+ ++chars;
+ }
+ const int sp = min(chars, order_-1);
+ int wend = 0; cur = 0;
+ while(cur < sp) {
+ wend += UTF8Len(word[wend]);
+ ++cur;
+ }
+ int wi = 0;
+ int ci = (order_ - sp - 1);
+ // cerr << "WORD: " << word << endl;
+ while (wi != wend) {
+ const int clen = UTF8Len(word[wi]);
+ string cur_char(&word[wi], clen);
+ wi += clen;
+ // cerr << " char: " << cur_char << " ci=" << ci << endl;
+ sc[ci++] = vocab_.getIndex(cur_char.c_str());
+ }
+ // cerr << " END sp=" << sp << endl;
+ sc[sp] = Vocab_None;
+ const double startprob = -ngram_.wordProb(kEOS, sc);
+ // cerr << " PROB=" << startprob << endl;
+ return startprob;
+ }
+ private:
+ const int order_;
+ Vocab vocab_;
+ VocabIndex kBOS;
+ VocabIndex kEOS;
+ Ngram ngram_;
+ VocabIndex sc[80];
+};
+
+ReverseCharLMCSplitFeature::ReverseCharLMCSplitFeature(const string& param) :
+ pimpl_(new ReverseCharLMCSplitFeatureImpl(param)),
+ fid_(FD::Convert("RevCharLM")) {}
+
+void ReverseCharLMCSplitFeature::TraversalFeaturesImpl(
+ const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const {
+ if (edge.Arity() != 1) return;
+ if (edge.rule_->EWords() != 1) return;
+ const double lpp = pimpl_->LeftPhonotacticProb(TD::Convert(edge.rule_->e_[1]));
+ features->set_value(fid_, lpp);
+}
+
diff --git a/src/ff_csplit.h b/src/ff_csplit.h
new file mode 100644
index 00000000..c1cfb64b
--- /dev/null
+++ b/src/ff_csplit.h
@@ -0,0 +1,39 @@
+#ifndef _FF_CSPLIT_H_
+#define _FF_CSPLIT_H_
+
+#include <boost/shared_ptr.hpp>
+
+#include "ff.h"
+
+class BasicCSplitFeaturesImpl;
+class BasicCSplitFeatures : public FeatureFunction {
+ public:
+ BasicCSplitFeatures(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ boost::shared_ptr<BasicCSplitFeaturesImpl> pimpl_;
+};
+
+class ReverseCharLMCSplitFeatureImpl;
+class ReverseCharLMCSplitFeature : public FeatureFunction {
+ public:
+ ReverseCharLMCSplitFeature(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* out_context) const;
+ private:
+ boost::shared_ptr<ReverseCharLMCSplitFeatureImpl> pimpl_;
+ const int fid_;
+};
+
+#endif
diff --git a/src/freqdict.cc b/src/freqdict.cc
index 4cfffe58..9e25d346 100644
--- a/src/freqdict.cc
+++ b/src/freqdict.cc
@@ -2,11 +2,17 @@
#include <fstream>
#include <cassert>
#include "freqdict.h"
+#include "tdict.h"
+#include "filelib.h"
-void FreqDict::load(const std::string& fname) {
- std::ifstream ifs(fname.c_str());
+using namespace std;
+
+void FreqDict::Load(const std::string& fname) {
+ cerr << "Reading word frequencies: " << fname << endl;
+ ReadFile rf(fname);
+ istream& ifs = *rf.stream();
int cc=0;
- while (!ifs.eof()) {
+ while (ifs) {
std::string word;
ifs >> word;
if (word.size() == 0) continue;
@@ -14,7 +20,7 @@ void FreqDict::load(const std::string& fname) {
double count = 0;
ifs >> count;
assert(count > 0.0); // use -log(f)
- counts_[word]=count;
+ counts_[TD::Convert(word)]=count;
++cc;
if (cc % 10000 == 0) { std::cerr << "."; }
}
diff --git a/src/freqdict.h b/src/freqdict.h
index c9bb4c42..9acf0c33 100644
--- a/src/freqdict.h
+++ b/src/freqdict.h
@@ -3,17 +3,18 @@
#include <map>
#include <string>
+#include "wordid.h"
class FreqDict {
public:
- void load(const std::string& fname);
- float frequency(const std::string& word) const {
- std::map<std::string,float>::const_iterator i = counts_.find(word);
+ void Load(const std::string& fname);
+ float LookUp(const WordID& word) const {
+ std::map<WordID,float>::const_iterator i = counts_.find(word);
if (i == counts_.end()) return 0;
return i->second;
}
private:
- std::map<std::string, float> counts_;
+ std::map<WordID, float> counts_;
};
#endif
diff --git a/src/stringlib.h b/src/stringlib.h
index d26952c7..76efee8f 100644
--- a/src/stringlib.h
+++ b/src/stringlib.h
@@ -88,4 +88,14 @@ inline void SplitCommandAndParam(const std::string& in, std::string* cmd, std::s
void ProcessAndStripSGML(std::string* line, std::map<std::string, std::string>* out);
+// given the first character of a UTF8 block, find out how wide it is
+// see http://en.wikipedia.org/wiki/UTF-8 for more info
+inline unsigned int UTF8Len(unsigned char x) {
+ if (x < 0x80) return 1;
+ else if ((x >> 5) == 0x06) return 2;
+ else if ((x >> 4) == 0x0e) return 3;
+ else if ((x >> 3) == 0x1e) return 4;
+ else return 0;
+}
+
#endif