summary | refs | log | tree | commit | diff
path: root/decoder/hg_io.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/hg_io.cc')
-rw-r--r-- decoder/hg_io.cc 90
1 files changed, 68 insertions, 22 deletions
diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc
index 626b2954..71f50a29 100644
--- a/decoder/hg_io.cc
+++ b/decoder/hg_io.cc
@@ -87,6 +87,12 @@ string HypergraphIO::AsPLF(const Hypergraph& hg, bool include_global_parentheses
return os.str();
}
+// Returns the single scalar cost that AsPLF(const Lattice&) prints for an arc.
+//
+// TODO this should write out the PLF with the Python dictionary format
+// rather than just the single "LatticeCost" feature
+//
+// Lookup order:
+//   1. "LatticeCost"   - the feature this function has always reported.
+//   2. "LatticeCost_0" - the name ReadPLFEdge assigns to the first cost of a
+//                        dense ('word', cost, dist) edge.
+// Without the second lookup, a dense PLF that is read with PLFtoLattice and
+// written back with AsPLF would have every cost printed as 0, because the
+// reader never stores a feature called plain "LatticeCost" for dense input.
+//
+// A value of exactly 0 for "LatticeCost" is treated as "not set"; if
+// "LatticeCost_0" is also unset the result is whatever get() yields for it.
+double PLFFeatureDictionary(const SparseVector<double>& f) {
+  const double cost = f.get(FD::Convert("LatticeCost"));
+  if (cost != 0.0) { return cost; }
+  return f.get(FD::Convert("LatticeCost_0"));
+}
+
string HypergraphIO::AsPLF(const Lattice& lat, bool include_global_parentheses) {
static bool first = true;
if (first) { InitEscapes(); first = false; }
@@ -99,7 +105,7 @@ string HypergraphIO::AsPLF(const Lattice& lat, bool include_global_parentheses)
os << '(';
for (int j = 0; j < arcs.size(); ++j) {
os << "('" << Escape(TD::Convert(arcs[j].label)) << "',"
- << arcs[j].cost << ',' << arcs[j].dist2next << "),";
+ << PLFFeatureDictionary(arcs[j].features) << ',' << arcs[j].dist2next << "),";
}
os << "),";
}
@@ -128,7 +134,10 @@ inline void eatws(const std::string& in, int& c) {
std::string getEscapedString(const std::string& in, int &c)
{
eatws(in,c);
- if (get(in,c++) != quote) return "ERROR";
+ if (get(in,c++) != quote) {
+ cerr << "Expected escaped string to begin with " << quote << ". Got " << get(in, c - 1) << "\n";
+ abort();
+ }
std::string res;
char cur = 0;
do {
@@ -146,7 +155,7 @@ float getFloat(const std::string& in, int &c)
{
std::string tmp;
eatws(in,c);
- while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',') {
+ while (c < (int)in.size() && get(in,c) != ' ' && get(in,c) != ')' && get(in,c) != ',' && get(in,c) != '}') {
tmp += get(in,c++);
}
eatws(in,c);
@@ -171,7 +180,18 @@ int getInt(const std::string& in, int &c)
// maximum number of nodes permitted
#define MAX_NODES 100000000
-// parse ('foo', 0.23)
+
+// Parses one `'name': value` entry of a PLF feature dictionary.
+//
+//   in        the PLF text being parsed
+//   c         current offset into `in`; advanced past the entry and any
+//             whitespace that follows it
+//   features  receives the parsed pair; an existing entry with the same
+//             name is overwritten
+//
+// Aborts the process if the feature name is not followed by ':'.
+void ReadPLFFeature(const std::string& in, int &c, map<string, float>& features) {
+  eatws(in, c);
+  const string feature_name = getEscapedString(in, c);
+  eatws(in, c);
+
+  // The separator is consumed whether or not it is the expected ':'.
+  const bool has_separator = (get(in, c++) == ':');
+  if (!has_separator) {
+    cerr << "PCN/PLF parse error: expected : after feature name " << feature_name << "\n";
+    abort();
+  }
+
+  const float feature_value = getFloat(in, c);
+  eatws(in, c);
+  features[feature_name] = feature_value;
+}
+
+// parse ('foo', 0.23, 1)
void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
if (get(in,c++) != '(') { cerr << "PCN/PLF parse error: expected (\n"; abort(); }
vector<WordID> ewords(2, 0);
@@ -180,22 +200,49 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
r->ComputeArity();
// cerr << "RULE: " << r->AsString() << endl;
if (get(in,c++) != ',') { cerr << in << endl; cerr << "PCN/PLF parse error: expected , after string\n"; abort(); }
+ eatws(in,c);
+
+ map<string, float> features;
size_t cnNext = 1;
- std::vector<float> probs;
- probs.push_back(getFloat(in,c));
- while (get(in,c) == ',') {
+ // Read in sparse feature format
+ if (get(in,c) == '{') {
c++;
- float val = getFloat(in,c);
- probs.push_back(val);
- // cerr << val << endl; //REMO
+ eatws(in,c);
+ if (get(in,c) != '}') {
+ ReadPLFFeature(in, c, features);
+ }
+ while (get(in,c) == ',') {
+ c++;
+ if (get(in,c) == '}') { break; }
+ ReadPLFFeature(in, c, features);
+ }
+ if (get(in,c++) != '}') { cerr << "PCN/PLF parse error: expected } after feature dictionary\n"; abort(); }
+ eatws(in,c);
+ if (get(in, c++) != ',') { cerr << "PCN/PLF parse error: expected , after feature dictionary\n"; abort(); }
+ cnNext = static_cast<size_t>(getFloat(in, c));
}
- //if we read more than one prob, this was a lattice, last item was column increment
- if (probs.size()>1) {
+ // Read in dense feature format
+ else {
+ std::vector<float> probs;
+ probs.push_back(getFloat(in,c));
+ while (get(in,c) == ',') {
+ c++;
+ float val = getFloat(in,c);
+ probs.push_back(val);
+ // cerr << val << endl; //REMO
+ }
+ if (probs.size() == 0) { cerr << "PCN/PLF parse error: missing destination state increment\n"; abort(); }
+
+ // the last item was column increment
cnNext = static_cast<size_t>(probs.back());
probs.pop_back();
- if (cnNext < 1) { cerr << cnNext << endl << "PCN/PLF parse error: bad link length at last element of cn alt block\n"; abort(); }
+
+ for (unsigned i = 0; i < probs.size(); ++i) {
+ features["LatticeCost_" + to_string(i)] = probs[i];
+ }
}
- if (get(in,c++) != ')') { cerr << "PCN/PLF parse error: expected ) at end of cn alt block\n"; abort(); }
+ if (get(in,c++) != ')') { cerr << "PCN/PLF parse error: expected ) at end of cn alt block. Got " << get(in, c-1) << "\n"; abort(); }
+ if (cnNext < 1) { cerr << cnNext << endl << "PCN/PLF parse error: bad link length at last element of cn alt block\n"; abort(); }
eatws(in,c);
Hypergraph::TailNodeVector tail(1, cur_node);
Hypergraph::Edge* edge = hg->AddEdge(r, tail);
@@ -204,15 +251,15 @@ void ReadPLFEdge(const std::string& in, int &c, int cur_node, Hypergraph* hg) {
assert(head_node < MAX_NODES); // prevent malicious PLFs from using all the memory
if (hg->nodes_.size() < (head_node + 1)) { hg->ResizeNodes(head_node + 1); }
hg->ConnectEdgeToHeadNode(edge, &hg->nodes_[head_node]);
- for (int i = 0; i < probs.size(); ++i)
- edge->feature_values_.set_value(FD::Convert("Feature_" + boost::lexical_cast<string>(i)), probs[i]);
+ for (map<string, float>::iterator it = features.begin(); it != features.end(); ++it) {
+ edge->feature_values_.set_value(FD::Convert(it->first), it->second);
+ }
}
-// parse (('foo', 0.23), ('bar', 0.77))
+// parse (('foo', 0.23, 1), ('bar', 0.77, 1))
void ReadPLFNode(const std::string& in, int &c, int cur_node, int line, Hypergraph* hg) {
- //cerr << "PLF READING NODE " << cur_node << endl;
if (hg->nodes_.size() < (cur_node + 1)) { hg->ResizeNodes(cur_node + 1); }
- if (get(in,c++) != '(') { cerr << line << ": Syntax error 1\n"; abort(); }
+ if (get(in,c++) != '(') { cerr << line << ": Syntax error 1 in PLF\n"; abort(); }
eatws(in,c);
while (1) {
if (c > (int)in.size()) { break; }
@@ -237,7 +284,7 @@ void HypergraphIO::ReadFromPLF(const std::string& in, Hypergraph* hg, int line)
hg->clear();
int c = 0;
int cur_node = 0;
- if (in[c++] != '(') { cerr << line << ": Syntax error!\n"; abort(); }
+ if (in[c++] != '(') { cerr << line << ": Syntax error in PLF!\n"; abort(); }
while (1) {
if (c > (int)in.size()) { break; }
if (PLF::get(in,c) == ')') {
@@ -263,7 +310,6 @@ void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) {
ReadFromPLF(plf, &g, 0);
const int num_nodes = g.nodes_.size() - 1;
l.resize(num_nodes);
- int fid0=FD::Convert("Feature_0");
for (int i = 0; i < num_nodes; ++i) {
vector<LatticeArc>& alts = l[i];
const Hypergraph::Node& node = g.nodes_[i];
@@ -272,7 +318,7 @@ void HypergraphIO::PLFtoLattice(const string& plf, Lattice* pl) {
for (int j = 0; j < num_alts; ++j) {
const Hypergraph::Edge& edge = g.edges_[node.out_edges_[j]];
alts[j].label = edge.rule_->e_[1];
- alts[j].cost = edge.feature_values_.get(fid0);
+ alts[j].features = edge.feature_values_;
alts[j].dist2next = edge.head_node_ - node.id_;
}
}