diff options
author | Patrick Simianer <p@simianer.de> | 2014-08-23 22:59:16 +0100 |
---|---|---|
committer | Patrick Simianer <p@simianer.de> | 2014-08-23 22:59:16 +0100 |
commit | cef65063cec641a93973b38a48e100fdd115db44 (patch) | |
tree | 32d5f10757e021a9fad01156fbff62a96212f006 /fast/grammar.hh | |
parent | 190f68c880eb27506669e95e2bc0493e2ec42c4c (diff) |
rewritten grammar
Diffstat (limited to 'fast/grammar.hh')
-rw-r--r-- | fast/grammar.hh | 303 |
1 files changed, 259 insertions, 44 deletions
diff --git a/fast/grammar.hh b/fast/grammar.hh index 1b9ac5a..e5acb8a 100644 --- a/fast/grammar.hh +++ b/fast/grammar.hh @@ -1,13 +1,14 @@ #pragma once -#include <fstream> #include <iostream> +#include <fstream> +#include <map> #include <sstream> #include <string> -#include <map> -#include <msgpack.hpp> #include <vector> +#include <msgpack.hpp> + #include "sparse_vector.hh" #include "util.hh" @@ -16,46 +17,138 @@ using namespace std; namespace G { -struct NT { - string symbol; - size_t index; +enum item_type { + UNDEFINED, + TERMINAL, + NON_TERMINAL +}; - NT() {}; - NT(string& s); +struct Item { + virtual size_t index() const { return 0; } + virtual symbol_t symbol() const { return ""; } + virtual item_type type() const { return UNDEFINED; } - string repr() const; - string escaped() const; + virtual ostream& repr(ostream& os) const { return os << "Item<>"; } + virtual ostream& escaped(ostream& os) const { return os << ""; } - friend ostream& operator<<(ostream& os, const NT& t); + friend ostream& + operator<<(ostream& os, const Item& i) + { + return i.repr(os); + }; }; -struct T { - string word; // use word ids instead? +struct NT : public Item { + symbol_t symbol_; + size_t index_; + + NT() {} + + NT(string const& s) + { + index_ = 0; // default + string t(s); + t.erase(0, 1); t.pop_back(); // remove '[' and ']' + istringstream ss(s); + if (ss >> index_) { // [i] + symbol_ = ""; + index_ = stoi(s); + + return; + } else { // [X] + symbol_ = s; + + return; + } + string buf; + size_t j = 0; + while (ss.good() && getline(ss, buf, ',')) { + if (j == 0) { + symbol_ = buf; + } else { + index_ = stoi(buf); + } + j++; + } + } - T(const string& s); + virtual size_t index() const { return index_; } + virtual symbol_t symbol() const { return symbol_; } + virtual item_type type() { return NON_TERMINAL; } - string repr() const; - string escaped() const; + virtual ostream& + repr(ostream& os) const + { + return os << "NT<" << symbol_ << "," << index_ << ">"; + } - friend ostream& operator<<(ostream& os, const NT& nt); + virtual ostream& + escaped(ostream& os) const + { + os << "[" << symbol_; + if (index_ > 0) + os << "," << index_; + os << "]"; + + return os; + } }; -enum item_type { - NON_TERMINAL, - TERMINAL +struct T : public Item { + symbol_t symbol_; + + T(string const& s) + { + symbol_ = s; + } + + virtual symbol_t symbol() const { return symbol_; } + virtual item_type type() { return TERMINAL; } + + virtual ostream& + repr(ostream& os) const + { + return os << "T<" << symbol_ << ">"; + } + + virtual ostream& + escaped(ostream& os) const + { + os << util::json_escape(symbol_); + } }; -struct Item { - item_type type; - NT* nt; - T* t; +struct Vocabulary +{ + unordered_map<symbol_t, size_t> m_; + vector<Item*> items_; - Item(string& s); + bool + is_non_terminal(string const& s) + { + return s.front() == '[' && s.back() == ']'; + } - string repr() const; - string escaped() const; + Item* + get(symbol_t const& s) + { + if (is_non_terminal(s)) + return new NT(s); + if (m_.find(s) != m_.end()) + return items_[m_[s]]; + return add(s); + } - friend ostream& operator<<(ostream& os, const Item& i); + Item* + add(symbol_t const& s) + { + size_t next_index_ = items_.size(); + T* item = new T(s); + items_.push_back(item); + m_[s] = next_index_; + + return item; + } }; struct Rule { @@ -65,35 +158,157 @@ struct Rule { size_t arity; Sv::SparseVector<string, score_t>* f; map<size_t, size_t> order; - string as_str_; // FIXME + string as_str_; + + Rule() {} + + Rule(string const& s, Vocabulary& vocab) { from_s(this, s, vocab); } + + static void + from_s(Rule* r, string const& s, Vocabulary& vocab) + { + istringstream ss(s); + string buf; + size_t j = 0, i = 0; + r->arity = 0; + vector<NT*> rhs_non_terminals; + r->f = new Sv::SparseVector<string, score_t>(); + while (ss >> buf) { + if (buf == "|||") { j++; continue; } + if (j == 0) { // left-hand side + r->lhs = new NT(buf); + } else if (j == 1) { // right-hand side + Item* item = vocab.get(buf); + r->rhs.push_back(item); + if (item->type() == NON_TERMINAL) { + r->arity++; + rhs_non_terminals.push_back(reinterpret_cast<NT*>(item)); + } + } else if (j == 2) { // target + Item* item = vocab.get(buf); + if (item->type() == NON_TERMINAL) { + r->order[i] = item->index(); + i++; + if (item->symbol() == "") { // only [1], [2] ... on target + reinterpret_cast<NT*>(item)->symbol_ = \ + rhs_non_terminals[item->index()-1]->symbol(); + } + } + r->target.push_back(item); + } else if (j == 3) { // feature vector + Sv::SparseVector<string, score_t>::from_s(r->f, buf); + // FIXME: this is slow!!! ^^^ + } else if (j == 4) { // alignment + } else { + // error + } + if (j == 4) break; + } + } - Rule() {}; - Rule(const string& s); - static void from_s(Rule* r, const string& s); + ostream& + repr(ostream& os) const + { + os << "Rule<lhs="; + lhs->repr(os); + os << ", rhs:{"; + for (auto it = rhs.begin(); it != rhs.end(); it++) { + (**it).repr(os); + if (next(it) != rhs.end()) os << " "; + } + os << "}, target:{"; + for (auto it = target.begin(); it != target.end(); it++) { + (**it).repr(os); + if (next(it) != target.end()) os << " "; + } + os << "}, f:"; + f->repr(os); + os << ", arity=" << arity << \ + ", map:" << "TODO" << \ + ">"; - string repr() const; - string escaped() const; + return os; + } - friend ostream& operator<<(ostream& os, const Rule& r); + ostream& + escaped(ostream& os) const + { + lhs->escaped(os); + os << " ||| "; + for (auto it = rhs.begin(); it != rhs.end(); it++) { + (**it).escaped(os); + if (next(it) != rhs.end()) os << " "; + } + os << " ||| "; + for (auto it = target.begin(); it != target.end(); it++) { + (**it).escaped(os); + if (next(it) != target.end()) os << " "; + } + os << " ||| "; + f->escaped(os); + os << " ||| "; + os << "TODO"; - void prep_for_serialization_() { as_str_ = escaped(); }; // FIXME + return os; + }; - MSGPACK_DEFINE(as_str_); // TODO + friend ostream& + operator<<(ostream& os, const Rule& r) + { + return r.repr(os); + }; + + // -- + void + prep_for_serialization_() + { + ostringstream os; + escaped(os); + as_str_ = os.str(); + }; + MSGPACK_DEFINE(as_str_); + // ^^^ FIXME }; struct Grammar { vector<Rule*> rules; vector<Rule*> flat; - vector<Rule*> start_nt; - vector<Rule*> start_t; + vector<Rule*> start_non_terminal; + vector<Rule*> start_terminal; + + Grammar() {} + + Grammar(string const& fn, Vocabulary& vocab) + { + ifstream ifs(fn); + string line; + while (getline(ifs, line)) { + G::Rule* r = new G::Rule(line, vocab); + rules.push_back(r); + if (r->arity == 0) + flat.push_back(r); + else if (r->rhs.front()->type() == NON_TERMINAL) + start_non_terminal.push_back(r); + else + start_terminal.push_back(r); + } + } - Grammar() {}; - Grammar(const string& fn); + void add_glue(); + // ^^^ TODO + void add_pass_through(const string& input); + // ^^^ TODO - void add_glue(); // TODO - void add_pass_through(const string& input); // TODO + friend ostream& + operator<<(ostream& os, Grammar const& g) + { + for (const auto it: g.rules) { + it->repr(os); + os << endl; + } - friend ostream& operator<<(ostream& os, const Grammar& g); + return os; + } }; } // namespace G |