summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2013-11-13 11:22:24 -0800
committerChris Dyer <redpony@gmail.com>2013-11-13 11:22:24 -0800
commit9be8d89b5a4065a81b26d8af1f3443d152e7922a (patch)
tree237090ff519a0419c3ba379ec3a6884f05caa6c2
parent8a24bb77bc2e9fd17a6f6529a2942cde96a6af49 (diff)
parent4a9449a564e626fe004200b730bfaa44d6152e0f (diff)
Merge pull request #27 from pks/master
Tidying (soft) syntax features; loo for C++ extractor; updates for dtrain
-rw-r--r--decoder/Makefile.am8
-rw-r--r--decoder/cdec_ff.cc25
-rw-r--r--decoder/ff_parse_match.cc18
-rw-r--r--decoder/ff_parse_match.h1
-rw-r--r--decoder/ff_soft_syntax.cc49
-rw-r--r--decoder/ff_soft_syntax.h16
-rw-r--r--decoder/ff_soft_syntax_mindist.cc (renamed from decoder/ff_soft_syntax2.cc)58
-rw-r--r--decoder/ff_soft_syntax_mindist.h (renamed from decoder/ff_soft_syntax2.h)16
-rw-r--r--decoder/ff_source_syntax.cc43
-rw-r--r--decoder/ff_source_syntax.h10
-rw-r--r--decoder/ff_source_syntax2.cc32
-rw-r--r--decoder/ff_source_syntax2.h5
-rw-r--r--decoder/ff_source_syntax2_p.cc170
-rw-r--r--decoder/ff_source_syntax2_p.h25
-rw-r--r--decoder/ff_source_syntax_p.cc250
-rw-r--r--decoder/ff_source_syntax_p.h42
-rw-r--r--extractor/grammar_extractor.cc6
-rw-r--r--extractor/grammar_extractor.h3
-rw-r--r--extractor/grammar_extractor_test.cc7
-rw-r--r--extractor/mocks/mock_rule_factory.h2
-rw-r--r--extractor/rule_factory.cc5
-rw-r--r--extractor/rule_factory.h3
-rw-r--r--extractor/rule_factory_test.cc8
-rw-r--r--extractor/run_extractor.cc13
-rw-r--r--extractor/sample_source.txt2
-rw-r--r--extractor/sampler.cc35
-rw-r--r--extractor/sampler.h5
-rw-r--r--extractor/sampler_test.cc24
-rw-r--r--extractor/sampler_test_blacklist.cc102
-rw-r--r--training/dtrain/Makefile.am2
-rw-r--r--training/dtrain/README.md30
-rw-r--r--training/dtrain/dtrain.cc201
-rw-r--r--training/dtrain/dtrain.h2
-rw-r--r--training/dtrain/examples/standard/dtrain.ini11
-rw-r--r--training/dtrain/examples/standard/expected-output125
-rw-r--r--training/dtrain/examples/standard/nc-wmt11.gzbin0 -> 113504 bytes
-rwxr-xr-xtraining/dtrain/parallelize.rb20
-rw-r--r--utils/filelib.h5
38 files changed, 620 insertions, 759 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index f02299e6..8280b22c 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -61,12 +61,10 @@ libcdec_a_SOURCES = \
ff_ruleshape.h \
ff_sample_fsa.h \
ff_soft_syntax.h \
- ff_soft_syntax2.h \
+ ff_soft_syntax_mindist.h \
ff_source_path.h \
ff_source_syntax.h \
ff_source_syntax2.h \
- ff_source_syntax2_p.h \
- ff_source_syntax_p.h \
ff_spans.h \
ff_tagger.h \
ff_wordalign.h \
@@ -127,12 +125,10 @@ libcdec_a_SOURCES = \
ff_rules.cc \
ff_ruleshape.cc \
ff_soft_syntax.cc \
- ff_soft_syntax2.cc \
+ ff_soft_syntax_mindist.cc \
ff_source_path.cc \
ff_source_syntax.cc \
ff_source_syntax2.cc \
- ff_source_syntax2_p.cc \
- ff_source_syntax_p.cc \
ff_spans.cc \
ff_tagger.cc \
ff_wordalign.cc \
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 09597e87..d586c1d1 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -15,17 +15,11 @@
#include "ff_ruleshape.h"
#include "ff_bleu.h"
#include "ff_soft_syntax.h"
-#include "ff_soft_syntax2.h"
+#include "ff_soft_syntax_mindist.h"
#include "ff_source_path.h"
-
-
#include "ff_parse_match.h"
#include "ff_source_syntax.h"
-#include "ff_source_syntax_p.h"
#include "ff_source_syntax2.h"
-#include "ff_source_syntax2_p.h"
-
-
#include "ff_register.h"
#include "ff_charset.h"
#include "ff_wordset.h"
@@ -51,23 +45,12 @@ void register_feature_functions() {
ff_registry.Register("NgramFeatures", new FFFactory<NgramDetector>());
ff_registry.Register("RuleContextFeatures", new FFFactory<RuleContextFeatures>());
ff_registry.Register("RuleIdentityFeatures", new FFFactory<RuleIdentityFeatures>());
-
-
ff_registry.Register("ParseMatchFeatures", new FFFactory<ParseMatchFeatures>);
-
- ff_registry.Register("SoftSyntacticFeatures", new FFFactory<SoftSyntacticFeatures>);
- ff_registry.Register("SoftSyntacticFeatures2", new FFFactory<SoftSyntacticFeatures2>);
-
+ ff_registry.Register("SoftSyntaxFeatures", new FFFactory<SoftSyntaxFeatures>);
+ ff_registry.Register("SoftSyntaxFeaturesMindist", new FFFactory<SoftSyntaxFeaturesMindist>);
ff_registry.Register("SourceSyntaxFeatures", new FFFactory<SourceSyntaxFeatures>);
- ff_registry.Register("SourceSyntaxFeatures2", new FFFactory<SourceSyntaxFeatures2>);
-
ff_registry.Register("SourceSpanSizeFeatures", new FFFactory<SourceSpanSizeFeatures>);
-
- //ff_registry.Register("PSourceSyntaxFeatures", new FFFactory<PSourceSyntaxFeatures>);
- //ff_registry.Register("PSourceSpanSizeFeatures", new FFFactory<PSourceSpanSizeFeatures>);
- //ff_registry.Register("PSourceSyntaxFeatures2", new FFFactory<PSourceSyntaxFeatures2>);
-
-
+ ff_registry.Register("SourceSyntaxFeatures2", new FFFactory<SourceSyntaxFeatures2>);
ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory<CMR2008ReorderingFeatures>());
ff_registry.Register("RuleSourceBigramFeatures", new FFFactory<RuleSourceBigramFeatures>());
ff_registry.Register("RuleTargetBigramFeatures", new FFFactory<RuleTargetBigramFeatures>());
diff --git a/decoder/ff_parse_match.cc b/decoder/ff_parse_match.cc
index ed556b91..58026975 100644
--- a/decoder/ff_parse_match.cc
+++ b/decoder/ff_parse_match.cc
@@ -42,10 +42,8 @@ struct ParseMatchFeaturesImpl {
void InitializeGrids(const string& tree, unsigned src_len) {
assert(tree.size() > 0);
- //fids_cat.clear();
fids_ef.clear();
src_tree.clear();
- //fids_cat.resize(src_len, src_len + 1);
fids_ef.resize(src_len, src_len + 1);
src_tree.resize(src_len, src_len + 1, TD::Convert("XX"));
ParseTreeString(tree, src_len);
@@ -112,7 +110,7 @@ struct ParseMatchFeaturesImpl {
int fid_ef = FD::Convert("PM");
int min_dist; // minimal distance to next syntactic constituent of this rule's LHS
int summed_min_dists; // minimal distances of LHS and NTs summed up
- if (TD::Convert(lhs).compare("XX") != 0)
+ if (TD::Convert(lhs).compare("XX") != 0)
min_dist= 0;
// compute the distance to the next syntactical constituent
else {
@@ -131,7 +129,7 @@ struct ParseMatchFeaturesImpl {
ok = 1;
break;
}
- // check if removing k words from the rule span will
+ // check if removing k words from the rule span will
// lead to a syntactical constituent
else {
//cerr << "Hilfe...!" << endl;
@@ -144,7 +142,7 @@ struct ParseMatchFeaturesImpl {
ok = 1;
break;
}
- }
+ }
}
if (ok) break;
}
@@ -183,9 +181,9 @@ struct ParseMatchFeaturesImpl {
return min_dist;
}
- Array2D<WordID> src_tree; // src_tree(i,j) NT = type
+ Array2D<WordID> src_tree; // src_tree(i,j) NT = type
unsigned int src_sent_len;
- mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
+ mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
int scoring_method;
};
@@ -214,5 +212,9 @@ void ParseMatchFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
void ParseMatchFeatures::PrepareForInput(const SentenceMetadata& smeta) {
- impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength());
+ ReadFile f = ReadFile(smeta.GetSGMLValue("src_tree"));
+ string tree;
+ f.ReadAll(tree);
+ impl->InitializeGrids(tree, smeta.GetSourceLength());
}
+
diff --git a/decoder/ff_parse_match.h b/decoder/ff_parse_match.h
index fa73481a..7820b418 100644
--- a/decoder/ff_parse_match.h
+++ b/decoder/ff_parse_match.h
@@ -23,3 +23,4 @@ class ParseMatchFeatures : public FeatureFunction {
};
#endif
+
diff --git a/decoder/ff_soft_syntax.cc b/decoder/ff_soft_syntax.cc
index 9981fa45..23fe87bd 100644
--- a/decoder/ff_soft_syntax.cc
+++ b/decoder/ff_soft_syntax.cc
@@ -13,16 +13,15 @@
using namespace std;
-// Implements the soft syntactic features described in
+// Implements the soft syntactic features described in
// Marton and Resnik (2008): "Soft Syntacitc Constraints for Hierarchical Phrase-Based Translation".
// Source trees must be represented in Penn Treebank format,
// e.g. (S (NP John) (VP (V left))).
-struct SoftSyntacticFeaturesImpl {
- SoftSyntacticFeaturesImpl(const string& param) {
+struct SoftSyntaxFeaturesImpl {
+ SoftSyntaxFeaturesImpl(const string& param) {
vector<string> labels = SplitOnWhitespace(param);
- for (unsigned int i = 0; i < labels.size(); i++)
- //cerr << "Labels: " << labels.at(i) << endl;
+ //for (unsigned int i = 0; i < labels.size(); i++) { cerr << "Labels: " << labels.at(i) << endl; }
for (unsigned int i = 0; i < labels.size(); i++) {
string label = labels.at(i);
pair<string, string> feat_label;
@@ -34,10 +33,8 @@ struct SoftSyntacticFeaturesImpl {
void InitializeGrids(const string& tree, unsigned src_len) {
assert(tree.size() > 0);
- //fids_cat.clear();
fids_ef.clear();
src_tree.clear();
- //fids_cat.resize(src_len, src_len + 1);
fids_ef.resize(src_len, src_len + 1);
src_tree.resize(src_len, src_len + 1, TD::Convert("XX"));
ParseTreeString(tree, src_len);
@@ -99,7 +96,7 @@ struct SoftSyntacticFeaturesImpl {
const WordID lhs = src_tree(i,j);
string lhs_str = TD::Convert(lhs);
//cerr << "LHS: " << lhs_str << " from " << i << " to " << j << endl;
- //cerr << "RULE :"<< rule << endl;
+ //cerr << "RULE :"<< rule << endl;
int& fid_ef = fids_ef(i,j)[&rule];
for (unsigned int i = 0; i < feat_labels.size(); i++) {
ostringstream os;
@@ -110,10 +107,10 @@ struct SoftSyntacticFeaturesImpl {
switch(feat_type) {
case '2':
if (lhs_str.compare(label) == 0) {
- os << "SYN:" << label << "_conform";
+ os << "SOFT:" << label << "_conform";
}
else {
- os << "SYN:" << label << "_cross";
+ os << "SOFT:" << label << "_cross";
}
fid_ef = FD::Convert(os.str());
if (fid_ef > 0) {
@@ -122,11 +119,11 @@ struct SoftSyntacticFeaturesImpl {
}
break;
case '_':
- os << "SYN:" << label;
+ os << "SOFT:" << label;
fid_ef = FD::Convert(os.str());
if (lhs_str.compare(label) == 0) {
if (fid_ef > 0) {
- //cerr << "Feature: " << os.str() << endl;
+ //cerr << "Feature: " << os.str() << endl;
feats->set_value(fid_ef, 1.0);
}
}
@@ -139,7 +136,7 @@ struct SoftSyntacticFeaturesImpl {
break;
case '+':
if (lhs_str.compare(label) == 0) {
- os << "SYN:" << label << "_conform";
+ os << "SOFT:" << label << "_conform";
fid_ef = FD::Convert(os.str());
if (fid_ef > 0) {
//cerr << "Feature: " << os.str() << endl;
@@ -147,10 +144,10 @@ struct SoftSyntacticFeaturesImpl {
}
}
break;
- case '-':
- //cerr << "-" << endl;
+ case '-':
+ //cerr << "-" << endl;
if (lhs_str.compare(label) != 0) {
- os << "SYN:" << label << "_cross";
+ os << "SOFT:" << label << "_cross";
fid_ef = FD::Convert(os.str());
if (fid_ef > 0) {
//cerr << "Feature :" << os.str() << endl;
@@ -167,22 +164,22 @@ struct SoftSyntacticFeaturesImpl {
return lhs;
}
- Array2D<WordID> src_tree; // src_tree(i,j) NT = type
- mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
+ Array2D<WordID> src_tree; // src_tree(i,j) NT = type
+ mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
vector<pair<string, string> > feat_labels;
};
-SoftSyntacticFeatures::SoftSyntacticFeatures(const string& param) :
+SoftSyntaxFeatures::SoftSyntaxFeatures(const string& param) :
FeatureFunction(sizeof(WordID)) {
- impl = new SoftSyntacticFeaturesImpl(param);
+ impl = new SoftSyntaxFeaturesImpl(param);
}
-SoftSyntacticFeatures::~SoftSyntacticFeatures() {
+SoftSyntaxFeatures::~SoftSyntaxFeatures() {
delete impl;
impl = NULL;
}
-void SoftSyntacticFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+void SoftSyntaxFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
const vector<const void*>& ant_contexts,
SparseVector<double>* features,
@@ -196,6 +193,10 @@ void SoftSyntacticFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features);
}
-void SoftSyntacticFeatures::PrepareForInput(const SentenceMetadata& smeta) {
- impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength());
+void SoftSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) {
+ ReadFile f = ReadFile(smeta.GetSGMLValue("src_tree"));
+ string tree;
+ f.ReadAll(tree);
+ impl->InitializeGrids(tree, smeta.GetSourceLength());
}
+
diff --git a/decoder/ff_soft_syntax.h b/decoder/ff_soft_syntax.h
index 79352f49..e71825d5 100644
--- a/decoder/ff_soft_syntax.h
+++ b/decoder/ff_soft_syntax.h
@@ -1,15 +1,15 @@
-#ifndef _FF_SOFTSYNTAX_H_
-#define _FF_SOFTSYNTAX_H_
+#ifndef _FF_SOFT_SYNTAX_H_
+#define _FF_SOFT_SYNTAX_H_
#include "ff.h"
#include "hg.h"
-struct SoftSyntacticFeaturesImpl;
+struct SoftSyntaxFeaturesImpl;
-class SoftSyntacticFeatures : public FeatureFunction {
+class SoftSyntaxFeatures : public FeatureFunction {
public:
- SoftSyntacticFeatures(const std::string& param);
- ~SoftSyntacticFeatures();
+ SoftSyntaxFeatures(const std::string& param);
+ ~SoftSyntaxFeatures();
protected:
virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
@@ -19,9 +19,9 @@ class SoftSyntacticFeatures : public FeatureFunction {
void* context) const;
virtual void PrepareForInput(const SentenceMetadata& smeta);
private:
- SoftSyntacticFeaturesImpl* impl;
+ SoftSyntaxFeaturesImpl* impl;
};
-
#endif
+
diff --git a/decoder/ff_soft_syntax2.cc b/decoder/ff_soft_syntax_mindist.cc
index 121bc39b..a23f70f8 100644
--- a/decoder/ff_soft_syntax2.cc
+++ b/decoder/ff_soft_syntax_mindist.cc
@@ -1,4 +1,4 @@
-#include "ff_soft_syntax2.h"
+#include "ff_soft_syntax_mindist.h"
#include <cstdio>
#include <sstream>
@@ -13,16 +13,18 @@
using namespace std;
-// Implements the soft syntactic features described in
+// Implements the soft syntactic features described in
// Marton and Resnik (2008): "Soft Syntacitc Constraints for Hierarchical Phrase-Based Translation".
// Source trees must be represented in Penn Treebank format,
// e.g. (S (NP John) (VP (V left))).
+//
+// This variant accepts fuzzy matches, choosing the constituent with
+// minimum distance.
-struct SoftSyntacticFeatures2Impl {
- SoftSyntacticFeatures2Impl(const string& param) {
+struct SoftSyntaxFeaturesMindistImpl {
+ SoftSyntaxFeaturesMindistImpl(const string& param) {
vector<string> labels = SplitOnWhitespace(param);
- //for (unsigned int i = 0; i < labels.size(); i++)
- //cerr << "Labels: " << labels.at(i) << endl;
+ //for (unsigned int i = 0; i < labels.size(); i++) { cerr << "Labels: " << labels.at(i) << endl; }
for (unsigned int i = 0; i < labels.size(); i++) {
string label = labels.at(i);
pair<string, string> feat_label;
@@ -30,14 +32,12 @@ struct SoftSyntacticFeatures2Impl {
feat_label.second = label.at(label.size() - 1);
feat_labels.push_back(feat_label);
}
- }
+ }
void InitializeGrids(const string& tree, unsigned src_len) {
assert(tree.size() > 0);
- //fids_cat.clear();
fids_ef.clear();
src_tree.clear();
- //fids_cat.resize(src_len, src_len + 1);
fids_ef.resize(src_len, src_len + 1);
src_tree.resize(src_len, src_len + 1, TD::Convert("XX"));
ParseTreeString(tree, src_len);
@@ -99,14 +99,14 @@ struct SoftSyntacticFeatures2Impl {
const WordID lhs = src_tree(i,j);
string lhs_str = TD::Convert(lhs);
//cerr << "LHS: " << lhs_str << " from " << i << " to " << j << endl;
- //cerr << "RULE :"<< rule << endl;
+ //cerr << "RULE :"<< rule << endl;
int& fid_ef = fids_ef(i,j)[&rule];
string lhs_to_str = TD::Convert(lhs);
int min_dist;
string min_dist_label;
if (lhs_to_str.compare("XX") != 0) {
min_dist = 0;
- min_dist_label = lhs_to_str;
+ min_dist_label = lhs_to_str;
}
else {
int ok = 0;
@@ -128,7 +128,7 @@ struct SoftSyntacticFeatures2Impl {
min_dist_label = (TD::Convert(src_tree(l_rem, r_rem)));
break;
}
- }
+ }
}
if (ok) break;
}
@@ -146,10 +146,10 @@ struct SoftSyntacticFeatures2Impl {
case '2':
if (min_dist_label.compare(label) == 0) {
if (min_dist == 0) {
- os << "SYN:" << label << "_conform";
+ os << "SOFTM:" << label << "_conform";
}
else {
- os << "SYN:" << label << "_cross";
+ os << "SOFTM:" << label << "_cross";
}
fid_ef = FD::Convert(os.str());
//cerr << "Feature :" << os.str() << endl;
@@ -157,7 +157,7 @@ struct SoftSyntacticFeatures2Impl {
}
break;
case '_':
- os << "SYN:" << label;
+ os << "SOFTM:" << label;
fid_ef = FD::Convert(os.str());
if (min_dist_label.compare(label) == 0) {
//cerr << "Feature: " << os.str() << endl;
@@ -172,7 +172,7 @@ struct SoftSyntacticFeatures2Impl {
break;
case '+':
if (min_dist_label.compare(label) == 0) {
- os << "SYN:" << label << "_conform";
+ os << "SOFTM:" << label << "_conform";
fid_ef = FD::Convert(os.str());
if (min_dist == 0) {
//cerr << "Feature: " << os.str() << endl;
@@ -180,10 +180,10 @@ struct SoftSyntacticFeatures2Impl {
}
}
break;
- case '-':
- //cerr << "-" << endl;
+ case '-':
+ //cerr << "-" << endl;
if (min_dist_label.compare(label) != 0) {
- os << "SYN:" << label << "_cross";
+ os << "SOFTM:" << label << "_cross";
fid_ef = FD::Convert(os.str());
if (min_dist > 0) {
//cerr << "Feature :" << os.str() << endl;
@@ -200,22 +200,22 @@ struct SoftSyntacticFeatures2Impl {
return lhs;
}
- Array2D<WordID> src_tree; // src_tree(i,j) NT = type
- mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
+ Array2D<WordID> src_tree; // src_tree(i,j) NT = type
+ mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
vector<pair<string, string> > feat_labels;
};
-SoftSyntacticFeatures2::SoftSyntacticFeatures2(const string& param) :
+SoftSyntaxFeaturesMindist::SoftSyntaxFeaturesMindist(const string& param) :
FeatureFunction(sizeof(WordID)) {
- impl = new SoftSyntacticFeatures2Impl(param);
+ impl = new SoftSyntaxFeaturesMindistImpl(param);
}
-SoftSyntacticFeatures2::~SoftSyntacticFeatures2() {
+SoftSyntaxFeaturesMindist::~SoftSyntaxFeaturesMindist() {
delete impl;
impl = NULL;
}
-void SoftSyntacticFeatures2::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+void SoftSyntaxFeaturesMindist::TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
const vector<const void*>& ant_contexts,
SparseVector<double>* features,
@@ -229,6 +229,10 @@ void SoftSyntacticFeatures2::TraversalFeaturesImpl(const SentenceMetadata& smeta
impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features);
}
-void SoftSyntacticFeatures2::PrepareForInput(const SentenceMetadata& smeta) {
- impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength());
+void SoftSyntaxFeaturesMindist::PrepareForInput(const SentenceMetadata& smeta) {
+ ReadFile f = ReadFile(smeta.GetSGMLValue("src_tree"));
+ string tree;
+ f.ReadAll(tree);
+ impl->InitializeGrids(tree, smeta.GetSourceLength());
}
+
diff --git a/decoder/ff_soft_syntax2.h b/decoder/ff_soft_syntax_mindist.h
index 4de91d86..bf938b38 100644
--- a/decoder/ff_soft_syntax2.h
+++ b/decoder/ff_soft_syntax_mindist.h
@@ -1,15 +1,15 @@
-#ifndef _FF_SOFTSYNTAX2_H_
-#define _FF_SOFTSYNTAX2_H_
+#ifndef _FF_SOFT_SYNTAX_MINDIST_H_
+#define _FF_SOFT_SYNTAX_MINDIST_H_
#include "ff.h"
#include "hg.h"
-struct SoftSyntacticFeatures2Impl;
+struct SoftSyntaxFeaturesMindistImpl;
-class SoftSyntacticFeatures2 : public FeatureFunction {
+class SoftSyntaxFeaturesMindist : public FeatureFunction {
public:
- SoftSyntacticFeatures2(const std::string& param);
- ~SoftSyntacticFeatures2();
+ SoftSyntaxFeaturesMindist(const std::string& param);
+ ~SoftSyntaxFeaturesMindist();
protected:
virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
const Hypergraph::Edge& edge,
@@ -19,9 +19,9 @@ class SoftSyntacticFeatures2 : public FeatureFunction {
void* context) const;
virtual void PrepareForInput(const SentenceMetadata& smeta);
private:
- SoftSyntacticFeatures2Impl* impl;
+ SoftSyntaxFeaturesMindistImpl* impl;
};
-
#endif
+
diff --git a/decoder/ff_source_syntax.cc b/decoder/ff_source_syntax.cc
index 88f6714c..6b183863 100644
--- a/decoder/ff_source_syntax.cc
+++ b/decoder/ff_source_syntax.cc
@@ -9,7 +9,6 @@
namespace std { using std::tr1::unordered_set; }
#endif
-#include "hg.h"
#include "sentence_metadata.h"
#include "array2d.h"
#include "filelib.h"
@@ -30,6 +29,17 @@ inline int SpanSizeTransform(unsigned span_size) {
struct SourceSyntaxFeaturesImpl {
SourceSyntaxFeaturesImpl() {}
+ SourceSyntaxFeaturesImpl(const string& param) {
+ if (!(param.compare("") == 0)) {
+ string triggered_features_fn = param;
+ ReadFile triggered_features(triggered_features_fn);
+ string in;
+ while(getline(*triggered_features, in)) {
+ feature_filter.insert(FD::Convert(in));
+ }
+ }
+ }
+
void InitializeGrids(const string& tree, unsigned src_len) {
assert(tree.size() > 0);
//fids_cat.clear();
@@ -99,7 +109,7 @@ struct SourceSyntaxFeaturesImpl {
if (fid_ef <= 0) {
ostringstream os;
//ostringstream os2;
- os << "SYN:" << TD::Convert(lhs);
+ os << "SSYN:" << TD::Convert(lhs);
//os2 << "SYN:" << TD::Convert(lhs) << '_' << SpanSizeTransform(j - i);
//fid_cat = FD::Convert(os2.str());
os << ':';
@@ -124,21 +134,28 @@ struct SourceSyntaxFeaturesImpl {
}
fid_ef = FD::Convert(os.str());
}
- //if (fid_cat > 0)
- // feats->set_value(fid_cat, 1.0);
- if (fid_ef > 0)
- feats->set_value(fid_ef, 1.0);
+ if (fid_ef > 0) {
+ if (feature_filter.size()>0) {
+ if (feature_filter.find(fid_ef) != feature_filter.end()) {
+ feats->set_value(fid_ef, 1.0);
+ }
+ } else {
+ feats->set_value(fid_ef, 1.0);
+ }
+ }
+ cerr << FD::Convert(fid_ef) << endl;
return lhs;
}
- Array2D<WordID> src_tree; // src_tree(i,j) NT = type
- // mutable Array2D<int> fids_cat; // this tends to overfit baddly
- mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
+ Array2D<WordID> src_tree; // src_tree(i,j) NT = type
+ // mutable Array2D<int> fids_cat; // this tends to overfit baddly
+ mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
+ unordered_set<int> feature_filter;
};
SourceSyntaxFeatures::SourceSyntaxFeatures(const string& param) :
FeatureFunction(sizeof(WordID)) {
- impl = new SourceSyntaxFeaturesImpl;
+ impl = new SourceSyntaxFeaturesImpl(param);
}
SourceSyntaxFeatures::~SourceSyntaxFeatures() {
@@ -161,7 +178,10 @@ void SourceSyntaxFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
void SourceSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) {
- impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength());
+ ReadFile f = ReadFile(smeta.GetSGMLValue("src_tree"));
+ string tree;
+ f.ReadAll(tree);
+ impl->InitializeGrids(tree, smeta.GetSourceLength());
}
struct SourceSpanSizeFeaturesImpl {
@@ -236,4 +256,3 @@ void SourceSpanSizeFeatures::PrepareForInput(const SentenceMetadata& smeta) {
impl->InitializeGrids(smeta.GetSourceLength());
}
-
diff --git a/decoder/ff_source_syntax.h b/decoder/ff_source_syntax.h
index a8c7150a..bdd638c1 100644
--- a/decoder/ff_source_syntax.h
+++ b/decoder/ff_source_syntax.h
@@ -1,7 +1,8 @@
-#ifndef _FF_SOURCE_TOOLS_H_
-#define _FF_SOURCE_TOOLS_H_
+#ifndef _FF_SOURCE_SYNTAX_H_
+#define _FF_SOURCE_SYNTAX_H_
#include "ff.h"
+#include "hg.h"
struct SourceSyntaxFeaturesImpl;
@@ -11,7 +12,7 @@ class SourceSyntaxFeatures : public FeatureFunction {
~SourceSyntaxFeatures();
protected:
virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const HG::Edge& edge,
+ const Hypergraph::Edge& edge,
const std::vector<const void*>& ant_contexts,
SparseVector<double>* features,
SparseVector<double>* estimated_features,
@@ -28,7 +29,7 @@ class SourceSpanSizeFeatures : public FeatureFunction {
~SourceSpanSizeFeatures();
protected:
virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const HG::Edge& edge,
+ const Hypergraph::Edge& edge,
const std::vector<const void*>& ant_contexts,
SparseVector<double>* features,
SparseVector<double>* estimated_features,
@@ -39,3 +40,4 @@ class SourceSpanSizeFeatures : public FeatureFunction {
};
#endif
+
diff --git a/decoder/ff_source_syntax2.cc b/decoder/ff_source_syntax2.cc
index 622c6908..a97e31d8 100644
--- a/decoder/ff_source_syntax2.cc
+++ b/decoder/ff_source_syntax2.cc
@@ -16,7 +16,7 @@ using namespace std;
struct SourceSyntaxFeatures2Impl {
SourceSyntaxFeatures2Impl(const string& param) {
- if (!(param.compare("") == 0)) {
+ if (param.compare("") != 0) {
string triggered_features_fn = param;
ReadFile triggered_features(triggered_features_fn);
string in;
@@ -28,10 +28,8 @@ struct SourceSyntaxFeatures2Impl {
void InitializeGrids(const string& tree, unsigned src_len) {
assert(tree.size() > 0);
- //fids_cat.clear();
fids_ef.clear();
src_tree.clear();
- //fids_cat.resize(src_len, src_len + 1);
fids_ef.resize(src_len, src_len + 1);
src_tree.resize(src_len, src_len + 1, TD::Convert("XX"));
ParseTreeString(tree, src_len);
@@ -39,7 +37,7 @@ struct SourceSyntaxFeatures2Impl {
void ParseTreeString(const string& tree, unsigned src_len) {
//cerr << "TREE: " << tree << endl;
- stack<pair<int, WordID> > stk; // first = i, second = category
+ stack<pair<int, WordID> > stk; // first = i, second = category
pair<int, WordID> cur_cat; cur_cat.first = -1;
unsigned i = 0;
unsigned p = 0;
@@ -91,7 +89,7 @@ struct SourceSyntaxFeatures2Impl {
const WordID lhs = src_tree(i,j);
int& fid_ef = fids_ef(i,j)[&rule];
ostringstream os;
- os << "SYN:" << TD::Convert(lhs);
+ os << "SSYN2:" << TD::Convert(lhs);
os << ':';
unsigned ntc = 0;
for (unsigned k = 0; k < rule.f_.size(); ++k) {
@@ -99,7 +97,7 @@ struct SourceSyntaxFeatures2Impl {
if (k > 0 && fj <= 0) os << '_';
if (fj <= 0) {
os << '[' << TD::Convert(ants[ntc++]) << ']';
- } /*else {
+ }/*else {
os << TD::Convert(fj);
}*/
}
@@ -115,16 +113,22 @@ struct SourceSyntaxFeatures2Impl {
fid_ef = FD::Convert(os.str());
//cerr << "FEATURE: " << os.str() << endl;
//cerr << "FID_EF: " << fid_ef << endl;
- if (feature_filter.find(fid_ef) != feature_filter.end()) {
- cerr << "SYN-Feature was trigger more than once on training set." << endl;
+ if (feature_filter.size() > 0) {
+ if (feature_filter.find(fid_ef) != feature_filter.end()) {
+ //cerr << "SYN-Feature was trigger more than once on training set." << endl;
+ feats->set_value(fid_ef, 1.0);
+ }
+ //else cerr << "SYN-Feature was triggered less than once on training set." << endli;
+ }
+ else {
feats->set_value(fid_ef, 1.0);
}
- else cerr << "SYN-Feature was triggered less than once on training set." << endl;
+ cerr << FD::Convert(fid_ef) << endl;
return lhs;
}
- Array2D<WordID> src_tree; // src_tree(i,j) NT = type
- mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
+ Array2D<WordID> src_tree; // src_tree(i,j) NT = type
+ mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
unordered_set<int> feature_filter;
};
@@ -153,5 +157,9 @@ void SourceSyntaxFeatures2::TraversalFeaturesImpl(const SentenceMetadata& smeta,
}
void SourceSyntaxFeatures2::PrepareForInput(const SentenceMetadata& smeta) {
- impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength());
+ ReadFile f = ReadFile(smeta.GetSGMLValue("src_tree"));
+ string tree;
+ f.ReadAll(tree);
+ impl->InitializeGrids(tree, smeta.GetSourceLength());
}
+
diff --git a/decoder/ff_source_syntax2.h b/decoder/ff_source_syntax2.h
index b6b7dc3d..f606c2bf 100644
--- a/decoder/ff_source_syntax2.h
+++ b/decoder/ff_source_syntax2.h
@@ -1,5 +1,5 @@
-#ifndef _FF_SOURCE_TOOLS2_H_
-#define _FF_SOURCE_TOOLS2_H_
+#ifndef _FF_SOURCE_SYNTAX2_H_
+#define _FF_SOURCE_SYNTAX2_H_
#include "ff.h"
#include "hg.h"
@@ -23,3 +23,4 @@ class SourceSyntaxFeatures2 : public FeatureFunction {
};
#endif
+
diff --git a/decoder/ff_source_syntax2_p.cc b/decoder/ff_source_syntax2_p.cc
deleted file mode 100644
index 6a2ae742..00000000
--- a/decoder/ff_source_syntax2_p.cc
+++ /dev/null
@@ -1,170 +0,0 @@
-#include "ff_source_syntax2_p.h"
-
-#include <sstream>
-#include <stack>
-#include <string>
-#ifndef HAVE_OLD_CPP
-# include <unordered_set>
-#else
-# include <tr1/unordered_set>
-namespace std { using std::tr1::unordered_set; }
-#endif
-
-#include "sentence_metadata.h"
-#include "array2d.h"
-#include "filelib.h"
-
-using namespace std;
-
-// implements the source side syntax features described in Blunsom et al. (EMNLP 2008)
-// source trees must be represented in Penn Treebank format, e.g.
-// (S (NP John) (VP (V left)))
-
-struct PSourceSyntaxFeatures2Impl {
- PSourceSyntaxFeatures2Impl(const string& param) {
- if (param.compare("") != 0) {
- string triggered_features_fn = param;
- ReadFile triggered_features(triggered_features_fn);
- string in;
- while(getline(*triggered_features, in)) {
- feature_filter.insert(FD::Convert(in));
- }
- }
- /*cerr << "find(\"One\") == " << boolalpha << (table.find("One") != table.end()) << endl;
- cerr << "find(\"Three\") == " << boolalpha << (table.find("Three") != table.end()) << endl;*/
- }
-
- void InitializeGrids(const string& tree, unsigned src_len) {
- assert(tree.size() > 0);
- //fids_cat.clear();
- fids_ef.clear();
- src_tree.clear();
- //fids_cat.resize(src_len, src_len + 1);
- fids_ef.resize(src_len, src_len + 1);
- src_tree.resize(src_len, src_len + 1, TD::Convert("XX"));
- ParseTreeString(tree, src_len);
- }
-
- void ParseTreeString(const string& tree, unsigned src_len) {
- //cerr << "TREE: " << tree << endl;
- stack<pair<int, WordID> > stk; // first = i, second = category
- pair<int, WordID> cur_cat; cur_cat.first = -1;
- unsigned i = 0;
- unsigned p = 0;
- while(p < tree.size()) {
- const char cur = tree[p];
- if (cur == '(') {
- stk.push(cur_cat);
- ++p;
- unsigned k = p + 1;
- while (k < tree.size() && tree[k] != ' ') { ++k; }
- cur_cat.first = i;
- cur_cat.second = TD::Convert(tree.substr(p, k - p));
- // cerr << "NT: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n";
- p = k + 1;
- } else if (cur == ')') {
- unsigned k = p;
- while (k < tree.size() && tree[k] == ')') { ++k; }
- const unsigned num_closes = k - p;
- for (unsigned ci = 0; ci < num_closes; ++ci) {
- src_tree(cur_cat.first, i) = cur_cat.second;
- cur_cat = stk.top();
- stk.pop();
- }
- p = k;
- while (p < tree.size() && (tree[p] == ' ' || tree[p] == '\t')) { ++p; }
- } else if (cur == ' ' || cur == '\t') {
- cerr << "Unexpected whitespace in: " << tree << endl;
- abort();
- } else { // terminal symbol
- unsigned k = p + 1;
- do {
- while (k < tree.size() && tree[k] != ')' && tree[k] != ' ') { ++k; }
- // cerr << "TERM: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n";
- ++i;
- assert(i <= src_len);
- while (k < tree.size() && tree[k] == ' ') { ++k; }
- p = k;
- } while (p < tree.size() && tree[p] != ')');
- }
- //cerr << "i=" << i << " src_len=" << src_len << endl;
- }
- //cerr << "i=" << i << " src_len=" << src_len << endl;
- assert(i == src_len); // make sure tree specified in src_tree is
- // the same length as the source sentence
- }
-
- WordID FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector<double>* feats) {
- //cerr << "fire features: " << rule.AsString() << " for " << i << "," << j << endl;
- const WordID lhs = src_tree(i,j);
- int& fid_ef = fids_ef(i,j)[&rule];
- ostringstream os;
- os << "SYN:" << TD::Convert(lhs);
- os << ':';
- unsigned ntc = 0;
- for (unsigned k = 0; k < rule.f_.size(); ++k) {
- int fj = rule.f_[k];
- if (k > 0 && fj <= 0) os << '_';
- if (fj <= 0) {
- os << '[' << TD::Convert(ants[ntc++]) << ']';
- } /*else {
- os << TD::Convert(fj);
- }*/
- }
- os << ':';
- for (unsigned k = 0; k < rule.e_.size(); ++k) {
- const int ei = rule.e_[k];
- if (k > 0) os << '_';
- if (ei <= 0)
- os << '[' << (1-ei) << ']';
- else
- os << TD::Convert(ei);
- }
- fid_ef = FD::Convert(os.str());
- //cerr << "FEATURE: " << os.str() << endl;
- //cerr << "FID_EF: " << fid_ef << endl;
- if (feature_filter.size() > 0) {
- if (feature_filter.find(fid_ef) != feature_filter.end()) {
- //cerr << "SYN-Feature was trigger more than once on training set." << endl;
- feats->set_value(fid_ef, 1.0);
- }
- //else cerr << "SYN-Feature was triggered less than once on training set." << endli;
- }
- else {
- feats->set_value(fid_ef, 1.0);
- }
- return lhs;
- }
-
- Array2D<WordID> src_tree; // src_tree(i,j) NT = type
- mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
- unordered_set<int> feature_filter;
-};
-
-PSourceSyntaxFeatures2::PSourceSyntaxFeatures2(const string& param) :
- FeatureFunction(sizeof(WordID)) {
- impl = new PSourceSyntaxFeatures2Impl(param);
-}
-
-PSourceSyntaxFeatures2::~PSourceSyntaxFeatures2() {
- delete impl;
- impl = NULL;
-}
-
-void PSourceSyntaxFeatures2::TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const {
- WordID ants[8];
- for (unsigned i = 0; i < ant_contexts.size(); ++i)
- ants[i] = *static_cast<const WordID*>(ant_contexts[i]);
-
- *static_cast<WordID*>(context) =
- impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features);
-}
-
-void PSourceSyntaxFeatures2::PrepareForInput(const SentenceMetadata& smeta) {
- impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength());
-}
diff --git a/decoder/ff_source_syntax2_p.h b/decoder/ff_source_syntax2_p.h
deleted file mode 100644
index d56ecab0..00000000
--- a/decoder/ff_source_syntax2_p.h
+++ /dev/null
@@ -1,25 +0,0 @@
-#ifndef _FF_SOURCE_TOOLS2_H_
-#define _FF_SOURCE_TOOLS2_H_
-
-#include "ff.h"
-#include "hg.h"
-
-struct PSourceSyntaxFeatures2Impl;
-
-class PSourceSyntaxFeatures2 : public FeatureFunction {
- public:
- PSourceSyntaxFeatures2(const std::string& param);
- ~PSourceSyntaxFeatures2();
- protected:
- virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const std::vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const;
- virtual void PrepareForInput(const SentenceMetadata& smeta);
- private:
- PSourceSyntaxFeatures2Impl* impl;
-};
-
-#endif
diff --git a/decoder/ff_source_syntax_p.cc b/decoder/ff_source_syntax_p.cc
deleted file mode 100644
index c094de59..00000000
--- a/decoder/ff_source_syntax_p.cc
+++ /dev/null
@@ -1,250 +0,0 @@
-#include "ff_source_syntax_p.h"
-
-#include <sstream>
-#include <stack>
-#ifndef HAVE_OLD_CPP
-# include <unordered_set>
-#else
-# include <tr1/unordered_set>
-namespace std { using std::tr1::unordered_map; using std::tr1::unordered_set; }
-#endif
-
-#include "sentence_metadata.h"
-#include "array2d.h"
-#include "filelib.h"
-
-using namespace std;
-
-// implements the source side syntax features described in Blunsom et al. (EMNLP 2008)
-// source trees must be represented in Penn Treebank format, e.g.
-// (S (NP John) (VP (V left)))
-
-// log transform to make long spans cluster together
-// but preserve differences
-inline int SpanSizeTransform(unsigned span_size) {
- if (!span_size) return 0;
- return static_cast<int>(log(span_size+1) / log(1.39)) - 1;
-}
-
-struct PSourceSyntaxFeaturesImpl {
- PSourceSyntaxFeaturesImpl() {}
-
- PSourceSyntaxFeaturesImpl(const string& param) {
- if (!(param.compare("") == 0)) {
- string triggered_features_fn = param;
- ReadFile triggered_features(triggered_features_fn);
- string in;
- while(getline(*triggered_features, in)) {
- feature_filter.insert(FD::Convert(in));
- }
- }
- }
-
- void InitializeGrids(const string& tree, unsigned src_len) {
- assert(tree.size() > 0);
- //fids_cat.clear();
- fids_ef.clear();
- src_tree.clear();
- //fids_cat.resize(src_len, src_len + 1);
- fids_ef.resize(src_len, src_len + 1);
- src_tree.resize(src_len, src_len + 1, TD::Convert("XX"));
- ParseTreeString(tree, src_len);
- }
-
- void ParseTreeString(const string& tree, unsigned src_len) {
- stack<pair<int, WordID> > stk; // first = i, second = category
- pair<int, WordID> cur_cat; cur_cat.first = -1;
- unsigned i = 0;
- unsigned p = 0;
- while(p < tree.size()) {
- const char cur = tree[p];
- if (cur == '(') {
- stk.push(cur_cat);
- ++p;
- unsigned k = p + 1;
- while (k < tree.size() && tree[k] != ' ') { ++k; }
- cur_cat.first = i;
- cur_cat.second = TD::Convert(tree.substr(p, k - p));
- // cerr << "NT: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n";
- p = k + 1;
- } else if (cur == ')') {
- unsigned k = p;
- while (k < tree.size() && tree[k] == ')') { ++k; }
- const unsigned num_closes = k - p;
- for (unsigned ci = 0; ci < num_closes; ++ci) {
- // cur_cat.second spans from cur_cat.first to i
- // cerr << TD::Convert(cur_cat.second) << " from " << cur_cat.first << " to " << i << endl;
- // NOTE: unary rule chains end up being labeled with the top-most category
- src_tree(cur_cat.first, i) = cur_cat.second;
- cur_cat = stk.top();
- stk.pop();
- }
- p = k;
- while (p < tree.size() && (tree[p] == ' ' || tree[p] == '\t')) { ++p; }
- } else if (cur == ' ' || cur == '\t') {
- cerr << "Unexpected whitespace in: " << tree << endl;
- abort();
- } else { // terminal symbol
- unsigned k = p + 1;
- do {
- while (k < tree.size() && tree[k] != ')' && tree[k] != ' ') { ++k; }
- // cerr << "TERM: '" << tree.substr(p, k-p) << "' (i=" << i << ")\n";
- ++i;
- assert(i <= src_len);
- while (k < tree.size() && tree[k] == ' ') { ++k; }
- p = k;
- } while (p < tree.size() && tree[p] != ')');
- }
- }
- // cerr << "i=" << i << " src_len=" << src_len << endl;
- assert(i == src_len); // make sure tree specified in src_tree is
- // the same length as the source sentence
- }
-
- WordID FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector<double>* feats) {
- //cerr << "fire features: " << rule.AsString() << " for " << i << "," << j << endl;
- const WordID lhs = src_tree(i,j);
- //int& fid_cat = fids_cat(i,j);
- int& fid_ef = fids_ef(i,j)[&rule];
- if (fid_ef <= 0) {
- ostringstream os;
- //ostringstream os2;
- os << "SYN:" << TD::Convert(lhs);
- //os2 << "SYN:" << TD::Convert(lhs) << '_' << SpanSizeTransform(j - i);
- //fid_cat = FD::Convert(os2.str());
- os << ':';
- unsigned ntc = 0;
- for (unsigned k = 0; k < rule.f_.size(); ++k) {
- if (k > 0) os << '_';
- int fj = rule.f_[k];
- if (fj <= 0) {
- os << '[' << TD::Convert(ants[ntc++]) << ']';
- } else {
- os << TD::Convert(fj);
- }
- }
- os << ':';
- for (unsigned k = 0; k < rule.e_.size(); ++k) {
- const int ei = rule.e_[k];
- if (k > 0) os << '_';
- if (ei <= 0)
- os << '[' << (1-ei) << ']';
- else
- os << TD::Convert(ei);
- }
- fid_ef = FD::Convert(os.str());
- }
- //if (fid_cat > 0)
- // feats->set_value(fid_cat, 1.0);
- if (fid_ef > 0 && (feature_filter.find(fid_ef) != feature_filter.end()))
- feats->set_value(fid_ef, 1.0);
- return lhs;
- }
-
- Array2D<WordID> src_tree; // src_tree(i,j) NT = type
- // mutable Array2D<int> fids_cat; // this tends to overfit baddly
- mutable Array2D<map<const TRule*, int> > fids_ef; // fires for fully lexicalized
- unordered_set<int> feature_filter;
-};
-
-PSourceSyntaxFeatures::PSourceSyntaxFeatures(const string& param) :
- FeatureFunction(sizeof(WordID)) {
- impl = new PSourceSyntaxFeaturesImpl(param);
-}
-
-PSourceSyntaxFeatures::~PSourceSyntaxFeatures() {
- delete impl;
- impl = NULL;
-}
-
-void PSourceSyntaxFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const {
- WordID ants[8];
- for (unsigned i = 0; i < ant_contexts.size(); ++i)
- ants[i] = *static_cast<const WordID*>(ant_contexts[i]);
-
- *static_cast<WordID*>(context) =
- impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features);
-}
-
-void PSourceSyntaxFeatures::PrepareForInput(const SentenceMetadata& smeta) {
- impl->InitializeGrids(smeta.GetSGMLValue("src_tree"), smeta.GetSourceLength());
-}
-
-struct PSourceSpanSizeFeaturesImpl {
- PSourceSpanSizeFeaturesImpl() {}
-
- void InitializeGrids(unsigned src_len) {
- fids.clear();
- fids.resize(src_len, src_len + 1);
- }
-
- int FireFeatures(const TRule& rule, const int i, const int j, const WordID* ants, SparseVector<double>* feats) {
- if (rule.Arity() > 0) {
- int& fid = fids(i,j)[&rule];
- if (fid <= 0) {
- ostringstream os;
- os << "SSS:";
- unsigned ntc = 0;
- for (unsigned k = 0; k < rule.f_.size(); ++k) {
- if (k > 0) os << '_';
- int fj = rule.f_[k];
- if (fj <= 0) {
- os << '[' << TD::Convert(-fj) << ants[ntc++] << ']';
- } else {
- os << TD::Convert(fj);
- }
- }
- os << ':';
- for (unsigned k = 0; k < rule.e_.size(); ++k) {
- const int ei = rule.e_[k];
- if (k > 0) os << '_';
- if (ei <= 0)
- os << '[' << (1-ei) << ']';
- else
- os << TD::Convert(ei);
- }
- fid = FD::Convert(os.str());
- }
- if (fid > 0)
- feats->set_value(fid, 1.0);
- }
- return SpanSizeTransform(j - i);
- }
-
- mutable Array2D<map<const TRule*, int> > fids;
-};
-
-PSourceSpanSizeFeatures::PSourceSpanSizeFeatures(const string& param) :
- FeatureFunction(sizeof(char)) {
- impl = new PSourceSpanSizeFeaturesImpl;
-}
-
-PSourceSpanSizeFeatures::~PSourceSpanSizeFeatures() {
- delete impl;
- impl = NULL;
-}
-
-void PSourceSpanSizeFeatures::TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const {
- int ants[8];
- for (unsigned i = 0; i < ant_contexts.size(); ++i)
- ants[i] = *static_cast<const char*>(ant_contexts[i]);
-
- *static_cast<char*>(context) =
- impl->FireFeatures(*edge.rule_, edge.i_, edge.j_, ants, features);
-}
-
-void PSourceSpanSizeFeatures::PrepareForInput(const SentenceMetadata& smeta) {
- impl->InitializeGrids(smeta.GetSourceLength());
-}
-
-
diff --git a/decoder/ff_source_syntax_p.h b/decoder/ff_source_syntax_p.h
deleted file mode 100644
index 2dd9094a..00000000
--- a/decoder/ff_source_syntax_p.h
+++ /dev/null
@@ -1,42 +0,0 @@
-#ifndef _FF_SOURCE_TOOLS_H_
-#define _FF_SOURCE_TOOLS_H_
-
-#include "ff.h"
-#include "hg.h"
-
-struct PSourceSyntaxFeaturesImpl;
-
-class PSourceSyntaxFeatures : public FeatureFunction {
- public:
- PSourceSyntaxFeatures(const std::string& param);
- ~PSourceSyntaxFeatures();
- protected:
- virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const std::vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const;
- virtual void PrepareForInput(const SentenceMetadata& smeta);
- private:
- PSourceSyntaxFeaturesImpl* impl;
-};
-
-struct PSourceSpanSizeFeaturesImpl;
-class PSourceSpanSizeFeatures : public FeatureFunction {
- public:
- PSourceSpanSizeFeatures(const std::string& param);
- ~PSourceSpanSizeFeatures();
- protected:
- virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
- const Hypergraph::Edge& edge,
- const std::vector<const void*>& ant_contexts,
- SparseVector<double>* features,
- SparseVector<double>* estimated_features,
- void* context) const;
- virtual void PrepareForInput(const SentenceMetadata& smeta);
- private:
- PSourceSpanSizeFeaturesImpl* impl;
-};
-
-#endif
diff --git a/extractor/grammar_extractor.cc b/extractor/grammar_extractor.cc
index 8050ce7b..1fbdee5b 100644
--- a/extractor/grammar_extractor.cc
+++ b/extractor/grammar_extractor.cc
@@ -3,11 +3,13 @@
#include <iterator>
#include <sstream>
#include <vector>
+#include <unordered_set>
#include "grammar.h"
#include "rule.h"
#include "rule_factory.h"
#include "vocabulary.h"
+#include "data_array.h"
using namespace std;
@@ -32,10 +34,10 @@ GrammarExtractor::GrammarExtractor(
vocabulary(vocabulary),
rule_factory(rule_factory) {}
-Grammar GrammarExtractor::GetGrammar(const string& sentence) {
+Grammar GrammarExtractor::GetGrammar(const string& sentence, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
vector<string> words = TokenizeSentence(sentence);
vector<int> word_ids = AnnotateWords(words);
- return rule_factory->GetGrammar(word_ids);
+ return rule_factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
}
vector<string> GrammarExtractor::TokenizeSentence(const string& sentence) {
diff --git a/extractor/grammar_extractor.h b/extractor/grammar_extractor.h
index b36ceeb9..6c0aafbf 100644
--- a/extractor/grammar_extractor.h
+++ b/extractor/grammar_extractor.h
@@ -4,6 +4,7 @@
#include <memory>
#include <string>
#include <vector>
+#include <unordered_set>
using namespace std;
@@ -44,7 +45,7 @@ class GrammarExtractor {
// Converts the sentence to a vector of word ids and uses the RuleFactory to
// extract the SCFG rules which may be used to decode the sentence.
- Grammar GetGrammar(const string& sentence);
+ Grammar GetGrammar(const string& sentence, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
private:
// Splits the sentence in a vector of words.
diff --git a/extractor/grammar_extractor_test.cc b/extractor/grammar_extractor_test.cc
index 823bb8b4..f32a9599 100644
--- a/extractor/grammar_extractor_test.cc
+++ b/extractor/grammar_extractor_test.cc
@@ -39,12 +39,15 @@ TEST(GrammarExtractorTest, TestAnnotatingWords) {
vector<Rule> rules;
vector<string> feature_names;
Grammar grammar(rules, feature_names);
- EXPECT_CALL(*factory, GetGrammar(word_ids))
+ unordered_set<int> blacklisted_sentence_ids;
+ shared_ptr<DataArray> source_data_array;
+ EXPECT_CALL(*factory, GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array))
.WillOnce(Return(grammar));
GrammarExtractor extractor(vocabulary, factory);
string sentence = "Anna has many many apples .";
- extractor.GetGrammar(sentence);
+
+ extractor.GetGrammar(sentence, blacklisted_sentence_ids, source_data_array);
}
} // namespace
diff --git a/extractor/mocks/mock_rule_factory.h b/extractor/mocks/mock_rule_factory.h
index 7389b396..86a084b5 100644
--- a/extractor/mocks/mock_rule_factory.h
+++ b/extractor/mocks/mock_rule_factory.h
@@ -7,7 +7,7 @@ namespace extractor {
class MockHieroCachingRuleFactory : public HieroCachingRuleFactory {
public:
- MOCK_METHOD1(GetGrammar, Grammar(const vector<int>& word_ids));
+ MOCK_METHOD3(GetGrammar, Grammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array));
};
} // namespace extractor
diff --git a/extractor/rule_factory.cc b/extractor/rule_factory.cc
index 8c30fb9e..e52019ae 100644
--- a/extractor/rule_factory.cc
+++ b/extractor/rule_factory.cc
@@ -17,6 +17,7 @@
#include "suffix_array.h"
#include "time_util.h"
#include "vocabulary.h"
+#include "data_array.h"
using namespace std;
using namespace chrono;
@@ -100,7 +101,7 @@ HieroCachingRuleFactory::HieroCachingRuleFactory() {}
HieroCachingRuleFactory::~HieroCachingRuleFactory() {}
-Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
+Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) {
Clock::time_point start_time = Clock::now();
double total_extract_time = 0;
double total_intersect_time = 0;
@@ -192,7 +193,7 @@ Grammar HieroCachingRuleFactory::GetGrammar(const vector<int>& word_ids) {
Clock::time_point extract_start = Clock::now();
if (!state.starts_with_x) {
// Extract rules for the sampled set of occurrences.
- PhraseLocation sample = sampler->Sample(next_node->matchings);
+ PhraseLocation sample = sampler->Sample(next_node->matchings, blacklisted_sentence_ids, source_data_array);
vector<Rule> new_rules =
rule_extractor->ExtractRules(next_phrase, sample);
rules.insert(rules.end(), new_rules.begin(), new_rules.end());
diff --git a/extractor/rule_factory.h b/extractor/rule_factory.h
index 52e8712a..c7332720 100644
--- a/extractor/rule_factory.h
+++ b/extractor/rule_factory.h
@@ -3,6 +3,7 @@
#include <memory>
#include <vector>
+#include <unordered_set>
#include "matchings_trie.h"
@@ -71,7 +72,7 @@ class HieroCachingRuleFactory {
// Constructs SCFG rules for a given sentence.
// (See class description for more details.)
- virtual Grammar GetGrammar(const vector<int>& word_ids);
+ virtual Grammar GetGrammar(const vector<int>& word_ids, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array);
protected:
HieroCachingRuleFactory();
diff --git a/extractor/rule_factory_test.cc b/extractor/rule_factory_test.cc
index 08af3dcd..f26cc567 100644
--- a/extractor/rule_factory_test.cc
+++ b/extractor/rule_factory_test.cc
@@ -76,7 +76,9 @@ TEST_F(RuleFactoryTest, TestGetGrammarDifferentWords) {
.WillRepeatedly(Return(PhraseLocation(0, 1)));
vector<int> word_ids = {2, 3, 4};
- Grammar grammar = factory->GetGrammar(word_ids);
+ unordered_set<int> blacklisted_sentence_ids;
+ shared_ptr<DataArray> source_data_array;
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(7, grammar.GetRules().size());
}
@@ -94,7 +96,9 @@ TEST_F(RuleFactoryTest, TestGetGrammarRepeatingWords) {
.WillRepeatedly(Return(PhraseLocation(0, 1)));
vector<int> word_ids = {2, 3, 4, 2, 3};
- Grammar grammar = factory->GetGrammar(word_ids);
+ unordered_set<int> blacklisted_sentence_ids;
+ shared_ptr<DataArray> source_data_array;
+ Grammar grammar = factory->GetGrammar(word_ids, blacklisted_sentence_ids, source_data_array);
EXPECT_EQ(feature_names, grammar.GetFeatureNames());
EXPECT_EQ(28, grammar.GetRules().size());
}
diff --git a/extractor/run_extractor.cc b/extractor/run_extractor.cc
index 8a9ca89d..6eb55073 100644
--- a/extractor/run_extractor.cc
+++ b/extractor/run_extractor.cc
@@ -75,7 +75,9 @@ int main(int argc, char** argv) {
("max_samples", po::value<int>()->default_value(300),
"Maximum number of samples")
("tight_phrases", po::value<bool>()->default_value(true),
- "False if phrases may be loose (better, but slower)");
+ "False if phrases may be loose (better, but slower)")
+ ("leave_one_out", po::value<bool>()->zero_tokens(),
+ "do leave-one-out estimation of grammars (e.g. for extracting grammars for the training set");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
@@ -96,6 +98,11 @@ int main(int argc, char** argv) {
return 1;
}
+ bool leave_one_out = false;
+ if (vm.count("leave_one_out")) {
+ leave_one_out = true;
+ }
+
int num_threads = vm["threads"].as<int>();
cerr << "Grammar extraction will use " << num_threads << " threads." << endl;
@@ -223,7 +230,9 @@ int main(int argc, char** argv) {
}
suffixes[i] = suffix;
- Grammar grammar = extractor.GetGrammar(sentences[i]);
+ unordered_set<int> blacklisted_sentence_ids;
+ if (leave_one_out) blacklisted_sentence_ids.insert(i);
+ Grammar grammar = extractor.GetGrammar(sentences[i], blacklisted_sentence_ids, source_data_array);
ofstream output(GetGrammarFilePath(grammar_path, i).c_str());
output << grammar;
}
diff --git a/extractor/sample_source.txt b/extractor/sample_source.txt
new file mode 100644
index 00000000..971baf6d
--- /dev/null
+++ b/extractor/sample_source.txt
@@ -0,0 +1,2 @@
+ana are mere .
+ana bea mult lapte .
diff --git a/extractor/sampler.cc b/extractor/sampler.cc
index d81956b5..d332dd90 100644
--- a/extractor/sampler.cc
+++ b/extractor/sampler.cc
@@ -12,7 +12,7 @@ Sampler::Sampler() {}
Sampler::~Sampler() {}
-PhraseLocation Sampler::Sample(const PhraseLocation& location) const {
+PhraseLocation Sampler::Sample(const PhraseLocation& location, unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const {
vector<int> sample;
int num_subpatterns;
if (location.matchings == NULL) {
@@ -20,8 +20,37 @@ PhraseLocation Sampler::Sample(const PhraseLocation& location) const {
num_subpatterns = 1;
int low = location.sa_low, high = location.sa_high;
double step = max(1.0, (double) (high - low) / max_samples);
- for (double i = low; i < high && sample.size() < max_samples; i += step) {
- sample.push_back(suffix_array->GetSuffix(Round(i)));
+ double i = low, last = i;
+ bool found;
+ while (sample.size() < max_samples && i < high) {
+ int x = suffix_array->GetSuffix(Round(i));
+ int id = source_data_array->GetSentenceId(x);
+ if (find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) != blacklisted_sentence_ids.end()) {
+ found = false;
+ double backoff_step = 1;
+ while (true) {
+ if ((double)backoff_step >= step) break;
+ double j = i - backoff_step;
+ x = suffix_array->GetSuffix(Round(j));
+ id = source_data_array->GetSentenceId(x);
+ if (x >= 0 && j > last && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
+ found = true; last = i; break;
+ }
+ double k = i + backoff_step;
+ x = suffix_array->GetSuffix(Round(k));
+ id = source_data_array->GetSentenceId(x);
+ if (k < min(i+step, (double)high) && find(blacklisted_sentence_ids.begin(), blacklisted_sentence_ids.end(), id) == blacklisted_sentence_ids.end()) {
+ found = true; last = k; break;
+ }
+ if (j <= last && k >= high) break;
+ backoff_step++;
+ }
+ } else {
+ found = true;
+ last = i;
+ }
+ if (found) sample.push_back(x);
+ i += step;
}
} else {
// Sample vector of occurrences.
diff --git a/extractor/sampler.h b/extractor/sampler.h
index be4aa1bb..30e747fd 100644
--- a/extractor/sampler.h
+++ b/extractor/sampler.h
@@ -2,6 +2,9 @@
#define _SAMPLER_H_
#include <memory>
+#include <unordered_set>
+
+#include "data_array.h"
using namespace std;
@@ -20,7 +23,7 @@ class Sampler {
virtual ~Sampler();
// Samples uniformly at most max_samples phrase occurrences.
- virtual PhraseLocation Sample(const PhraseLocation& location) const;
+ virtual PhraseLocation Sample(const PhraseLocation& location, const unordered_set<int> blacklisted_sentence_ids, const shared_ptr<DataArray> source_data_array) const;
protected:
Sampler();
diff --git a/extractor/sampler_test.cc b/extractor/sampler_test.cc
index e9abebfa..965567ba 100644
--- a/extractor/sampler_test.cc
+++ b/extractor/sampler_test.cc
@@ -3,6 +3,7 @@
#include <memory>
#include "mocks/mock_suffix_array.h"
+#include "mocks/mock_data_array.h"
#include "phrase_location.h"
#include "sampler.h"
@@ -15,6 +16,8 @@ namespace {
class SamplerTest : public Test {
protected:
virtual void SetUp() {
+ source_data_array = make_shared<MockDataArray>();
+ EXPECT_CALL(*source_data_array, GetSentenceId(_)).WillRepeatedly(Return(9999));
suffix_array = make_shared<MockSuffixArray>();
for (int i = 0; i < 10; ++i) {
EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
@@ -23,51 +26,54 @@ class SamplerTest : public Test {
shared_ptr<MockSuffixArray> suffix_array;
shared_ptr<Sampler> sampler;
+ shared_ptr<MockDataArray> source_data_array;
};
TEST_F(SamplerTest, TestSuffixArrayRange) {
PhraseLocation location(0, 10);
+ unordered_set<int> blacklist;
sampler = make_shared<Sampler>(suffix_array, 1);
vector<int> expected_locations = {0};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 2);
expected_locations = {0, 5};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 3);
expected_locations = {0, 3, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 4);
expected_locations = {0, 3, 5, 8};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 100);
expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
}
TEST_F(SamplerTest, TestSubstringsSample) {
vector<int> locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+ unordered_set<int> blacklist;
PhraseLocation location(locations, 2);
sampler = make_shared<Sampler>(suffix_array, 1);
vector<int> expected_locations = {0, 1};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 2);
expected_locations = {0, 1, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 3);
expected_locations = {0, 1, 4, 5, 6, 7};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
sampler = make_shared<Sampler>(suffix_array, 7);
expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
- EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location));
+ EXPECT_EQ(PhraseLocation(expected_locations, 2), sampler->Sample(location, blacklist, source_data_array));
}
} // namespace
diff --git a/extractor/sampler_test_blacklist.cc b/extractor/sampler_test_blacklist.cc
new file mode 100644
index 00000000..3305b990
--- /dev/null
+++ b/extractor/sampler_test_blacklist.cc
@@ -0,0 +1,102 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+
+#include "mocks/mock_suffix_array.h"
+#include "mocks/mock_data_array.h"
+#include "phrase_location.h"
+#include "sampler.h"
+
+using namespace std;
+using namespace ::testing;
+
+namespace extractor {
+namespace {
+
+class SamplerTestBlacklist : public Test {
+ protected:
+ virtual void SetUp() {
+ source_data_array = make_shared<MockDataArray>();
+ for (int i = 0; i < 10; ++i) {
+ EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(i));
+ }
+ for (int i = -10; i < 0; ++i) {
+ EXPECT_CALL(*source_data_array, GetSentenceId(i)).WillRepeatedly(Return(0));
+ }
+ suffix_array = make_shared<MockSuffixArray>();
+ for (int i = -10; i < 10; ++i) {
+ EXPECT_CALL(*suffix_array, GetSuffix(i)).WillRepeatedly(Return(i));
+ }
+ }
+
+ shared_ptr<MockSuffixArray> suffix_array;
+ shared_ptr<Sampler> sampler;
+ shared_ptr<MockDataArray> source_data_array;
+};
+
+TEST_F(SamplerTestBlacklist, TestSuffixArrayRange) {
+ PhraseLocation location(0, 10);
+ unordered_set<int> blacklist;
+ vector<int> expected_locations;
+
+ blacklist.insert(0);
+ sampler = make_shared<Sampler>(suffix_array, 1);
+ expected_locations = {1};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ for (int i = 0; i < 9; i++) {
+ blacklist.insert(i);
+ }
+ sampler = make_shared<Sampler>(suffix_array, 1);
+ expected_locations = {9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(5);
+ sampler = make_shared<Sampler>(suffix_array, 2);
+ expected_locations = {1, 4};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(1);
+ blacklist.insert(2);
+ blacklist.insert(3);
+ sampler = make_shared<Sampler>(suffix_array, 2);
+ expected_locations = {4, 5};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(3);
+ blacklist.insert(7);
+ sampler = make_shared<Sampler>(suffix_array, 3);
+ expected_locations = {1, 2, 6};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ blacklist.insert(3);
+ blacklist.insert(5);
+ blacklist.insert(8);
+ sampler = make_shared<Sampler>(suffix_array, 4);
+ expected_locations = {1, 2, 4, 7};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(0);
+ sampler = make_shared<Sampler>(suffix_array, 100);
+ expected_locations = {1, 2, 3, 4, 5, 6, 7, 8, 9};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+ blacklist.clear();
+
+ blacklist.insert(9);
+ sampler = make_shared<Sampler>(suffix_array, 100);
+ expected_locations = {0, 1, 2, 3, 4, 5, 6, 7, 8};
+ EXPECT_EQ(PhraseLocation(expected_locations, 1), sampler->Sample(location, blacklist, source_data_array));
+}
+
+} // namespace
+} // namespace extractor
diff --git a/training/dtrain/Makefile.am b/training/dtrain/Makefile.am
index 844c790d..ecb6c128 100644
--- a/training/dtrain/Makefile.am
+++ b/training/dtrain/Makefile.am
@@ -1,7 +1,7 @@
bin_PROGRAMS = dtrain
dtrain_SOURCES = dtrain.cc score.cc dtrain.h kbestget.h ksampler.h pairsampling.h score.h
-dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a
+dtrain_LDADD = ../../decoder/libcdec.a ../../klm/search/libksearch.a ../../mteval/libmteval.a ../../utils/libutils.a ../../klm/lm/libklm.a ../../klm/util/libklm_util.a ../../klm/util/double-conversion/libklm_util_double.a -lboost_regex
AM_CPPFLAGS = -W -Wall -Wno-sign-compare -I$(top_srcdir)/utils -I$(top_srcdir)/decoder -I$(top_srcdir)/mteval
diff --git a/training/dtrain/README.md b/training/dtrain/README.md
index 2bae6b48..aa1ab3e7 100644
--- a/training/dtrain/README.md
+++ b/training/dtrain/README.md
@@ -1,10 +1,15 @@
This is a simple (and parallelizable) tuning method for cdec
-which is able to train the weights of very many (sparse) features.
-It was used here:
- "Joint Feature Selection in Distributed Stochastic
- Learning for Large-Scale Discriminative Training in
- SMT"
-(Simianer, Riezler, Dyer; ACL 2012)
+which is able to train the weights of very many (sparse) features
+on the training set.
+
+It was used in these papers:
+> "Joint Feature Selection in Distributed Stochastic
+> Learning for Large-Scale Discriminative Training in
+> SMT" (Simianer, Riezler, Dyer; ACL 2012)
+>
+> "Multi-Task Learning for Improved Discriminative
+> Training in SMT" (Simianer, Riezler; WMT 2013)
+>
Building
@@ -17,20 +22,9 @@ To build only parts needed for dtrain do
cd training/dtrain/; make
```
-Ideas
------
- * get approx_bleu to work?
- * implement minibatches (Minibatch and Parallelization for Online Large Margin Structured Learning)
- * learning rate 1/T?
- * use an oracle? mira-like (model vs. BLEU), feature repr. of reference!?
- * implement lc_bleu properly
- * merge kbest lists of previous epochs (as MERT does)
- * ``walk entire regularization path''
- * rerank after each update?
-
Running
-------
-See directories under test/ .
+See directories under examples/ .
Legal
-----
diff --git a/training/dtrain/dtrain.cc b/training/dtrain/dtrain.cc
index 0ee2f124..0a27a068 100644
--- a/training/dtrain/dtrain.cc
+++ b/training/dtrain/dtrain.cc
@@ -12,8 +12,9 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
{
po::options_description ini("Configuration File Options");
ini.add_options()
- ("input", po::value<string>()->default_value("-"), "input file (src)")
+ ("input", po::value<string>(), "input file (src)")
("refs,r", po::value<string>(), "references")
+ ("bitext,b", po::value<string>(), "bitext: 'src ||| tgt'")
("output", po::value<string>()->default_value("-"), "output weights file, '-' for STDOUT")
("input_weights", po::value<string>(), "input weights file (e.g. from previous iteration)")
("decoder_config", po::value<string>(), "configuration file for cdec")
@@ -40,6 +41,10 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
("scale_bleu_diff", po::value<bool>()->zero_tokens(), "learning rate <- bleu diff of a misranked pair")
("loss_margin", po::value<weight_t>()->default_value(0.), "update if no error in pref pair but model scores this near")
("max_pairs", po::value<unsigned>()->default_value(std::numeric_limits<unsigned>::max()), "max. # of pairs per Sent.")
+ ("pclr", po::value<string>()->default_value("no"), "use a (simple|adagrad) per-coordinate learning rate")
+ ("batch", po::value<bool>()->zero_tokens(), "do batch optimization")
+ ("repeat", po::value<unsigned>()->default_value(1), "repeat optimization over kbest list this number of times")
+ //("test-k-best", po::value<bool>()->zero_tokens(), "check if optimization works (use repeat >= 2)")
("noup", po::value<bool>()->zero_tokens(), "do not update weights");
po::options_description cl("Command Line Options");
cl.add_options()
@@ -72,13 +77,17 @@ dtrain_init(int argc, char** argv, po::variables_map* cfg)
cerr << "Wrong 'pair_sampling' param: '" << (*cfg)["pair_sampling"].as<string>() << "'." << endl;
return false;
}
- if(cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") {
+ if (cfg->count("hi_lo") && (*cfg)["pair_sampling"].as<string>() != "XYX") {
cerr << "Warning: hi_lo only works with pair_sampling XYX." << endl;
}
- if((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) {
+ if ((*cfg)["hi_lo"].as<float>() > 0.5 || (*cfg)["hi_lo"].as<float>() < 0.01) {
cerr << "hi_lo must lie in [0.01, 0.5]" << endl;
return false;
}
+ if ((cfg->count("input")>0 || cfg->count("refs")>0) && cfg->count("bitext")>0) {
+ cerr << "Provide 'input' and 'refs' or 'bitext', not both." << endl;
+ return false;
+ }
if ((*cfg)["pair_threshold"].as<score_t>() < 0) {
cerr << "The threshold must be >= 0!" << endl;
return false;
@@ -120,10 +129,16 @@ main(int argc, char** argv)
const float hi_lo = cfg["hi_lo"].as<float>();
const score_t approx_bleu_d = cfg["approx_bleu_d"].as<score_t>();
const unsigned max_pairs = cfg["max_pairs"].as<unsigned>();
+ int repeat = cfg["repeat"].as<unsigned>();
+ //bool test_k_best = false;
+ //if (cfg.count("test-k-best")) test_k_best = true;
weight_t loss_margin = cfg["loss_margin"].as<weight_t>();
+ bool batch = false;
+ if (cfg.count("batch")) batch = true;
if (loss_margin > 9998.) loss_margin = std::numeric_limits<float>::max();
bool scale_bleu_diff = false;
if (cfg.count("scale_bleu_diff")) scale_bleu_diff = true;
+ const string pclr = cfg["pclr"].as<string>();
bool average = false;
if (select_weights == "avg")
average = true;
@@ -131,7 +146,6 @@ main(int argc, char** argv)
if (cfg.count("print_weights"))
boost::split(print_weights, cfg["print_weights"].as<string>(), boost::is_any_of(" "));
-
// setup decoder
register_feature_functions();
SetSilent(true);
@@ -178,17 +192,16 @@ main(int argc, char** argv)
observer->SetScorer(scorer);
// init weights
- vector<weight_t>& dense_weights = decoder.CurrentWeightVector();
+ vector<weight_t>& decoder_weights = decoder.CurrentWeightVector();
SparseVector<weight_t> lambdas, cumulative_penalties, w_average;
- if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &dense_weights);
- Weights::InitSparseVector(dense_weights, &lambdas);
+ if (cfg.count("input_weights")) Weights::InitFromFile(cfg["input_weights"].as<string>(), &decoder_weights);
+ Weights::InitSparseVector(decoder_weights, &lambdas);
// meta params for perceptron, SVM
weight_t eta = cfg["learning_rate"].as<weight_t>();
weight_t gamma = cfg["gamma"].as<weight_t>();
// faster perceptron: consider only misranked pairs, see
- // DO NOT ENABLE WITH SVM (gamma > 0) OR loss_margin!
bool faster_perceptron = false;
if (gamma==0 && loss_margin==0) faster_perceptron = true;
@@ -208,13 +221,24 @@ main(int argc, char** argv)
// output
string output_fn = cfg["output"].as<string>();
// input
- string input_fn = cfg["input"].as<string>();
+ bool read_bitext = false;
+ string input_fn;
+ if (cfg.count("bitext")) {
+ read_bitext = true;
+ input_fn = cfg["bitext"].as<string>();
+ } else {
+ input_fn = cfg["input"].as<string>();
+ }
ReadFile input(input_fn);
// buffer input for t > 0
vector<string> src_str_buf; // source strings (decoder takes only strings)
vector<vector<WordID> > ref_ids_buf; // references as WordID vecs
- string refs_fn = cfg["refs"].as<string>();
- ReadFile refs(refs_fn);
+ ReadFile refs;
+ string refs_fn;
+ if (!read_bitext) {
+ refs_fn = cfg["refs"].as<string>();
+ refs.Init(refs_fn);
+ }
unsigned in_sz = std::numeric_limits<unsigned>::max(); // input index, input size
vector<pair<score_t, score_t> > all_scores;
@@ -229,6 +253,7 @@ main(int argc, char** argv)
cerr << setw(25) << "k " << k << endl;
cerr << setw(25) << "N " << N << endl;
cerr << setw(25) << "T " << T << endl;
+ cerr << setw(25) << "batch " << batch << endl;
cerr << setw(26) << "scorer '" << scorer_str << "'" << endl;
if (scorer_str == "approx_bleu")
cerr << setw(25) << "approx. B discount " << approx_bleu_d << endl;
@@ -249,10 +274,14 @@ main(int argc, char** argv)
cerr << setw(25) << "l1 reg " << l1_reg << " '" << cfg["l1_reg"].as<string>() << "'" << endl;
if (rescale)
cerr << setw(25) << "rescale " << rescale << endl;
+ cerr << setw(25) << "pclr " << pclr << endl;
cerr << setw(25) << "max pairs " << max_pairs << endl;
+ cerr << setw(25) << "repeat " << repeat << endl;
+ //cerr << setw(25) << "test k-best " << test_k_best << endl;
cerr << setw(25) << "cdec cfg " << "'" << cfg["decoder_config"].as<string>() << "'" << endl;
cerr << setw(25) << "input " << "'" << input_fn << "'" << endl;
- cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
+ if (!read_bitext)
+ cerr << setw(25) << "refs " << "'" << refs_fn << "'" << endl;
cerr << setw(25) << "output " << "'" << output_fn << "'" << endl;
if (cfg.count("input_weights"))
cerr << setw(25) << "weights in " << "'" << cfg["input_weights"].as<string>() << "'" << endl;
@@ -261,6 +290,11 @@ main(int argc, char** argv)
if (!verbose) cerr << "(a dot represents " << DTRAIN_DOTS << " inputs)" << endl;
}
+ // pclr
+ SparseVector<weight_t> learning_rates;
+ // batch
+ SparseVector<weight_t> batch_updates;
+ score_t batch_loss;
for (unsigned t = 0; t < T; t++) // T epochs
{
@@ -269,16 +303,24 @@ main(int argc, char** argv)
time(&start);
score_t score_sum = 0.;
score_t model_sum(0);
- unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0;
+ unsigned ii = 0, rank_errors = 0, margin_violations = 0, npairs = 0, f_count = 0, list_sz = 0, kbest_loss_improve = 0;
+ batch_loss = 0.;
if (!quiet) cerr << "Iteration #" << t+1 << " of " << T << "." << endl;
while(true)
{
string in;
+ string ref;
bool next = false, stop = false; // next iteration or premature stop
if (t == 0) {
if(!getline(*input, in)) next = true;
+ if(read_bitext) {
+ vector<string> strs;
+ boost::algorithm::split_regex(strs, in, boost::regex(" \\|\\|\\| "));
+ in = strs[0];
+ ref = strs[1];
+ }
} else {
if (ii == in_sz) next = true; // stop if we reach the end of our input
}
@@ -310,15 +352,16 @@ main(int argc, char** argv)
if (next || stop) break;
// weights
- lambdas.init_vector(&dense_weights);
+ lambdas.init_vector(&decoder_weights);
// getting input
vector<WordID> ref_ids; // reference as vector<WordID>
if (t == 0) {
- string r_;
- getline(*refs, r_);
+ if (!read_bitext) {
+ getline(*refs, ref);
+ }
vector<string> ref_tok;
- boost::split(ref_tok, r_, boost::is_any_of(" "));
+ boost::split(ref_tok, ref, boost::is_any_of(" "));
register_and_convert(ref_tok, ref_ids);
ref_ids_buf.push_back(ref_ids);
src_str_buf.push_back(in);
@@ -348,8 +391,10 @@ main(int argc, char** argv)
}
}
- score_sum += (*samples)[0].score; // stats for 1best
- model_sum += (*samples)[0].model;
+ if (repeat == 1) {
+ score_sum += (*samples)[0].score; // stats for 1best
+ model_sum += (*samples)[0].model;
+ }
f_count += observer->get_f_count();
list_sz += observer->get_sz();
@@ -364,30 +409,74 @@ main(int argc, char** argv)
partXYX(samples, pairs, pair_threshold, max_pairs, faster_perceptron, hi_lo);
if (pair_sampling == "PRO")
PROsampling(samples, pairs, pair_threshold, max_pairs);
- npairs += pairs.size();
+ int cur_npairs = pairs.size();
+ npairs += cur_npairs;
+
+ score_t kbest_loss_first, kbest_loss_last = 0.0;
- SparseVector<weight_t> lambdas_copy;
+ for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
+ it != pairs.end(); it++) {
+ score_t model_diff = it->first.model - it->second.model;
+ kbest_loss_first += max(0.0, -1.0 * model_diff);
+ }
+
+ for (int ki=0; ki < repeat; ki++) {
+
+ score_t kbest_loss = 0.0; // test-k-best
+ SparseVector<weight_t> lambdas_copy; // for l1 regularization
+ SparseVector<weight_t> sum_up; // for pclr
if (l1naive||l1clip||l1cumul) lambdas_copy = lambdas;
for (vector<pair<ScoredHyp,ScoredHyp> >::iterator it = pairs.begin();
it != pairs.end(); it++) {
- bool rank_error;
+ score_t model_diff = it->first.model - it->second.model;
+ if (repeat > 1) {
+ model_diff = lambdas.dot(it->first.f) - lambdas.dot(it->second.f);
+ kbest_loss += max(0.0, -1.0 * model_diff);
+ }
+ bool rank_error = false;
score_t margin;
if (faster_perceptron) { // we only have considering misranked pairs
rank_error = true; // pair sampling already did this for us
margin = std::numeric_limits<float>::max();
} else {
- rank_error = it->first.model <= it->second.model;
- margin = fabs(it->first.model - it->second.model);
+ rank_error = model_diff<=0.0;
+ margin = fabs(model_diff);
if (!rank_error && margin < loss_margin) margin_violations++;
}
- if (rank_error) rank_errors++;
+ if (rank_error && ki==1) rank_errors++;
if (scale_bleu_diff) eta = it->first.score - it->second.score;
if (rank_error || margin < loss_margin) {
SparseVector<weight_t> diff_vec = it->first.f - it->second.f;
- lambdas.plus_eq_v_times_s(diff_vec, eta);
- if (gamma)
- lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
+ if (batch) {
+ batch_loss += max(0., -1.0*model_diff);
+ batch_updates += diff_vec;
+ continue;
+ }
+ if (pclr != "no") {
+ sum_up += diff_vec;
+ } else {
+ lambdas.plus_eq_v_times_s(diff_vec, eta);
+ if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./cur_npairs));
+ }
+ }
+ }
+
+ // per-coordinate learning rate
+ if (pclr != "no") {
+ SparseVector<weight_t>::iterator it = sum_up.begin();
+ for (; it != sum_up.end(); ++it) {
+ if (pclr == "simple") {
+ lambdas[it->first] += it->second / max(1.0, learning_rates[it->first]);
+ learning_rates[it->first]++;
+ } else if (pclr == "adagrad") {
+ if (learning_rates[it->first] == 0) {
+ lambdas[it->first] += it->second * eta;
+ } else {
+ lambdas[it->first] += it->second * eta * learning_rates[it->first];
+ }
+ learning_rates[it->first] += pow(it->second, 2.0);
+ }
}
}
@@ -395,14 +484,16 @@ main(int argc, char** argv)
// please note that this regularizations happen
// after a _sentence_ -- not after each example/pair!
if (l1naive) {
- FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ SparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
+ it->second *= max(0.0000001, eta/(eta+learning_rates[it->first])); // FIXME
+ learning_rates[it->first]++;
it->second -= sign(it->second) * l1_reg;
}
}
} else if (l1clip) {
- FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ SparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
if (it->second != 0) {
@@ -417,7 +508,7 @@ main(int argc, char** argv)
}
} else if (l1cumul) {
weight_t acc_penalty = (ii+1) * l1_reg; // ii is the index of the current input
- FastSparseVector<weight_t>::iterator it = lambdas.begin();
+ SparseVector<weight_t>::iterator it = lambdas.begin();
for (; it != lambdas.end(); ++it) {
if (!lambdas_copy.get(it->first) || lambdas_copy.get(it->first)!=it->second) {
if (it->second != 0) {
@@ -435,7 +526,28 @@ main(int argc, char** argv)
}
}
- }
+ if (ki==repeat-1) { // done
+ kbest_loss_last = kbest_loss;
+ if (repeat > 1) {
+ score_t best_score = -1.;
+ score_t best_model = -std::numeric_limits<score_t>::max();
+ unsigned best_idx;
+ for (unsigned i=0; i < samples->size(); i++) {
+ score_t s = lambdas.dot((*samples)[i].f);
+ if (s > best_model) {
+ best_idx = i;
+ best_model = s;
+ }
+ }
+ score_sum += (*samples)[best_idx].score;
+ model_sum += best_model;
+ }
+ }
+ } // repeat
+
+ if ((kbest_loss_first - kbest_loss_last) >= 0) kbest_loss_improve++;
+
+ } // noup
if (rescale) lambdas /= lambdas.l2norm();
@@ -443,14 +555,19 @@ main(int argc, char** argv)
} // input loop
- if (average) w_average += lambdas;
+ if (t == 0) in_sz = ii; // remember size of input (# lines)
- if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset();
- if (t == 0) {
- in_sz = ii; // remember size of input (# lines)
+ if (batch) {
+ lambdas.plus_eq_v_times_s(batch_updates, eta);
+ if (gamma) lambdas.plus_eq_v_times_s(lambdas, -2*gamma*eta*(1./npairs));
+ batch_updates.clear();
}
+ if (average) w_average += lambdas;
+
+ if (scorer_str == "approx_bleu" || scorer_str == "lc_bleu") scorer->Reset();
+
// print some stats
score_t score_avg = score_sum/(score_t)in_sz;
score_t model_avg = model_sum/(score_t)in_sz;
@@ -477,13 +594,15 @@ main(int argc, char** argv)
cerr << _np << " 1best avg model score: " << model_avg;
cerr << _p << " (" << model_diff << ")" << endl;
cerr << " avg # pairs: ";
- cerr << _np << npairs/(float)in_sz;
+ cerr << _np << npairs/(float)in_sz << endl;
+ cerr << " avg # rank err: ";
+ cerr << rank_errors/(float)in_sz;
if (faster_perceptron) cerr << " (meaningless)";
cerr << endl;
- cerr << " avg # rank err: ";
- cerr << rank_errors/(float)in_sz << endl;
cerr << " avg # margin viol: ";
cerr << margin_violations/(float)in_sz << endl;
+ if (batch) cerr << " batch loss: " << batch_loss << endl;
+ cerr << " k-best loss imp: " << ((float)kbest_loss_improve/in_sz)*100 << "%" << endl;
cerr << " non0 feature count: " << nonz << endl;
cerr << " avg list sz: " << list_sz/(float)in_sz << endl;
cerr << " avg f count: " << f_count/(float)list_sz << endl;
@@ -510,9 +629,9 @@ main(int argc, char** argv)
// write weights to file
if (select_weights == "best" || keep) {
- lambdas.init_vector(&dense_weights);
+ lambdas.init_vector(&decoder_weights);
string w_fn = "weights." + boost::lexical_cast<string>(t) + ".gz";
- Weights::WriteToFile(w_fn, dense_weights, true);
+ Weights::WriteToFile(w_fn, decoder_weights, true);
}
} // outer loop
diff --git a/training/dtrain/dtrain.h b/training/dtrain/dtrain.h
index 3981fb39..ccb5ad4d 100644
--- a/training/dtrain/dtrain.h
+++ b/training/dtrain/dtrain.h
@@ -9,6 +9,8 @@
#include <string.h>
#include <boost/algorithm/string.hpp>
+#include <boost/regex.hpp>
+#include <boost/algorithm/string/regex.hpp>
#include <boost/program_options.hpp>
#include "decoder.h"
diff --git a/training/dtrain/examples/standard/dtrain.ini b/training/dtrain/examples/standard/dtrain.ini
index 23e94285..fc83f08e 100644
--- a/training/dtrain/examples/standard/dtrain.ini
+++ b/training/dtrain/examples/standard/dtrain.ini
@@ -1,5 +1,6 @@
-input=./nc-wmt11.de.gz
-refs=./nc-wmt11.en.gz
+#input=./nc-wmt11.de.gz
+#refs=./nc-wmt11.en.gz
+bitext=./nc-wmt11.gz
output=- # a weights file (add .gz for gzip compression) or STDOUT '-'
select_weights=VOID # output average (over epochs) weight vector
decoder_config=./cdec.ini # config for cdec
@@ -10,11 +11,11 @@ print_weights=Glue WordPenalty LanguageModel LanguageModel_OOV PhraseModel_0 Phr
stop_after=10 # stop epoch after 10 inputs
# interesting stuff
-epochs=2 # run over input 2 times
+epochs=3 # run over input 3 times
k=100 # use 100best lists
N=4 # optimize (approx) BLEU4
scorer=fixed_stupid_bleu # use 'stupid' BLEU+1
-learning_rate=1.0 # learning rate, don't care if gamma=0 (perceptron)
+learning_rate=0.1 # learning rate, don't care if gamma=0 (perceptron) and loss_margin=0 (not margin perceptron)
gamma=0 # use SVM reg
sample_from=kbest # use kbest lists (as opposed to forest)
filter=uniq # only unique entries in kbest (surface form)
@@ -22,3 +23,5 @@ pair_sampling=XYX #
hi_lo=0.1 # 10 vs 80 vs 10 and 80 vs 10 here
pair_threshold=0 # minimum distance in BLEU (here: > 0)
loss_margin=0 # update if correctly ranked, but within this margin
+repeat=1 # repeat training on a kbest list 1 times
+#batch=true # batch tuning, update after accumulating over all sentences and all kbest lists
diff --git a/training/dtrain/examples/standard/expected-output b/training/dtrain/examples/standard/expected-output
index 21f91244..75f47337 100644
--- a/training/dtrain/examples/standard/expected-output
+++ b/training/dtrain/examples/standard/expected-output
@@ -4,17 +4,18 @@ Reading ./nc-wmt11.en.srilm.gz
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
Example feature: Shape_S00000_T00000
-Seeding random number sequence to 970626287
+Seeding random number sequence to 3751911392
dtrain
Parameters:
k 100
N 4
- T 2
+ T 3
+ batch 0
scorer 'fixed_stupid_bleu'
sample from 'kbest'
filter 'uniq'
- learning rate 1
+ learning rate 0.1
gamma 0
loss margin 0
faster perceptron 1
@@ -23,69 +24,99 @@ Parameters:
pair threshold 0
select weights 'VOID'
l1 reg 0 'none'
+ pclr no
max pairs 4294967295
+ repeat 1
cdec cfg './cdec.ini'
- input './nc-wmt11.de.gz'
- refs './nc-wmt11.en.gz'
+ input './nc-wmt11.gz'
output '-'
stop_after 10
(a dot represents 10 inputs)
-Iteration #1 of 2.
+Iteration #1 of 3.
. 10
Stopping after 10 input sentences.
WEIGHTS
- Glue = -614
- WordPenalty = +1256.8
- LanguageModel = +5610.5
- LanguageModel_OOV = -1449
- PhraseModel_0 = -2107
- PhraseModel_1 = -4666.1
- PhraseModel_2 = -2713.5
- PhraseModel_3 = +4204.3
- PhraseModel_4 = -1435.8
- PhraseModel_5 = +916
- PhraseModel_6 = +190
- PassThrough = -2527
+ Glue = -110
+ WordPenalty = -8.2082
+ LanguageModel = -319.91
+ LanguageModel_OOV = -19.2
+ PhraseModel_0 = +312.82
+ PhraseModel_1 = -161.02
+ PhraseModel_2 = -433.65
+ PhraseModel_3 = +291.03
+ PhraseModel_4 = +252.32
+ PhraseModel_5 = +50.6
+ PhraseModel_6 = +146.7
+ PassThrough = -38.7
---
- 1best avg score: 0.17874 (+0.17874)
- 1best avg model score: 88399 (+88399)
- avg # pairs: 798.2 (meaningless)
- avg # rank err: 798.2
+ 1best avg score: 0.16966 (+0.16966)
+ 1best avg model score: 29874 (+29874)
+ avg # pairs: 906.3
+ avg # rank err: 0 (meaningless)
avg # margin viol: 0
- non0 feature count: 887
+ k-best loss imp: 100%
+ non0 feature count: 832
avg list sz: 91.3
- avg f count: 126.85
-(time 0.33 min, 2 s/S)
+ avg f count: 139.77
+(time 0.35 min, 2.1 s/S)
-Iteration #2 of 2.
+Iteration #2 of 3.
. 10
WEIGHTS
- Glue = -1025
- WordPenalty = +1751.5
- LanguageModel = +10059
- LanguageModel_OOV = -4490
- PhraseModel_0 = -2640.7
- PhraseModel_1 = -3757.4
- PhraseModel_2 = -1133.1
- PhraseModel_3 = +1837.3
- PhraseModel_4 = -3534.3
- PhraseModel_5 = +2308
- PhraseModel_6 = +1677
- PassThrough = -6222
+ Glue = -122.1
+ WordPenalty = +83.689
+ LanguageModel = +233.23
+ LanguageModel_OOV = -145.1
+ PhraseModel_0 = +150.72
+ PhraseModel_1 = -272.84
+ PhraseModel_2 = -418.36
+ PhraseModel_3 = +181.63
+ PhraseModel_4 = -289.47
+ PhraseModel_5 = +140.3
+ PhraseModel_6 = +3.5
+ PassThrough = -109.7
---
- 1best avg score: 0.30764 (+0.12891)
- 1best avg model score: -2.5042e+05 (-3.3882e+05)
- avg # pairs: 725.9 (meaningless)
- avg # rank err: 725.9
+ 1best avg score: 0.17399 (+0.004325)
+ 1best avg model score: 4936.9 (-24937)
+ avg # pairs: 662.4
+ avg # rank err: 0 (meaningless)
avg # margin viol: 0
- non0 feature count: 1499
+ k-best loss imp: 100%
+ non0 feature count: 1240
avg list sz: 91.3
- avg f count: 114.34
-(time 0.32 min, 1.9 s/S)
+ avg f count: 125.11
+(time 0.27 min, 1.6 s/S)
+
+Iteration #3 of 3.
+ . 10
+WEIGHTS
+ Glue = -157.4
+ WordPenalty = -1.7372
+ LanguageModel = +686.18
+ LanguageModel_OOV = -399.7
+ PhraseModel_0 = -39.876
+ PhraseModel_1 = -341.96
+ PhraseModel_2 = -318.67
+ PhraseModel_3 = +105.08
+ PhraseModel_4 = -290.27
+ PhraseModel_5 = -48.6
+ PhraseModel_6 = -43.6
+ PassThrough = -298.5
+ ---
+ 1best avg score: 0.30742 (+0.13343)
+ 1best avg model score: -15393 (-20329)
+ avg # pairs: 623.8
+ avg # rank err: 0 (meaningless)
+ avg # margin viol: 0
+ k-best loss imp: 100%
+ non0 feature count: 1776
+ avg list sz: 91.3
+ avg f count: 118.58
+(time 0.28 min, 1.7 s/S)
Writing weights file to '-' ...
done
---
-Best iteration: 2 [SCORE 'fixed_stupid_bleu'=0.30764].
-This took 0.65 min.
+Best iteration: 3 [SCORE 'fixed_stupid_bleu'=0.30742].
+This took 0.9 min.
diff --git a/training/dtrain/examples/standard/nc-wmt11.gz b/training/dtrain/examples/standard/nc-wmt11.gz
new file mode 100644
index 00000000..c39c5aef
--- /dev/null
+++ b/training/dtrain/examples/standard/nc-wmt11.gz
Binary files differ
diff --git a/training/dtrain/parallelize.rb b/training/dtrain/parallelize.rb
index 285f3c9b..60ca9422 100755
--- a/training/dtrain/parallelize.rb
+++ b/training/dtrain/parallelize.rb
@@ -21,6 +21,8 @@ opts = Trollop::options do
opt :qsub, "use qsub", :type => :bool, :default => false
opt :dtrain_binary, "path to dtrain binary", :type => :string
opt :extra_qsub, "extra qsub args", :type => :string, :default => ""
+ opt :per_shard_decoder_configs, "give special decoder config per shard", :type => :string, :short => '-o'
+ opt :first_input_weights, "input weights for first iter", :type => :string, :default => '', :short => '-w'
end
usage if not opts[:config]&&opts[:shards]&&opts[:input]&&opts[:references]
@@ -41,9 +43,11 @@ epochs = opts[:epochs]
rand = opts[:randomize]
reshard = opts[:reshard]
predefined_shards = false
+per_shard_decoder_configs = false
if opts[:shards] == 0
predefined_shards = true
num_shards = 0
+ per_shard_decoder_configs = true if opts[:per_shard_decoder_configs]
else
num_shards = opts[:shards]
end
@@ -51,6 +55,7 @@ input = opts[:input]
refs = opts[:references]
use_qsub = opts[:qsub]
shards_at_once = opts[:processes_at_once]
+first_input_weights = opts[:first_input_weights]
`mkdir work`
@@ -101,6 +106,9 @@ refs_files = []
if predefined_shards
input_files = File.new(input).readlines.map {|i| i.strip }
refs_files = File.new(refs).readlines.map {|i| i.strip }
+ if per_shard_decoder_configs
+ decoder_configs = File.new(opts[:per_shard_decoder_configs]).readlines.map {|i| i.strip}
+ end
num_shards = input_files.size
else
input_files, refs_files = make_shards input, refs, num_shards, 0, rand
@@ -126,10 +134,18 @@ end
else
local_end = "2>work/out.#{shard}.#{epoch}"
end
+ if per_shard_decoder_configs
+ cdec_cfg = "--decoder_config #{decoder_configs[shard]}"
+ else
+ cdec_cfg = ""
+ end
+ if first_input_weights!='' && epoch == 0
+ input_weights = "--input_weights #{first_input_weights}"
+ end
pids << Kernel.fork {
- `#{qsub_str_start}#{dtrain_bin} -c #{ini}\
+ `#{qsub_str_start}#{dtrain_bin} -c #{ini} #{cdec_cfg} #{input_weights}\
--input #{input_files[shard]}\
- --refs #{refs_files[shard]} #{input_weights}\
+ --refs #{refs_files[shard]}\
--output work/weights.#{shard}.#{epoch}#{qsub_str_end} #{local_end}`
}
weights_files << "work/weights.#{shard}.#{epoch}"
diff --git a/utils/filelib.h b/utils/filelib.h
index b9ea3940..4fa69760 100644
--- a/utils/filelib.h
+++ b/utils/filelib.h
@@ -75,7 +75,10 @@ class ReadFile : public BaseFile<std::istream> {
}
}
}
-
+ void ReadAll(std::string& s) {
+ getline(*stream(), s, (char) EOF);
+ if (s.size() > 0) s.resize(s.size()-1);
+ }
};
class WriteFile : public BaseFile<std::ostream> {