author      Patrick Simianer <p@simianer.de>           2015-02-26 14:24:41 +0100
committer   Patrick Simianer <p@simianer.de>           2015-02-26 14:24:41 +0100
commit      29ddfafb0dea599965e6a881c25b396a6db2f40f (patch)
tree        5755ec058361776657041ba088062b086eea6d68 /decoder
parent      4223261682388944fe1b1cf31b9d51d88f9ad53b (diff)
parent      03989754cb2511431e1df6001fca41b3806ad461 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'decoder')
-rw-r--r--   decoder/bottom_up_parser-rs.cc   29
-rw-r--r--   decoder/bottom_up_parser.cc      43
-rw-r--r--   decoder/csplit.cc                 2
-rw-r--r--   decoder/grammar_test.cc           4
-rw-r--r--   decoder/hg_io.cc                 90
-rw-r--r--   decoder/hg_test.cc                6
-rw-r--r--   decoder/lattice.cc                2
-rw-r--r--   decoder/lattice.h                 7
-rw-r--r--   decoder/parser_test.cc            4
9 files changed, 115 insertions, 72 deletions
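The change running through the whole diff below is the replacement of the per-arc float lattice_cost with a SparseVector<double> of input features (lattice_feats), which is summed directly into each edge's feature vector instead of being written to a single dedicated LatticeCost feature. A minimal sketch of what lattice construction looks like with the new LatticeArc signature from decoder/lattice.h; the tdict.h/fdict.h include names and the feature name and weight are illustrative assumptions, not part of this commit:

// Sketch only: building a two-word lattice with the new arc type.
#include "lattice.h"        // LatticeArc, Lattice
#include "sparse_vector.h"  // SparseVector<double>
#include "tdict.h"          // TD::Convert (assumed header name)
#include "fdict.h"          // FD::Convert (assumed header name)

Lattice BuildExampleLattice() {
  SparseVector<double> feats;
  feats.set_value(FD::Convert("LatticeCost"), 0.5);  // any named feature(s) may go here

  Lattice lattice(2);
  // old API: LatticeArc(TD::Convert("ein"), 0.5, 1) -- a single float cost
  // new API: the arc carries a full feature vector
  lattice[0].push_back(LatticeArc(TD::Convert("ein"), feats, 1));
  lattice[1].push_back(LatticeArc(TD::Convert("haus"), SparseVector<double>(), 1));
  return lattice;
}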
diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc
index fbde7e24..863d7e2f 100644
--- a/decoder/bottom_up_parser-rs.cc
+++ b/decoder/bottom_up_parser-rs.cc
@@ -35,7 +35,7 @@ class RSChart {
                   const int j,
                   const RuleBin* rules,
                   const Hypergraph::TailNodeVector& tail,
-                  const float lattice_cost);
+                  const SparseVector<double>& lattice_feats);
   // returns true if a new node was added to the chart
   // false otherwise
@@ -43,7 +43,7 @@ class RSChart {
                  const int j,
                  const TRulePtr& r,
                  const Hypergraph::TailNodeVector& ant_nodes,
-                 const float lattice_cost);
+                 const SparseVector<double>& lattice_feats);
   void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx);
   void TopoSortUnaries();
@@ -69,9 +69,9 @@ WordID RSChart::kGOAL = 0;
 // "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span"
 struct RSActiveItem {
   explicit RSActiveItem(const GrammarIter* g, int i) :
-    gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {}
-  void ExtendTerminal(int symbol, float src_cost) {
-    lattice_cost += src_cost;
+    gptr_(g), ant_nodes_(), lattice_feats(), i_(i) {}
+  void ExtendTerminal(int symbol, const SparseVector<double>& src_feats) {
+    lattice_feats += src_feats;
     if (symbol != kEPS)
       gptr_ = gptr_->Extend(symbol);
   }
@@ -85,7 +85,7 @@ struct RSActiveItem {
   }
   const GrammarIter* gptr_;
   Hypergraph::TailNodeVector ant_nodes_;
-  float lattice_cost; // TODO: use SparseVector<double> to encode input features
+  SparseVector<double> lattice_feats;
   short i_;
 };
@@ -174,7 +174,7 @@ bool RSChart::ApplyRule(const int i,
                         const int j,
                         const TRulePtr& r,
                         const Hypergraph::TailNodeVector& ant_nodes,
-                        const float lattice_cost) {
+                        const SparseVector<double>& lattice_feats) {
   Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
   //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
   new_edge->prev_i_ = r->prev_i;
@@ -182,8 +182,7 @@ bool RSChart::ApplyRule(const int i,
   new_edge->i_ = i;
   new_edge->j_ = j;
   new_edge->feature_values_ = r->GetFeatureValues();
-  if (lattice_cost && lc_fid_)
-    new_edge->feature_values_.set_value(lc_fid_, lattice_cost);
+  new_edge->feature_values_ += lattice_feats;
   Cat2NodeMap& c2n = nodemap_(i,j);
   const bool is_goal = (r->GetLHS() == kGOAL);
   const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS());
@@ -211,7 +210,7 @@ void RSChart::ApplyRules(const int i,
                          const int j,
                          const RuleBin* rules,
                          const Hypergraph::TailNodeVector& tail,
-                         const float lattice_cost) {
+                         const SparseVector<double>& lattice_feats) {
   const int n = rules->GetNumRules();
   //cerr << i << " " << j << ": NUM RULES: " << n << endl;
   for (int k = 0; k < n; ++k) {
@@ -219,7 +218,7 @@ void RSChart::ApplyRules(const int i,
     TRulePtr rule = rules->GetIthRule(k);
     // apply rule, and if we create a new node, apply any necessary
     // unary rules
-    if (ApplyRule(i, j, rule, tail, lattice_cost)) {
+    if (ApplyRule(i, j, rule, tail, lattice_feats)) {
       unsigned nodeidx = nodemap_(i,j)[rule->lhs_];
       ApplyUnaryRules(i, j, rule->lhs_, nodeidx);
     }
@@ -233,7 +232,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig
       //cerr << "  --MATCH\n";
       WordID new_lhs = unaries_[ri]->GetLHS();
       const Hypergraph::TailNodeVector ant(1, nodeidx);
-      if (ApplyRule(i, j, unaries_[ri], ant, 0)) {
+      if (ApplyRule(i, j, unaries_[ri], ant, SparseVector<double>())) {
        //cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl;
        unsigned nodeidx = nodemap_(i,j)[new_lhs];
        ApplyUnaryRules(i, j, new_lhs, nodeidx);
@@ -245,7 +244,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig
 void RSChart::AddToChart(const RSActiveItem& x, int i, int j) {
   // deal with completed rules
   const RuleBin* rb = x.gptr_->GetRules();
-  if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost);
+  if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_feats);
   //cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n";
   // continue looking for extensions of the rule to the right
@@ -264,7 +263,7 @@ void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) {
     if (in_edge.dist2next == check_edge_len) {
       //cerr << "  Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl;
       RSActiveItem copy = x;
-      copy.ExtendTerminal(in_edge.label, in_edge.cost);
+      copy.ExtendTerminal(in_edge.label, in_edge.features);
       if (copy) AddToChart(copy, i, k);
     }
   }
@@ -306,7 +305,7 @@ bool RSChart::Parse() {
     const Hypergraph::Node& node = forest_->nodes_[dh[di]];
     if (node.cat_ == goal_cat_) {
       Hypergraph::TailNodeVector ant(1, node.id_);
-      ApplyRule(0, input_.size(), goal_rule_, ant, 0);
+      ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector<double>());
     }
   }
   if (!SILENT) cerr << endl;
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc
index b30f1ec6..7ce8e09d 100644
--- a/decoder/bottom_up_parser.cc
+++ b/decoder/bottom_up_parser.cc
@@ -38,13 +38,13 @@ class PassiveChart {
                   const int j,
                   const RuleBin* rules,
                   const Hypergraph::TailNodeVector& tail,
-                  const float lattice_cost);
+                  const SparseVector<double>& lattice_feats);
   void ApplyRule(const int i,
                  const int j,
                  const TRulePtr& r,
                  const Hypergraph::TailNodeVector& ant_nodes,
-                 const float lattice_cost);
+                 const SparseVector<double>& lattice_feats);
   void ApplyUnaryRules(const int i, const int j);
   void TopoSortUnaries();
@@ -59,7 +59,6 @@ class PassiveChart {
   const WordID goal_cat_;    // category that is being searched for at [0,n]
   TRulePtr goal_rule_;
   int goal_idx_;             // index of goal node, if found
-  const int lc_fid_;
   vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars
   static WordID kGOAL;       // [Goal]
@@ -74,18 +73,18 @@ class ActiveChart {
       act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {}
   struct ActiveItem {
-    ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, float lcost) :
-      gptr_(g), ant_nodes_(a), lattice_cost(lcost) {}
+    ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, const SparseVector<double>& lfeats) :
+      gptr_(g), ant_nodes_(a), lattice_feats(lfeats) {}
     explicit ActiveItem(const GrammarIter* g) :
-      gptr_(g), ant_nodes_(), lattice_cost(0.0) {}
+      gptr_(g), ant_nodes_(), lattice_feats() {}
-    void ExtendTerminal(int symbol, float src_cost, vector<ActiveItem>* out_cell) const {
+    void ExtendTerminal(int symbol, const SparseVector<double>& src_feats, vector<ActiveItem>* out_cell) const {
       if (symbol == kEPS) {
-        out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_cost + src_cost));
+        out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_feats + src_feats));
       } else {
         const GrammarIter* ni = gptr_->Extend(symbol);
         if (ni)
-          out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost));
+          out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_feats + src_feats));
       }
     }
     void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector<ActiveItem>* out_cell) const {
@@ -96,12 +95,12 @@ class ActiveChart {
       for (unsigned i = 0; i < ant_nodes_.size(); ++i) na[i] = ant_nodes_[i];
       na[ant_nodes_.size()] = node_index;
-      out_cell->push_back(ActiveItem(ni, na, lattice_cost));
+      out_cell->push_back(ActiveItem(ni, na, lattice_feats));
     }
     const GrammarIter* gptr_;
     Hypergraph::TailNodeVector ant_nodes_;
-    float lattice_cost;  // TODO? use SparseVector<double>
+    SparseVector<double> lattice_feats;
   };
   inline const vector<ActiveItem>& operator()(int i, int j) const { return act_chart_(i,j); }
@@ -134,12 +133,12 @@ class ActiveChart {
       for (vector<LatticeArc>::const_iterator ai = out_arcs.begin(); ai != out_arcs.end(); ++ai) {
         const WordID& f = ai->label;
-        const double& c = ai->cost;
+        const SparseVector<double>& c = ai->features;
         const int& len = ai->dist2next;
         //cerr << "F: " << TD::Convert(f) << "  dest=" << i << "," << (j+len-1) << endl;
         const vector<ActiveItem>& ec = act_chart_(i, j-1);
         //cerr << "  SRC=" << i << "," << (j-1) << " [ec=" << ec.size() << "]" << endl;
-        //if (ec.size() > 0) { cerr << "   LC=" << ec[0].lattice_cost << endl; }
+        //if (ec.size() > 0) { cerr << "   LC=" << ec[0].lattice_feats << endl; }
         for (vector<ActiveItem>::const_iterator di = ec.begin(); di != ec.end(); ++di)
           di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1));
       }
@@ -163,7 +162,6 @@ PassiveChart::PassiveChart(const string& goal,
     goal_cat_(TD::Convert(goal) * -1),
     goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")),
     goal_idx_(-1),
-    lc_fid_(FD::Convert("LatticeCost")),
     unaries_() {
   act_chart_.resize(grammars_.size());
   for (unsigned i = 0; i < grammars_.size(); ++i) {
@@ -232,7 +230,7 @@ void PassiveChart::ApplyRule(const int i,
                              const int j,
                              const TRulePtr& r,
                              const Hypergraph::TailNodeVector& ant_nodes,
-                             const float lattice_cost) {
+                             const SparseVector<double>& lattice_feats) {
   Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
   // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
   new_edge->prev_i_ = r->prev_i;
@@ -240,8 +238,7 @@ void PassiveChart::ApplyRule(const int i,
   new_edge->i_ = i;
   new_edge->j_ = j;
   new_edge->feature_values_ = r->GetFeatureValues();
-  if (lattice_cost && lc_fid_)
-    new_edge->feature_values_.set_value(lc_fid_, lattice_cost);
+  new_edge->feature_values_ += lattice_feats;
   Cat2NodeMap& c2n = nodemap_(i,j);
   const bool is_goal = (r->GetLHS() == kGOAL);
   const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS());
@@ -265,25 +262,25 @@ void PassiveChart::ApplyRules(const int i,
                               const int j,
                               const RuleBin* rules,
                               const Hypergraph::TailNodeVector& tail,
-                              const float lattice_cost) {
+                              const SparseVector<double>& lattice_feats) {
   const int n = rules->GetNumRules();
   //cerr << i << " " << j << ": NUM RULES: " << n << endl;
   for (int k = 0; k < n; ++k) {
     //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl;
-    ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost);
+    ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_feats);
   }
 }
 void PassiveChart::ApplyUnaryRules(const int i, const int j) {
   const vector<int>& nodes = chart_(i,j); // reference is important!
   for (unsigned di = 0; di < nodes.size(); ++di) {
-    const WordID& cat = forest_->nodes_[nodes[di]].cat_;
+    const WordID cat = forest_->nodes_[nodes[di]].cat_;
     for (unsigned ri = 0; ri < unaries_.size(); ++ri) {
       //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl;
       if (unaries_[ri]->f()[0] == cat) {
         //cerr << "  --MATCH\n";
         const Hypergraph::TailNodeVector ant(1, nodes[di]);
-        ApplyRule(i, j, unaries_[ri], ant, 0);  // may update nodes
+        ApplyRule(i, j, unaries_[ri], ant, SparseVector<double>());  // may update nodes
       }
     }
   }
@@ -313,7 +310,7 @@ bool PassiveChart::Parse() {
             ai != cell.end(); ++ai) {
         const RuleBin* rules = (ai->gptr_->GetRules());
         if (!rules) continue;
-        ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_cost);
+        ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_feats);
       }
     }
   }
@@ -331,7 +328,7 @@ bool PassiveChart::Parse() {
     const Hypergraph::Node& node = forest_->nodes_[dh[di]];
     if (node.cat_ == goal_cat_) {
       Hypergraph::TailNodeVector ant(1, node.id_);
-      ApplyRule(0, input_.size(), goal_rule_, ant, 0);
+      ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector<double>());
     }
   }
 }
diff --git a/decoder/csplit.cc b/decoder/csplit.cc
index 7ee4092e..7a6ed102 100644
--- a/decoder/csplit.cc
+++ b/decoder/csplit.cc
@@ -150,7 +150,7 @@ bool CompoundSplit::TranslateImpl(const string& input,
   SplitUTF8String(input, &in);
   smeta->SetSourceLength(in.size());  // TODO do utf8 or somethign
   for (int i = 0; i < in.size(); ++i)
-    smeta->src_lattice_.push_back(vector<LatticeArc>(1, LatticeArc(TD::Convert(in[i]), 0.0, 1)));
+    smeta->src_lattice_.push_back(vector<LatticeArc>(1, LatticeArc(TD::Convert(in[i]), SparseVector<double>(), 1)));
   smeta->ComputeInputLatticeType();
   pimpl_->BuildTrellis(in, forest);
   forest->Reweight(weights);
diff --git a/decoder/grammar_test.cc b/decoder/grammar_test.cc
index 69240139..11df1e3f 100644
--- a/decoder/grammar_test.cc
+++ b/decoder/grammar_test.cc
@@ -47,8 +47,8 @@ BOOST_AUTO_TEST_CASE(TestTextGrammarFile) {
   GrammarPtr g(new TextGrammar(path + "/grammar.prune"));
   vector<GrammarPtr> grammars(1, g);
-  LatticeArc a(TD::Convert("ein"), 0.0, 1);
-  LatticeArc b(TD::Convert("haus"), 0.0, 1);
+  LatticeArc a(TD::Convert("ein"), SparseVector<double>(), 1);
+  LatticeArc b(TD::Convert("haus"), SparseVector<double>(), 1);
   Lattice lattice(2);
   lattice[0].push_back(a);
   lattice[1].push_back(b);
diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc
index 626b2954..71f50a29 100644
--- a/decoder/hg_io.cc
+++ b/decoder/hg_io.cc
@@ -87,6 +87,12 @@ string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses
   return os.str();
 }
+// TODO this should write out the PLF with the Python dictionary format
+// rather than just the single "LatticeCost" feature
+double PLFFeatureDictionary(const SparseVector<double>& f) {
+  return f.get(FD::Convert("LatticeCost"));
+}
+
 string HypergraphIO::AsPLF(const Lattice& lat, bool include_global_parentheses) {
   static bool first = true;
   if (first) { InitEscapes(); first = false; }
@@ -99,7 +105,7 @@ string HypergraphIO::AsPLF(const Lattice& lat, bool include_global_parentheses)
     os << '(';
     for (int j = 0; j < arcs.size(); ++j) {
       os << "('" << Escape(TD::Convert(arcs[j].label)) << "',"
-         << arcs[j].cost << ',' << arcs[j].dist2next << "),";
+         << PLFFeatureDictionary(arcs[j].features) << ',' << arcs[j].dist2next << "),";
     }
     os << "),";
   }
@@ -128,7 +134,10 @@ inline void eatws(const std::string& in, int& c) {
 std::string getEscapedString(const std::string& in, int &c)
 {
   eatws(in,c);
-  if (get(in,c++) != quote) return "ERROR";
+  if (get(in,c++) != quote) {
+    cerr << "Expected escaped string to begin with " << quote << ". Got " << get(in, c - 1) << "\n";
+    abort();
+  }
   std::string res;
   char cur = 0;
   do {
@@ -146,7 +155,7 @@ float getFloat(const std::string& in, int &c)
 {
   std::string tmp;
   eatws(in,c);
-  while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') {
+  while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',' && get(in,c) != '}') {
     tmp += get(in,c++);
   }
   eatws(in,c);
@@ -171,7 +180,18 @@ int getInt(const std::string& in, int &c)
 // maximum number of nodes permitted
 #define MAX_NODES 100000000
-// parse ('foo', 0.23)
+
+void ReadPLFFeature(const std::string& in, int &c, map<string, float>& features) {
+  eatws(in,c);
+  string name = getEscapedString(in,c);
+  eatws(in,c);
+  if (get(in,c++) != ':') { cerr << "PCN/PLF parse error: expected : after feature name " << name << "\n"; abort(); }
+  float value = getFloat(in, c);
+  eatws(in,c);
+  features[name] = value;
+}
+
+// parse ('foo', 0.23, 1)
 void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
   if (get(in,c++) != '(') { cerr << "PCN/PLF parse error: expected (\n"; abort(); }
   vector<WordID> ewords(2, 0);
@@ -180,22 +200,49 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
   r->ComputeArity();
   // cerr << "RULE: " << r->AsString() << endl;
   if (get(in,c++) != ',') { cerr << in << endl; cerr << "PCN/PLF parse error: expected , after string\n"; abort(); }
+  eatws(in,c);
+
+  map<string, float> features;
   size_t cnNext = 1;
-  std::vector<float> probs;
-  probs.push_back(getFloat(in,c));
-  while (get(in,c) == ',') {
+  // Read in sparse feature format
+  if (get(in,c) == '{') {
     c++;
-    float val = getFloat(in,c);
-    probs.push_back(val);
-    // cerr << val << endl; //REMO
+    eatws(in,c);
+    if (get(in,c) != '}') {
+      ReadPLFFeature(in, c, features);
+    }
+    while (get(in,c) == ',') {
+      c++;
+      if (get(in,c) == '}') { break; }
+      ReadPLFFeature(in, c, features);
+    }
+    if (get(in,c++) != '}') { cerr << "PCN/PLF parse error: expected } after feature dictionary\n"; abort(); }
+    eatws(in,c);
+    if (get(in, c++) != ',') { cerr << "PCN/PLF parse error: expected , after feature dictionary\n"; abort(); }
+    cnNext = static_cast<size_t>(getFloat(in, c));
   }
-  //if we read more than one prob, this was a lattice, last item was column increment
-  if (probs.size()>1) {
+  // Read in dense feature format
+  else {
+    std::vector<float> probs;
+    probs.push_back(getFloat(in,c));
+    while (get(in,c) == ',') {
+      c++;
+      float val = getFloat(in,c);
+      probs.push_back(val);
+      // cerr << val << endl; //REMO
+    }
+    if (probs.size() == 0) { cerr << "PCN/PLF parse error: missing destination state increment\n"; abort(); }
+
+    // the last item was column increment
     cnNext = static_cast<size_t>(probs.back());
     probs.pop_back();
-    if (cnNext < 1) { cerr << cnNext << endl << "PCN/PLF parse error: bad link length at last element of cn alt block\n"; abort(); }
+
+    for (unsigned i = 0; i < probs.size(); ++i) {
+      features["LatticeCost_" + to_string(i)] = probs[i];
+    }
   }
-  if (get(in,c++) != ')') { cerr << "PCN/PLF parse error: expected ) at end of cn alt block\n"; abort(); }
+  if (get(in,c++) != ')') { cerr << "PCN/PLF parse error: expected ) at end of cn alt block. Got " << get(in, c-1) << "\n"; abort(); }
+  if (cnNext < 1) { cerr << cnNext << endl << "PCN/PLF parse error: bad link length at last element of cn alt block\n"; abort(); }
   eatws(in,c);
   Hypergraph::TailNodeVector tail(1, cur_node);
   Hypergraph::Edge* edge = hg->AddEdge(r, tail);
@@ -204,15 +251,15 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
   assert(head_node < MAX_NODES);  // prevent malicious PLFs from using all the memory
   if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); }
   hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]);
-  for (int i = 0; i < probs.size(); ++i)
-    edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast<string>(i)), probs[i]);
+  for (map<string, float>::iterator it = features.begin(); it != features.end(); ++it) {
+    edge->feature_values_.set_value(FD::Convert(it->first), it->second);
+  }
 }
-// parse (('foo', 0.23), ('bar', 0.77))
+// parse (('foo', 0.23, 1), ('bar', 0.77, 1))
 void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) {
-  //cerr << "PLF READING NODE " << cur_node << endl;
   if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); }
-  if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); }
+  if (get(in,c++) != '(') { cerr << line << ": Syntax error 1 in PLF\n"; abort(); }
   eatws(in,c);
   while (1) {
     if (c > (int)in.size()) { break; }
@@ -237,7 +284,7 @@ void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line)
   hg->clear();
   int c = 0;
   int cur_node = 0;
-  if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); }
+  if (in[c++] != '(') { cerr << line << ": Syntax error in PLF!\n"; abort(); }
   while (1) {
     if (c > (int)in.size()) { break; }
     if (PLF::get(in,c) == ')') {
@@ -263,7 +310,6 @@ void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) {
   ReadFromPLF(plf, &g, 0);
   const int num_nodes = g.nodes_.size() - 1;
   l.resize(num_nodes);
-  int fid0=FD::Convert("Feature_0");
   for (int i = 0; i < num_nodes; ++i) {
     vector<LatticeArc>& alts = l[i];
     const Hypergraph::Node& node = g.nodes_[i];
@@ -272,7 +318,7 @@ void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) {
     for (int j = 0; j < num_alts; ++j) {
       const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]];
       alts[j].label = edge.rule_->e_[1];
-      alts[j].cost = edge.feature_values_.get(fid0);
+      alts[j].features = edge.feature_values_;
       alts[j].dist2next = edge.head_node_ - node.id_;
     }
   }
diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc
index 366b269d..a597ad8d 100644
--- a/decoder/hg_test.cc
+++ b/decoder/hg_test.cc
@@ -214,8 +214,8 @@ BOOST_AUTO_TEST_CASE(TestIntersect) {
   BOOST_CHECK_EQUAL(4, best);
   Lattice target(2);
-  target[0].push_back(LatticeArc(TD::Convert("a"), 0.0, 1));
-  target[1].push_back(LatticeArc(TD::Convert("b"), 0.0, 1));
+  target[0].push_back(LatticeArc(TD::Convert("a"), SparseVector<double>(), 1));
+  target[1].push_back(LatticeArc(TD::Convert("b"), SparseVector<double>(), 1));
   HG::Intersect(target, &hg);
   hg.PrintGraphviz();
 }
@@ -256,7 +256,7 @@ BOOST_AUTO_TEST_CASE(PLF) {
   string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)";
   HypergraphIO::ReadFromPLF(inplf, &hg);
   SparseVector<double> wts;
-  wts.set_value(FD::Convert("Feature_0"), 1.0);
+  wts.set_value(FD::Convert("LatticeCost_0"), 1.0);
   hg.Reweight(wts);
   hg.PrintGraphviz();
   string outplf = HypergraphIO::AsPLF(hg);
diff --git a/decoder/lattice.cc b/decoder/lattice.cc
index 1f97048d..2740ce63 100644
--- a/decoder/lattice.cc
+++ b/decoder/lattice.cc
@@ -49,7 +49,7 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) {
   l.clear();
   l.resize(ids.size());
   for (int i = 0; i < l.size(); ++i)
-    l[i].push_back(LatticeArc(ids[i], 0.0, 1));
+    l[i].push_back(LatticeArc(ids[i], SparseVector<double>(), 1));
 }
 void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) {
diff --git a/decoder/lattice.h b/decoder/lattice.h
index 1258d3f5..469615b5 100644
--- a/decoder/lattice.h
+++ b/decoder/lattice.h
@@ -3,6 +3,7 @@
 #include <string>
 #include <vector>
+#include "sparse_vector.h"
 #include "wordid.h"
 #include "array2d.h"
@@ -15,10 +16,10 @@ struct LatticeTools {
 struct LatticeArc {
   WordID label;
-  double cost;
+  SparseVector<double> features;
   int dist2next;
-  LatticeArc() : label(), cost(), dist2next() {}
-  LatticeArc(WordID w, double c, int i) : label(w), cost(c), dist2next(i) {}
+  LatticeArc() : label(), features(), dist2next() {}
+  LatticeArc(WordID w, const SparseVector<double>& f, int i) : label(w), features(f), dist2next(i) {}
 };
 class Lattice : public std::vector<std::vector<LatticeArc> > {
diff --git a/decoder/parser_test.cc b/decoder/parser_test.cc
index e2916e44..37c45cc8 100644
--- a/decoder/parser_test.cc
+++ b/decoder/parser_test.cc
@@ -10,8 +10,8 @@ using namespace std;
 BOOST_AUTO_TEST_CASE(Parse) {
-  LatticeArc a(TD::Convert("ein"), 0.0, 1);
-  LatticeArc b(TD::Convert("haus"), 0.0, 1);
+  LatticeArc a(TD::Convert("ein"), SparseVector<double>(), 1);
+  LatticeArc b(TD::Convert("haus"), SparseVector<double>(), 1);
   Lattice lattice(2);
   lattice[0].push_back(a);
   lattice[1].push_back(b);
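The hg_io.cc hunks above also teach the PLF reader two per-arc input formats: the old dense tuple, whose leading floats now become LatticeCost_0, LatticeCost_1, ... and whose last number is the jump to the next node, and a new Python-style feature dictionary. A short usage sketch; the input strings and header names are inferred from the parser and tests above, not taken verbatim from the commit:

#include <string>
#include "hg.h"      // Hypergraph (assumed header name)
#include "hg_io.h"   // HypergraphIO (assumed header name)
#include "lattice.h" // Lattice

void ReadBothPlfFormats() {
  Hypergraph hg;

  // Dense format: per-arc floats plus a trailing node increment.
  // The floats are stored on the edge as LatticeCost_0, LatticeCost_1, ...
  std::string dense = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),)";
  HypergraphIO::ReadFromPLF(dense, &hg);

  // Sparse format: a dictionary of named features, then the node increment.
  std::string sparse = "((('haupt',{'LatticeCost':-2.06655},1),),)";
  HypergraphIO::ReadFromPLF(sparse, &hg);

  // The per-arc feature vectors survive conversion into a Lattice.
  Lattice lat;
  HypergraphIO::PLFtoLattice(sparse, &lat);
}

Keeping the dense branch and renaming its scores to LatticeCost_i preserves existing PLF inputs, while the dictionary branch allows arbitrarily named input features on each arc.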