summaryrefslogtreecommitdiff
path: root/training/utils/grammar_convert.cc
diff options
context:
space:
mode:
Diffstat (limited to 'training/utils/grammar_convert.cc')
-rw-r--r--training/utils/grammar_convert.cc21
1 files changed, 15 insertions, 6 deletions
diff --git a/training/utils/grammar_convert.cc b/training/utils/grammar_convert.cc
index 607a7cb9..5c1b4d4a 100644
--- a/training/utils/grammar_convert.cc
+++ b/training/utils/grammar_convert.cc
@@ -56,15 +56,22 @@ int GetOrCreateNode(const WordID& lhs, map<WordID, int>* lhs2node, Hypergraph* h
return node_id - 1;
}
+void AddDummyGoalNode(Hypergraph* hg) {
+ static const int kGOAL = -TD::Convert("Goal");
+ static TRulePtr kGOAL_RULE(new TRule("[Goal] ||| [X] ||| [1]"));
+ unsigned old_goal_node_idx = hg->nodes_.size() - 1;
+ HG::Node* goal_node = hg->AddNode(kGOAL);
+ goal_node->node_hash = goal_node->id_ * 10 + 1;
+ TailNodeVector tail(1, old_goal_node_idx);
+ HG::Edge* new_edge = hg->AddEdge(kGOAL_RULE, tail);
+ hg->ConnectEdgeToHeadNode(new_edge, goal_node);
+}
+
void FilterAndCheckCorrectness(int goal, Hypergraph* hg) {
if (goal < 0) {
cerr << "Error! [S] not found in grammar!\n";
exit(1);
}
- if (hg->nodes_[goal].in_edges_.size() != 1) {
- cerr << "Error! [S] has more than one rewrite!\n";
- exit(1);
- }
int old_size = hg->nodes_.size();
hg->TopologicallySortNodesAndEdges(goal);
if (hg->nodes_.size() != old_size) {
@@ -292,10 +299,10 @@ int main(int argc, char **argv) {
int lc = 0;
Hypergraph hg;
map<WordID, int> lhs2node;
+ string line;
while(*in) {
- string line;
+ getline(*in,line);
++lc;
- getline(*in, line);
if (is_json_input) {
if (line.empty() || line[0] == '#') continue;
string ref;
@@ -319,6 +326,7 @@ int main(int argc, char **argv) {
if (line.empty()) {
int goal = lhs2node[kSTART] - 1;
FilterAndCheckCorrectness(goal, &hg);
+ AddDummyGoalNode(&hg);
ProcessHypergraph(w, conf, "", &hg);
hg.clear();
lhs2node.clear();
@@ -342,6 +350,7 @@ int main(int argc, char **argv) {
edge->feature_values_ = tr->scores_;
Hypergraph::Node* node = &hg.nodes_[head];
hg.ConnectEdgeToHeadNode(edge, node);
+ node->node_hash = lc;
}
}
}