diff options
39 files changed, 617 insertions, 741 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am index f02299e6..8280b22c 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -61,12 +61,10 @@ libcdec_a_SOURCES = \ ff_ruleshape.h \ ff_sample_fsa.h \ ff_soft_syntax.h \ - ff_soft_syntax2.h \ + ff_soft_syntax_mindist.h \ ff_source_path.h \ ff_source_syntax.h \ ff_source_syntax2.h \ - ff_source_syntax2_p.h \ - ff_source_syntax_p.h \ ff_spans.h \ ff_tagger.h \ ff_wordalign.h \ @@ -127,12 +125,10 @@ libcdec_a_SOURCES = \ ff_rules.cc \ ff_ruleshape.cc \ ff_soft_syntax.cc \ - ff_soft_syntax2.cc \ + ff_soft_syntax_mindist.cc \ ff_source_path.cc \ ff_source_syntax.cc \ ff_source_syntax2.cc \ - ff_source_syntax2_p.cc \ - ff_source_syntax_p.cc \ ff_spans.cc \ ff_tagger.cc \ ff_wordalign.cc \ diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 09597e87..d586c1d1 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -15,17 +15,11 @@ #include "ff_ruleshape.h" #include "ff_bleu.h" #include "ff_soft_syntax.h" -#include "ff_soft_syntax2.h" +#include "ff_soft_syntax_mindist.h" #include "ff_source_path.h" - - #include "ff_parse_match.h" #include "ff_source_syntax.h" -#include "ff_source_syntax_p.h" #include "ff_source_syntax2.h" -#include "ff_source_syntax2_p.h" - - #include "ff_register.h" #include "ff_charset.h" #include "ff_wordset.h" @@ -51,23 +45,12 @@ void register_feature_functions() { ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>()); ff_registry.Register("RuleContextFeatures", new FFFactory<RuleContextFeatures>()); ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>()); - - ff_registry.Register("ParseMatchFeatures", new FFFactory<ParseMatchFeatures>); - - ff_registry.Register("SoftSyntacticFeatures", new FFFactory<SoftSyntacticFeatures>); - ff_registry.Register("SoftSyntacticFeatures2", new FFFactory<SoftSyntacticFeatures2>); - + ff_registry.Register("SoftSyntaxFeatures", new FFFactory<SoftSyntaxFeatures>); + ff_registry.Register("SoftSyntaxFeaturesMindist", new FFFactory<SoftSyntaxFeaturesMindist>); ff_registry.Register("SourceSyntaxFeatures", new FFFactory<SourceSyntaxFeatures>); - ff_registry.Register("SourceSyntaxFeatures2", new FFFactory<SourceSyntaxFeatures2>); - ff_registry.Register("SourceSpanSizeFeatures", new FFFactory<SourceSpanSizeFeatures>); - - //ff_registry.Register("PSourceSyntaxFeatures", new FFFactory<PSourceSyntaxFeatures>); - //ff_registry.Register("PSourceSpanSizeFeatures", new FFFactory<PSourceSpanSizeFeatures>); - //ff_registry.Register("PSourceSyntaxFeatures2", new FFFactory<PSourceSyntaxFeatures2>); - - + ff_registry.Register("SourceSyntaxFeatures2", new FFFactory<SourceSyntaxFeatures2>); ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory<CMR2008ReorderingFeatures>()); ff_registry.Register("RuleSourceBigramFeatures", new FFFactory<RuleSourceBigramFeatures>()); ff_registry.Register("RuleTargetBigramFeatures", new FFFactory<RuleTargetBigramFeatures>()); diff --git a/decoder/ff_parse_match.cc b/decoder/ff_parse_match.cc index ed556b91..58026975 100644 --- a/decoder/ff_parse_match.cc +++ b/decoder/ff_parse_match.cc @@ -42,10 +42,8 @@ struct ParseMatchFeaturesImpl { void InitializeGrids(const string& tree, unsigned src_len) { assert(tree.size() > 0); - //fids_cat.clear(); fids_ef.clear(); src_tree.clear(); - //fids_cat.resize(src_len, src_len + 1); fids_ef.resize(src_len, src_len + 1); src_tree.resize(src_len, src_len + 1, TD::Convert("XX")); ParseTreeString(tree, src_len); @@ -112,7 +110,7 @@ struct ParseMatchFeaturesImpl { int fid_ef = FD::Convert("PM"); int min_dist; // minimal distance to next syntactic constituent of this rule's LHS int summed_min_dists; // minimal distances of LHS and NTs summed up - if (TD::Convert(lhs).compare("XX") != 0) + if (TD::Convert(lhs).compare("XX") != 0) min_dist= 0; // compute the distance to the next syntactical constituent else { @@ -131,7 +129,7 @@ struct ParseMatchFeaturesImpl { ok = 1; break; } - // check if removing k words from the rule span will + // check if removing k words from the rule span will // lead to a syntactical constituent else { //cerr << "Hilfe...!" << endl; @@ -144,7 +142,7 @@ struct ParseMatchFeaturesImpl { ok = 1; break; } - } + } } if (ok) break; } @@ -183,9 +181,9 @@ struct ParseMatchFeaturesImpl { return min_dist; } - Array2D<WordID> src_tree; // src_tree(i,j) NT = type + Array2D<WordID> src_tree; // src_tree(i,j) NT = type unsigned int src_sent_len; - mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized + mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized int scoring_method; }; @@ -214,5 +212,9 @@ void ParseMatchFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, } void ParseMatchFeatures::PrepareForInput(const SentenceMetadata& smeta) { - impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); + ReadFile f = ReadFile(smeta.GetSGMLValue("src_tree")); + string tree; + f.ReadAll(tree); + impl->InitializeGrids(tree, smeta.GetSourceLength()); } + diff --git a/decoder/ff_parse_match.h b/decoder/ff_parse_match.h index fa73481a..7820b418 100644 --- a/decoder/ff_parse_match.h +++ b/decoder/ff_parse_match.h @@ -23,3 +23,4 @@ class ParseMatchFeatures : public FeatureFunction { }; #endif + diff --git a/decoder/ff_soft_syntax.cc b/decoder/ff_soft_syntax.cc index 9981fa45..23fe87bd 100644 --- a/decoder/ff_soft_syntax.cc +++ b/decoder/ff_soft_syntax.cc @@ -13,16 +13,15 @@ using namespace std; -// Implements the soft syntactic features described in +// Implements the soft syntactic features described in // Marton and Resnik (2008): "Soft Syntacitc Constraints for Hierarchical Phrase-Based Translation". // Source trees must be represented in Penn Treebank format, // e.g. (S (NP John) (VP (V left))). -struct SoftSyntacticFeaturesImpl { - SoftSyntacticFeaturesImpl(const string& param) { +struct SoftSyntaxFeaturesImpl { + SoftSyntaxFeaturesImpl(const string& param) { vector<string> labels = SplitOnWhitespace(param); - for (unsigned int i = 0; i < labels.size(); i++) - //cerr << "Labels: " << labels.at(i) << endl; + //for (unsigned int i = 0; i < labels.size(); i++) { cerr << "Labels: " << labels.at(i) << endl; } for (unsigned int i = 0; i < labels.size(); i++) { string label = labels.at(i); pair<string, string> feat_label; @@ -34,10 +33,8 @@ struct SoftSyntacticFeaturesImpl { void InitializeGrids(const string& tree, unsigned src_len) { assert(tree.size() > 0); - //fids_cat.clear(); fids_ef.clear(); src_tree.clear(); - //fids_cat.resize(src_len, src_len + 1); fids_ef.resize(src_len, src_len + 1); src_tree.resize(src_len, src_len + 1, TD::Convert("XX")); ParseTreeString(tree, src_len); @@ -99,7 +96,7 @@ struct SoftSyntacticFeaturesImpl { const WordID lhs = src_tree(i,j); string lhs_str = TD::Convert(lhs); //cerr << "LHS: " << lhs_str << " from " << i << " to " << j << endl; - //cerr << "RULE :"<< rule << endl; + //cerr << "RULE :"<< rule << endl; int& fid_ef = fids_ef(i,j)[&rule]; for (unsigned int i = 0; i < feat_labels.size(); i++) { ostringstream os; @@ -110,10 +107,10 @@ struct SoftSyntacticFeaturesImpl { switch(feat_type) { case '2': if (lhs_str.compare(label) == 0) { - os << "SYN:" << label << "_conform"; + os << "SOFT:" << label << "_conform"; } else { - os << "SYN:" << label << "_cross"; + os << "SOFT:" << label << "_cross"; } fid_ef = FD::Convert(os.str()); if (fid_ef > 0) { @@ -122,11 +119,11 @@ struct SoftSyntacticFeaturesImpl { } break; case '_': - os << "SYN:" << label; + os << "SOFT:" << label; fid_ef = FD::Convert(os.str()); if (lhs_str.compare(label) == 0) { if (fid_ef > 0) { - //cerr << "Feature: " << os.str() << endl; + //cerr << "Feature: " << os.str() << endl; feats->set_value(fid_ef, 1.0); } } @@ -139,7 +136,7 @@ struct SoftSyntacticFeaturesImpl { break; case '+': if (lhs_str.compare(label) == 0) { - os << "SYN:" << label << "_conform"; + os << "SOFT:" << label << "_conform"; fid_ef = FD::Convert(os.str()); if (fid_ef > 0) { //cerr << "Feature: " << os.str() << endl; @@ -147,10 +144,10 @@ struct SoftSyntacticFeaturesImpl { } } break; - case '-': - //cerr << "-" << endl; + case '-': + //cerr << "-" << endl; if (lhs_str.compare(label) != 0) { - os << "SYN:" << label << "_cross"; + os << "SOFT:" << label << "_cross"; fid_ef = FD::Convert(os.str()); if (fid_ef > 0) { //cerr << "Feature :" << os.str() << endl; @@ -167,22 +164,22 @@ struct SoftSyntacticFeaturesImpl { return lhs; } - Array2D<WordID> src_tree; // src_tree(i,j) NT = type - mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized + Array2D<WordID> src_tree; // src_tree(i,j) NT = type + mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized vector<pair<string, string> > feat_labels; }; -SoftSyntacticFeatures::SoftSyntacticFeatures(const string& param) : +SoftSyntaxFeatures::SoftSyntaxFeatures(const string& param) : FeatureFunction(sizeof(WordID)) { - impl = new SoftSyntacticFeaturesImpl(param); + impl = new SoftSyntaxFeaturesImpl(param); } -SoftSyntacticFeatures::~SoftSyntacticFeatures() { +SoftSyntaxFeatures::~SoftSyntaxFeatures() { delete impl; impl = NULL; } -void SoftSyntacticFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, +void SoftSyntaxFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, const vector<const void*>& ant_contexts, SparseVector<double>* features, @@ -196,6 +193,10 @@ void SoftSyntacticFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features); } -void SoftSyntacticFeatures::PrepareForInput(const SentenceMetadata& smeta) { - impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); +void SoftSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) { + ReadFile f = ReadFile(smeta.GetSGMLValue("src_tree")); + string tree; + f.ReadAll(tree); + impl->InitializeGrids(tree, smeta.GetSourceLength()); } + diff --git a/decoder/ff_soft_syntax.h b/decoder/ff_soft_syntax.h index 79352f49..e71825d5 100644 --- a/decoder/ff_soft_syntax.h +++ b/decoder/ff_soft_syntax.h @@ -1,15 +1,15 @@ -#ifndef _FF_SOFTSYNTAX_H_ -#define _FF_SOFTSYNTAX_H_ +#ifndef _FF_SOFT_SYNTAX_H_ +#define _FF_SOFT_SYNTAX_H_ #include "ff.h" #include "hg.h" -struct SoftSyntacticFeaturesImpl; +struct SoftSyntaxFeaturesImpl; -class SoftSyntacticFeatures : public FeatureFunction { +class SoftSyntaxFeatures : public FeatureFunction { public: - SoftSyntacticFeatures(const std::string& param); - ~SoftSyntacticFeatures(); + SoftSyntaxFeatures(const std::string& param); + ~SoftSyntaxFeatures(); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, @@ -19,9 +19,9 @@ class SoftSyntacticFeatures : public FeatureFunction { void* context) const; virtual void PrepareForInput(const SentenceMetadata& smeta); private: - SoftSyntacticFeaturesImpl* impl; + SoftSyntaxFeaturesImpl* impl; }; - #endif + diff --git a/decoder/ff_soft_syntax2.cc b/decoder/ff_soft_syntax_mindist.cc index 121bc39b..a23f70f8 100644 --- a/decoder/ff_soft_syntax2.cc +++ b/decoder/ff_soft_syntax_mindist.cc @@ -1,4 +1,4 @@ -#include "ff_soft_syntax2.h" +#include "ff_soft_syntax_mindist.h" #include <cstdio> #include <sstream> @@ -13,16 +13,18 @@ using namespace std; -// Implements the soft syntactic features described in +// Implements the soft syntactic features described in // Marton and Resnik (2008): "Soft Syntacitc Constraints for Hierarchical Phrase-Based Translation". // Source trees must be represented in Penn Treebank format, // e.g. (S (NP John) (VP (V left))). +// +// This variant accepts fuzzy matches, choosing the constituent with +// minimum distance. -struct SoftSyntacticFeatures2Impl { - SoftSyntacticFeatures2Impl(const string& param) { +struct SoftSyntaxFeaturesMindistImpl { + SoftSyntaxFeaturesMindistImpl(const string& param) { vector<string> labels = SplitOnWhitespace(param); - //for (unsigned int i = 0; i < labels.size(); i++) - //cerr << "Labels: " << labels.at(i) << endl; + //for (unsigned int i = 0; i < labels.size(); i++) { cerr << "Labels: " << labels.at(i) << endl; } for (unsigned int i = 0; i < labels.size(); i++) { string label = labels.at(i); pair<string, string> feat_label; @@ -30,14 +32,12 @@ struct SoftSyntacticFeatures2Impl { feat_label.second = label.at(label.size() - 1); feat_labels.push_back(feat_label); } - } + } void InitializeGrids(const string& tree, unsigned src_len) { assert(tree.size() > 0); - //fids_cat.clear(); fids_ef.clear(); src_tree.clear(); - //fids_cat.resize(src_len, src_len + 1); fids_ef.resize(src_len, src_len + 1); src_tree.resize(src_len, src_len + 1, TD::Convert("XX")); ParseTreeString(tree, src_len); @@ -99,14 +99,14 @@ struct SoftSyntacticFeatures2Impl { const WordID lhs = src_tree(i,j); string lhs_str = TD::Convert(lhs); //cerr << "LHS: " << lhs_str << " from " << i << " to " << j << endl; - //cerr << "RULE :"<< rule << endl; + //cerr << "RULE :"<< rule << endl; int& fid_ef = fids_ef(i,j)[&rule]; string lhs_to_str = TD::Convert(lhs); int min_dist; string min_dist_label; if (lhs_to_str.compare("XX") != 0) { min_dist = 0; - min_dist_label = lhs_to_str; + min_dist_label = lhs_to_str; } else { int ok = 0; @@ -128,7 +128,7 @@ struct SoftSyntacticFeatures2Impl { min_dist_label = (TD::Convert(src_tree(l_rem, r_rem))); break; } - } + } } if (ok) break; } @@ -146,10 +146,10 @@ struct SoftSyntacticFeatures2Impl { case '2': if (min_dist_label.compare(label) == 0) { if (min_dist == 0) { - os << "SYN:" << label << "_conform"; + os << "SOFTM:" << label << "_conform"; } else { - os << "SYN:" << label << "_cross"; + os << "SOFTM:" << label << "_cross"; } fid_ef = FD::Convert(os.str()); //cerr << "Feature :" << os.str() << endl; @@ -157,7 +157,7 @@ struct SoftSyntacticFeatures2Impl { } break; case '_': - os << "SYN:" << label; + os << "SOFTM:" << label; fid_ef = FD::Convert(os.str()); if (min_dist_label.compare(label) == 0) { //cerr << "Feature: " << os.str() << endl; @@ -172,7 +172,7 @@ struct SoftSyntacticFeatures2Impl { break; case '+': if (min_dist_label.compare(label) == 0) { - os << "SYN:" << label << "_conform"; + os << "SOFTM:" << label << "_conform"; fid_ef = FD::Convert(os.str()); if (min_dist == 0) { //cerr << "Feature: " << os.str() << endl; @@ -180,10 +180,10 @@ struct SoftSyntacticFeatures2Impl { } } break; - case '-': - //cerr << "-" << endl; + case '-': + //cerr << "-" << endl; if (min_dist_label.compare(label) != 0) { - os << "SYN:" << label << "_cross"; + os << "SOFTM:" << label << "_cross"; fid_ef = FD::Convert(os.str()); if (min_dist > 0) { //cerr << "Feature :" << os.str() << endl; @@ -200,22 +200,22 @@ struct SoftSyntacticFeatures2Impl { return lhs; } - Array2D<WordID> src_tree; // src_tree(i,j) NT = type - mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized + Array2D<WordID> src_tree; // src_tree(i,j) NT = type + mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized vector<pair<string, string> > feat_labels; }; -SoftSyntacticFeatures2::SoftSyntacticFeatures2(const string& param) : +SoftSyntaxFeaturesMindist::SoftSyntaxFeaturesMindist(const string& param) : FeatureFunction(sizeof(WordID)) { - impl = new SoftSyntacticFeatures2Impl(param); + impl = new SoftSyntaxFeaturesMindistImpl(param); } -SoftSyntacticFeatures2::~SoftSyntacticFeatures2() { +SoftSyntaxFeaturesMindist::~SoftSyntaxFeaturesMindist() { delete impl; impl = NULL; } -void SoftSyntacticFeatures2::TraversalFeaturesImpl(const SentenceMetadata& smeta, +void SoftSyntaxFeaturesMindist::TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, const vector<const void*>& ant_contexts, SparseVector<double>* features, @@ -229,6 +229,10 @@ void SoftSyntacticFeatures2::TraversalFeaturesImpl(const SentenceMetadata& smeta impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features); } -void SoftSyntacticFeatures2::PrepareForInput(const SentenceMetadata& smeta) { - impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); +void SoftSyntaxFeaturesMindist::PrepareForInput(const SentenceMetadata& smeta) { + ReadFile f = ReadFile(smeta.GetSGMLValue("src_tree")); + string tree; + f.ReadAll(tree); + impl->InitializeGrids(tree, smeta.GetSourceLength()); } + diff --git a/decoder/ff_soft_syntax2.h b/decoder/ff_soft_syntax_mindist.h index 4de91d86..bf938b38 100644 --- a/decoder/ff_soft_syntax2.h +++ b/decoder/ff_soft_syntax_mindist.h @@ -1,15 +1,15 @@ -#ifndef _FF_SOFTSYNTAX2_H_ -#define _FF_SOFTSYNTAX2_H_ +#ifndef _FF_SOFT_SYNTAX_MINDIST_H_ +#define _FF_SOFT_SYNTAX_MINDIST_H_ #include "ff.h" #include "hg.h" -struct SoftSyntacticFeatures2Impl; +struct SoftSyntaxFeaturesMindistImpl; -class SoftSyntacticFeatures2 : public FeatureFunction { +class SoftSyntaxFeaturesMindist : public FeatureFunction { public: - SoftSyntacticFeatures2(const std::string& param); - ~SoftSyntacticFeatures2(); + SoftSyntaxFeaturesMindist(const std::string& param); + ~SoftSyntaxFeaturesMindist(); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, const Hypergraph::Edge& edge, @@ -19,9 +19,9 @@ class SoftSyntacticFeatures2 : public FeatureFunction { void* context) const; virtual void PrepareForInput(const SentenceMetadata& smeta); private: - SoftSyntacticFeatures2Impl* impl; + SoftSyntaxFeaturesMindistImpl* impl; }; - #endif + diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc index 88f6714c..6b183863 100644 --- a/decoder/ff_source_syntax.cc +++ b/decoder/ff_source_syntax.cc @@ -9,7 +9,6 @@ namespace std { using std::tr1::unordered_set; } #endif -#include "hg.h" #include "sentence_metadata.h" #include "array2d.h" #include "filelib.h" @@ -30,6 +29,17 @@ inline int SpanSizeTransform(unsigned span_size) { struct SourceSyntaxFeaturesImpl { SourceSyntaxFeaturesImpl() {} + SourceSyntaxFeaturesImpl(const string& param) { + if (!(param.compare("") == 0)) { + string triggered_features_fn = param; + ReadFile triggered_features(triggered_features_fn); + string in; + while(getline(*triggered_features, in)) { + feature_filter.insert(FD::Convert(in)); + } + } + } + void InitializeGrids(const string& tree, unsigned src_len) { assert(tree.size() > 0); //fids_cat.clear(); @@ -99,7 +109,7 @@ struct SourceSyntaxFeaturesImpl { if (fid_ef <= 0) { ostringstream os; //ostringstream os2; - os << "SYN:" << TD::Convert(lhs); + os << "SSYN:" << TD::Convert(lhs); //os2 << "SYN:" << TD::Convert(lhs) << '_' << SpanSizeTransform(j - i); //fid_cat = FD::Convert(os2.str()); os << ':'; @@ -124,21 +134,28 @@ struct SourceSyntaxFeaturesImpl { } fid_ef = FD::Convert(os.str()); } - //if (fid_cat > 0) - // feats->set_value(fid_cat, 1.0); - if (fid_ef > 0) - feats->set_value(fid_ef, 1.0); + if (fid_ef > 0) { + if (feature_filter.size()>0) { + if (feature_filter.find(fid_ef) != feature_filter.end()) { + feats->set_value(fid_ef, 1.0); + } + } else { + feats->set_value(fid_ef, 1.0); + } + } + cerr << FD::Convert(fid_ef) << endl; return lhs; } - Array2D<WordID> src_tree; // src_tree(i,j) NT = type - // mutable Array2D<int> fids_cat; // this tends to overfit baddly - mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized + Array2D<WordID> src_tree; // src_tree(i,j) NT = type + // mutable Array2D<int> fids_cat; // this tends to overfit baddly + mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized + unordered_set<int> feature_filter; }; SourceSyntaxFeatures::SourceSyntaxFeatures(const string& param) : FeatureFunction(sizeof(WordID)) { - impl = new SourceSyntaxFeaturesImpl; + impl = new SourceSyntaxFeaturesImpl(param); } SourceSyntaxFeatures::~SourceSyntaxFeatures() { @@ -161,7 +178,10 @@ void SourceSyntaxFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, } void SourceSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) { - impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); + ReadFile f = ReadFile(smeta.GetSGMLValue("src_tree")); + string tree; + f.ReadAll(tree); + impl->InitializeGrids(tree, smeta.GetSourceLength()); } struct SourceSpanSizeFeaturesImpl { @@ -236,4 +256,3 @@ void SourceSpanSizeFeatures::PrepareForInput(const SentenceMetadata& smeta) { impl->InitializeGrids(smeta.GetSourceLength()); } - diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h index a8c7150a..bdd638c1 100644 --- a/decoder/ff_source_syntax.h +++ b/decoder/ff_source_syntax.h @@ -1,7 +1,8 @@ -#ifndef _FF_SOURCE_TOOLS_H_ -#define _FF_SOURCE_TOOLS_H_ +#ifndef _FF_SOURCE_SYNTAX_H_ +#define _FF_SOURCE_SYNTAX_H_ #include "ff.h" +#include "hg.h" struct SourceSyntaxFeaturesImpl; @@ -11,7 +12,7 @@ class SourceSyntaxFeatures : public FeatureFunction { ~SourceSyntaxFeatures(); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const HG::Edge& edge, + const Hypergraph::Edge& edge, const std::vector<const void*>& ant_contexts, SparseVector<double>* features, SparseVector<double>* estimated_features, @@ -28,7 +29,7 @@ class SourceSpanSizeFeatures : public FeatureFunction { ~SourceSpanSizeFeatures(); protected: virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const HG::Edge& edge, + const Hypergraph::Edge& edge, const std::vector<const void*>& ant_contexts, SparseVector<double>* features, SparseVector<double>* estimated_features, @@ -39,3 +40,4 @@ class SourceSpanSizeFeatures : public FeatureFunction { }; #endif + diff --git a/decoder/ff_source_syntax2.cc b/decoder/ff_source_syntax2.cc index 622c6908..a97e31d8 100644 --- a/decoder/ff_source_syntax2.cc +++ b/decoder/ff_source_syntax2.cc @@ -16,7 +16,7 @@ using namespace std; struct SourceSyntaxFeatures2Impl { SourceSyntaxFeatures2Impl(const string& param) { - if (!(param.compare("") == 0)) { + if (param.compare("") != 0) { string triggered_features_fn = param; ReadFile triggered_features(triggered_features_fn); string in; @@ -28,10 +28,8 @@ struct SourceSyntaxFeatures2Impl { void InitializeGrids(const string& tree, unsigned src_len) { assert(tree.size() > 0); - //fids_cat.clear(); fids_ef.clear(); src_tree.clear(); - //fids_cat.resize(src_len, src_len + 1); fids_ef.resize(src_len, src_len + 1); src_tree.resize(src_len, src_len + 1, TD::Convert("XX")); ParseTreeString(tree, src_len); @@ -39,7 +37,7 @@ struct SourceSyntaxFeatures2Impl { void ParseTreeString(const string& tree, unsigned src_len) { //cerr << "TREE: " << tree << endl; - stack<pair<int, WordID> > stk; // first = i, second = category + stack<pair<int, WordID> > stk; // first = i, second = category pair<int, WordID> cur_cat; cur_cat.first = -1; unsigned i = 0; unsigned p = 0; @@ -91,7 +89,7 @@ struct SourceSyntaxFeatures2Impl { const WordID lhs = src_tree(i,j); int& fid_ef = fids_ef(i,j)[&rule]; ostringstream os; - os << "SYN:" << TD::Convert(lhs); + os << "SSYN2:" << TD::Convert(lhs); os << ':'; unsigned ntc = 0; for (unsigned k = 0; k < rule.f_.size(); ++k) { @@ -99,7 +97,7 @@ struct SourceSyntaxFeatures2Impl { if (k > 0 && fj <= 0) os << '_'; if (fj <= 0) { os << '[' << TD::Convert(ants[ntc++]) << ']'; - } /*else { + }/*else { os << TD::Convert(fj); }*/ } @@ -115,16 +113,22 @@ struct SourceSyntaxFeatures2Impl { fid_ef = FD::Convert(os.str()); //cerr << "FEATURE: " << os.str() << endl; //cerr << "FID_EF: " << fid_ef << endl; - if (feature_filter.find(fid_ef) != feature_filter.end()) { - cerr << "SYN-Feature was trigger more than once on training set." << endl; + if (feature_filter.size() > 0) { + if (feature_filter.find(fid_ef) != feature_filter.end()) { + //cerr << "SYN-Feature was trigger more than once on training set." << endl; + feats->set_value(fid_ef, 1.0); + } + //else cerr << "SYN-Feature was triggered less than once on training set." << endli; + } + else { feats->set_value(fid_ef, 1.0); } - else cerr << "SYN-Feature was triggered less than once on training set." << endl; + cerr << FD::Convert(fid_ef) << endl; return lhs; } - Array2D<WordID> src_tree; // src_tree(i,j) NT = type - mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized + Array2D<WordID> src_tree; // src_tree(i,j) NT = type + mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized unordered_set<int> feature_filter; }; @@ -153,5 +157,9 @@ void SourceSyntaxFeatures2::TraversalFeaturesImpl(const SentenceMetadata& smeta, } void SourceSyntaxFeatures2::PrepareForInput(const SentenceMetadata& smeta) { - impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); + ReadFile f = ReadFile(smeta.GetSGMLValue("src_tree")); + string tree; + f.ReadAll(tree); + impl->InitializeGrids(tree, smeta.GetSourceLength()); } + diff --git a/decoder/ff_source_syntax2.h b/decoder/ff_source_syntax2.h index b6b7dc3d..f606c2bf 100644 --- a/decoder/ff_source_syntax2.h +++ b/decoder/ff_source_syntax2.h @@ -1,5 +1,5 @@ -#ifndef _FF_SOURCE_TOOLS2_H_ -#define _FF_SOURCE_TOOLS2_H_ +#ifndef _FF_SOURCE_SYNTAX2_H_ +#define _FF_SOURCE_SYNTAX2_H_ #include "ff.h" #include "hg.h" @@ -23,3 +23,4 @@ class SourceSyntaxFeatures2 : public FeatureFunction { }; #endif + diff --git a/decoder/ff_source_syntax2_p.cc b/decoder/ff_source_syntax2_p.cc deleted file mode 100644 index 6a2ae742..00000000 --- a/decoder/ff_source_syntax2_p.cc +++ /dev/null @@ -1,170 +0,0 @@ -#include "ff_source_syntax2_p.h" - -#include <sstream> -#include <stack> -#include <string> -#ifndef HAVE_OLD_CPP -# include <unordered_set> -#else -# include <tr1/unordered_set> -namespace std { using std::tr1::unordered_set; } -#endif - -#include "sentence_metadata.h" -#include "array2d.h" -#include "filelib.h" - -using namespace std; - -// implements the source side syntax features described in Blunsom et al. (EMNLP 2008) -// source trees must be represented in Penn Treebank format, e.g. -// (S (NP John) (VP (V left))) - -struct PSourceSyntaxFeatures2Impl { - PSourceSyntaxFeatures2Impl(const string& param) { - if (param.compare("") != 0) { - string triggered_features_fn = param; - ReadFile triggered_features(triggered_features_fn); - string in; - while(getline(*triggered_features, in)) { - feature_filter.insert(FD::Convert(in)); - } - } - /*cerr << "find(\"One\") == " << boolalpha << (table.find("One") != table.end()) << endl; - cerr << "find(\"Three\") == " << boolalpha << (table.find("Three") != table.end()) << endl;*/ - } - - void InitializeGrids(const string& tree, unsigned src_len) { - assert(tree.size() > 0); - //fids_cat.clear(); - fids_ef.clear(); - src_tree.clear(); - //fids_cat.resize(src_len, src_len + 1); - fids_ef.resize(src_len, src_len + 1); - src_tree.resize(src_len, src_len + 1, TD::Convert("XX")); - ParseTreeString(tree, src_len); - } - - void ParseTreeString(const string& tree, unsigned src_len) { - //cerr << "TREE: " << tree << endl; - stack<pair<int, WordID> > stk; // first = i, second = category - pair<int, WordID> cur_cat; cur_cat.first = -1; - unsigned i = 0; - unsigned p = 0; - while(p < tree.size()) { - const char cur = tree[p]; - if (cur == '(') { - stk.push(cur_cat); - ++p; - unsigned k = p + 1; - while (k < tree.size() && tree[k] != ' ') { ++k; } - cur_cat.first = i; - cur_cat.second = TD::Convert(tree.substr(p, k - p)); - // cerr << "NT: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n"; - p = k + 1; - } else if (cur == ')') { - unsigned k = p; - while (k < tree.size() && tree[k] == ')') { ++k; } - const unsigned num_closes = k - p; - for (unsigned ci = 0; ci < num_closes; ++ci) { - src_tree(cur_cat.first, i) = cur_cat.second; - cur_cat = stk.top(); - stk.pop(); - } - p = k; - while (p < tree.size() && (tree[p] == ' ' || tree[p] == '\t')) { ++p; } - } else if (cur == ' ' || cur == '\t') { - cerr << "Unexpected whitespace in: " << tree << endl; - abort(); - } else { // terminal symbol - unsigned k = p + 1; - do { - while (k < tree.size() && tree[k] != ')' && tree[k] != ' ') { ++k; } - // cerr << "TERM: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n"; - ++i; - assert(i <= src_len); - while (k < tree.size() && tree[k] == ' ') { ++k; } - p = k; - } while (p < tree.size() && tree[p] != ')'); - } - //cerr << "i=" << i << " src_len=" << src_len << endl; - } - //cerr << "i=" << i << " src_len=" << src_len << endl; - assert(i == src_len); // make sure tree specified in src_tree is - // the same length as the source sentence - } - - WordID FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector<double>* feats) { - //cerr << "fire features: " << rule.AsString() << " for " << i << "," << j << endl; - const WordID lhs = src_tree(i,j); - int& fid_ef = fids_ef(i,j)[&rule]; - ostringstream os; - os << "SYN:" << TD::Convert(lhs); - os << ':'; - unsigned ntc = 0; - for (unsigned k = 0; k < rule.f_.size(); ++k) { - int fj = rule.f_[k]; - if (k > 0 && fj <= 0) os << '_'; - if (fj <= 0) { - os << '[' << TD::Convert(ants[ntc++]) << ']'; - } /*else { - os << TD::Convert(fj); - }*/ - } - os << ':'; - for (unsigned k = 0; k < rule.e_.size(); ++k) { - const int ei = rule.e_[k]; - if (k > 0) os << '_'; - if (ei <= 0) - os << '[' << (1-ei) << ']'; - else - os << TD::Convert(ei); - } - fid_ef = FD::Convert(os.str()); - //cerr << "FEATURE: " << os.str() << endl; - //cerr << "FID_EF: " << fid_ef << endl; - if (feature_filter.size() > 0) { - if (feature_filter.find(fid_ef) != feature_filter.end()) { - //cerr << "SYN-Feature was trigger more than once on training set." << endl; - feats->set_value(fid_ef, 1.0); - } - //else cerr << "SYN-Feature was triggered less than once on training set." << endli; - } - else { - feats->set_value(fid_ef, 1.0); - } - return lhs; - } - - Array2D<WordID> src_tree; // src_tree(i,j) NT = type - mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized - unordered_set<int> feature_filter; -}; - -PSourceSyntaxFeatures2::PSourceSyntaxFeatures2(const string& param) : - FeatureFunction(sizeof(WordID)) { - impl = new PSourceSyntaxFeatures2Impl(param); -} - -PSourceSyntaxFeatures2::~PSourceSyntaxFeatures2() { - delete impl; - impl = NULL; -} - -void PSourceSyntaxFeatures2::TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const vector<const void*>& ant_contexts, - SparseVector<double>* features, - SparseVector<double>* estimated_features, - void* context) const { - WordID ants[8]; - for (unsigned i = 0; i < ant_contexts.size(); ++i) - ants[i] = *static_cast<const WordID*>(ant_contexts[i]); - - *static_cast<WordID*>(context) = - impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features); -} - -void PSourceSyntaxFeatures2::PrepareForInput(const SentenceMetadata& smeta) { - impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); -} diff --git a/decoder/ff_source_syntax2_p.h b/decoder/ff_source_syntax2_p.h deleted file mode 100644 index d56ecab0..00000000 --- a/decoder/ff_source_syntax2_p.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef _FF_SOURCE_TOOLS2_H_ -#define _FF_SOURCE_TOOLS2_H_ - -#include "ff.h" -#include "hg.h" - -struct PSourceSyntaxFeatures2Impl; - -class PSourceSyntaxFeatures2 : public FeatureFunction { - public: - PSourceSyntaxFeatures2(const std::string& param); - ~PSourceSyntaxFeatures2(); - protected: - virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const std::vector<const void*>& ant_contexts, - SparseVector<double>* features, - SparseVector<double>* estimated_features, - void* context) const; - virtual void PrepareForInput(const SentenceMetadata& smeta); - private: - PSourceSyntaxFeatures2Impl* impl; -}; - -#endif diff --git a/decoder/ff_source_syntax_p.cc b/decoder/ff_source_syntax_p.cc deleted file mode 100644 index c094de59..00000000 --- a/decoder/ff_source_syntax_p.cc +++ /dev/null @@ -1,250 +0,0 @@ -#include "ff_source_syntax_p.h" - -#include <sstream> -#include <stack> -#ifndef HAVE_OLD_CPP -# include <unordered_set> -#else -# include <tr1/unordered_set> -namespace std { using std::tr1::unordered_map; using std::tr1::unordered_set; } -#endif - -#include "sentence_metadata.h" -#include "array2d.h" -#include "filelib.h" - -using namespace std; - -// implements the source side syntax features described in Blunsom et al. (EMNLP 2008) -// source trees must be represented in Penn Treebank format, e.g. -// (S (NP John) (VP (V left))) - -// log transform to make long spans cluster together -// but preserve differences -inline int SpanSizeTransform(unsigned span_size) { - if (!span_size) return 0; - return static_cast<int>(log(span_size+1) / log(1.39)) - 1; -} - -struct PSourceSyntaxFeaturesImpl { - PSourceSyntaxFeaturesImpl() {} - - PSourceSyntaxFeaturesImpl(const string& param) { - if (!(param.compare("") == 0)) { - string triggered_features_fn = param; - ReadFile triggered_features(triggered_features_fn); - string in; - while(getline(*triggered_features, in)) { - feature_filter.insert(FD::Convert(in)); - } - } - } - - void InitializeGrids(const string& tree, unsigned src_len) { - assert(tree.size() > 0); - //fids_cat.clear(); - fids_ef.clear(); - src_tree.clear(); - //fids_cat.resize(src_len, src_len + 1); - fids_ef.resize(src_len, src_len + 1); - src_tree.resize(src_len, src_len + 1, TD::Convert("XX")); - ParseTreeString(tree, src_len); - } - - void ParseTreeString(const string& tree, unsigned src_len) { - stack<pair<int, WordID> > stk; // first = i, second = category - pair<int, WordID> cur_cat; cur_cat.first = -1; - unsigned i = 0; - unsigned p = 0; - while(p < tree.size()) { - const char cur = tree[p]; - if (cur == '(') { - stk.push(cur_cat); - ++p; - unsigned k = p + 1; - while (k < tree.size() && tree[k] != ' ') { ++k; } - cur_cat.first = i; - cur_cat.second = TD::Convert(tree.substr(p, k - p)); - // cerr << "NT: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n"; - p = k + 1; - } else if (cur == ')') { - unsigned k = p; - while (k < tree.size() && tree[k] == ')') { ++k; } - const unsigned num_closes = k - p; - for (unsigned ci = 0; ci < num_closes; ++ci) { - // cur_cat.second spans from cur_cat.first to i - // cerr << TD::Convert(cur_cat.second) << " from " << cur_cat.first << " to " << i << endl; - // NOTE: unary rule chains end up being labeled with the top-most category - src_tree(cur_cat.first, i) = cur_cat.second; - cur_cat = stk.top(); - stk.pop(); - } - p = k; - while (p < tree.size() && (tree[p] == ' ' || tree[p] == '\t')) { ++p; } - } else if (cur == ' ' || cur == '\t') { - cerr << "Unexpected whitespace in: " << tree << endl; - abort(); - } else { // terminal symbol - unsigned k = p + 1; - do { - while (k < tree.size() && tree[k] != ')' && tree[k] != ' ') { ++k; } - // cerr << "TERM: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n"; - ++i; - assert(i <= src_len); - while (k < tree.size() && tree[k] == ' ') { ++k; } - p = k; - } while (p < tree.size() && tree[p] != ')'); - } - } - // cerr << "i=" << i << " src_len=" << src_len << endl; - assert(i == src_len); // make sure tree specified in src_tree is - // the same length as the source sentence - } - - WordID FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector<double>* feats) { - //cerr << "fire features: " << rule.AsString() << " for " << i << "," << j << endl; - const WordID lhs = src_tree(i,j); - //int& fid_cat = fids_cat(i,j); - int& fid_ef = fids_ef(i,j)[&rule]; - if (fid_ef <= 0) { - ostringstream os; - //ostringstream os2; - os << "SYN:" << TD::Convert(lhs); - //os2 << "SYN:" << TD::Convert(lhs) << '_' << SpanSizeTransform(j - i); - //fid_cat = FD::Convert(os2.str()); - os << ':'; - unsigned ntc = 0; - for (unsigned k = 0; k < rule.f_.size(); ++k) { - if (k > 0) os << '_'; - int fj = rule.f_[k]; - if (fj <= 0) { - os << '[' << TD::Convert(ants[ntc++]) << ']'; - } else { - os << TD::Convert(fj); - } - } - os << ':'; - for (unsigned k = 0; k < rule.e_.size(); ++k) { - const int ei = rule.e_[k]; - if (k > 0) os << '_'; - if (ei <= 0) - os << '[' << (1-ei) << ']'; - else - os << TD::Convert(ei); - } - fid_ef = FD::Convert(os.str()); - } - //if (fid_cat > 0) - // feats->set_value(fid_cat, 1.0); - if (fid_ef > 0 && (feature_filter.find(fid_ef) != feature_filter.end())) - feats->set_value(fid_ef, 1.0); - return lhs; - } - - Array2D<WordID> src_tree; // src_tree(i,j) NT = type - // mutable Array2D<int> fids_cat; // this tends to overfit baddly - mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized - unordered_set<int> feature_filter; -}; - -PSourceSyntaxFeatures::PSourceSyntaxFeatures(const string& param) : - FeatureFunction(sizeof(WordID)) { - impl = new PSourceSyntaxFeaturesImpl(param); -} - -PSourceSyntaxFeatures::~PSourceSyntaxFeatures() { - delete impl; - impl = NULL; -} - -void PSourceSyntaxFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const vector<const void*>& ant_contexts, - SparseVector<double>* features, - SparseVector<double>* estimated_features, - void* context) const { - WordID ants[8]; - for (unsigned i = 0; i < ant_contexts.size(); ++i) - ants[i] = *static_cast<const WordID*>(ant_contexts[i]); - - *static_cast<WordID*>(context) = - impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features); -} - -void PSourceSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) { - impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength()); -} - -struct PSourceSpanSizeFeaturesImpl { - PSourceSpanSizeFeaturesImpl() {} - - void InitializeGrids(unsigned src_len) { - fids.clear(); - fids.resize(src_len, src_len + 1); - } - - int FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector<double>* feats) { - if (rule.Arity() > 0) { - int& fid = fids(i,j)[&rule]; - if (fid <= 0) { - ostringstream os; - os << "SSS:"; - unsigned ntc = 0; - for (unsigned k = 0; k < rule.f_.size(); ++k) { - if (k > 0) os << '_'; - int fj = rule.f_[k]; - if (fj <= 0) { - os << '[' << TD::Convert(-fj) << ants[ntc++] << ']'; - } else { - os << TD::Convert(fj); - } - } - os << ':'; - for (unsigned k = 0; k < rule.e_.size(); ++k) { - const int ei = rule.e_[k]; - if (k > 0) os << '_'; - if (ei <= 0) - os << '[' << (1-ei) << ']'; - else - os << TD::Convert(ei); - } - fid = FD::Convert(os.str()); - } - if (fid > 0) - feats->set_value(fid, 1.0); - } - return SpanSizeTransform(j - i); - } - - mutable Array2D<map<const TRule*, int> > fids; -}; - -PSourceSpanSizeFeatures::PSourceSpanSizeFeatures(const string& param) : - FeatureFunction(sizeof(char)) { - impl = new PSourceSpanSizeFeaturesImpl; -} - -PSourceSpanSizeFeatures::~PSourceSpanSizeFeatures() { - delete impl; - impl = NULL; -} - -void PSourceSpanSizeFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const vector<const void*>& ant_contexts, - SparseVector<double>* features, - SparseVector<double>* estimated_features, - void* context) const { - int ants[8]; - for (unsigned i = 0; i < ant_contexts.size(); ++i) - ants[i] = *static_cast<const char*>(ant_contexts[i]); - - *static_cast<char*>(context) = - impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features); -} - -void PSourceSpanSizeFeatures::PrepareForInput(const SentenceMetadata& smeta) { - impl->InitializeGrids(smeta.GetSourceLength()); -} - - diff --git a/decoder/ff_source_syntax_p.h b/decoder/ff_source_syntax_p.h deleted file mode 100644 index 2dd9094a..00000000 --- a/decoder/ff_source_syntax_p.h +++ /dev/null @@ -1,42 +0,0 @@ -#ifndef _FF_SOURCE_TOOLS_H_ -#define _FF_SOURCE_TOOLS_H_ - -#include "ff.h" -#include "hg.h" - -struct PSourceSyntaxFeaturesImpl; - -class PSourceSyntaxFeatures : public FeatureFunction { - public: - PSourceSyntaxFeatures(const std::string& param); - ~PSourceSyntaxFeatures(); - protected: - virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const std::vector<const void*>& ant_contexts, - SparseVector<double>* features, - SparseVector<double>* estimated_features, - void* context) const; - virtual void PrepareForInput(const SentenceMetadata& smeta); - private: - PSourceSyntaxFeaturesImpl* impl; -}; - -struct PSourceSpanSizeFeaturesImpl; -class PSourceSpanSizeFeatures : public FeatureFunction { - public: - PSourceSpanSizeFeatures(const std::string& param); - ~PSourceSpanSizeFeatures(); - protected: - virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, - const Hypergraph::Edge& edge, - const std::vector<const void*>& ant_contexts, - SparseVector<double>* features, - SparseVector<double>* estimated_features, - void* context) const; - virtual void PrepareForInput(const SentenceMetadata& smeta); - private: - PSourceSpanSizeFeaturesImpl* impl; -}; - -#endif diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc index 8050ce7b..1fbdee5b 100644 --- a/extractor/grammar_extractor.cc +++ b/extractor/grammar_extractor.cc @@ -3,11 +3,13 @@ #include <iterator> #include <sstream> #include <vector> +#include <unordered_set> #include "grammar.h" #include "rule.h" #include "rule_factory.h" #include "vocabulary.h" +#include "data_array.h" using namespace std; @@ -32,10 +34,10 @@ GrammarExtractor::GrammarExtractor( vocabulary(vocabulary), rule_factory(rule_factory) {} -Grammar GrammarExtractor::GetGrammar(const string& sentence) { +Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) { vector<string> words = TokenizeSentence(sentence); vector<int> word_ids = AnnotateWords(words); - return rule_factory->GetGrammar(word_ids); + return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); } vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) { diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h index b36ceeb9..6c0aafbf 100644 --- a/extractor/grammar_extractor.h +++ b/extractor/grammar_extractor.h @@ -4,6 +4,7 @@ #include <memory> #include <string> #include <vector> +#include <unordered_set> using namespace std; @@ -44,7 +45,7 @@ class GrammarExtractor { // Converts the sentence to a vector of word ids and uses the RuleFactory to // extract the SCFG rules which may be used to decode the sentence. - Grammar GetGrammar(const string& sentence); + Grammar GetGrammar(const string& sentence, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array); private: // Splits the sentence in a vector of words. diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc index 823bb8b4..f32a9599 100644 --- a/extractor/grammar_extractor_test.cc +++ b/extractor/grammar_extractor_test.cc @@ -39,12 +39,15 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) { vector<Rule> rules; vector<string> feature_names; Grammar grammar(rules, feature_names); - EXPECT_CALL(*factory, GetGrammar(word_ids)) + unordered_set<int> blacklisted_sentence_ids; + shared_ptr<DataArray> source_data_array; + EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array)) .WillOnce(Return(grammar)); GrammarExtractor extractor(vocabulary, factory); string sentence = "Anna has many many apples ."; - extractor.GetGrammar(sentence); + + extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array); } } // namespace diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h index 7389b396..86a084b5 100644 --- a/extractor/mocks/mock_rule_factory.h +++ b/extractor/mocks/mock_rule_factory.h @@ -7,7 +7,7 @@ namespace extractor { class MockHieroCachingRuleFactory : public HieroCachingRuleFactory { public: - MOCK_METHOD1(GetGrammar, Grammar(const vector<int>& word_ids)); + MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array)); }; } // namespace extractor diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc index 8c30fb9e..e52019ae 100644 --- a/extractor/rule_factory.cc +++ b/extractor/rule_factory.cc @@ -17,6 +17,7 @@ #include "suffix_array.h" #include "time_util.h" #include "vocabulary.h" +#include "data_array.h" using namespace std; using namespace chrono; @@ -100,7 +101,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {} HieroCachingRuleFactory::~HieroCachingRuleFactory() {} -Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { +Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) { Clock::time_point start_time = Clock::now(); double total_extract_time = 0; double total_intersect_time = 0; @@ -192,7 +193,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) { Clock::time_point extract_start = Clock::now(); if (!state.starts_with_x) { // Extract rules for the sampled set of occurrences. - PhraseLocation sample = sampler->Sample(next_node->matchings); + PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array); vector<Rule> new_rules = rule_extractor->ExtractRules(next_phrase, sample); rules.insert(rules.end(), new_rules.begin(), new_rules.end()); diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h index 52e8712a..c7332720 100644 --- a/extractor/rule_factory.h +++ b/extractor/rule_factory.h @@ -3,6 +3,7 @@ #include <memory> #include <vector> +#include <unordered_set> #include "matchings_trie.h" @@ -71,7 +72,7 @@ class HieroCachingRuleFactory { // Constructs SCFG rules for a given sentence. // (See class description for more details.) - virtual Grammar GetGrammar(const vector<int>& word_ids); + virtual Grammar GetGrammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array); protected: HieroCachingRuleFactory(); diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc index 08af3dcd..f26cc567 100644 --- a/extractor/rule_factory_test.cc +++ b/extractor/rule_factory_test.cc @@ -76,7 +76,9 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) { .WillRepeatedly(Return(PhraseLocation(0, 1))); vector<int> word_ids = {2, 3, 4}; - Grammar grammar = factory->GetGrammar(word_ids); + unordered_set<int> blacklisted_sentence_ids; + shared_ptr<DataArray> source_data_array; + Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(7, grammar.GetRules().size()); } @@ -94,7 +96,9 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) { .WillRepeatedly(Return(PhraseLocation(0, 1))); vector<int> word_ids = {2, 3, 4, 2, 3}; - Grammar grammar = factory->GetGrammar(word_ids); + unordered_set<int> blacklisted_sentence_ids; + shared_ptr<DataArray> source_data_array; + Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array); EXPECT_EQ(feature_names, grammar.GetFeatureNames()); EXPECT_EQ(28, grammar.GetRules().size()); } diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc index 8a9ca89d..6eb55073 100644 --- a/extractor/run_extractor.cc +++ b/extractor/run_extractor.cc @@ -75,7 +75,9 @@ int main(int argc, char** argv) { ("max_samples", po::value<int>()->default_value(300), "Maximum number of samples") ("tight_phrases", po::value<bool>()->default_value(true), - "False if phrases may be loose (better, but slower)"); + "False if phrases may be loose (better, but slower)") + ("leave_one_out", po::value<bool>()->zero_tokens(), + "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -96,6 +98,11 @@ int main(int argc, char** argv) { return 1; } + bool leave_one_out = false; + if (vm.count("leave_one_out")) { + leave_one_out = true; + } + int num_threads = vm["threads"].as<int>(); cerr << "Grammar extraction will use " << num_threads << " threads." << endl; @@ -223,7 +230,9 @@ int main(int argc, char** argv) { } suffixes[i] = suffix; - Grammar grammar = extractor.GetGrammar(sentences[i]); + unordered_set<int> blacklisted_sentence_ids; + if (leave_one_out) blacklisted_sentence_ids.insert(i); + Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array); ofstream output(GetGrammarFilePath(grammar_path, i).c_str()); output << grammar; } diff --git a/extractor/sample_alignment.txt b/extractor/sample_alignment.txt index 80b446a4..f0292b01 100644 --- a/extractor/sample_alignment.txt +++ b/extractor/sample_alignment.txt @@ -1,2 +1,5 @@ 0-0 1-1 2-2 1-0 2-1 +0-0 +0-0 1-1 +0-0 1-1 diff --git a/extractor/sample_bitext.txt b/extractor/sample_bitext.txt index 93d6b39d..2b7c8e40 100644 --- a/extractor/sample_bitext.txt +++ b/extractor/sample_bitext.txt @@ -1,2 +1,5 @@ +asdf ||| dontseeme +qqq asdf ||| zzz fdsa +asdf qqq ||| fdsa zzz ana are mere . ||| anna has apples . ana bea mult lapte . ||| anna drinks a lot of milk . diff --git a/extractor/sample_source.txt b/extractor/sample_source.txt new file mode 100644 index 00000000..9b46dd6a --- /dev/null +++ b/extractor/sample_source.txt @@ -0,0 +1,5 @@ +asdf +qqq asdf +asdf qqq +ana are mere . +ana bea mult lapte . diff --git a/extractor/sampler.cc b/extractor/sampler.cc index d81956b5..d332dd90 100644 --- a/extractor/sampler.cc +++ b/extractor/sampler.cc @@ -12,7 +12,7 @@ Sampler::Sampler() {} Sampler::~Sampler() {} -PhraseLocation Sampler::Sample(const PhraseLocation& location) const { +PhraseLocation Sampler::Sample(const PhraseLocation& location, unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const { vector<int> sample; int num_subpatterns; if (location.matchings == NULL) { @@ -20,8 +20,37 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const { num_subpatterns = 1; int low = location.sa_low, high = location.sa_high; double step = max(1.0, (double) (high - low) / max_samples); - for (double i = low; i < high && sample.size() < max_samples; i += step) { - sample.push_back(suffix_array->GetSuffix(Round(i))); + double i = low, last = i; + bool found; + while (sample.size() < max_samples && i < high) { + int x = suffix_array->GetSuffix(Round(i)); + int id = source_data_array->GetSentenceId(x); + if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) { + found = false; + double backoff_step = 1; + while (true) { + if ((double)backoff_step >= step) break; + double j = i - backoff_step; + x = suffix_array->GetSuffix(Round(j)); + id = source_data_array->GetSentenceId(x); + if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { + found = true; last = i; break; + } + double k = i + backoff_step; + x = suffix_array->GetSuffix(Round(k)); + id = source_data_array->GetSentenceId(x); + if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) { + found = true; last = k; break; + } + if (j <= last && k >= high) break; + backoff_step++; + } + } else { + found = true; + last = i; + } + if (found) sample.push_back(x); + i += step; } } else { // Sample vector of occurrences. diff --git a/extractor/sampler.h b/extractor/sampler.h index be4aa1bb..30e747fd 100644 --- a/extractor/sampler.h +++ b/extractor/sampler.h @@ -2,6 +2,9 @@ #define _SAMPLER_H_ #include <memory> +#include <unordered_set> + +#include "data_array.h" using namespace std; @@ -20,7 +23,7 @@ class Sampler { virtual ~Sampler(); // Samples uniformly at most max_samples phrase occurrences. - virtual PhraseLocation Sample(const PhraseLocation& location) const; + virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const; protected: Sampler(); diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc index e9abebfa..965567ba 100644 --- a/extractor/sampler_test.cc +++ b/extractor/sampler_test.cc @@ -3,6 +3,7 @@ #include <memory> #include "mocks/mock_suffix_array.h" +#include "mocks/mock_data_array.h" #include "phrase_location.h" #include "sampler.h" @@ -15,6 +16,8 @@ namespace { class SamplerTest : public Test { protected: virtual void SetUp() { + source_data_array = make_shared<MockDataArray>(); + EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999)); suffix_array = make_shared<MockSuffixArray>(); for (int i = 0; i < 10; ++i) { EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); @@ -23,51 +26,54 @@ class SamplerTest : public Test { shared_ptr<MockSuffixArray> suffix_array; shared_ptr<Sampler> sampler; + shared_ptr<MockDataArray> source_data_array; }; TEST_F(SamplerTest, TestSuffixArrayRange) { PhraseLocation location(0, 10); + unordered_set<int> blacklist; sampler = make_shared<Sampler>(suffix_array, 1); vector<int> expected_locations = {0}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); sampler = make_shared<Sampler>(suffix_array, 2); expected_locations = {0, 5}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); sampler = make_shared<Sampler>(suffix_array, 3); expected_locations = {0, 3, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); sampler = make_shared<Sampler>(suffix_array, 4); expected_locations = {0, 3, 5, 8}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); sampler = make_shared<Sampler>(suffix_array, 100); expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location)); + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); } TEST_F(SamplerTest, TestSubstringsSample) { vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + unordered_set<int> blacklist; PhraseLocation location(locations, 2); sampler = make_shared<Sampler>(suffix_array, 1); vector<int> expected_locations = {0, 1}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); sampler = make_shared<Sampler>(suffix_array, 2); expected_locations = {0, 1, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); sampler = make_shared<Sampler>(suffix_array, 3); expected_locations = {0, 1, 4, 5, 6, 7}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); sampler = make_shared<Sampler>(suffix_array, 7); expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location)); + EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array)); } } // namespace diff --git a/extractor/sampler_test_blacklist.cc b/extractor/sampler_test_blacklist.cc new file mode 100644 index 00000000..3305b990 --- /dev/null +++ b/extractor/sampler_test_blacklist.cc @@ -0,0 +1,102 @@ +#include <gtest/gtest.h> + +#include <memory> + +#include "mocks/mock_suffix_array.h" +#include "mocks/mock_data_array.h" +#include "phrase_location.h" +#include "sampler.h" + +using namespace std; +using namespace ::testing; + +namespace extractor { +namespace { + +class SamplerTestBlacklist : public Test { + protected: + virtual void SetUp() { + source_data_array = make_shared<MockDataArray>(); + for (int i = 0; i < 10; ++i) { + EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(i)); + } + for (int i = -10; i < 0; ++i) { + EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(0)); + } + suffix_array = make_shared<MockSuffixArray>(); + for (int i = -10; i < 10; ++i) { + EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i)); + } + } + + shared_ptr<MockSuffixArray> suffix_array; + shared_ptr<Sampler> sampler; + shared_ptr<MockDataArray> source_data_array; +}; + +TEST_F(SamplerTestBlacklist, TestSuffixArrayRange) { + PhraseLocation location(0, 10); + unordered_set<int> blacklist; + vector<int> expected_locations; + + blacklist.insert(0); + sampler = make_shared<Sampler>(suffix_array, 1); + expected_locations = {1}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + blacklist.clear(); + + for (int i = 0; i < 9; i++) { + blacklist.insert(i); + } + sampler = make_shared<Sampler>(suffix_array, 1); + expected_locations = {9}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + blacklist.clear(); + + blacklist.insert(0); + blacklist.insert(5); + sampler = make_shared<Sampler>(suffix_array, 2); + expected_locations = {1, 4}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + blacklist.clear(); + + blacklist.insert(0); + blacklist.insert(1); + blacklist.insert(2); + blacklist.insert(3); + sampler = make_shared<Sampler>(suffix_array, 2); + expected_locations = {4, 5}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + blacklist.clear(); + + blacklist.insert(0); + blacklist.insert(3); + blacklist.insert(7); + sampler = make_shared<Sampler>(suffix_array, 3); + expected_locations = {1, 2, 6}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + blacklist.clear(); + + blacklist.insert(0); + blacklist.insert(3); + blacklist.insert(5); + blacklist.insert(8); + sampler = make_shared<Sampler>(suffix_array, 4); + expected_locations = {1, 2, 4, 7}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + blacklist.clear(); + + blacklist.insert(0); + sampler = make_shared<Sampler>(suffix_array, 100); + expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); + blacklist.clear(); + + blacklist.insert(9); + sampler = make_shared<Sampler>(suffix_array, 100); + expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array)); +} + +} // namespace +} // namespace extractor diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am index 844c790d..ecb6c128 100644 --- a/training/dtrain/Makefile.am +++ b/training/dtrain/Makefile.am @@ -1,7 +1,7 @@ bin_PROGRAMS = dtrain dtrain_SOURCES = dtrain.cc score.cc dtrain.h kbestget.h ksampler.h pairsampling.h score.h -dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a +dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lboost_regex AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc index 0ee2f124..0a27a068 100644 --- a/training/dtrain/dtrain.cc +++ b/training/dtrain/dtrain.cc @@ -12,8 +12,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) { po::options_description ini("Configuration File Options"); ini.add_options() - ("input", po::value<string>()->default_value("-"), "input file (src)") + ("input", po::value<string>(), "input file (src)") ("refs,r", po::value<string>(), "references") + ("bitext,b", po::value<string>(), "bitext: 'src ||| tgt'") ("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT") ("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)") ("decoder_config", po::value<string>(), "configuration file for cdec") @@ -40,6 +41,10 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) ("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair") ("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near") ("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.") + ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate") + ("batch", po::value<bool>()->zero_tokens(), "do batch optimization") + ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times") + //("test-k-best", po::value<bool>()->zero_tokens(), "check if optimization works (use repeat >= 2)") ("noup", po::value<bool>()->zero_tokens(), "do not update weights"); po::options_description cl("Command Line Options"); cl.add_options() @@ -72,13 +77,17 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg) cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl; return false; } - if(cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") { + if (cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") { cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl; } - if((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) { + if ((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) { cerr << "hi_lo must lie in [0.01, 0.5]" << endl; return false; } + if ((cfg->count("input")>0 || cfg->count("refs")>0) && cfg->count("bitext")>0) { + cerr << "Provide 'input' and 'refs' or 'bitext', not both." << endl; + return false; + } if ((*cfg)["pair_threshold"].as<score_t>() < 0) { cerr << "The threshold must be >= 0!" << endl; return false; @@ -120,10 +129,16 @@ main(int argc, char** argv) const float hi_lo = cfg["hi_lo"].as<float>(); const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>(); const unsigned max_pairs = cfg["max_pairs"].as<unsigned>(); + int repeat = cfg["repeat"].as<unsigned>(); + //bool test_k_best = false; + //if (cfg.count("test-k-best")) test_k_best = true; weight_t loss_margin = cfg["loss_margin"].as<weight_t>(); + bool batch = false; + if (cfg.count("batch")) batch = true; if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max(); bool scale_bleu_diff = false; if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true; + const string pclr = cfg["pclr"].as<string>(); bool average = false; if (select_weights == "avg") average = true; @@ -131,7 +146,6 @@ main(int argc, char** argv) if (cfg.count("print_weights")) boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" ")); - // setup decoder register_feature_functions(); SetSilent(true); @@ -178,17 +192,16 @@ main(int argc, char** argv) observer->SetScorer(scorer); // init weights - vector<weight_t>& dense_weights = decoder.CurrentWeightVector(); + vector<weight_t>& decoder_weights = decoder.CurrentWeightVector(); SparseVector<weight_t> lambdas, cumulative_penalties, w_average; - if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights); - Weights::InitSparseVector(dense_weights, &lambdas); + if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &decoder_weights); + Weights::InitSparseVector(decoder_weights, &lambdas); // meta params for perceptron, SVM weight_t eta = cfg["learning_rate"].as<weight_t>(); weight_t gamma = cfg["gamma"].as<weight_t>(); // faster perceptron: consider only misranked pairs, see - // DO NOT ENABLE WITH SVM (gamma > 0) OR loss_margin! bool faster_perceptron = false; if (gamma==0 && loss_margin==0) faster_perceptron = true; @@ -208,13 +221,24 @@ main(int argc, char** argv) // output string output_fn = cfg["output"].as<string>(); // input - string input_fn = cfg["input"].as<string>(); + bool read_bitext = false; + string input_fn; + if (cfg.count("bitext")) { + read_bitext = true; + input_fn = cfg["bitext"].as<string>(); + } else { + input_fn = cfg["input"].as<string>(); + } ReadFile input(input_fn); // buffer input for t > 0 vector<string> src_str_buf; // source strings (decoder takes only strings) vector<vector<WordID> > ref_ids_buf; // references as WordID vecs - string refs_fn = cfg["refs"].as<string>(); - ReadFile refs(refs_fn); + ReadFile refs; + string refs_fn; + if (!read_bitext) { + refs_fn = cfg["refs"].as<string>(); + refs.Init(refs_fn); + } unsigned in_sz = std::numeric_limits<unsigned>::max(); // input index, input size vector<pair<score_t, score_t> > all_scores; @@ -229,6 +253,7 @@ main(int argc, char** argv) cerr << setw(25) << "k " << k << endl; cerr << setw(25) << "N " << N << endl; cerr << setw(25) << "T " << T << endl; + cerr << setw(25) << "batch " << batch << endl; cerr << setw(26) << "scorer '" << scorer_str << "'" << endl; if (scorer_str == "approx_bleu") cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl; @@ -249,10 +274,14 @@ main(int argc, char** argv) cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl; if (rescale) cerr << setw(25) << "rescale " << rescale << endl; + cerr << setw(25) << "pclr " << pclr << endl; cerr << setw(25) << "max pairs " << max_pairs << endl; + cerr << setw(25) << "repeat " << repeat << endl; + //cerr << setw(25) << "test k-best " << test_k_best << endl; cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl; cerr << setw(25) << "input " << "'" << input_fn << "'" << endl; - cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; + if (!read_bitext) + cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl; cerr << setw(25) << "output " << "'" << output_fn << "'" << endl; if (cfg.count("input_weights")) cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl; @@ -261,6 +290,11 @@ main(int argc, char** argv) if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl; } + // pclr + SparseVector<weight_t> learning_rates; + // batch + SparseVector<weight_t> batch_updates; + score_t batch_loss; for (unsigned t = 0; t < T; t++) // T epochs { @@ -269,16 +303,24 @@ main(int argc, char** argv) time(&start); score_t score_sum = 0.; score_t model_sum(0); - unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0; + unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0, kbest_loss_improve = 0; + batch_loss = 0.; if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl; while(true) { string in; + string ref; bool next = false, stop = false; // next iteration or premature stop if (t == 0) { if(!getline(*input, in)) next = true; + if(read_bitext) { + vector<string> strs; + boost::algorithm::split_regex(strs, in, boost::regex(" \\|\\|\\| ")); + in = strs[0]; + ref = strs[1]; + } } else { if (ii == in_sz) next = true; // stop if we reach the end of our input } @@ -310,15 +352,16 @@ main(int argc, char** argv) if (next || stop) break; // weights - lambdas.init_vector(&dense_weights); + lambdas.init_vector(&decoder_weights); // getting input vector<WordID> ref_ids; // reference as vector<WordID> if (t == 0) { - string r_; - getline(*refs, r_); + if (!read_bitext) { + getline(*refs, ref); + } vector<string> ref_tok; - boost::split(ref_tok, r_, boost::is_any_of(" ")); + boost::split(ref_tok, ref, boost::is_any_of(" ")); register_and_convert(ref_tok, ref_ids); ref_ids_buf.push_back(ref_ids); src_str_buf.push_back(in); @@ -348,8 +391,10 @@ main(int argc, char** argv) } } - score_sum += (*samples)[0].score; // stats for 1best - model_sum += (*samples)[0].model; + if (repeat == 1) { + score_sum += (*samples)[0].score; // stats for 1best + model_sum += (*samples)[0].model; + } f_count += observer->get_f_count(); list_sz += observer->get_sz(); @@ -364,30 +409,74 @@ main(int argc, char** argv) partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo); if (pair_sampling == "PRO") PROsampling(samples, pairs, pair_threshold, max_pairs); - npairs += pairs.size(); + int cur_npairs = pairs.size(); + npairs += cur_npairs; + + score_t kbest_loss_first, kbest_loss_last = 0.0; - SparseVector<weight_t> lambdas_copy; + for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); + it != pairs.end(); it++) { + score_t model_diff = it->first.model - it->second.model; + kbest_loss_first += max(0.0, -1.0 * model_diff); + } + + for (int ki=0; ki < repeat; ki++) { + + score_t kbest_loss = 0.0; // test-k-best + SparseVector<weight_t> lambdas_copy; // for l1 regularization + SparseVector<weight_t> sum_up; // for pclr if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas; for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin(); it != pairs.end(); it++) { - bool rank_error; + score_t model_diff = it->first.model - it->second.model; + if (repeat > 1) { + model_diff = lambdas.dot(it->first.f) - lambdas.dot(it->second.f); + kbest_loss += max(0.0, -1.0 * model_diff); + } + bool rank_error = false; score_t margin; if (faster_perceptron) { // we only have considering misranked pairs rank_error = true; // pair sampling already did this for us margin = std::numeric_limits<float>::max(); } else { - rank_error = it->first.model <= it->second.model; - margin = fabs(it->first.model - it->second.model); + rank_error = model_diff<=0.0; + margin = fabs(model_diff); if (!rank_error && margin < loss_margin) margin_violations++; } - if (rank_error) rank_errors++; + if (rank_error && ki==1) rank_errors++; if (scale_bleu_diff) eta = it->first.score - it->second.score; if (rank_error || margin < loss_margin) { SparseVector<weight_t> diff_vec = it->first.f - it->second.f; - lambdas.plus_eq_v_times_s(diff_vec, eta); - if (gamma) - lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); + if (batch) { + batch_loss += max(0., -1.0*model_diff); + batch_updates += diff_vec; + continue; + } + if (pclr != "no") { + sum_up += diff_vec; + } else { + lambdas.plus_eq_v_times_s(diff_vec, eta); + if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./cur_npairs)); + } + } + } + + // per-coordinate learning rate + if (pclr != "no") { + SparseVector<weight_t>::iterator it = sum_up.begin(); + for (; it != sum_up.end(); ++it) { + if (pclr == "simple") { + lambdas[it->first] += it->second / max(1.0, learning_rates[it->first]); + learning_rates[it->first]++; + } else if (pclr == "adagrad") { + if (learning_rates[it->first] == 0) { + lambdas[it->first] += it->second * eta; + } else { + lambdas[it->first] += it->second * eta * learning_rates[it->first]; + } + learning_rates[it->first] += pow(it->second, 2.0); + } } } @@ -395,14 +484,16 @@ main(int argc, char** argv) // please note that this regularizations happen // after a _sentence_ -- not after each example/pair! if (l1naive) { - FastSparseVector<weight_t>::iterator it = lambdas.begin(); + SparseVector<weight_t>::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { + it->second *= max(0.0000001, eta/(eta+learning_rates[it->first])); // FIXME + learning_rates[it->first]++; it->second -= sign(it->second) * l1_reg; } } } else if (l1clip) { - FastSparseVector<weight_t>::iterator it = lambdas.begin(); + SparseVector<weight_t>::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { if (it->second != 0) { @@ -417,7 +508,7 @@ main(int argc, char** argv) } } else if (l1cumul) { weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input - FastSparseVector<weight_t>::iterator it = lambdas.begin(); + SparseVector<weight_t>::iterator it = lambdas.begin(); for (; it != lambdas.end(); ++it) { if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) { if (it->second != 0) { @@ -435,7 +526,28 @@ main(int argc, char** argv) } } - } + if (ki==repeat-1) { // done + kbest_loss_last = kbest_loss; + if (repeat > 1) { + score_t best_score = -1.; + score_t best_model = -std::numeric_limits<score_t>::max(); + unsigned best_idx; + for (unsigned i=0; i < samples->size(); i++) { + score_t s = lambdas.dot((*samples)[i].f); + if (s > best_model) { + best_idx = i; + best_model = s; + } + } + score_sum += (*samples)[best_idx].score; + model_sum += best_model; + } + } + } // repeat + + if ((kbest_loss_first - kbest_loss_last) >= 0) kbest_loss_improve++; + + } // noup if (rescale) lambdas /= lambdas.l2norm(); @@ -443,14 +555,19 @@ main(int argc, char** argv) } // input loop - if (average) w_average += lambdas; + if (t == 0) in_sz = ii; // remember size of input (# lines) - if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset(); - if (t == 0) { - in_sz = ii; // remember size of input (# lines) + if (batch) { + lambdas.plus_eq_v_times_s(batch_updates, eta); + if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs)); + batch_updates.clear(); } + if (average) w_average += lambdas; + + if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset(); + // print some stats score_t score_avg = score_sum/(score_t)in_sz; score_t model_avg = model_sum/(score_t)in_sz; @@ -477,13 +594,15 @@ main(int argc, char** argv) cerr << _np << " 1best avg model score: " << model_avg; cerr << _p << " (" << model_diff << ")" << endl; cerr << " avg # pairs: "; - cerr << _np << npairs/(float)in_sz; + cerr << _np << npairs/(float)in_sz << endl; + cerr << " avg # rank err: "; + cerr << rank_errors/(float)in_sz; if (faster_perceptron) cerr << " (meaningless)"; cerr << endl; - cerr << " avg # rank err: "; - cerr << rank_errors/(float)in_sz << endl; cerr << " avg # margin viol: "; cerr << margin_violations/(float)in_sz << endl; + if (batch) cerr << " batch loss: " << batch_loss << endl; + cerr << " k-best loss imp: " << ((float)kbest_loss_improve/in_sz)*100 << "%" << endl; cerr << " non0 feature count: " << nonz << endl; cerr << " avg list sz: " << list_sz/(float)in_sz << endl; cerr << " avg f count: " << f_count/(float)list_sz << endl; @@ -510,9 +629,9 @@ main(int argc, char** argv) // write weights to file if (select_weights == "best" || keep) { - lambdas.init_vector(&dense_weights); + lambdas.init_vector(&decoder_weights); string w_fn = "weights." + boost::lexical_cast<string>(t) + ".gz"; - Weights::WriteToFile(w_fn, dense_weights, true); + Weights::WriteToFile(w_fn, decoder_weights, true); } } // outer loop diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h index 3981fb39..ccb5ad4d 100644 --- a/training/dtrain/dtrain.h +++ b/training/dtrain/dtrain.h @@ -9,6 +9,8 @@ #include <string.h> #include <boost/algorithm/string.hpp> +#include <boost/regex.hpp> +#include <boost/algorithm/string/regex.hpp> #include <boost/program_options.hpp> #include "decoder.h" diff --git a/training/dtrain/examples/standard/dtrain.ini b/training/dtrain/examples/standard/dtrain.ini index 23e94285..fc83f08e 100644 --- a/training/dtrain/examples/standard/dtrain.ini +++ b/training/dtrain/examples/standard/dtrain.ini @@ -1,5 +1,6 @@ -input=./nc-wmt11.de.gz -refs=./nc-wmt11.en.gz +#input=./nc-wmt11.de.gz +#refs=./nc-wmt11.en.gz +bitext=./nc-wmt11.gz output=- # a weights file (add .gz for gzip compression) or STDOUT '-' select_weights=VOID # output average (over epochs) weight vector decoder_config=./cdec.ini # config for cdec @@ -10,11 +11,11 @@ print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 Phr stop_after=10 # stop epoch after 10 inputs # interesting stuff -epochs=2 # run over input 2 times +epochs=3 # run over input 3 times k=100 # use 100best lists N=4 # optimize (approx) BLEU4 scorer=fixed_stupid_bleu # use 'stupid' BLEU+1 -learning_rate=1.0 # learning rate, don't care if gamma=0 (perceptron) +learning_rate=0.1 # learning rate, don't care if gamma=0 (perceptron) and loss_margin=0 (not margin perceptron) gamma=0 # use SVM reg sample_from=kbest # use kbest lists (as opposed to forest) filter=uniq # only unique entries in kbest (surface form) @@ -22,3 +23,5 @@ pair_sampling=XYX # hi_lo=0.1 # 10 vs 80 vs 10 and 80 vs 10 here pair_threshold=0 # minimum distance in BLEU (here: > 0) loss_margin=0 # update if correctly ranked, but within this margin +repeat=1 # repeat training on a kbest list 1 times +#batch=true # batch tuning, update after accumulating over all sentences and all kbest lists diff --git a/training/dtrain/examples/standard/expected-output b/training/dtrain/examples/standard/expected-output index 21f91244..75f47337 100644 --- a/training/dtrain/examples/standard/expected-output +++ b/training/dtrain/examples/standard/expected-output @@ -4,17 +4,18 @@ Reading ./nc-wmt11.en.srilm.gz ----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100 **************************************************************************************************** Example feature: Shape_S00000_T00000 -Seeding random number sequence to 970626287 +Seeding random number sequence to 3751911392 dtrain Parameters: k 100 N 4 - T 2 + T 3 + batch 0 scorer 'fixed_stupid_bleu' sample from 'kbest' filter 'uniq' - learning rate 1 + learning rate 0.1 gamma 0 loss margin 0 faster perceptron 1 @@ -23,69 +24,99 @@ Parameters: pair threshold 0 select weights 'VOID' l1 reg 0 'none' + pclr no max pairs 4294967295 + repeat 1 cdec cfg './cdec.ini' - input './nc-wmt11.de.gz' - refs './nc-wmt11.en.gz' + input './nc-wmt11.gz' output '-' stop_after 10 (a dot represents 10 inputs) -Iteration #1 of 2. +Iteration #1 of 3. . 10 Stopping after 10 input sentences. WEIGHTS - Glue = -614 - WordPenalty = +1256.8 - LanguageModel = +5610.5 - LanguageModel_OOV = -1449 - PhraseModel_0 = -2107 - PhraseModel_1 = -4666.1 - PhraseModel_2 = -2713.5 - PhraseModel_3 = +4204.3 - PhraseModel_4 = -1435.8 - PhraseModel_5 = +916 - PhraseModel_6 = +190 - PassThrough = -2527 + Glue = -110 + WordPenalty = -8.2082 + LanguageModel = -319.91 + LanguageModel_OOV = -19.2 + PhraseModel_0 = +312.82 + PhraseModel_1 = -161.02 + PhraseModel_2 = -433.65 + PhraseModel_3 = +291.03 + PhraseModel_4 = +252.32 + PhraseModel_5 = +50.6 + PhraseModel_6 = +146.7 + PassThrough = -38.7 --- - 1best avg score: 0.17874 (+0.17874) - 1best avg model score: 88399 (+88399) - avg # pairs: 798.2 (meaningless) - avg # rank err: 798.2 + 1best avg score: 0.16966 (+0.16966) + 1best avg model score: 29874 (+29874) + avg # pairs: 906.3 + avg # rank err: 0 (meaningless) avg # margin viol: 0 - non0 feature count: 887 + k-best loss imp: 100% + non0 feature count: 832 avg list sz: 91.3 - avg f count: 126.85 -(time 0.33 min, 2 s/S) + avg f count: 139.77 +(time 0.35 min, 2.1 s/S) -Iteration #2 of 2. +Iteration #2 of 3. . 10 WEIGHTS - Glue = -1025 - WordPenalty = +1751.5 - LanguageModel = +10059 - LanguageModel_OOV = -4490 - PhraseModel_0 = -2640.7 - PhraseModel_1 = -3757.4 - PhraseModel_2 = -1133.1 - PhraseModel_3 = +1837.3 - PhraseModel_4 = -3534.3 - PhraseModel_5 = +2308 - PhraseModel_6 = +1677 - PassThrough = -6222 + Glue = -122.1 + WordPenalty = +83.689 + LanguageModel = +233.23 + LanguageModel_OOV = -145.1 + PhraseModel_0 = +150.72 + PhraseModel_1 = -272.84 + PhraseModel_2 = -418.36 + PhraseModel_3 = +181.63 + PhraseModel_4 = -289.47 + PhraseModel_5 = +140.3 + PhraseModel_6 = +3.5 + PassThrough = -109.7 --- - 1best avg score: 0.30764 (+0.12891) - 1best avg model score: -2.5042e+05 (-3.3882e+05) - avg # pairs: 725.9 (meaningless) - avg # rank err: 725.9 + 1best avg score: 0.17399 (+0.004325) + 1best avg model score: 4936.9 (-24937) + avg # pairs: 662.4 + avg # rank err: 0 (meaningless) avg # margin viol: 0 - non0 feature count: 1499 + k-best loss imp: 100% + non0 feature count: 1240 avg list sz: 91.3 - avg f count: 114.34 -(time 0.32 min, 1.9 s/S) + avg f count: 125.11 +(time 0.27 min, 1.6 s/S) + +Iteration #3 of 3. + . 10 +WEIGHTS + Glue = -157.4 + WordPenalty = -1.7372 + LanguageModel = +686.18 + LanguageModel_OOV = -399.7 + PhraseModel_0 = -39.876 + PhraseModel_1 = -341.96 + PhraseModel_2 = -318.67 + PhraseModel_3 = +105.08 + PhraseModel_4 = -290.27 + PhraseModel_5 = -48.6 + PhraseModel_6 = -43.6 + PassThrough = -298.5 + --- + 1best avg score: 0.30742 (+0.13343) + 1best avg model score: -15393 (-20329) + avg # pairs: 623.8 + avg # rank err: 0 (meaningless) + avg # margin viol: 0 + k-best loss imp: 100% + non0 feature count: 1776 + avg list sz: 91.3 + avg f count: 118.58 +(time 0.28 min, 1.7 s/S) Writing weights file to '-' ... done --- -Best iteration: 2 [SCORE 'fixed_stupid_bleu'=0.30764]. -This took 0.65 min. +Best iteration: 3 [SCORE 'fixed_stupid_bleu'=0.30742]. +This took 0.9 min. diff --git a/training/dtrain/examples/standard/nc-wmt11.gz b/training/dtrain/examples/standard/nc-wmt11.gz Binary files differnew file mode 100644 index 00000000..c39c5aef --- /dev/null +++ b/training/dtrain/examples/standard/nc-wmt11.gz diff --git a/training/dtrain/parallelize.rb b/training/dtrain/parallelize.rb index 285f3c9b..60ca9422 100755 --- a/training/dtrain/parallelize.rb +++ b/training/dtrain/parallelize.rb @@ -21,6 +21,8 @@ opts = Trollop::options do opt :qsub, "use qsub", :type => :bool, :default => false opt :dtrain_binary, "path to dtrain binary", :type => :string opt :extra_qsub, "extra qsub args", :type => :string, :default => "" + opt :per_shard_decoder_configs, "give special decoder config per shard", :type => :string, :short => '-o' + opt :first_input_weights, "input weights for first iter", :type => :string, :default => '', :short => '-w' end usage if not opts[:config]&&opts[:shards]&&opts[:input]&&opts[:references] @@ -41,9 +43,11 @@ epochs = opts[:epochs] rand = opts[:randomize] reshard = opts[:reshard] predefined_shards = false +per_shard_decoder_configs = false if opts[:shards] == 0 predefined_shards = true num_shards = 0 + per_shard_decoder_configs = true if opts[:per_shard_decoder_configs] else num_shards = opts[:shards] end @@ -51,6 +55,7 @@ input = opts[:input] refs = opts[:references] use_qsub = opts[:qsub] shards_at_once = opts[:processes_at_once] +first_input_weights = opts[:first_input_weights] `mkdir work` @@ -101,6 +106,9 @@ refs_files = [] if predefined_shards input_files = File.new(input).readlines.map {|i| i.strip } refs_files = File.new(refs).readlines.map {|i| i.strip } + if per_shard_decoder_configs + decoder_configs = File.new(opts[:per_shard_decoder_configs]).readlines.map {|i| i.strip} + end num_shards = input_files.size else input_files, refs_files = make_shards input, refs, num_shards, 0, rand @@ -126,10 +134,18 @@ end else local_end = "2>work/out.#{shard}.#{epoch}" end + if per_shard_decoder_configs + cdec_cfg = "--decoder_config #{decoder_configs[shard]}" + else + cdec_cfg = "" + end + if first_input_weights!='' && epoch == 0 + input_weights = "--input_weights #{first_input_weights}" + end pids << Kernel.fork { - `#{qsub_str_start}#{dtrain_bin} -c #{ini}\ + `#{qsub_str_start}#{dtrain_bin} -c #{ini} #{cdec_cfg} #{input_weights}\ --input #{input_files[shard]}\ - --refs #{refs_files[shard]} #{input_weights}\ + --refs #{refs_files[shard]}\ --output work/weights.#{shard}.#{epoch}#{qsub_str_end} #{local_end}` } weights_files << "work/weights.#{shard}.#{epoch}" diff --git a/utils/filelib.h b/utils/filelib.h index b9ea3940..4fa69760 100644 --- a/utils/filelib.h +++ b/utils/filelib.h @@ -75,7 +75,10 @@ class ReadFile : public BaseFile<std::istream> { } } } - + void ReadAll(std::string& s) { + getline(*stream(), s, (char) EOF); + if (s.size() > 0) s.resize(s.size()-1); + } }; class WriteFile : public BaseFile<std::ostream> { |