From 3a7bca942d838f945c1cd0cbe5977e20c61ebc2d Mon Sep 17 00:00:00 2001
From: Chris Dyer
Date: Thu, 18 Feb 2010 22:34:17 -0500
Subject: check in modified ones too

---
 decoder/Makefile.am     |    3 +-
 decoder/cdec.cc         |   17 ++++---
 decoder/dict_test.cc    |   17 +++++++
 decoder/fdict.cc        |  124 ++++++++++++++++++++++++++++++++++++++++++++++++
 decoder/fdict.h         |    3 ++
 decoder/ff_wordalign.cc |   79 +++++++++++++++---------------
 decoder/ff_wordalign.h  |    6 +--
 decoder/lexalign.cc     |   34 ++++++++-----
 decoder/lextrans.cc     |   12 ++---
 decoder/lextrans.h      |   12 ++---
 decoder/stringlib.cc    |    1 +
 11 files changed, 231 insertions(+), 77 deletions(-)

(limited to 'decoder')

diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index d4e2a77c..81cd43e7 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -65,7 +65,8 @@ libcdec_a_SOURCES = \
   ff_csplit.cc \
   ff_tagger.cc \
   freqdict.cc \
-  lexcrf.cc \
+  lexalign.cc \
+  lextrans.cc \
   tagger.cc \
   bottom_up_parser.cc \
   phrasebased_translator.cc \
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index b130e7fd..811a0d04 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -18,7 +18,8 @@
 #include "sampler.h"
 #include "sparse_vector.h"
 #include "tagger.h"
-#include "lexcrf.h"
+#include "lextrans.h"
+#include "lexalign.h"
 #include "csplit.h"
 #include "weights.h"
 #include "tdict.h"
@@ -50,7 +51,7 @@ void ShowBanner() {
 void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
   po::options_description opts("Configuration options");
   opts.add_options()
-        ("formalism,f",po::value<string>(),"Decoding formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSplit (compound splitting), Tagger (sequence labeling)")
+        ("formalism,f",po::value<string>(),"Decoding formalism; values include SCFG, FST, PB, LexTrans (lexical translation model, also disc training), CSplit (compound splitting), Tagger (sequence labeling), LexAlign (alignment only, or EM training)")
         ("input,i",po::value<string>()->default_value("-"),"Source file")
         ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
         ("weights,w",po::value<string>(),"Feature weights file")
@@ -72,7 +73,7 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
         ("show_expected_length", "Show the expected translation length under the model")
         ("show_partition,z", "Compute and show the partition (inside score)")
         ("beam_prune", po::value<double>(), "Prune paths from +LM forest")
-        ("lexcrf_use_null", "Support source-side null words in lexical translation")
+        ("lexalign_use_null", "Support source-side null words in lexical translation")
         ("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set")
         ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
         ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
@@ -117,8 +118,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
   }

   const string formalism = LowercaseString((*conf)["formalism"].as<string>());
-  if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb" && formalism != "csplit" && formalism != "tagger") {
-    cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lexcrf', or 'tagger'\n";
+  if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") {
+    cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n";
     cerr << dcmdline_options << endl;
     exit(1);
   }
@@ -273,8 +274,10 @@ int main(int argc, char** argv) {
     translator.reset(new PhraseBasedTranslator(conf));
   else if (formalism == "csplit")
     translator.reset(new CompoundSplit(conf));
-  else if (formalism == "lexcrf")
-    translator.reset(new LexicalCRF(conf));
+  else if (formalism == "lextrans")
+    translator.reset(new LexicalTrans(conf));
+  else if (formalism == "lexalign")
+    translator.reset(new LexicalAlign(conf));
   else if (formalism == "tagger")
     translator.reset(new Tagger(conf));
   else
diff --git a/decoder/dict_test.cc b/decoder/dict_test.cc
index 5c5d84f0..2049ec27 100644
--- a/decoder/dict_test.cc
+++ b/decoder/dict_test.cc
@@ -1,8 +1,13 @@
 #include "dict.h"

+#include "fdict.h"
+
+#include <iostream>
 #include <gtest/gtest.h>
 #include <cassert>

+using namespace std;
+
 class DTest : public testing::Test {
  public:
   DTest() {}
@@ -23,6 +28,18 @@ TEST_F(DTest, Convert) {
   EXPECT_EQ(d.Convert(b), "bar");
 }

+TEST_F(DTest, FDictTest) {
+  int fid = FD::Convert("First");
+  EXPECT_GT(fid, 0);
+  EXPECT_EQ(FD::Convert(fid), "First");
+  string x = FD::Escape("=");
+  cerr << x << endl;
+  EXPECT_NE(x, "=");
+  x = FD::Escape(";");
+  cerr << x << endl;
+  EXPECT_NE(x, ";");
+}
+
 int main(int argc, char** argv) {
   testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();
diff --git a/decoder/fdict.cc b/decoder/fdict.cc
index 8218a5d3..7e1b0e1f 100644
--- a/decoder/fdict.cc
+++ b/decoder/fdict.cc
@@ -1,5 +1,129 @@
 #include "fdict.h"

+#include
+
+using namespace std;
+
 Dict FD::dict_;
 bool FD::frozen_ = false;

+static int HexPairValue(const char * code) {
+  int value = 0;
+  const char * pch = code;
+  for (;;) {
+    int digit = *pch++;
+    if (digit >= '0' && digit <= '9') {
+      value += digit - '0';
+    }
+    else if (digit >= 'A' && digit <= 'F') {
+      value += digit - 'A' + 10;
+    }
+    else if (digit >= 'a' && digit <= 'f') {
+      value += digit - 'a' + 10;
+    }
+    else {
+      return -1;
+    }
+    if (pch == code + 2)
+      return value;
+    value <<= 4;
+  }
+}
+
+int UrlDecode(const char *source, char *dest)
+{
+  char * start = dest;
+
+  while (*source) {
+    switch (*source) {
+    case '+':
+      *(dest++) = ' ';
+      break;
+    case '%':
+      if (source[1] && source[2]) {
+        int value = HexPairValue(source + 1);
+        if (value >= 0) {
+          *(dest++) = value;
+          source += 2;
+        }
+        else {
+          *dest++ = '?';
+        }
+      }
+      else {
+        *dest++ = '?';
+      }
+      break;
+    default:
+      *dest++ = *source;
+    }
+    source++;
+  }
+
+  *dest = 0;
+  return dest - start;
+}
+
+int UrlEncode(const char *source, char *dest, unsigned max) {
+  static const char *digits = "0123456789ABCDEF";
+  unsigned char ch;
+  unsigned len = 0;
+  char *start = dest;
+
+  while (len < max - 4 && *source)
+  {
+    ch = (unsigned char)*source;
+    if (*source == ' ') {
+      *dest++ = '+';
+    }
+    else if (strchr("=:;,_| %", ch)) {
+      *dest++ = '%';
+      *dest++ = digits[(ch >> 4) & 0x0F];
+      *dest++ = digits[ ch & 0x0F];
+    }
+    else {
+      *dest++ = *source;
+    }
+    source++;
+  }
+  *dest = 0;
+  return start - dest;
+}
+
+std::string UrlDecodeString(const std::string & encoded) {
+  const char * sz_encoded = encoded.c_str();
+  size_t needed_length = encoded.length();
+  for (const char * pch = sz_encoded; *pch; pch++) {
+    if (*pch == '%')
+      needed_length += 2;
+  }
+  needed_length += 10;
+  char stackalloc[64];
+  char * buf = needed_length > sizeof(stackalloc)/sizeof(*stackalloc) ?
+                   (char *)malloc(needed_length) : stackalloc;
+  UrlDecode(encoded.c_str(), buf);
+  std::string result(buf);
+  if (buf != stackalloc) {
+    free(buf);
+  }
+  return result;
+}
+
+std::string UrlEncodeString(const std::string & decoded) {
+  const char * sz_decoded = decoded.c_str();
+  size_t needed_length = decoded.length() * 3 + 3;
+  char stackalloc[64];
+  char * buf = needed_length > sizeof(stackalloc)/sizeof(*stackalloc) ?
+                   (char *)malloc(needed_length) : stackalloc;
+  UrlEncode(decoded.c_str(), buf, needed_length);
+  std::string result(buf);
+  if (buf != stackalloc) {
+    free(buf);
+  }
+  return result;
+}
+
+string FD::Escape(const string& s) {
+  return UrlEncodeString(s);
+}
+
diff --git a/decoder/fdict.h b/decoder/fdict.h
index d05f1706..c4236580 100644
--- a/decoder/fdict.h
+++ b/decoder/fdict.h
@@ -20,6 +20,9 @@ struct FD {
   static inline const std::string& Convert(const WordID& w) {
     return dict_.Convert(w);
   }
+  // Escape any string to a form that can be used as the name
+  // of a weight in a weights file
+  static std::string Escape(const std::string& s);
   static Dict dict_;
  private:
   static bool frozen_;
diff --git a/decoder/ff_wordalign.cc b/decoder/ff_wordalign.cc
index fb90df62..669aa530 100644
--- a/decoder/ff_wordalign.cc
+++ b/decoder/ff_wordalign.cc
@@ -26,7 +26,7 @@ Model2BinaryFeatures::Model2BinaryFeatures(const string& param) :
       val = -1;
       if (j < i) {
         ostringstream os;
-        os << "M2_FL:" << i << "_SI:" << j << "_TI:" << k;
+        os << "M2FL:" << i << ":TI:" << k << "_SI:" << j;
         val = FD::Convert(os.str());
       }
     }
@@ -181,32 +181,27 @@ void MarkovJumpFClass::TraversalFeaturesImpl(const SentenceMetadata& smeta,
   }
 }

+// std::vector<std::map<int, int> > flen2jump2fid_;
 MarkovJump::MarkovJump(const string& param) :
     FeatureFunction(1),
     fid_(FD::Convert("MarkovJump")),
-    individual_params_per_jumpsize_(false),
-    condition_on_flen_(false) {
+    binary_params_(false) {
   cerr << " MarkovJump";
   vector<string> argv;
   int argc = SplitOnWhitespace(param, &argv);
-  if (argc > 0) {
-    if (argv[0] == "--fclasses") {
-      argc--;
-      assert(argc > 0);
-      const string f_class_file = argv[1];
-    }
-    if (argc != 1 || !(argv[0] == "-f" || argv[0] == "-i" || argv[0] == "-if")) {
-      cerr << "MarkovJump: expected parameters to be -f, -i, or -if\n";
-      exit(1);
-    }
-    individual_params_per_jumpsize_ = (argv[0][1] == 'i');
-    condition_on_flen_ = (argv[0][argv[0].size() - 1] == 'f');
-    if (individual_params_per_jumpsize_) {
-      template_ = "Jump:000";
-      cerr << ", individual jump parameters";
-      if (condition_on_flen_) {
-        template_ += ":F00";
-        cerr << " (split by f-length)";
+  if (argc != 1 || !(argv[0] == "-b" || argv[0] == "+b")) {
+    cerr << "MarkovJump: expected parameters to be -b or +b\n";
+    exit(1);
+  }
+  binary_params_ = argv[0] == "+b";
+  if (binary_params_) {
+    flen2jump2fid_.resize(MAX_SENTENCE_SIZE);
+    for (int i = 1; i < MAX_SENTENCE_SIZE; ++i) {
+      map<int, int>& jump2fid = flen2jump2fid_[i];
+      for (int jump = -i; jump <= i; ++jump) {
+        ostringstream os;
+        os << "Jump:FLen:" << i << "_J:" << jump;
+        jump2fid[jump] = FD::Convert(os.str());
       }
     }
   } else {
@@ -215,6 +210,7 @@ MarkovJump::MarkovJump(const string& param) :
   cerr << endl;
 }

+// TODO handle NULLs according to Och 2000
 void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                        const Hypergraph::Edge& edge,
                                        const vector<const void*>& ant_states,
@@ -222,8 +218,24 @@ void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
                                        SparseVector<double>* estimated_features,
                                        void* state) const {
   unsigned char& dpstate = *((unsigned char*)state);
+  const int flen = smeta.GetSourceLength();
   if (edge.Arity() == 0) {
     dpstate = static_cast<unsigned char>(edge.i_);
+    if (edge.prev_i_ == 0) {
+      if (binary_params_) {
+        // NULL will be tricky
+        // TODO initial state distribution, not normal jumps
+        const int fid = flen2jump2fid_[flen].find(edge.i_ + 1)->second;
+        features->set_value(fid, 1.0);
+      }
+    } else if (edge.prev_i_ == smeta.GetTargetLength() - 1) {
+      // NULL will be tricky
+      if (binary_params_) {
+        int jumpsize = flen - edge.i_;
+        const int fid = flen2jump2fid_[flen].find(jumpsize)->second;
+        features->set_value(fid, 1.0);
+      }
+    }
   } else if (edge.Arity() == 1) {
     dpstate = *((unsigned char*)ant_states[0]);
   } else if (edge.Arity() == 2) {
@@ -234,27 +246,12 @@ void MarkovJump::TraversalFeaturesImpl(const SentenceMetadata& smeta,
     else
       dpstate = static_cast<unsigned char>(right_index);
     const int jumpsize = right_index - left_index;
-    features->set_value(fid_, fabs(jumpsize - 1));  // Blunsom and Cohn def
-    if (individual_params_per_jumpsize_) {
-      string fname = template_;
-      int param = jumpsize;
-      if (jumpsize < 0) {
-        param *= -1;
-        fname[5]='L';
-      } else if (jumpsize > 0) {
-        fname[5]='R';
-      }
-      if (param) {
-        fname[6] = '0' + (param / 10);
-        fname[7] = '0' + (param % 10);
-      }
-      if (condition_on_flen_) {
-        const int flen = smeta.GetSourceLength();
-        fname[10] = '0' + (flen / 10);
-        fname[11] = '0' + (flen % 10);
-      }
-      features->set_value(FD::Convert(fname), 1.0);
+    if (binary_params_) {
+      const int fid = flen2jump2fid_[flen].find(jumpsize)->second;
+      features->set_value(fid, 1.0);
+    } else {
+      features->set_value(fid_, fabs(jumpsize - 1));  // Blunsom and Cohn def
     }
   } else {
     assert(!"something really unexpected is happening");
diff --git a/decoder/ff_wordalign.h b/decoder/ff_wordalign.h
index 688750de..c44ad26b 100644
--- a/decoder/ff_wordalign.h
+++ b/decoder/ff_wordalign.h
@@ -49,10 +49,8 @@ class MarkovJump : public FeatureFunction {
                                      void* out_context) const;
  private:
   const int fid_;
-  bool individual_params_per_jumpsize_;
-  bool condition_on_flen_;
-  bool condition_on_fclass_;
-  std::string template_;
+  bool binary_params_;
+  std::vector<std::map<int, int> > flen2jump2fid_;
 };

 class MarkovJumpFClass : public FeatureFunction {
diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc
index ee3b5fe0..8dd77c53 100644
--- a/decoder/lexalign.cc
+++ b/decoder/lexalign.cc
@@ -31,17 +31,24 @@ struct LexicalAlignImpl {
     const WordID& e_i = target[i][0].label;
     Hypergraph::Node* node = forest->AddNode(kXCAT);
     const int new_node_id = node->id_;
+    int num_srcs = 0;
     for (int j = f_start; j < f_len; ++j) { // for each word in the source
       const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label);
-      TRulePtr& rule = LexRule(src_sym, e_i);
-      Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
-      edge->i_ = j;
-      edge->j_ = j+1;
-      edge->prev_i_ = i;
-      edge->prev_j_ = i+1;
-      edge->feature_values_ += edge->rule_->GetFeatureValues();
-      forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
+      const TRulePtr& rule = LexRule(src_sym, e_i);
+      if (rule) {
+        Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
+        edge->i_ = j;
+        edge->j_ = j+1;
+        edge->prev_i_ = i;
+        edge->prev_j_ = i+1;
+        edge->feature_values_ += edge->rule_->GetFeatureValues();
+        ++num_srcs;
+        forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
+      } else {
+        cerr << TD::Convert(src_sym) << " does not translate to " << TD::Convert(e_i) << endl;
+      }
     }
+    assert(num_srcs > 0);
     if (prev_node_id >= 0) {
       const int comb_node_id = forest->AddNode(kXCAT)->id_;
       Hypergraph::TailNodeVector tail(2, prev_node_id);
@@ -66,21 +73,23 @@ struct LexicalAlignImpl {
       return it->second;
     int& fid = e2fid[e];
     if (f == 0) {
-      fid = FD::Convert("Lx__" + FD::Escape(TD::Convert(e)));
+      fid = FD::Convert("Lx:_" + FD::Escape(TD::Convert(e)));
     } else {
-      fid = FD::Convert("Lx_" + FD::Escape(TD::Convert(f)) + "_" + FD::Escape(TD::Convert(e)));
+      fid = FD::Convert("Lx:" + FD::Escape(TD::Convert(f)) + "_" + FD::Escape(TD::Convert(e)));
     }
     return fid;
   }

-  inline TRulePtr& LexRule(const WordID& f, const WordID& e) {
+  inline const TRulePtr& LexRule(const WordID& f, const WordID& e) {
+    const int fid = LexFeatureId(f, e);
+    if (!fid) { return kNULL_PTR; }
     map<WordID, TRulePtr>& e2rule = f2e2rule[f];
     map<WordID, TRulePtr>::iterator it = e2rule.find(e);
     if (it != e2rule.end())
       return it->second;
     TRulePtr& tr = e2rule[e];
     tr.reset(TRule::CreateLexicalRule(f, e));
-    tr->scores_.set_value(LexFeatureId(f, e), 1.0);
+    tr->scores_.set_value(fid, 1.0);
     return tr;
   }
@@ -90,6 +99,7 @@ struct LexicalAlignImpl {
   const WordID kNULL;
   const TRulePtr kBINARY;
   const TRulePtr kGOAL_RULE;
+  const TRulePtr kNULL_PTR;
   map<WordID, map<WordID, TRulePtr> > f2e2rule;
   map<WordID, map<WordID, int> > f2e2fid;
   GrammarPtr grammar;
diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc
index b0e03c69..e7fa1aa1 100644
--- a/decoder/lextrans.cc
+++ b/decoder/lextrans.cc
@@ -1,4 +1,4 @@
-#include "lexcrf.h"
+#include "lextrans.h"

 #include
@@ -10,8 +10,8 @@

 using namespace std;

-struct LexicalCRFImpl {
-  LexicalCRFImpl(const boost::program_options::variables_map& conf) :
+struct LexicalTransImpl {
+  LexicalTransImpl(const boost::program_options::variables_map& conf) :
       use_null(conf.count("lexcrf_use_null") > 0),
       kXCAT(TD::Convert("X")*-1),
       kNULL(TD::Convert("")),
@@ -95,10 +95,10 @@ struct LexicalCRFImpl {
   GrammarPtr grammar;
 };

-LexicalCRF::LexicalCRF(const boost::program_options::variables_map& conf) :
-  pimpl_(new LexicalCRFImpl(conf)) {}
+LexicalTrans::LexicalTrans(const boost::program_options::variables_map& conf) :
+  pimpl_(new LexicalTransImpl(conf)) {}

-bool LexicalCRF::Translate(const string& input,
+bool LexicalTrans::Translate(const string& input,
                            SentenceMetadata* smeta,
                            const vector<double>& weights,
                            Hypergraph* forest) {
diff --git a/decoder/lextrans.h b/decoder/lextrans.h
index 99362c81..9920f79c 100644
--- a/decoder/lextrans.h
+++ b/decoder/lextrans.h
@@ -1,18 +1,18 @@
-#ifndef _LEXCRF_H_
-#define _LEXCRF_H_
+#ifndef _LEXTrans_H_
+#define _LEXTrans_H_

 #include "translator.h"
 #include "lattice.h"

-struct LexicalCRFImpl;
-struct LexicalCRF : public Translator {
-  LexicalCRF(const boost::program_options::variables_map& conf);
+struct LexicalTransImpl;
+struct LexicalTrans : public Translator {
+  LexicalTrans(const boost::program_options::variables_map& conf);
   bool Translate(const std::string& input,
                  SentenceMetadata* smeta,
                  const std::vector<double>& weights,
                  Hypergraph* forest);
  private:
-  boost::shared_ptr<LexicalCRFImpl> pimpl_;
+  boost::shared_ptr<LexicalTransImpl> pimpl_;
 };

 #endif
diff --git a/decoder/stringlib.cc b/decoder/stringlib.cc
index 3ed74bef..3e52ae87 100644
--- a/decoder/stringlib.cc
+++ b/decoder/stringlib.cc
@@ -1,5 +1,6 @@
 #include "stringlib.h"

+#include
 #include
 #include
 #include
--
cgit v1.2.3
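
A note on the escaping behavior this commit introduces: the new FDictTest only asserts that FD::Escape("=") and FD::Escape(";") return something different from their inputs. The sketch below is a minimal standalone illustration, not part of the commit, of what the percent-escaping in fdict.cc is expected to do to the feature names built in lexalign.cc ("Lx:" + escaped f + "_" + escaped e). EscapeLikeFDict is a hypothetical helper written only for this example; it copies the escape set ("=:;,_| %") used by UrlEncode, and the words passed to it are made up.

// Minimal standalone sketch (illustration only, not part of the commit).
#include <cstring>
#include <iostream>
#include <string>

std::string EscapeLikeFDict(const std::string& in) {
  static const char* digits = "0123456789ABCDEF";
  std::string out;
  for (size_t i = 0; i < in.size(); ++i) {
    unsigned char ch = static_cast<unsigned char>(in[i]);
    if (ch == ' ') {
      out += '+';                      // spaces become '+'
    } else if (strchr("=:;,_| %", ch)) {
      out += '%';                      // other unsafe characters become %XX
      out += digits[(ch >> 4) & 0x0F];
      out += digits[ch & 0x0F];
    } else {
      out += static_cast<char>(ch);    // everything else passes through
    }
  }
  return out;
}

int main() {
  // '=' (0x3D) and ';' (0x3B) are escaped, consistent with the EXPECT_NE
  // checks in dict_test.cc: prints %3D and %3B.
  std::cout << EscapeLikeFDict("=") << "\n" << EscapeLikeFDict(";") << "\n";
  // A lexical feature name assembled the way lexalign.cc does it: the
  // literal "_" separator stays, while '_' inside a word would become %5F.
  std::cout << "Lx:" << EscapeLikeFDict("a=b") << "_" << EscapeLikeFDict("x;y") << "\n";
  return 0;
}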