From 24cee877f2bb4b490495ea578564d0266b1addd3 Mon Sep 17 00:00:00 2001 From: CHRISTOPHER DYER Date: Tue, 3 Feb 2015 21:24:07 -0500 Subject: support multiple sparse features on lattice edges --- decoder/bottom_up_parser-rs.cc | 29 ++++++++++++++--------------- decoder/bottom_up_parser.cc | 41 +++++++++++++++++++---------------------- decoder/csplit.cc | 2 +- decoder/grammar_test.cc | 4 ++-- decoder/hg_io.cc | 21 ++++++++++++++++----- decoder/hg_test.cc | 6 +++--- decoder/lattice.cc | 2 +- decoder/lattice.h | 7 ++++--- decoder/parser_test.cc | 4 ++-- 9 files changed, 62 insertions(+), 54 deletions(-) diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc index fbde7e24..863d7e2f 100644 --- a/decoder/bottom_up_parser-rs.cc +++ b/decoder/bottom_up_parser-rs.cc @@ -35,7 +35,7 @@ class RSChart { const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost); + const SparseVector& lattice_feats); // returns true if a new node was added to the chart // false otherwise @@ -43,7 +43,7 @@ class RSChart { const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost); + const SparseVector& lattice_feats); void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx); void TopoSortUnaries(); @@ -69,9 +69,9 @@ WordID RSChart::kGOAL = 0; // "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span" struct RSActiveItem { explicit RSActiveItem(const GrammarIter* g, int i) : - gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {} - void ExtendTerminal(int symbol, float src_cost) { - lattice_cost += src_cost; + gptr_(g), ant_nodes_(), lattice_feats(), i_(i) {} + void ExtendTerminal(int symbol, const SparseVector& src_feats) { + lattice_feats += src_feats; if (symbol != kEPS) gptr_ = gptr_->Extend(symbol); } @@ -85,7 +85,7 @@ struct RSActiveItem { } const GrammarIter* gptr_; Hypergraph::TailNodeVector ant_nodes_; - float lattice_cost; // TODO: use SparseVector to encode input features + SparseVector lattice_feats; short i_; }; @@ -174,7 +174,7 @@ bool RSChart::ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost) { + const SparseVector& lattice_feats) { Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; new_edge->prev_i_ = r->prev_i; @@ -182,8 +182,7 @@ bool RSChart::ApplyRule(const int i, new_edge->i_ = i; new_edge->j_ = j; new_edge->feature_values_ = r->GetFeatureValues(); - if (lattice_cost && lc_fid_) - new_edge->feature_values_.set_value(lc_fid_, lattice_cost); + new_edge->feature_values_ += lattice_feats; Cat2NodeMap& c2n = nodemap_(i,j); const bool is_goal = (r->GetLHS() == kGOAL); const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); @@ -211,7 +210,7 @@ void RSChart::ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost) { + const SparseVector& lattice_feats) { const int n = rules->GetNumRules(); //cerr << i << " " << j << ": NUM RULES: " << n << endl; for (int k = 0; k < n; ++k) { @@ -219,7 +218,7 @@ void RSChart::ApplyRules(const int i, TRulePtr rule = rules->GetIthRule(k); // apply rule, and if we create a new node, apply any necessary // unary rules - if (ApplyRule(i, j, rule, tail, lattice_cost)) { + if (ApplyRule(i, j, rule, tail, lattice_feats)) { unsigned nodeidx = nodemap_(i,j)[rule->lhs_]; ApplyUnaryRules(i, j, rule->lhs_, nodeidx); } @@ -233,7 +232,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig //cerr << " --MATCH\n"; WordID new_lhs = unaries_[ri]->GetLHS(); const Hypergraph::TailNodeVector ant(1, nodeidx); - if (ApplyRule(i, j, unaries_[ri], ant, 0)) { + if (ApplyRule(i, j, unaries_[ri], ant, SparseVector())) { //cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl; unsigned nodeidx = nodemap_(i,j)[new_lhs]; ApplyUnaryRules(i, j, new_lhs, nodeidx); @@ -245,7 +244,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig void RSChart::AddToChart(const RSActiveItem& x, int i, int j) { // deal with completed rules const RuleBin* rb = x.gptr_->GetRules(); - if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost); + if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_feats); //cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n"; // continue looking for extensions of the rule to the right @@ -264,7 +263,7 @@ void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) { if (in_edge.dist2next == check_edge_len) { //cerr << " Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl; RSActiveItem copy = x; - copy.ExtendTerminal(in_edge.label, in_edge.cost); + copy.ExtendTerminal(in_edge.label, in_edge.features); if (copy) AddToChart(copy, i, k); } } @@ -306,7 +305,7 @@ bool RSChart::Parse() { const Hypergraph::Node& node = forest_->nodes_[dh[di]]; if (node.cat_ == goal_cat_) { Hypergraph::TailNodeVector ant(1, node.id_); - ApplyRule(0, input_.size(), goal_rule_, ant, 0); + ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector()); } } if (!SILENT) cerr << endl; diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index b30f1ec6..a614b8b3 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -38,13 +38,13 @@ class PassiveChart { const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost); + const SparseVector& lattice_feats); void ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost); + const SparseVector& lattice_feats); void ApplyUnaryRules(const int i, const int j); void TopoSortUnaries(); @@ -59,7 +59,6 @@ class PassiveChart { const WordID goal_cat_; // category that is being searched for at [0,n] TRulePtr goal_rule_; int goal_idx_; // index of goal node, if found - const int lc_fid_; vector unaries_; // topologically sorted list of unary rules from all grammars static WordID kGOAL; // [Goal] @@ -74,18 +73,18 @@ class ActiveChart { act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {} struct ActiveItem { - ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, float lcost) : - gptr_(g), ant_nodes_(a), lattice_cost(lcost) {} + ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, const SparseVector& lfeats) : + gptr_(g), ant_nodes_(a), lattice_feats(lfeats) {} explicit ActiveItem(const GrammarIter* g) : - gptr_(g), ant_nodes_(), lattice_cost(0.0) {} + gptr_(g), ant_nodes_(), lattice_feats() {} - void ExtendTerminal(int symbol, float src_cost, vector* out_cell) const { + void ExtendTerminal(int symbol, const SparseVector& src_feats, vector* out_cell) const { if (symbol == kEPS) { - out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_cost + src_cost)); + out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_feats + src_feats)); } else { const GrammarIter* ni = gptr_->Extend(symbol); if (ni) - out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost)); + out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_feats + src_feats)); } } void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector* out_cell) const { @@ -96,12 +95,12 @@ class ActiveChart { for (unsigned i = 0; i < ant_nodes_.size(); ++i) na[i] = ant_nodes_[i]; na[ant_nodes_.size()] = node_index; - out_cell->push_back(ActiveItem(ni, na, lattice_cost)); + out_cell->push_back(ActiveItem(ni, na, lattice_feats)); } const GrammarIter* gptr_; Hypergraph::TailNodeVector ant_nodes_; - float lattice_cost; // TODO? use SparseVector + SparseVector lattice_feats; }; inline const vector& operator()(int i, int j) const { return act_chart_(i,j); } @@ -134,12 +133,12 @@ class ActiveChart { for (vector::const_iterator ai = out_arcs.begin(); ai != out_arcs.end(); ++ai) { const WordID& f = ai->label; - const double& c = ai->cost; + const SparseVector& c = ai->features; const int& len = ai->dist2next; //cerr << "F: " << TD::Convert(f) << " dest=" << i << "," << (j+len-1) << endl; const vector& ec = act_chart_(i, j-1); //cerr << " SRC=" << i << "," << (j-1) << " [ec=" << ec.size() << "]" << endl; - //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_cost << endl; } + //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_feats << endl; } for (vector::const_iterator di = ec.begin(); di != ec.end(); ++di) di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1)); } @@ -163,7 +162,6 @@ PassiveChart::PassiveChart(const string& goal, goal_cat_(TD::Convert(goal) * -1), goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), goal_idx_(-1), - lc_fid_(FD::Convert("LatticeCost")), unaries_() { act_chart_.resize(grammars_.size()); for (unsigned i = 0; i < grammars_.size(); ++i) { @@ -232,7 +230,7 @@ void PassiveChart::ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, - const float lattice_cost) { + const SparseVector& lattice_feats) { Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; new_edge->prev_i_ = r->prev_i; @@ -240,8 +238,7 @@ void PassiveChart::ApplyRule(const int i, new_edge->i_ = i; new_edge->j_ = j; new_edge->feature_values_ = r->GetFeatureValues(); - if (lattice_cost && lc_fid_) - new_edge->feature_values_.set_value(lc_fid_, lattice_cost); + new_edge->feature_values_ += lattice_feats; Cat2NodeMap& c2n = nodemap_(i,j); const bool is_goal = (r->GetLHS() == kGOAL); const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); @@ -265,12 +262,12 @@ void PassiveChart::ApplyRules(const int i, const int j, const RuleBin* rules, const Hypergraph::TailNodeVector& tail, - const float lattice_cost) { + const SparseVector& lattice_feats) { const int n = rules->GetNumRules(); //cerr << i << " " << j << ": NUM RULES: " << n << endl; for (int k = 0; k < n; ++k) { //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl; - ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost); + ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_feats); } } @@ -283,7 +280,7 @@ void PassiveChart::ApplyUnaryRules(const int i, const int j) { if (unaries_[ri]->f()[0] == cat) { //cerr << " --MATCH\n"; const Hypergraph::TailNodeVector ant(1, nodes[di]); - ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes + ApplyRule(i, j, unaries_[ri], ant, SparseVector()); // may update nodes } } } @@ -313,7 +310,7 @@ bool PassiveChart::Parse() { ai != cell.end(); ++ai) { const RuleBin* rules = (ai->gptr_->GetRules()); if (!rules) continue; - ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_cost); + ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_feats); } } } @@ -331,7 +328,7 @@ bool PassiveChart::Parse() { const Hypergraph::Node& node = forest_->nodes_[dh[di]]; if (node.cat_ == goal_cat_) { Hypergraph::TailNodeVector ant(1, node.id_); - ApplyRule(0, input_.size(), goal_rule_, ant, 0); + ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector()); } } } diff --git a/decoder/csplit.cc b/decoder/csplit.cc index 7ee4092e..7a6ed102 100644 --- a/decoder/csplit.cc +++ b/decoder/csplit.cc @@ -150,7 +150,7 @@ bool CompoundSplit::TranslateImpl(const string& input, SplitUTF8String(input, &in); smeta->SetSourceLength(in.size()); // TODO do utf8 or somethign for (int i = 0; i < in.size(); ++i) - smeta->src_lattice_.push_back(vector(1, LatticeArc(TD::Convert(in[i]), 0.0, 1))); + smeta->src_lattice_.push_back(vector(1, LatticeArc(TD::Convert(in[i]), SparseVector(), 1))); smeta->ComputeInputLatticeType(); pimpl_->BuildTrellis(in, forest); forest->Reweight(weights); diff --git a/decoder/grammar_test.cc b/decoder/grammar_test.cc index 69240139..11df1e3f 100644 --- a/decoder/grammar_test.cc +++ b/decoder/grammar_test.cc @@ -47,8 +47,8 @@ BOOST_AUTO_TEST_CASE(TestTextGrammarFile) { GrammarPtr g(new TextGrammar(path + "/grammar.prune")); vector grammars(1, g); - LatticeArc a(TD::Convert("ein"), 0.0, 1); - LatticeArc b(TD::Convert("haus"), 0.0, 1); + LatticeArc a(TD::Convert("ein"), SparseVector(), 1); + LatticeArc b(TD::Convert("haus"), SparseVector(), 1); Lattice lattice(2); lattice[0].push_back(a); lattice[1].push_back(b); diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc index 626b2954..d97ab3dc 100644 --- a/decoder/hg_io.cc +++ b/decoder/hg_io.cc @@ -87,6 +87,12 @@ string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses return os.str(); } +// TODO this should write out the PLF with the Python dictionary format +// rather than just the single "LatticeCost" feature +double PLFFeatureDictionary(const SparseVector& f) { + return f.get(FD::Convert("LatticeCost")); +} + string HypergraphIO::AsPLF(const Lattice& lat, bool include_global_parentheses) { static bool first = true; if (first) { InitEscapes(); first = false; } @@ -99,7 +105,7 @@ string HypergraphIO::AsPLF(const Lattice& lat, bool include_global_parentheses) os << '('; for (int j = 0; j < arcs.size(); ++j) { os << "('" << Escape(TD::Convert(arcs[j].label)) << "'," - << arcs[j].cost << ',' << arcs[j].dist2next << "),"; + << PLFFeatureDictionary(arcs[j].features) << ',' << arcs[j].dist2next << "),"; } os << "),"; } @@ -204,8 +210,14 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) { assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); } hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]); - for (int i = 0; i < probs.size(); ++i) - edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast(i)), probs[i]); + if (probs.size() != 0) { + if (probs.size() == 1) { + edge->feature_values_.set_value(FD::Convert("LatticeCost"), probs[0]); + } else { + cerr << "Don't know how to deal with multiple lattice edge features: implement Python dictionary format.\n"; + abort(); + } + } } // parse (('foo', 0.23), ('bar', 0.77)) @@ -263,7 +275,6 @@ void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) { ReadFromPLF(plf, &g, 0); const int num_nodes = g.nodes_.size() - 1; l.resize(num_nodes); - int fid0=FD::Convert("Feature_0"); for (int i = 0; i < num_nodes; ++i) { vector& alts = l[i]; const Hypergraph::Node& node = g.nodes_[i]; @@ -272,7 +283,7 @@ void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) { for (int j = 0; j < num_alts; ++j) { const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]]; alts[j].label = edge.rule_->e_[1]; - alts[j].cost = edge.feature_values_.get(fid0); + alts[j].features = edge.feature_values_; alts[j].dist2next = edge.head_node_ - node.id_; } } diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc index 366b269d..ec91cd3b 100644 --- a/decoder/hg_test.cc +++ b/decoder/hg_test.cc @@ -214,8 +214,8 @@ BOOST_AUTO_TEST_CASE(TestIntersect) { BOOST_CHECK_EQUAL(4, best); Lattice target(2); - target[0].push_back(LatticeArc(TD::Convert("a"), 0.0, 1)); - target[1].push_back(LatticeArc(TD::Convert("b"), 0.0, 1)); + target[0].push_back(LatticeArc(TD::Convert("a"), SparseVector(), 1)); + target[1].push_back(LatticeArc(TD::Convert("b"), SparseVector(), 1)); HG::Intersect(target, &hg); hg.PrintGraphviz(); } @@ -256,7 +256,7 @@ BOOST_AUTO_TEST_CASE(PLF) { string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)"; HypergraphIO::ReadFromPLF(inplf, &hg); SparseVector wts; - wts.set_value(FD::Convert("Feature_0"), 1.0); + wts.set_value(FD::Convert("LatticeCost"), 1.0); hg.Reweight(wts); hg.PrintGraphviz(); string outplf = HypergraphIO::AsPLF(hg); diff --git a/decoder/lattice.cc b/decoder/lattice.cc index 1f97048d..2740ce63 100644 --- a/decoder/lattice.cc +++ b/decoder/lattice.cc @@ -49,7 +49,7 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) { l.clear(); l.resize(ids.size()); for (int i = 0; i < l.size(); ++i) - l[i].push_back(LatticeArc(ids[i], 0.0, 1)); + l[i].push_back(LatticeArc(ids[i], SparseVector(), 1)); } void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) { diff --git a/decoder/lattice.h b/decoder/lattice.h index 1258d3f5..469615b5 100644 --- a/decoder/lattice.h +++ b/decoder/lattice.h @@ -3,6 +3,7 @@ #include #include +#include "sparse_vector.h" #include "wordid.h" #include "array2d.h" @@ -15,10 +16,10 @@ struct LatticeTools { struct LatticeArc { WordID label; - double cost; + SparseVector features; int dist2next; - LatticeArc() : label(), cost(), dist2next() {} - LatticeArc(WordID w, double c, int i) : label(w), cost(c), dist2next(i) {} + LatticeArc() : label(), features(), dist2next() {} + LatticeArc(WordID w, const SparseVector& f, int i) : label(w), features(f), dist2next(i) {} }; class Lattice : public std::vector > { diff --git a/decoder/parser_test.cc b/decoder/parser_test.cc index e2916e44..37c45cc8 100644 --- a/decoder/parser_test.cc +++ b/decoder/parser_test.cc @@ -10,8 +10,8 @@ using namespace std; BOOST_AUTO_TEST_CASE(Parse) { - LatticeArc a(TD::Convert("ein"), 0.0, 1); - LatticeArc b(TD::Convert("haus"), 0.0, 1); + LatticeArc a(TD::Convert("ein"), SparseVector(), 1); + LatticeArc b(TD::Convert("haus"), SparseVector(), 1); Lattice lattice(2); lattice[0].push_back(a); lattice[1].push_back(b); -- cgit v1.2.3