summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
Diffstat (limited to 'decoder')
-rw-r--r--decoder/Makefile.am2
-rw-r--r--decoder/decoder.cc16
-rw-r--r--decoder/decoder.h6
-rw-r--r--decoder/earley_composer.cc38
-rw-r--r--decoder/ff_ngrams.cc85
-rw-r--r--decoder/ff_ngrams.h2
-rw-r--r--decoder/ff_tagger.cc17
-rw-r--r--decoder/hg.cc63
-rw-r--r--decoder/hg.h17
-rw-r--r--decoder/hg_io.cc1
-rw-r--r--decoder/hg_remove_eps.cc91
-rw-r--r--decoder/hg_remove_eps.h13
-rw-r--r--decoder/inside_outside.h4
-rw-r--r--decoder/rescore_translator.cc58
-rw-r--r--decoder/scfg_translator.cc70
-rw-r--r--decoder/tagger.cc1
-rw-r--r--decoder/translator.h17
-rw-r--r--decoder/trule.cc4
18 files changed, 347 insertions, 158 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 00d01e53..0a792549 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -37,9 +37,11 @@ libcdec_a_SOURCES = \
fst_translator.cc \
csplit.cc \
translator.cc \
+ rescore_translator.cc \
scfg_translator.cc \
hg.cc \
hg_io.cc \
+ hg_remove_eps.cc \
decoder.cc \
hg_intersect.cc \
hg_sampler.cc \
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 333f0fb6..a6f7b1ce 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -527,8 +527,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
}
formalism = LowercaseString(str("formalism",conf));
- if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign") {
- cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', or 'tagger'\n";
+ if (formalism != "scfg" && formalism != "fst" && formalism != "lextrans" && formalism != "pb" && formalism != "csplit" && formalism != "tagger" && formalism != "lexalign" && formalism != "rescore") {
+ cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lextrans', 'lexalign', 'rescore', or 'tagger'\n";
cerr << dcmdline_options << endl;
exit(1);
}
@@ -675,6 +675,8 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
translator.reset(new LexicalTrans(conf));
else if (formalism == "lexalign")
translator.reset(new LexicalAlign(conf));
+ else if (formalism == "rescore")
+ translator.reset(new RescoreTranslator(conf));
else if (formalism == "tagger")
translator.reset(new Tagger(conf));
else
@@ -743,16 +745,14 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) {
}
vector<weight_t>& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); }
const vector<weight_t>& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); }
-void Decoder::SetSupplementalGrammar(const std::string& grammar_string) {
- assert(pimpl_->translator->GetDecoderType() == "SCFG");
- static_cast<SCFGTranslator&>(*pimpl_->translator).SetSupplementalGrammar(grammar_string);
+void Decoder::AddSupplementalGrammar(GrammarPtr gp) {
+ static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammar(gp);
}
-void Decoder::SetSentenceGrammarFromString(const std::string& grammar_str) {
+void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string) {
assert(pimpl_->translator->GetDecoderType() == "SCFG");
- static_cast<SCFGTranslator&>(*pimpl_->translator).SetSentenceGrammarFromString(grammar_str);
+ static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string);
}
-
bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
string buf = input;
NgramCache::Clear(); // clear ngram cache for remote LM (if used)
diff --git a/decoder/decoder.h b/decoder/decoder.h
index 6b2f7b16..bef2ff5e 100644
--- a/decoder/decoder.h
+++ b/decoder/decoder.h
@@ -37,6 +37,8 @@ struct DecoderObserver {
virtual void NotifyDecodingComplete(const SentenceMetadata& smeta);
};
+struct Grammar; // TODO once the decoder interface is cleaned up,
+ // this should be somewhere else
struct Decoder {
Decoder(int argc, char** argv);
Decoder(std::istream* config_file);
@@ -54,8 +56,8 @@ struct Decoder {
// add grammar rules (currently only supported by SCFG decoders)
// that will be used on subsequent calls to Decode. rules should be in standard
// text format. This function does NOT read from a file.
- void SetSupplementalGrammar(const std::string& grammar);
- void SetSentenceGrammarFromString(const std::string& grammar_str);
+ void AddSupplementalGrammar(boost::shared_ptr<Grammar> gp);
+ void AddSupplementalGrammarFromString(const std::string& grammar_string);
private:
boost::program_options::variables_map conf;
boost::shared_ptr<DecoderImpl> pimpl_;
diff --git a/decoder/earley_composer.cc b/decoder/earley_composer.cc
index d265d954..efce70a6 100644
--- a/decoder/earley_composer.cc
+++ b/decoder/earley_composer.cc
@@ -16,6 +16,7 @@
#include "sparse_vector.h"
#include "tdict.h"
#include "hg.h"
+#include "hg_remove_eps.h"
using namespace std;
using namespace std::tr1;
@@ -48,6 +49,27 @@ static void InitializeConstants() {
}
////////////////////////////////////////////////////////////
+TRulePtr CreateBinaryRule(int lhs, int rhs1, int rhs2) {
+ TRule* r = new TRule(*kX1X2);
+ r->lhs_ = lhs;
+ r->f_[0] = rhs1;
+ r->f_[1] = rhs2;
+ return TRulePtr(r);
+}
+
+TRulePtr CreateUnaryRule(int lhs, int rhs1) {
+ TRule* r = new TRule(*kX1);
+ r->lhs_ = lhs;
+ r->f_[0] = rhs1;
+ return TRulePtr(r);
+}
+
+TRulePtr CreateEpsilonRule(int lhs) {
+ TRule* r = new TRule(*kEPSRule);
+ r->lhs_ = lhs;
+ return TRulePtr(r);
+}
+
class EGrammarNode {
friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest);
friend void AddGrammarRule(const string& r, map<WordID, EGrammarNode>* g);
@@ -356,7 +378,7 @@ class EarleyComposerImpl {
}
if (goal_node) {
forest->PruneUnreachable(goal_node->id_);
- forest->EpsilonRemove(kEPS);
+ RemoveEpsilons(forest, kEPS);
}
FreeAll();
return goal_node;
@@ -557,24 +579,30 @@ class EarleyComposerImpl {
}
Hypergraph::Node*& head_node = edge2node[edge];
if (!head_node)
- head_node = hg->AddNode(kPHRASE);
+ head_node = hg->AddNode(edge->cat);
if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) {
assert(goal_node == NULL || goal_node == head_node);
goal_node = head_node;
}
+ int rhs1 = 0;
+ int rhs2 = 0;
Hypergraph::TailNodeVector tail;
SparseVector<double> extra;
if (edge->IsCreatedByPredict()) {
// extra.set_value(FD::Convert("predict"), 1);
} else if (edge->IsCreatedByScan()) {
tail.push_back(edge2node[edge->active_parent]->id_);
+ rhs1 = edge->active_parent->cat;
if (tps) {
tail.push_back(tps->id_);
+ rhs2 = kPHRASE;
}
//extra.set_value(FD::Convert("scan"), 1);
} else if (edge->IsCreatedByComplete()) {
tail.push_back(edge2node[edge->active_parent]->id_);
+ rhs1 = edge->active_parent->cat;
tail.push_back(edge2node[edge->passive_parent]->id_);
+ rhs2 = edge->passive_parent->cat;
//extra.set_value(FD::Convert("complete"), 1);
} else {
assert(!"unexpected edge type!");
@@ -592,11 +620,11 @@ class EarleyComposerImpl {
#endif
Hypergraph::Edge* hg_edge = NULL;
if (tail.size() == 0) {
- hg_edge = hg->AddEdge(kEPSRule, tail);
+ hg_edge = hg->AddEdge(CreateEpsilonRule(edge->cat), tail);
} else if (tail.size() == 1) {
- hg_edge = hg->AddEdge(kX1, tail);
+ hg_edge = hg->AddEdge(CreateUnaryRule(edge->cat, rhs1), tail);
} else if (tail.size() == 2) {
- hg_edge = hg->AddEdge(kX1X2, tail);
+ hg_edge = hg->AddEdge(CreateBinaryRule(edge->cat, rhs1, rhs2), tail);
}
if (edge->features)
hg_edge->feature_values_ += *edge->features;
diff --git a/decoder/ff_ngrams.cc b/decoder/ff_ngrams.cc
index d6d79f5e..9c13fdbb 100644
--- a/decoder/ff_ngrams.cc
+++ b/decoder/ff_ngrams.cc
@@ -48,6 +48,9 @@ struct State {
namespace {
string Escape(const string& x) {
+ if (x.find('=') == string::npos && x.find(';') == string::npos) {
+ return x;
+ }
string y = x;
for (int i = 0; i < y.size(); ++i) {
if (y[i] == '=') y[i]='_';
@@ -57,10 +60,17 @@ namespace {
}
}
-static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order) {
+static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order, vector<string>& prefixes, string& target_separator) {
vector<string> const& argv=SplitOnWhitespace(in);
*explicit_markers = false;
*order = 3;
+ prefixes.push_back("NOT-USED");
+ prefixes.push_back("U:"); // default unigram prefix
+ prefixes.push_back("B:"); // default bigram prefix
+ prefixes.push_back("T:"); // ...etc
+ prefixes.push_back("4:"); // ...etc
+ prefixes.push_back("5:"); // max allowed!
+ target_separator = "_";
#define LMSPEC_NEXTARG if (i==argv.end()) { \
cerr << "Missing argument for "<<*last<<". "; goto usage; \
} else { ++i; }
@@ -73,6 +83,30 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order)
case 'x':
*explicit_markers = true;
break;
+ case 'U':
+ LMSPEC_NEXTARG;
+ prefixes[1] = *i;
+ break;
+ case 'B':
+ LMSPEC_NEXTARG;
+ prefixes[2] = *i;
+ break;
+ case 'T':
+ LMSPEC_NEXTARG;
+ prefixes[3] = *i;
+ break;
+ case '4':
+ LMSPEC_NEXTARG;
+ prefixes[4] = *i;
+ break;
+ case '5':
+ LMSPEC_NEXTARG;
+ prefixes[5] = *i;
+ break;
+ case 'S':
+ LMSPEC_NEXTARG;
+ target_separator = *i;
+ break;
case 'o':
LMSPEC_NEXTARG; *order=atoi((*i).c_str());
break;
@@ -86,7 +120,29 @@ static bool ParseArgs(string const& in, bool* explicit_markers, unsigned* order)
}
return true;
usage:
- cerr << "NgramFeatures is incorrect!\n";
+ cerr << "Wrong parameters for NgramFeatures.\n\n"
+
+ << "NgramFeatures Usage: \n"
+ << " feature_function=NgramFeatures filename.lm [-x] [-o <order>] \n"
+ << " [-U <unigram-prefix>] [-B <bigram-prefix>][-T <trigram-prefix>]\n"
+ << " [-4 <4-gram-prefix>] [-5 <5-gram-prefix>] [-S <separator>]\n\n"
+
+ << "Defaults: \n"
+ << " <order> = 3\n"
+ << " <unigram-prefix> = U:\n"
+ << " <bigram-prefix> = B:\n"
+ << " <trigram-prefix> = T:\n"
+ << " <4-gram-prefix> = 4:\n"
+ << " <5-gram-prefix> = 5:\n"
+ << " <separator> = _\n"
+ << " -x (i.e. explicit sos/eos markers) is turned off\n\n"
+
+ << "Example configuration: \n"
+ << " feature_function=NgramFeatures -o 3 -T tri: -S |\n\n"
+
+ << "Example feature instantiation: \n"
+ << " tri:a|b|c \n\n";
+
return false;
}
@@ -158,16 +214,12 @@ class NgramDetectorImpl {
int& fid = ft->fids[curword];
++n;
if (!fid) {
- const char* code="_UBT456789"; // prefix code (unigram, bigram, etc.)
ostringstream os;
- os << code[n] << ':';
+ os << prefixes_[n];
for (int i = n-1; i >= 0; --i) {
- os << (i != n-1 ? "_" : "");
+ os << (i != n-1 ? target_separator_ : "");
const string& tok = TD::Convert(buf[i]);
- if (tok.find('=') == string::npos)
- os << tok;
- else
- os << Escape(tok);
+ os << Escape(tok);
}
fid = FD::Convert(os.str());
}
@@ -297,7 +349,8 @@ class NgramDetectorImpl {
}
public:
- explicit NgramDetectorImpl(bool explicit_markers, unsigned order) :
+ explicit NgramDetectorImpl(bool explicit_markers, unsigned order,
+ vector<string>& prefixes, string& target_separator) :
kCDEC_UNK(TD::Convert("<unk>")) ,
add_sos_eos_(!explicit_markers) {
order_ = order;
@@ -305,6 +358,8 @@ class NgramDetectorImpl {
unscored_size_offset_ = (order_ - 1) * sizeof(WordID);
is_complete_offset_ = unscored_size_offset_ + 1;
unscored_words_offset_ = is_complete_offset_ + 1;
+ prefixes_ = prefixes;
+ target_separator_ = target_separator;
// special handling of beginning / ending sentence markers
dummy_state_ = new char[state_size_];
@@ -340,6 +395,8 @@ class NgramDetectorImpl {
char* dummy_state_;
vector<const void*> dummy_ants_;
TRulePtr dummy_rule_;
+ vector<string> prefixes_;
+ string target_separator_;
struct FidTree {
map<WordID, int> fids;
map<WordID, FidTree> levels;
@@ -348,11 +405,13 @@ class NgramDetectorImpl {
};
NgramDetector::NgramDetector(const string& param) {
- string filename, mapfile, featname;
+ string filename, mapfile, featname, target_separator;
+ vector<string> prefixes;
bool explicit_markers = false;
unsigned order = 3;
- ParseArgs(param, &explicit_markers, &order);
- pimpl_ = new NgramDetectorImpl(explicit_markers, order);
+ ParseArgs(param, &explicit_markers, &order, prefixes, target_separator);
+ pimpl_ = new NgramDetectorImpl(explicit_markers, order, prefixes,
+ target_separator);
SetStateSize(pimpl_->ReserveStateSize());
}
diff --git a/decoder/ff_ngrams.h b/decoder/ff_ngrams.h
index 82f61b33..064dbb49 100644
--- a/decoder/ff_ngrams.h
+++ b/decoder/ff_ngrams.h
@@ -10,7 +10,7 @@
struct NgramDetectorImpl;
class NgramDetector : public FeatureFunction {
public:
- // param = "filename.lm [-o n]"
+ // param = "filename.lm [-o <order>] [-U <unigram-prefix>] [-B <bigram-prefix>] [-T <trigram-prefix>] [-4 <4-gram-prefix>] [-5 <5-gram-prefix>] [-S <separator>]"
NgramDetector(const std::string& param);
~NgramDetector();
virtual void FinalTraversalFeatures(const void* context,
diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc
index 019315a2..fd9210fa 100644
--- a/decoder/ff_tagger.cc
+++ b/decoder/ff_tagger.cc
@@ -8,6 +8,17 @@
using namespace std;
+namespace {
+ string Escape(const string& x) {
+ string y = x;
+ for (int i = 0; i < y.size(); ++i) {
+ if (y[i] == '=') y[i]='_';
+ if (y[i] == ';') y[i]='_';
+ }
+ return y;
+ }
+}
+
Tagger_BigramIndicator::Tagger_BigramIndicator(const std::string& param) :
FeatureFunction(sizeof(WordID)) {
no_uni_ = (LowercaseString(param) == "no_uni");
@@ -28,7 +39,7 @@ void Tagger_BigramIndicator::FireFeature(const WordID& left,
os << '_';
if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); }
}
- fid = FD::Convert(os.str());
+ fid = FD::Convert(Escape(os.str()));
}
features->set_value(fid, 1.0);
}
@@ -90,7 +101,7 @@ void LexicalPairIndicator::FireFeature(WordID src,
if (!fid) {
ostringstream os;
os << name_ << ':' << TD::Convert(src) << ':' << TD::Convert(trg);
- fid = FD::Convert(os.str());
+ fid = FD::Convert(Escape(os.str()));
}
features->set_value(fid, 1.0);
}
@@ -127,7 +138,7 @@ void OutputIndicator::FireFeature(WordID trg,
if (escape.count(trg)) trg = escape[trg];
ostringstream os;
os << "T:" << TD::Convert(trg);
- fid = FD::Convert(os.str());
+ fid = FD::Convert(Escape(os.str()));
}
features->set_value(fid, 1.0);
}
diff --git a/decoder/hg.cc b/decoder/hg.cc
index dd272221..7240a8ab 100644
--- a/decoder/hg.cc
+++ b/decoder/hg.cc
@@ -605,69 +605,6 @@ void Hypergraph::TopologicallySortNodesAndEdges(int goal_index,
#endif
}
-TRulePtr Hypergraph::kEPSRule;
-TRulePtr Hypergraph::kUnaryRule;
-
-void Hypergraph::EpsilonRemove(WordID eps) {
- if (!kEPSRule) {
- kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>"));
- kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]"));
- }
- vector<bool> kill(edges_.size(), false);
- for (unsigned i = 0; i < edges_.size(); ++i) {
- const Edge& edge = edges_[i];
- if (edge.tail_nodes_.empty() &&
- edge.rule_->f_.size() == 1 &&
- edge.rule_->f_[0] == eps) {
- kill[i] = true;
- if (!edge.feature_values_.empty()) {
- Node& node = nodes_[edge.head_node_];
- if (node.in_edges_.size() != 1) {
- cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n";
- // this *probably* means that there are multiple derivations of the
- // same sequence via different paths through the input forest
- // this needs to be investigated and fixed
- } else {
- for (unsigned j = 0; j < node.out_edges_.size(); ++j)
- edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_;
- // cerr << "PROMOTED " << edge.feature_values_ << endl;
- }
- }
- }
- }
- bool created_eps = false;
- PruneEdges(kill);
- for (unsigned i = 0; i < nodes_.size(); ++i) {
- const Node& node = nodes_[i];
- if (node.in_edges_.empty()) {
- for (unsigned j = 0; j < node.out_edges_.size(); ++j) {
- Edge& edge = edges_[node.out_edges_[j]];
- if (edge.rule_->Arity() == 2) {
- assert(edge.rule_->f_.size() == 2);
- assert(edge.rule_->e_.size() == 2);
- edge.rule_ = kUnaryRule;
- unsigned cur = node.id_;
- int t = -1;
- assert(edge.tail_nodes_.size() == 2);
- for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; }
- assert(t != -1);
- edge.tail_nodes_.resize(1);
- edge.tail_nodes_[0] = t;
- } else {
- edge.rule_ = kEPSRule;
- edge.rule_->f_[0] = eps;
- edge.rule_->e_[0] = eps;
- edge.tail_nodes_.clear();
- created_eps = true;
- }
- }
- }
- }
- vector<bool> k2(edges_.size(), false);
- PruneEdges(k2);
- if (created_eps) EpsilonRemove(eps);
-}
-
struct EdgeWeightSorter {
const Hypergraph& hg;
EdgeWeightSorter(const Hypergraph& h) : hg(h) {}
diff --git a/decoder/hg.h b/decoder/hg.h
index 91d25f01..591e98ce 100644
--- a/decoder/hg.h
+++ b/decoder/hg.h
@@ -148,7 +148,7 @@ public:
void show(std::ostream &o,unsigned mask=SPAN|RULE) const {
o<<'{';
if (mask&CATEGORY)
- o<<TD::Convert(rule_->GetLHS());
+ o<< '[' << TD::Convert(-rule_->GetLHS()) << ']';
if (mask&PREV_SPAN)
o<<'<'<<prev_i_<<','<<prev_j_<<'>';
if (mask&SPAN)
@@ -156,9 +156,9 @@ public:
if (mask&PROB)
o<<" p="<<edge_prob_;
if (mask&FEATURES)
- o<<" "<<feature_values_;
+ o<<' '<<feature_values_;
if (mask&RULE)
- o<<rule_->AsString(mask&RULE_LHS);
+ o<<' '<<rule_->AsString(mask&RULE_LHS);
if (USE_INFO_EDGE) {
std::string const& i=info();
if (mask&&!i.empty()) o << " |||"<<i; // remember, the initial space is expected as part of i
@@ -384,14 +384,6 @@ public:
// compute the total number of paths in the forest
double NumberOfPaths() const;
- // BEWARE. this assumes that the source and target language
- // strings are identical and that there are no loops.
- // It assumes a bunch of other things about where the
- // epsilons will be. It tries to assert failure if you
- // break these assumptions, but it may not.
- // TODO - make this work
- void EpsilonRemove(WordID eps);
-
// multiple the weights vector by the edge feature vector
// (inner product) to set the edge probabilities
template <class V>
@@ -535,9 +527,6 @@ public:
private:
Hypergraph(int num_nodes, int num_edges, bool is_lc) : is_linear_chain_(is_lc), nodes_(num_nodes), edges_(num_edges),edges_topo_(true) {}
-
- static TRulePtr kEPSRule;
- static TRulePtr kUnaryRule;
};
diff --git a/decoder/hg_io.cc b/decoder/hg_io.cc
index bfb2fb80..8bd40387 100644
--- a/decoder/hg_io.cc
+++ b/decoder/hg_io.cc
@@ -261,6 +261,7 @@ static void WriteRule(const TRule& r, ostream* out) {
}
bool HypergraphIO::WriteToJSON(const Hypergraph& hg, bool remove_rules, ostream* out) {
+ if (hg.empty()) { *out << "{}\n"; return true; }
map<const TRule*, int> rid;
ostream& o = *out;
rid[NULL] = 0;
diff --git a/decoder/hg_remove_eps.cc b/decoder/hg_remove_eps.cc
new file mode 100644
index 00000000..050c4876
--- /dev/null
+++ b/decoder/hg_remove_eps.cc
@@ -0,0 +1,91 @@
+#include "hg_remove_eps.h"
+
+#include <cassert>
+
+#include "trule.h"
+#include "hg.h"
+
+using namespace std;
+
+namespace {
+ TRulePtr kEPSRule;
+ TRulePtr kUnaryRule;
+
+ TRulePtr CreateUnaryRule(int lhs, int rhs) {
+ if (!kUnaryRule) kUnaryRule.reset(new TRule("[X] ||| [X,1] ||| [X,1]"));
+ TRule* r = new TRule(*kUnaryRule);
+ assert(lhs < 0);
+ assert(rhs < 0);
+ r->lhs_ = lhs;
+ r->f_[0] = rhs;
+ return TRulePtr(r);
+ }
+
+ TRulePtr CreateEpsilonRule(int lhs, WordID eps) {
+ if (!kEPSRule) kEPSRule.reset(new TRule("[X] ||| <eps> ||| <eps>"));
+ TRule* r = new TRule(*kEPSRule);
+ r->lhs_ = lhs;
+ assert(lhs < 0);
+ assert(eps > 0);
+ r->e_[0] = eps;
+ r->f_[0] = eps;
+ return TRulePtr(r);
+ }
+}
+
+void RemoveEpsilons(Hypergraph* g, WordID eps) {
+ vector<bool> kill(g->edges_.size(), false);
+ for (unsigned i = 0; i < g->edges_.size(); ++i) {
+ const Hypergraph::Edge& edge = g->edges_[i];
+ if (edge.tail_nodes_.empty() &&
+ edge.rule_->f_.size() == 1 &&
+ edge.rule_->f_[0] == eps) {
+ kill[i] = true;
+ if (!edge.feature_values_.empty()) {
+ Hypergraph::Node& node = g->nodes_[edge.head_node_];
+ if (node.in_edges_.size() != 1) {
+ cerr << "[WARNING] <eps> edge with features going into non-empty node - can't promote\n";
+ // this *probably* means that there are multiple derivations of the
+ // same sequence via different paths through the input forest
+ // this needs to be investigated and fixed
+ } else {
+ for (unsigned j = 0; j < node.out_edges_.size(); ++j)
+ g->edges_[node.out_edges_[j]].feature_values_ += edge.feature_values_;
+ // cerr << "PROMOTED " << edge.feature_values_ << endl;
+ }
+ }
+ }
+ }
+ bool created_eps = false;
+ g->PruneEdges(kill);
+ for (unsigned i = 0; i < g->nodes_.size(); ++i) {
+ const Hypergraph::Node& node = g->nodes_[i];
+ if (node.in_edges_.empty()) {
+ for (unsigned j = 0; j < node.out_edges_.size(); ++j) {
+ Hypergraph::Edge& edge = g->edges_[node.out_edges_[j]];
+ const int lhs = edge.rule_->lhs_;
+ if (edge.rule_->Arity() == 2) {
+ assert(edge.rule_->f_.size() == 2);
+ assert(edge.rule_->e_.size() == 2);
+ unsigned cur = node.id_;
+ int t = -1;
+ assert(edge.tail_nodes_.size() == 2);
+ int rhs = 0;
+ for (unsigned i = 0; i < 2u; ++i) if (edge.tail_nodes_[i] != cur) { t = edge.tail_nodes_[i]; rhs = edge.rule_->f_[i]; }
+ assert(t != -1);
+ edge.tail_nodes_.resize(1);
+ edge.tail_nodes_[0] = t;
+ edge.rule_ = CreateUnaryRule(lhs, rhs);
+ } else {
+ edge.rule_ = CreateEpsilonRule(lhs, eps);
+ edge.tail_nodes_.clear();
+ created_eps = true;
+ }
+ }
+ }
+ }
+ vector<bool> k2(g->edges_.size(), false);
+ g->PruneEdges(k2);
+ if (created_eps) RemoveEpsilons(g, eps);
+}
+
diff --git a/decoder/hg_remove_eps.h b/decoder/hg_remove_eps.h
new file mode 100644
index 00000000..82f06039
--- /dev/null
+++ b/decoder/hg_remove_eps.h
@@ -0,0 +1,13 @@
+#ifndef _HG_REMOVE_EPS_H_
+#define _HG_REMOVE_EPS_H_
+
+#include "wordid.h"
+class Hypergraph;
+
+// This is not a complete implementation of the general algorithm for
+// doing this. It makes a few weird assumptions, for example, that
+// if some nonterminal X rewrites as eps, then that is the only thing
+// that it rewrites as. This needs to be fixed for the general case!
+void RemoveEpsilons(Hypergraph* g, WordID eps);
+
+#endif
diff --git a/decoder/inside_outside.h b/decoder/inside_outside.h
index bb7f9fcc..f73a1d3f 100644
--- a/decoder/inside_outside.h
+++ b/decoder/inside_outside.h
@@ -41,10 +41,6 @@ WeightType Inside(const Hypergraph& hg,
WeightType* const cur_node_inside_score = &inside_score[i];
Hypergraph::EdgesVector const& in=hg.nodes_[i].in_edges_;
const unsigned num_in_edges = in.size();
- if (num_in_edges == 0) {
- *cur_node_inside_score = WeightType(1); //FIXME: why not call weight(edge) instead?
- continue;
- }
for (unsigned j = 0; j < num_in_edges; ++j) {
const Hypergraph::Edge& edge = hg.edges_[in[j]];
WeightType score = weight(edge);
diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc
new file mode 100644
index 00000000..10192f7a
--- /dev/null
+++ b/decoder/rescore_translator.cc
@@ -0,0 +1,58 @@
+#include "translator.h"
+
+#include <sstream>
+#include <boost/shared_ptr.hpp>
+
+#include "sentence_metadata.h"
+#include "hg.h"
+#include "hg_io.h"
+#include "tdict.h"
+
+using namespace std;
+
+struct RescoreTranslatorImpl {
+ RescoreTranslatorImpl(const boost::program_options::variables_map& conf) :
+ goal_sym(conf["goal"].as<string>()),
+ kGOAL_RULE(new TRule("[Goal] ||| [" + goal_sym + ",1] ||| [1]")),
+ kGOAL(TD::Convert("Goal") * -1) {
+ }
+
+ bool Translate(const string& input,
+ const vector<double>& weights,
+ Hypergraph* forest) {
+ if (input == "{}") return false;
+ if (input.find("{\"rules\"") == 0) {
+ istringstream is(input);
+ Hypergraph src_cfg_hg;
+ if (!HypergraphIO::ReadFromJSON(&is, forest)) {
+ cerr << "Parse error while reading HG from JSON.\n";
+ abort();
+ }
+ } else {
+ cerr << "Can only read HG input from JSON: use training/grammar_convert\n";
+ abort();
+ }
+ Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
+ Hypergraph::Node* goal = forest->AddNode(kGOAL);
+ Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
+ forest->ConnectEdgeToHeadNode(hg_edge, goal);
+ forest->Reweight(weights);
+ return true;
+ }
+
+ const string goal_sym;
+ const TRulePtr kGOAL_RULE;
+ const WordID kGOAL;
+};
+
+RescoreTranslator::RescoreTranslator(const boost::program_options::variables_map& conf) :
+ pimpl_(new RescoreTranslatorImpl(conf)) {}
+
+bool RescoreTranslator::TranslateImpl(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* minus_lm_forest) {
+ smeta->SetSourceLength(0); // don't know how to compute this
+ return pimpl_->Translate(input, weights, minus_lm_forest);
+}
+
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 185f979a..a978cfc2 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -20,7 +20,6 @@
#define reverse_foreach BOOST_REVERSE_FOREACH
using namespace std;
-static bool usingSentenceGrammar = false;
static bool printGrammarsUsed = false;
struct SCFGTranslatorImpl {
@@ -91,31 +90,31 @@ struct SCFGTranslatorImpl {
bool show_tree_structure_;
unsigned int ctf_iterations_;
vector<GrammarPtr> grammars;
- GrammarPtr sup_grammar_;
+ set<GrammarPtr> sup_grammars_;
- struct Equals { Equals(const GrammarPtr& v) : v_(v) {}
- bool operator()(const GrammarPtr& x) const { return x == v_; } const GrammarPtr& v_; };
+ struct ContainedIn {
+ ContainedIn(const set<GrammarPtr>& gs) : gs_(gs) {}
+ bool operator()(const GrammarPtr& x) const { return gs_.find(x) != gs_.end(); }
+ const set<GrammarPtr>& gs_;
+ };
- void SetSupplementalGrammar(const std::string& grammar_string) {
- grammars.erase(remove_if(grammars.begin(), grammars.end(), Equals(sup_grammar_)), grammars.end());
+ void AddSupplementalGrammarFromString(const std::string& grammar_string) {
+ grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end());
istringstream in(grammar_string);
- sup_grammar_.reset(new TextGrammar(&in));
- grammars.push_back(sup_grammar_);
+ TextGrammar* sent_grammar = new TextGrammar(&in);
+ sent_grammar->SetMaxSpan(max_span_limit);
+ sent_grammar->SetGrammarName("SupFromString");
+ AddSupplementalGrammar(GrammarPtr(sent_grammar));
}
- struct NameEquals { NameEquals(const string name) : name_(name) {}
- bool operator()(const GrammarPtr& x) const { return x->GetGrammarName() == name_; } const string name_; };
+ void AddSupplementalGrammar(GrammarPtr gp) {
+ sup_grammars_.insert(gp);
+ grammars.push_back(gp);
+ }
- void SetSentenceGrammarFromString(const std::string& grammar_str) {
- assert(grammar_str != "");
- if (!SILENT) cerr << "Setting sentence grammar" << endl;
- usingSentenceGrammar = true;
- istringstream in(grammar_str);
- TextGrammar* sent_grammar = new TextGrammar(&in);
- sent_grammar->SetMaxSpan(max_span_limit);
- sent_grammar->SetGrammarName("__psg");
- grammars.erase(remove_if(grammars.begin(), grammars.end(), NameEquals("__psg")), grammars.end());
- grammars.push_back(GrammarPtr(sent_grammar));
+ void RemoveSupplementalGrammars() {
+ grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end());
+ sup_grammars_.clear();
}
bool Translate(const string& input,
@@ -300,35 +299,24 @@ Check for grammar pointer in the sentence markup, for use with sentence specific
*/
void SCFGTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) {
map<string,string>::const_iterator it = kv.find("grammar");
-
-
- if (it == kv.end()) {
- usingSentenceGrammar= false;
- return;
+ if (it != kv.end()) {
+ TextGrammar* sentGrammar = new TextGrammar(it->second);
+ sentGrammar->SetMaxSpan(pimpl_->max_span_limit);
+ sentGrammar->SetGrammarName(it->second);
+ pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar));
}
- //Create sentence specific grammar from specified file name and load grammar into list of grammars
- usingSentenceGrammar = true;
- TextGrammar* sentGrammar = new TextGrammar(it->second);
- sentGrammar->SetMaxSpan(pimpl_->max_span_limit);
- sentGrammar->SetGrammarName(it->second);
- pimpl_->grammars.push_back(GrammarPtr(sentGrammar));
-
}
-void SCFGTranslator::SetSupplementalGrammar(const std::string& grammar) {
- pimpl_->SetSupplementalGrammar(grammar);
+void SCFGTranslator::AddSupplementalGrammarFromString(const std::string& grammar) {
+ pimpl_->AddSupplementalGrammarFromString(grammar);
}
-void SCFGTranslator::SetSentenceGrammarFromString(const std::string& grammar_str) {
- pimpl_->SetSentenceGrammarFromString(grammar_str);
+void SCFGTranslator::AddSupplementalGrammar(GrammarPtr grammar) {
+ pimpl_->AddSupplementalGrammar(grammar);
}
void SCFGTranslator::SentenceCompleteImpl() {
-
- if(usingSentenceGrammar) // Drop the last sentence grammar from the list of grammars
- {
- pimpl_->grammars.pop_back();
- }
+ pimpl_->RemoveSupplementalGrammars();
}
std::string SCFGTranslator::GetDecoderType() const {
diff --git a/decoder/tagger.cc b/decoder/tagger.cc
index 54890e85..63e855c8 100644
--- a/decoder/tagger.cc
+++ b/decoder/tagger.cc
@@ -54,6 +54,7 @@ struct TaggerImpl {
const int new_node_id = forest->AddNode(kXCAT)->id_;
for (int k = 0; k < tagset_.size(); ++k) {
TRulePtr rule(TRule::CreateLexicalRule(src, tagset_[k]));
+ rule->lhs_ = kXCAT;
Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
edge->i_ = i;
edge->j_ = i+1;
diff --git a/decoder/translator.h b/decoder/translator.h
index cfd3b08a..c0800e84 100644
--- a/decoder/translator.h
+++ b/decoder/translator.h
@@ -58,8 +58,8 @@ class SCFGTranslatorImpl;
class SCFGTranslator : public Translator {
public:
SCFGTranslator(const boost::program_options::variables_map& conf);
- void SetSupplementalGrammar(const std::string& grammar);
- void SetSentenceGrammarFromString(const std::string& grammar);
+ void AddSupplementalGrammar(GrammarPtr gp);
+ void AddSupplementalGrammarFromString(const std::string& grammar);
virtual std::string GetDecoderType() const;
protected:
bool TranslateImpl(const std::string& src,
@@ -85,4 +85,17 @@ class FSTTranslator : public Translator {
boost::shared_ptr<FSTTranslatorImpl> pimpl_;
};
+class RescoreTranslatorImpl;
+class RescoreTranslator : public Translator {
+ public:
+ RescoreTranslator(const boost::program_options::variables_map& conf);
+ private:
+ bool TranslateImpl(const std::string& src,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* minus_lm_forest);
+ private:
+ boost::shared_ptr<RescoreTranslatorImpl> pimpl_;
+};
+
#endif
diff --git a/decoder/trule.cc b/decoder/trule.cc
index 187a003d..896f9f3d 100644
--- a/decoder/trule.cc
+++ b/decoder/trule.cc
@@ -237,9 +237,9 @@ void TRule::ComputeArity() {
string TRule::AsString(bool verbose) const {
ostringstream os;
int idx = 0;
- if (lhs_ && verbose) {
+ if (lhs_) {
os << '[' << TD::Convert(lhs_ * -1) << "] |||";
- }
+ } else { os << "NOLHS |||"; }
for (unsigned i = 0; i < f_.size(); ++i) {
const WordID& w = f_[i];
if (w < 0) {