summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/bottom_up_parser.cc2
-rw-r--r--decoder/hg_io.cc83
-rw-r--r--decoder/hg_test.cc2
-rw-r--r--tests/system_tests/lattice/gold.statistics7
-rw-r--r--tests/system_tests/lattice/gold.stdout14
-rw-r--r--tests/system_tests/lattice/input.txt1
-rw-r--r--tests/system_tests/lattice/lattice.scfg1
-rw-r--r--tests/system_tests/lattice/weights3
-rwxr-xr-xtraining/mira/mira.py11
9 files changed, 88 insertions, 36 deletions
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc
index a614b8b3..7ce8e09d 100644
--- a/decoder/bottom_up_parser.cc
+++ b/decoder/bottom_up_parser.cc
@@ -274,7 +274,7 @@ void PassiveChart::ApplyRules(const int i,
void PassiveChart::ApplyUnaryRules(const int i, const int j) {
const vector<int>& nodes = chart_(i,j); // reference is important!
for (unsigned di = 0; di < nodes.size(); ++di) {
- const WordID& cat = forest_->nodes_[nodes[di]].cat_;
+ const WordID cat = forest_->nodes_[nodes[di]].cat_;
for (unsigned ri = 0; ri < unaries_.size(); ++ri) {
//cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl;
if (unaries_[ri]->f()[0] == cat) {
diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc
index d97ab3dc..71f50a29 100644
--- a/decoder/hg_io.cc
+++ b/decoder/hg_io.cc
@@ -134,7 +134,10 @@ inline void eatws(const std::string& in, int& c) {
std::string getEscapedString(const std::string& in, int &c)
{
eatws(in,c);
- if (get(in,c++) != quote) return "ERROR";
+ if (get(in,c++) != quote) {
+ cerr << "Expected escaped string to begin with " << quote << ". Got " << get(in, c - 1) << "\n";
+ abort();
+ }
std::string res;
char cur = 0;
do {
@@ -152,7 +155,7 @@ float getFloat(const std::string& in, int &c)
{
std::string tmp;
eatws(in,c);
- while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') {
+ while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',' && get(in,c) != '}') {
tmp += get(in,c++);
}
eatws(in,c);
@@ -177,7 +180,18 @@ int getInt(const std::string& in, int &c)
// maximum number of nodes permitted
#define MAX_NODES 100000000
-// parse ('foo', 0.23)
+
+void ReadPLFFeature(const std::string& in, int &c, map<string, float>& features) {
+ eatws(in,c);
+ string name = getEscapedString(in,c);
+ eatws(in,c);
+ if (get(in,c++) != ':') { cerr << "PCN/PLF parse error: expected : after feature name " << name << "\n"; abort(); }
+ float value = getFloat(in, c);
+ eatws(in,c);
+ features[name] = value;
+}
+
+// parse ('foo', 0.23, 1)
void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
if (get(in,c++) != '(') { cerr << "PCN/PLF parse error: expected (\n"; abort(); }
vector<WordID> ewords(2, 0);
@@ -186,22 +200,49 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
r->ComputeArity();
// cerr << "RULE: " << r->AsString() << endl;
if (get(in,c++) != ',') { cerr << in << endl; cerr << "PCN/PLF parse error: expected , after string\n"; abort(); }
+ eatws(in,c);
+
+ map<string, float> features;
size_t cnNext = 1;
- std::vector<float> probs;
- probs.push_back(getFloat(in,c));
- while (get(in,c) == ',') {
+ // Read in sparse feature format
+ if (get(in,c) == '{') {
c++;
- float val = getFloat(in,c);
- probs.push_back(val);
- // cerr << val << endl; //REMO
+ eatws(in,c);
+ if (get(in,c) != '}') {
+ ReadPLFFeature(in, c, features);
+ }
+ while (get(in,c) == ',') {
+ c++;
+ if (get(in,c) == '}') { break; }
+ ReadPLFFeature(in, c, features);
+ }
+ if (get(in,c++) != '}') { cerr << "PCN/PLF parse error: expected } after feature dictionary\n"; abort(); }
+ eatws(in,c);
+ if (get(in, c++) != ',') { cerr << "PCN/PLF parse error: expected , after feature dictionary\n"; abort(); }
+ cnNext = static_cast<size_t>(getFloat(in, c));
}
- //if we read more than one prob, this was a lattice, last item was column increment
- if (probs.size()>1) {
+ // Read in dense feature format
+ else {
+ std::vector<float> probs;
+ probs.push_back(getFloat(in,c));
+ while (get(in,c) == ',') {
+ c++;
+ float val = getFloat(in,c);
+ probs.push_back(val);
+ // cerr << val << endl; //REMO
+ }
+ if (probs.size() == 0) { cerr << "PCN/PLF parse error: missing destination state increment\n"; abort(); }
+
+ // the last item was column increment
cnNext = static_cast<size_t>(probs.back());
probs.pop_back();
- if (cnNext < 1) { cerr << cnNext << endl << "PCN/PLF parse error: bad link length at last element of cn alt block\n"; abort(); }
+
+ for (unsigned i = 0; i < probs.size(); ++i) {
+ features["LatticeCost_" + to_string(i)] = probs[i];
+ }
}
- if (get(in,c++) != ')') { cerr << "PCN/PLF parse error: expected ) at end of cn alt block\n"; abort(); }
+ if (get(in,c++) != ')') { cerr << "PCN/PLF parse error: expected ) at end of cn alt block. Got " << get(in, c-1) << "\n"; abort(); }
+ if (cnNext < 1) { cerr << cnNext << endl << "PCN/PLF parse error: bad link length at last element of cn alt block\n"; abort(); }
eatws(in,c);
Hypergraph::TailNodeVector tail(1, cur_node);
Hypergraph::Edge* edge = hg->AddEdge(r, tail);
@@ -210,21 +251,15 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory
if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); }
hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]);
- if (probs.size() != 0) {
- if (probs.size() == 1) {
- edge->feature_values_.set_value(FD::Convert("LatticeCost"), probs[0]);
- } else {
- cerr << "Don't know how to deal with multiple lattice edge features: implement Python dictionary format.\n";
- abort();
- }
+ for (map<string, float>::iterator it = features.begin(); it != features.end(); ++it) {
+ edge->feature_values_.set_value(FD::Convert(it->first), it->second);
}
}
-// parse (('foo', 0.23), ('bar', 0.77))
+// parse (('foo', 0.23, 1), ('bar', 0.77, 1))
void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) {
- //cerr << "PLF READING NODE " << cur_node << endl;
if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); }
- if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); }
+ if (get(in,c++) != '(') { cerr << line << ": Syntax error 1 in PLF\n"; abort(); }
eatws(in,c);
while (1) {
if (c > (int)in.size()) { break; }
@@ -249,7 +284,7 @@ void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line)
hg->clear();
int c = 0;
int cur_node = 0;
- if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); }
+ if (in[c++] != '(') { cerr << line << ": Syntax error in PLF!\n"; abort(); }
while (1) {
if (c > (int)in.size()) { break; }
if (PLF::get(in,c) == ')') {
diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc
index ec91cd3b..a597ad8d 100644
--- a/decoder/hg_test.cc
+++ b/decoder/hg_test.cc
@@ -256,7 +256,7 @@ BOOST_AUTO_TEST_CASE(PLF) {
string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)";
HypergraphIO::ReadFromPLF(inplf, &hg);
SparseVector<double> wts;
- wts.set_value(FD::Convert("LatticeCost"), 1.0);
+ wts.set_value(FD::Convert("LatticeCost_0"), 1.0);
hg.Reweight(wts);
hg.PrintGraphviz();
string outplf = HypergraphIO::AsPLF(hg);
diff --git a/tests/system_tests/lattice/gold.statistics b/tests/system_tests/lattice/gold.statistics
index 302ddf14..4ce55900 100644
--- a/tests/system_tests/lattice/gold.statistics
+++ b/tests/system_tests/lattice/gold.statistics
@@ -5,3 +5,10 @@
+lm_edges 10
+lm_paths 5
+lm_trans ab
+-lm_nodes 3
+-lm_edges 6
+-lm_paths 4
++lm_nodes 3
++lm_edges 6
++lm_paths 4
++lm_trans d'
diff --git a/tests/system_tests/lattice/gold.stdout b/tests/system_tests/lattice/gold.stdout
index 1adb51f1..dd2a2943 100644
--- a/tests/system_tests/lattice/gold.stdout
+++ b/tests/system_tests/lattice/gold.stdout
@@ -1,5 +1,9 @@
-0 ||| ab ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost=0.125 ||| -1.09359
-0 ||| cb ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.3 LatticeCost=2.25 ||| -3.85288
-0 ||| a_b ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.2 LatticeCost=2.5 ||| -4.00288
-0 ||| a' b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost=2.5 ||| -4.53718
-0 ||| a b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost=2.5 ||| -4.53718
+0 ||| ab ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost_0=0.125 ||| -1.09359
+0 ||| cb ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.3 LatticeCost_0=2.25 ||| -3.85288
+0 ||| a_b ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.2 LatticeCost_0=2.5 ||| -4.00288
+0 ||| a' b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost_0=2.5 ||| -4.53718
+0 ||| a b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost_0=2.5 ||| -4.53718
+1 ||| d' ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost0=-0.1 LatticeCost_0=0.1 UsesDPrime=1 ||| 999.031
+1 ||| b ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.2 ||| -1.06859
+1 ||| a' ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost_0=0.5 ||| -1.46859
+1 ||| a ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost_0=0.5 ||| -1.46859
diff --git a/tests/system_tests/lattice/input.txt b/tests/system_tests/lattice/input.txt
index e0cd1b57..17bfd47c 100644
--- a/tests/system_tests/lattice/input.txt
+++ b/tests/system_tests/lattice/input.txt
@@ -1 +1,2 @@
((('A',0.5,1),('C',0.25,1),('AB',0.125,2),),(('B',2,1),),)
+((('A',0.5,1),('D\'',{'LatticeCost_0':0.1, 'UsesDPrime':1.0,},1),('B', 1)),)
diff --git a/tests/system_tests/lattice/lattice.scfg b/tests/system_tests/lattice/lattice.scfg
index 87a72383..04fe0cf0 100644
--- a/tests/system_tests/lattice/lattice.scfg
+++ b/tests/system_tests/lattice/lattice.scfg
@@ -4,3 +4,4 @@
[X] ||| AB ||| ab ||| Cost=0.1
[X] ||| C B ||| cb ||| Cost=0.3
[X] ||| A B ||| a_b ||| Cost=0.2
+[X] ||| D' ||| d' ||| Cost0=-0.1
diff --git a/tests/system_tests/lattice/weights b/tests/system_tests/lattice/weights
index cb59b27b..7e7d0fa8 100644
--- a/tests/system_tests/lattice/weights
+++ b/tests/system_tests/lattice/weights
@@ -2,4 +2,5 @@ WordPenalty 1
SourceWordPenalty 1
Glue 0
Cost -1
-LatticeCost -1
+LatticeCost_0 -1
+UsesDPrime 1000
diff --git a/training/mira/mira.py b/training/mira/mira.py
index 32478c4f..4c87c664 100755
--- a/training/mira/mira.py
+++ b/training/mira/mira.py
@@ -138,6 +138,9 @@ def main():
args = parser.parse_args()
args.metric = args.metric.upper()
+ score_sign = 1.0
+ if args.metric == 'TER' or args.metric == 'WER' or args.metric == 'CER':
+ score_sign = -1.0
if not args.update_size:
args.update_size = args.kbest_size
@@ -187,7 +190,7 @@ def main():
args.devset = newdev
log_config(args)
- args.weights, hope_best_fear = optimize(args, script_dir, dev_size)
+ args.weights, hope_best_fear = optimize(args, script_dir, dev_size, score_sign)
graph_file = ''
if have_mpl: graph_file = graph(args.output_dir, hope_best_fear, args.metric)
@@ -305,7 +308,7 @@ def split_devset(dev, outdir):
refs.close()
return (outdir+'/source.input', outdir+'/refs.input')
-def optimize(args, script_dir, dev_size):
+def optimize(args, script_dir, dev_size, score_sign):
parallelize = script_dir+'/../utils/parallelize.pl'
if args.qsub:
parallelize += " -p %s"%args.pmem
@@ -316,7 +319,7 @@ def optimize(args, script_dir, dev_size):
num_features = 0
last_p_score = 0
best_score_iter = -1
- best_score = -1
+ best_score = -10 * score_sign
i = 0
hope_best_fear = {'hope':[],'best':[],'fear':[]}
#main optimization loop
@@ -433,7 +436,7 @@ def optimize(args, script_dir, dev_size):
hope_best_fear['fear'].append(dec_score_f)
logging.info('DECODER SCORE: {0} HOPE: {1} FEAR: {2}'.format(
dec_score, dec_score_h, dec_score_f))
- if dec_score > best_score:
+ if score_sign*dec_score > score_sign*best_score:
best_score_iter = i
best_score = dec_score