diff options
-rw-r--r--  decoder/bottom_up_parser.cc                |  2
-rw-r--r--  decoder/hg_io.cc                           | 83
-rw-r--r--  decoder/hg_test.cc                         |  2
-rw-r--r--  tests/system_tests/lattice/gold.statistics |  7
-rw-r--r--  tests/system_tests/lattice/gold.stdout     | 14
-rw-r--r--  tests/system_tests/lattice/input.txt       |  1
-rw-r--r--  tests/system_tests/lattice/lattice.scfg    |  1
-rw-r--r--  tests/system_tests/lattice/weights         |  3
-rwxr-xr-x  training/mira/mira.py                      | 11
9 files changed, 88 insertions, 36 deletions
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index a614b8b3..7ce8e09d 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -274,7 +274,7 @@ void PassiveChart::ApplyRules(const int i, void PassiveChart::ApplyUnaryRules(const int i, const int j) { const vector<int>& nodes = chart_(i,j); // reference is important! for (unsigned di = 0; di < nodes.size(); ++di) { - const WordID& cat = forest_->nodes_[nodes[di]].cat_; + const WordID cat = forest_->nodes_[nodes[di]].cat_; for (unsigned ri = 0; ri < unaries_.size(); ++ri) { //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; if (unaries_[ri]->f()[0] == cat) { diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index d97ab3dc..71f50a29 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -134,7 +134,10 @@ inline void eatws(const std::string& in, int& c) { std::string getEscapedString(const std::string& in, int &c) { eatws(in,c); - if (get(in,c++) != quote) return "ERROR"; + if (get(in,c++) != quote) { + cerr << "Expected escaped string to begin with " << quote << ". Got " << get(in, c - 1) << "\n"; + abort(); + } std::string res; char cur = 0; do { @@ -152,7 +155,7 @@ float getFloat(const std::string& in, int &c) { std::string tmp; eatws(in,c); - while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') { + while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',' && get(in,c) != '}') { tmp += get(in,c++); } eatws(in,c); @@ -177,7 +180,18 @@ int getInt(const std::string& in, int &c) // maximum number of nodes permitted #define MAX_NODES 100000000 -// parse ('foo', 0.23) + +void ReadPLFFeature(const std::string& in, int &c, map<string, float>& features) { + eatws(in,c); + string name = getEscapedString(in,c); + eatws(in,c); + if (get(in,c++) != ':') { cerr << "PCN/PLF parse error: expected : after feature name " << name << "\n"; abort(); } + float value = getFloat(in, c); + eatws(in,c); + features[name] = value; +} + +// parse ('foo', 0.23, 1) void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) { if (get(in,c++) != '(') { cerr << "PCN/PLF parse error: expected (\n"; abort(); } vector<WordID> ewords(2, 0); @@ -186,22 +200,49 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) { r->ComputeArity(); // cerr << "RULE: " << r->AsString() << endl; if (get(in,c++) != ',') { cerr << in << endl; cerr << "PCN/PLF parse error: expected , after string\n"; abort(); } + eatws(in,c); + + map<string, float> features; size_t cnNext = 1; - std::vector<float> probs; - probs.push_back(getFloat(in,c)); - while (get(in,c) == ',') { + // Read in sparse feature format + if (get(in,c) == '{') { c++; - float val = getFloat(in,c); - probs.push_back(val); - // cerr << val << endl; //REMO + eatws(in,c); + if (get(in,c) != '}') { + ReadPLFFeature(in, c, features); + } + while (get(in,c) == ',') { + c++; + if (get(in,c) == '}') { break; } + ReadPLFFeature(in, c, features); + } + if (get(in,c++) != '}') { cerr << "PCN/PLF parse error: expected } after feature dictionary\n"; abort(); } + eatws(in,c); + if (get(in, c++) != ',') { cerr << "PCN/PLF parse error: expected , after feature dictionary\n"; abort(); } + cnNext = static_cast<size_t>(getFloat(in, c)); } - //if we read more than one prob, this was a lattice, last item was column increment - if (probs.size()>1) { + // Read in dense feature format + else { + std::vector<float> probs; + probs.push_back(getFloat(in,c)); + while (get(in,c) == ',') { + c++; + float val = getFloat(in,c); + probs.push_back(val); + // cerr << val << endl; //REMO + } + if (probs.size() == 0) { cerr << "PCN/PLF parse error: missing destination state increment\n"; abort(); } + + // the last item was column increment cnNext = static_cast<size_t>(probs.back()); probs.pop_back(); - if (cnNext < 1) { cerr << cnNext << endl << "PCN/PLF parse error: bad link length at last element of cn alt block\n"; abort(); } + + for (unsigned i = 0; i < probs.size(); ++i) { + features["LatticeCost_" + to_string(i)] = probs[i]; + } } - if (get(in,c++) != ')') { cerr << "PCN/PLF parse error: expected ) at end of cn alt block\n"; abort(); } + if (get(in,c++) != ')') { cerr << "PCN/PLF parse error: expected ) at end of cn alt block. Got " << get(in, c-1) << "\n"; abort(); } + if (cnNext < 1) { cerr << cnNext << endl << "PCN/PLF parse error: bad link length at last element of cn alt block\n"; abort(); } eatws(in,c); Hypergraph::TailNodeVector tail(1, cur_node); Hypergraph::Edge* edge = hg->AddEdge(r, tail); @@ -210,21 +251,15 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) { assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); } hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]); - if (probs.size() != 0) { - if (probs.size() == 1) { - edge->feature_values_.set_value(FD::Convert("LatticeCost"), probs[0]); - } else { - cerr << "Don't know how to deal with multiple lattice edge features: implement Python dictionary format.\n"; - abort(); - } + for (map<string, float>::iterator it = features.begin(); it != features.end(); ++it) { + edge->feature_values_.set_value(FD::Convert(it->first), it->second); } } -// parse (('foo', 0.23), ('bar', 0.77)) +// parse (('foo', 0.23, 1), ('bar', 0.77, 1)) void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) { - //cerr << "PLF READING NODE " << cur_node << endl; if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); } - if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); } + if (get(in,c++) != '(') { cerr << line << ": Syntax error 1 in PLF\n"; abort(); } eatws(in,c); while (1) { if (c > (int)in.size()) { break; } @@ -249,7 +284,7 @@ void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line) hg->clear(); int c = 0; int cur_node = 0; - if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); } + if (in[c++] != '(') { cerr << line << ": Syntax error in PLF!\n"; abort(); } while (1) { if (c > (int)in.size()) { break; } if (PLF::get(in,c) == ')') { diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index ec91cd3b..a597ad8d 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -256,7 +256,7 @@ BOOST_AUTO_TEST_CASE(PLF) { string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)"; HypergraphIO::ReadFromPLF(inplf, &hg); SparseVector<double> wts; - wts.set_value(FD::Convert("LatticeCost"), 1.0); + wts.set_value(FD::Convert("LatticeCost_0"), 1.0); hg.Reweight(wts); hg.PrintGraphviz(); string outplf = HypergraphIO::AsPLF(hg); diff --git a/tests/system_tests/lattice/gold.statistics b/tests/system_tests/lattice/gold.statistics index 302ddf14..4ce55900 100644 --- a/tests/system_tests/lattice/gold.statistics +++ b/tests/system_tests/lattice/gold.statistics @@ -5,3 +5,10 @@ +lm_edges 10 +lm_paths 5 +lm_trans ab +-lm_nodes 3 +-lm_edges 6 +-lm_paths 4 ++lm_nodes 3 ++lm_edges 6 ++lm_paths 4 ++lm_trans d' diff --git a/tests/system_tests/lattice/gold.stdout b/tests/system_tests/lattice/gold.stdout index 1adb51f1..dd2a2943 100644 --- a/tests/system_tests/lattice/gold.stdout +++ b/tests/system_tests/lattice/gold.stdout @@ -1,5 +1,9 @@ -0 ||| ab ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost=0.125 ||| -1.09359 -0 ||| cb ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.3 LatticeCost=2.25 ||| -3.85288 -0 ||| a_b ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.2 LatticeCost=2.5 ||| -4.00288 -0 ||| a' b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost=2.5 ||| -4.53718 -0 ||| a b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost=2.5 ||| -4.53718 +0 ||| ab ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost_0=0.125 ||| -1.09359 +0 ||| cb ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.3 LatticeCost_0=2.25 ||| -3.85288 +0 ||| a_b ||| SourceWordPenalty=-0.868589 WordPenalty=-0.434294 Cost=0.2 LatticeCost_0=2.5 ||| -4.00288 +0 ||| a' b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost_0=2.5 ||| -4.53718 +0 ||| a b ||| Glue=1 SourceWordPenalty=-0.868589 WordPenalty=-0.868589 Cost=0.3 LatticeCost_0=2.5 ||| -4.53718 +1 ||| d' ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost0=-0.1 LatticeCost_0=0.1 UsesDPrime=1 ||| 999.031 +1 ||| b ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.2 ||| -1.06859 +1 ||| a' ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost_0=0.5 ||| -1.46859 +1 ||| a ||| SourceWordPenalty=-0.434294 WordPenalty=-0.434294 Cost=0.1 LatticeCost_0=0.5 ||| -1.46859 diff --git a/tests/system_tests/lattice/input.txt b/tests/system_tests/lattice/input.txt index e0cd1b57..17bfd47c 100644 --- a/tests/system_tests/lattice/input.txt +++ b/tests/system_tests/lattice/input.txt @@ -1 +1,2 @@ ((('A',0.5,1),('C',0.25,1),('AB',0.125,2),),(('B',2,1),),) +((('A',0.5,1),('D\'',{'LatticeCost_0':0.1, 'UsesDPrime':1.0,},1),('B', 1)),) diff --git a/tests/system_tests/lattice/lattice.scfg b/tests/system_tests/lattice/lattice.scfg index 87a72383..04fe0cf0 100644 --- a/tests/system_tests/lattice/lattice.scfg +++ b/tests/system_tests/lattice/lattice.scfg @@ -4,3 +4,4 @@ [X] ||| AB ||| ab ||| Cost=0.1 [X] ||| C B ||| cb ||| Cost=0.3 [X] ||| A B ||| a_b ||| Cost=0.2 +[X] ||| D' ||| d' ||| Cost0=-0.1 diff --git a/tests/system_tests/lattice/weights b/tests/system_tests/lattice/weights index cb59b27b..7e7d0fa8 100644 --- a/tests/system_tests/lattice/weights +++ b/tests/system_tests/lattice/weights @@ -2,4 +2,5 @@ WordPenalty 1 SourceWordPenalty 1 Glue 0 Cost -1 -LatticeCost -1 +LatticeCost_0 -1 +UsesDPrime 1000 diff --git a/training/mira/mira.py b/training/mira/mira.py index 32478c4f..4c87c664 100755 --- a/training/mira/mira.py +++ b/training/mira/mira.py @@ -138,6 +138,9 @@ def main(): args = parser.parse_args() args.metric = args.metric.upper() + score_sign = 1.0 + if args.metric == 'TER' or args.metric == 'WER' or args.metric == 'CER': + score_sign = -1.0 if not args.update_size: args.update_size = args.kbest_size @@ -187,7 +190,7 @@ def main(): args.devset = newdev log_config(args) - args.weights, hope_best_fear = optimize(args, script_dir, dev_size) + args.weights, hope_best_fear = optimize(args, script_dir, dev_size, score_sign) graph_file = '' if have_mpl: graph_file = graph(args.output_dir, hope_best_fear, args.metric) @@ -305,7 +308,7 @@ def split_devset(dev, outdir): refs.close() return (outdir+'/source.input', outdir+'/refs.input') -def optimize(args, script_dir, dev_size): +def optimize(args, script_dir, dev_size, score_sign): parallelize = script_dir+'/../utils/parallelize.pl' if args.qsub: parallelize += " -p %s"%args.pmem @@ -316,7 +319,7 @@ def optimize(args, script_dir, dev_size): num_features = 0 last_p_score = 0 best_score_iter = -1 - best_score = -1 + best_score = -10 * score_sign i = 0 hope_best_fear = {'hope':[],'best':[],'fear':[]} #main optimization loop @@ -433,7 +436,7 @@ def optimize(args, script_dir, dev_size): hope_best_fear['fear'].append(dec_score_f) logging.info('DECODER SCORE: {0} HOPE: {1} FEAR: {2}'.format( dec_score, dec_score_h, dec_score_f)) - if dec_score > best_score: + if score_sign*dec_score > score_sign*best_score: best_score_iter = i best_score = dec_score |