summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorCHRISTOPHER DYER <cdyer@CHRISTOPHERs-MacBook-Pro.local>2015-02-03 21:24:07 -0500
committerCHRISTOPHER DYER <cdyer@CHRISTOPHERs-MacBook-Pro.local>2015-02-03 21:24:07 -0500
commitc7b2a39958912d7b85a384a871609e6db73042c7 (patch)
tree5405ac792a173edaa50e67750c2b10722a47962a /decoder
parentafd65846cf1456a8b49e8482b9a40777014f6883 (diff)
support multiple sparse features on lattice edges
Diffstat (limited to 'decoder')
-rw-r--r--decoder/bottom_up_parser-rs.cc29
-rw-r--r--decoder/bottom_up_parser.cc41
-rw-r--r--decoder/csplit.cc2
-rw-r--r--decoder/grammar_test.cc4
-rw-r--r--decoder/hg_io.cc21
-rw-r--r--decoder/hg_test.cc6
-rw-r--r--decoder/lattice.cc2
-rw-r--r--decoder/lattice.h7
-rw-r--r--decoder/parser_test.cc4
9 files changed, 62 insertions, 54 deletions
diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc
index fbde7e24..863d7e2f 100644
--- a/decoder/bottom_up_parser-rs.cc
+++ b/decoder/bottom_up_parser-rs.cc
@@ -35,7 +35,7 @@ class RSChart {
const int j,
const RuleBin* rules,
const Hypergraph::TailNodeVector& tail,
- const float lattice_cost);
+ const SparseVector<double>& lattice_feats);
// returns true if a new node was added to the chart
// false otherwise
@@ -43,7 +43,7 @@ class RSChart {
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
- const float lattice_cost);
+ const SparseVector<double>& lattice_feats);
void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx);
void TopoSortUnaries();
@@ -69,9 +69,9 @@ WordID RSChart::kGOAL = 0;
// "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span"
struct RSActiveItem {
explicit RSActiveItem(const GrammarIter* g, int i) :
- gptr_(g), ant_nodes_(), lattice_cost(0.0), i_(i) {}
- void ExtendTerminal(int symbol, float src_cost) {
- lattice_cost += src_cost;
+ gptr_(g), ant_nodes_(), lattice_feats(), i_(i) {}
+ void ExtendTerminal(int symbol, const SparseVector<double>& src_feats) {
+ lattice_feats += src_feats;
if (symbol != kEPS)
gptr_ = gptr_->Extend(symbol);
}
@@ -85,7 +85,7 @@ struct RSActiveItem {
}
const GrammarIter* gptr_;
Hypergraph::TailNodeVector ant_nodes_;
- float lattice_cost; // TODO: use SparseVector<double> to encode input features
+ SparseVector<double> lattice_feats;
short i_;
};
@@ -174,7 +174,7 @@ bool RSChart::ApplyRule(const int i,
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
- const float lattice_cost) {
+ const SparseVector<double>& lattice_feats) {
Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
//cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
new_edge->prev_i_ = r->prev_i;
@@ -182,8 +182,7 @@ bool RSChart::ApplyRule(const int i,
new_edge->i_ = i;
new_edge->j_ = j;
new_edge->feature_values_ = r->GetFeatureValues();
- if (lattice_cost && lc_fid_)
- new_edge->feature_values_.set_value(lc_fid_, lattice_cost);
+ new_edge->feature_values_ += lattice_feats;
Cat2NodeMap& c2n = nodemap_(i,j);
const bool is_goal = (r->GetLHS() == kGOAL);
const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS());
@@ -211,7 +210,7 @@ void RSChart::ApplyRules(const int i,
const int j,
const RuleBin* rules,
const Hypergraph::TailNodeVector& tail,
- const float lattice_cost) {
+ const SparseVector<double>& lattice_feats) {
const int n = rules->GetNumRules();
//cerr << i << " " << j << ": NUM RULES: " << n << endl;
for (int k = 0; k < n; ++k) {
@@ -219,7 +218,7 @@ void RSChart::ApplyRules(const int i,
TRulePtr rule = rules->GetIthRule(k);
// apply rule, and if we create a new node, apply any necessary
// unary rules
- if (ApplyRule(i, j, rule, tail, lattice_cost)) {
+ if (ApplyRule(i, j, rule, tail, lattice_feats)) {
unsigned nodeidx = nodemap_(i,j)[rule->lhs_];
ApplyUnaryRules(i, j, rule->lhs_, nodeidx);
}
@@ -233,7 +232,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig
//cerr << " --MATCH\n";
WordID new_lhs = unaries_[ri]->GetLHS();
const Hypergraph::TailNodeVector ant(1, nodeidx);
- if (ApplyRule(i, j, unaries_[ri], ant, 0)) {
+ if (ApplyRule(i, j, unaries_[ri], ant, SparseVector<double>())) {
//cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl;
unsigned nodeidx = nodemap_(i,j)[new_lhs];
ApplyUnaryRules(i, j, new_lhs, nodeidx);
@@ -245,7 +244,7 @@ void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsig
void RSChart::AddToChart(const RSActiveItem& x, int i, int j) {
// deal with completed rules
const RuleBin* rb = x.gptr_->GetRules();
- if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_cost);
+ if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_feats);
//cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n";
// continue looking for extensions of the rule to the right
@@ -264,7 +263,7 @@ void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) {
if (in_edge.dist2next == check_edge_len) {
//cerr << " Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl;
RSActiveItem copy = x;
- copy.ExtendTerminal(in_edge.label, in_edge.cost);
+ copy.ExtendTerminal(in_edge.label, in_edge.features);
if (copy) AddToChart(copy, i, k);
}
}
@@ -306,7 +305,7 @@ bool RSChart::Parse() {
const Hypergraph::Node& node = forest_->nodes_[dh[di]];
if (node.cat_ == goal_cat_) {
Hypergraph::TailNodeVector ant(1, node.id_);
- ApplyRule(0, input_.size(), goal_rule_, ant, 0);
+ ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector<double>());
}
}
if (!SILENT) cerr << endl;
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc
index b30f1ec6..a614b8b3 100644
--- a/decoder/bottom_up_parser.cc
+++ b/decoder/bottom_up_parser.cc
@@ -38,13 +38,13 @@ class PassiveChart {
const int j,
const RuleBin* rules,
const Hypergraph::TailNodeVector& tail,
- const float lattice_cost);
+ const SparseVector<double>& lattice_feats);
void ApplyRule(const int i,
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
- const float lattice_cost);
+ const SparseVector<double>& lattice_feats);
void ApplyUnaryRules(const int i, const int j);
void TopoSortUnaries();
@@ -59,7 +59,6 @@ class PassiveChart {
const WordID goal_cat_; // category that is being searched for at [0,n]
TRulePtr goal_rule_;
int goal_idx_; // index of goal node, if found
- const int lc_fid_;
vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars
static WordID kGOAL; // [Goal]
@@ -74,18 +73,18 @@ class ActiveChart {
act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {}
struct ActiveItem {
- ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, float lcost) :
- gptr_(g), ant_nodes_(a), lattice_cost(lcost) {}
+ ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, const SparseVector<double>& lfeats) :
+ gptr_(g), ant_nodes_(a), lattice_feats(lfeats) {}
explicit ActiveItem(const GrammarIter* g) :
- gptr_(g), ant_nodes_(), lattice_cost(0.0) {}
+ gptr_(g), ant_nodes_(), lattice_feats() {}
- void ExtendTerminal(int symbol, float src_cost, vector<ActiveItem>* out_cell) const {
+ void ExtendTerminal(int symbol, const SparseVector<double>& src_feats, vector<ActiveItem>* out_cell) const {
if (symbol == kEPS) {
- out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_cost + src_cost));
+ out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_feats + src_feats));
} else {
const GrammarIter* ni = gptr_->Extend(symbol);
if (ni)
- out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost));
+ out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_feats + src_feats));
}
}
void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector<ActiveItem>* out_cell) const {
@@ -96,12 +95,12 @@ class ActiveChart {
for (unsigned i = 0; i < ant_nodes_.size(); ++i)
na[i] = ant_nodes_[i];
na[ant_nodes_.size()] = node_index;
- out_cell->push_back(ActiveItem(ni, na, lattice_cost));
+ out_cell->push_back(ActiveItem(ni, na, lattice_feats));
}
const GrammarIter* gptr_;
Hypergraph::TailNodeVector ant_nodes_;
- float lattice_cost; // TODO? use SparseVector<double>
+ SparseVector<double> lattice_feats;
};
inline const vector<ActiveItem>& operator()(int i, int j) const { return act_chart_(i,j); }
@@ -134,12 +133,12 @@ class ActiveChart {
for (vector<LatticeArc>::const_iterator ai = out_arcs.begin();
ai != out_arcs.end(); ++ai) {
const WordID& f = ai->label;
- const double& c = ai->cost;
+ const SparseVector<double>& c = ai->features;
const int& len = ai->dist2next;
//cerr << "F: " << TD::Convert(f) << " dest=" << i << "," << (j+len-1) << endl;
const vector<ActiveItem>& ec = act_chart_(i, j-1);
//cerr << " SRC=" << i << "," << (j-1) << " [ec=" << ec.size() << "]" << endl;
- //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_cost << endl; }
+ //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_feats << endl; }
for (vector<ActiveItem>::const_iterator di = ec.begin(); di != ec.end(); ++di)
di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1));
}
@@ -163,7 +162,6 @@ PassiveChart::PassiveChart(const string& goal,
goal_cat_(TD::Convert(goal) * -1),
goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")),
goal_idx_(-1),
- lc_fid_(FD::Convert("LatticeCost")),
unaries_() {
act_chart_.resize(grammars_.size());
for (unsigned i = 0; i < grammars_.size(); ++i) {
@@ -232,7 +230,7 @@ void PassiveChart::ApplyRule(const int i,
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
- const float lattice_cost) {
+ const SparseVector<double>& lattice_feats) {
Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
// cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
new_edge->prev_i_ = r->prev_i;
@@ -240,8 +238,7 @@ void PassiveChart::ApplyRule(const int i,
new_edge->i_ = i;
new_edge->j_ = j;
new_edge->feature_values_ = r->GetFeatureValues();
- if (lattice_cost && lc_fid_)
- new_edge->feature_values_.set_value(lc_fid_, lattice_cost);
+ new_edge->feature_values_ += lattice_feats;
Cat2NodeMap& c2n = nodemap_(i,j);
const bool is_goal = (r->GetLHS() == kGOAL);
const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS());
@@ -265,12 +262,12 @@ void PassiveChart::ApplyRules(const int i,
const int j,
const RuleBin* rules,
const Hypergraph::TailNodeVector& tail,
- const float lattice_cost) {
+ const SparseVector<double>& lattice_feats) {
const int n = rules->GetNumRules();
//cerr << i << " " << j << ": NUM RULES: " << n << endl;
for (int k = 0; k < n; ++k) {
//cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl;
- ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost);
+ ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_feats);
}
}
@@ -283,7 +280,7 @@ void PassiveChart::ApplyUnaryRules(const int i, const int j) {
if (unaries_[ri]->f()[0] == cat) {
//cerr << " --MATCH\n";
const Hypergraph::TailNodeVector ant(1, nodes[di]);
- ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes
+ ApplyRule(i, j, unaries_[ri], ant, SparseVector<double>()); // may update nodes
}
}
}
@@ -313,7 +310,7 @@ bool PassiveChart::Parse() {
ai != cell.end(); ++ai) {
const RuleBin* rules = (ai->gptr_->GetRules());
if (!rules) continue;
- ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_cost);
+ ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_feats);
}
}
}
@@ -331,7 +328,7 @@ bool PassiveChart::Parse() {
const Hypergraph::Node& node = forest_->nodes_[dh[di]];
if (node.cat_ == goal_cat_) {
Hypergraph::TailNodeVector ant(1, node.id_);
- ApplyRule(0, input_.size(), goal_rule_, ant, 0);
+ ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector<double>());
}
}
}
diff --git a/decoder/csplit.cc b/decoder/csplit.cc
index 7ee4092e..7a6ed102 100644
--- a/decoder/csplit.cc
+++ b/decoder/csplit.cc
@@ -150,7 +150,7 @@ bool CompoundSplit::TranslateImpl(const string& input,
SplitUTF8String(input, &in);
smeta->SetSourceLength(in.size()); // TODO do utf8 or somethign
for (int i = 0; i < in.size(); ++i)
- smeta->src_lattice_.push_back(vector<LatticeArc>(1, LatticeArc(TD::Convert(in[i]), 0.0, 1)));
+ smeta->src_lattice_.push_back(vector<LatticeArc>(1, LatticeArc(TD::Convert(in[i]), SparseVector<double>(), 1)));
smeta->ComputeInputLatticeType();
pimpl_->BuildTrellis(in, forest);
forest->Reweight(weights);
diff --git a/decoder/grammar_test.cc b/decoder/grammar_test.cc
index 69240139..11df1e3f 100644
--- a/decoder/grammar_test.cc
+++ b/decoder/grammar_test.cc
@@ -47,8 +47,8 @@ BOOST_AUTO_TEST_CASE(TestTextGrammarFile) {
GrammarPtr g(new TextGrammar(path + "/grammar.prune"));
vector<GrammarPtr> grammars(1, g);
- LatticeArc a(TD::Convert("ein"), 0.0, 1);
- LatticeArc b(TD::Convert("haus"), 0.0, 1);
+ LatticeArc a(TD::Convert("ein"), SparseVector<double>(), 1);
+ LatticeArc b(TD::Convert("haus"), SparseVector<double>(), 1);
Lattice lattice(2);
lattice[0].push_back(a);
lattice[1].push_back(b);
diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc
index 626b2954..d97ab3dc 100644
--- a/decoder/hg_io.cc
+++ b/decoder/hg_io.cc
@@ -87,6 +87,12 @@ string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses
return os.str();
}
+// TODO this should write out the PLF with the Python dictionary format
+// rather than just the single "LatticeCost" feature
+double PLFFeatureDictionary(const SparseVector<double>& f) {
+ return f.get(FD::Convert("LatticeCost"));
+}
+
string HypergraphIO::AsPLF(const Lattice& lat, bool include_global_parentheses) {
static bool first = true;
if (first) { InitEscapes(); first = false; }
@@ -99,7 +105,7 @@ string HypergraphIO::AsPLF(const Lattice& lat, bool include_global_parentheses)
os << '(';
for (int j = 0; j < arcs.size(); ++j) {
os << "('" << Escape(TD::Convert(arcs[j].label)) << "',"
- << arcs[j].cost << ',' << arcs[j].dist2next << "),";
+ << PLFFeatureDictionary(arcs[j].features) << ',' << arcs[j].dist2next << "),";
}
os << "),";
}
@@ -204,8 +210,14 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory
if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); }
hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]);
- for (int i = 0; i < probs.size(); ++i)
- edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast<string>(i)), probs[i]);
+ if (probs.size() != 0) {
+ if (probs.size() == 1) {
+ edge->feature_values_.set_value(FD::Convert("LatticeCost"), probs[0]);
+ } else {
+ cerr << "Don't know how to deal with multiple lattice edge features: implement Python dictionary format.\n";
+ abort();
+ }
+ }
}
// parse (('foo', 0.23), ('bar', 0.77))
@@ -263,7 +275,6 @@ void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) {
ReadFromPLF(plf, &g, 0);
const int num_nodes = g.nodes_.size() - 1;
l.resize(num_nodes);
- int fid0=FD::Convert("Feature_0");
for (int i = 0; i < num_nodes; ++i) {
vector<LatticeArc>& alts = l[i];
const Hypergraph::Node& node = g.nodes_[i];
@@ -272,7 +283,7 @@ void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) {
for (int j = 0; j < num_alts; ++j) {
const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]];
alts[j].label = edge.rule_->e_[1];
- alts[j].cost = edge.feature_values_.get(fid0);
+ alts[j].features = edge.feature_values_;
alts[j].dist2next = edge.head_node_ - node.id_;
}
}
diff --git a/decoder/hg_test.cc b/decoder/hg_test.cc
index 366b269d..ec91cd3b 100644
--- a/decoder/hg_test.cc
+++ b/decoder/hg_test.cc
@@ -214,8 +214,8 @@ BOOST_AUTO_TEST_CASE(TestIntersect) {
BOOST_CHECK_EQUAL(4, best);
Lattice target(2);
- target[0].push_back(LatticeArc(TD::Convert("a"), 0.0, 1));
- target[1].push_back(LatticeArc(TD::Convert("b"), 0.0, 1));
+ target[0].push_back(LatticeArc(TD::Convert("a"), SparseVector<double>(), 1));
+ target[1].push_back(LatticeArc(TD::Convert("b"), SparseVector<double>(), 1));
HG::Intersect(target, &hg);
hg.PrintGraphviz();
}
@@ -256,7 +256,7 @@ BOOST_AUTO_TEST_CASE(PLF) {
string inplf = "((('haupt',-2.06655,1),('hauptgrund',-5.71033,2),),(('grund',-1.78709,1),),(('für\\'',0.1,1),),)";
HypergraphIO::ReadFromPLF(inplf, &hg);
SparseVector<double> wts;
- wts.set_value(FD::Convert("Feature_0"), 1.0);
+ wts.set_value(FD::Convert("LatticeCost"), 1.0);
hg.Reweight(wts);
hg.PrintGraphviz();
string outplf = HypergraphIO::AsPLF(hg);
diff --git a/decoder/lattice.cc b/decoder/lattice.cc
index 1f97048d..2740ce63 100644
--- a/decoder/lattice.cc
+++ b/decoder/lattice.cc
@@ -49,7 +49,7 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) {
l.clear();
l.resize(ids.size());
for (int i = 0; i < l.size(); ++i)
- l[i].push_back(LatticeArc(ids[i], 0.0, 1));
+ l[i].push_back(LatticeArc(ids[i], SparseVector<double>(), 1));
}
void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) {
diff --git a/decoder/lattice.h b/decoder/lattice.h
index 1258d3f5..469615b5 100644
--- a/decoder/lattice.h
+++ b/decoder/lattice.h
@@ -3,6 +3,7 @@
#include <string>
#include <vector>
+#include "sparse_vector.h"
#include "wordid.h"
#include "array2d.h"
@@ -15,10 +16,10 @@ struct LatticeTools {
struct LatticeArc {
WordID label;
- double cost;
+ SparseVector<double> features;
int dist2next;
- LatticeArc() : label(), cost(), dist2next() {}
- LatticeArc(WordID w, double c, int i) : label(w), cost(c), dist2next(i) {}
+ LatticeArc() : label(), features(), dist2next() {}
+ LatticeArc(WordID w, const SparseVector<double>& f, int i) : label(w), features(f), dist2next(i) {}
};
class Lattice : public std::vector<std::vector<LatticeArc> > {
diff --git a/decoder/parser_test.cc b/decoder/parser_test.cc
index e2916e44..37c45cc8 100644
--- a/decoder/parser_test.cc
+++ b/decoder/parser_test.cc
@@ -10,8 +10,8 @@
using namespace std;
BOOST_AUTO_TEST_CASE(Parse) {
- LatticeArc a(TD::Convert("ein"), 0.0, 1);
- LatticeArc b(TD::Convert("haus"), 0.0, 1);
+ LatticeArc a(TD::Convert("ein"), SparseVector<double>(), 1);
+ LatticeArc b(TD::Convert("haus"), SparseVector<double>(), 1);
Lattice lattice(2);
lattice[0].push_back(a);
lattice[1].push_back(b);