summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/decoder.cc10
-rw-r--r--decoder/decoder.h6
-rw-r--r--decoder/scfg_translator.cc63
-rw-r--r--decoder/translator.h4
-rw-r--r--dtrain/dtrain.cc4
-rw-r--r--python/src/_cdec.pyx2
6 files changed, 38 insertions, 51 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 333f0fb6..ad4e9e07 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -743,16 +743,14 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) {
}
vector<weight_t>& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); }
const vector<weight_t>& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); }
-void Decoder::SetSupplementalGrammar(const std::string& grammar_string) {
- assert(pimpl_->translator->GetDecoderType() == "SCFG");
- static_cast<SCFGTranslator&>(*pimpl_->translator).SetSupplementalGrammar(grammar_string);
+void Decoder::AddSupplementalGrammar(GrammarPtr gp) {
+ static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammar(gp);
}
-void Decoder::SetSentenceGrammarFromString(const std::string& grammar_str) {
+void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string) {
assert(pimpl_->translator->GetDecoderType() == "SCFG");
- static_cast<SCFGTranslator&>(*pimpl_->translator).SetSentenceGrammarFromString(grammar_str);
+ static_cast<SCFGTranslator&>(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string);
}
-
bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
string buf = input;
NgramCache::Clear(); // clear ngram cache for remote LM (if used)
diff --git a/decoder/decoder.h b/decoder/decoder.h
index 6b2f7b16..bef2ff5e 100644
--- a/decoder/decoder.h
+++ b/decoder/decoder.h
@@ -37,6 +37,8 @@ struct DecoderObserver {
virtual void NotifyDecodingComplete(const SentenceMetadata& smeta);
};
+struct Grammar; // TODO once the decoder interface is cleaned up,
+ // this should be somewhere else
struct Decoder {
Decoder(int argc, char** argv);
Decoder(std::istream* config_file);
@@ -54,8 +56,8 @@ struct Decoder {
// add grammar rules (currently only supported by SCFG decoders)
// that will be used on subsequent calls to Decode. rules should be in standard
// text format. This function does NOT read from a file.
- void SetSupplementalGrammar(const std::string& grammar);
- void SetSentenceGrammarFromString(const std::string& grammar_str);
+ void AddSupplementalGrammar(boost::shared_ptr<Grammar> gp);
+ void AddSupplementalGrammarFromString(const std::string& grammar_string);
private:
boost::program_options::variables_map conf;
boost::shared_ptr<DecoderImpl> pimpl_;
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 185f979a..aaa6c40b 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -20,7 +20,6 @@
#define reverse_foreach BOOST_REVERSE_FOREACH
using namespace std;
-static bool usingSentenceGrammar = false;
static bool printGrammarsUsed = false;
struct SCFGTranslatorImpl {
@@ -91,31 +90,31 @@ struct SCFGTranslatorImpl {
bool show_tree_structure_;
unsigned int ctf_iterations_;
vector<GrammarPtr> grammars;
- GrammarPtr sup_grammar_;
+ set<GrammarPtr> sup_grammars_;
- struct Equals { Equals(const GrammarPtr& v) : v_(v) {}
- bool operator()(const GrammarPtr& x) const { return x == v_; } const GrammarPtr& v_; };
+ struct ContainedIn {
+ ContainedIn(const set<GrammarPtr>& gs) : gs_(gs) {}
+ bool operator()(const GrammarPtr& x) const { return gs_.find(x) != gs_.end(); }
+ const set<GrammarPtr>& gs_;
+ };
- void SetSupplementalGrammar(const std::string& grammar_string) {
- grammars.erase(remove_if(grammars.begin(), grammars.end(), Equals(sup_grammar_)), grammars.end());
+ void AddSupplementalGrammarFromString(const std::string& grammar_string) {
+ grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end());
istringstream in(grammar_string);
- sup_grammar_.reset(new TextGrammar(&in));
- grammars.push_back(sup_grammar_);
+ TextGrammar* sent_grammar = new TextGrammar(&in);
+ sent_grammar->SetMaxSpan(max_span_limit);
+ sent_grammar->SetGrammarName("SupFromString");
+ AddSupplementalGrammar(GrammarPtr(sent_grammar));
}
- struct NameEquals { NameEquals(const string name) : name_(name) {}
- bool operator()(const GrammarPtr& x) const { return x->GetGrammarName() == name_; } const string name_; };
+ void AddSupplementalGrammar(GrammarPtr gp) {
+ sup_grammars_.insert(gp);
+ grammars.push_back(gp);
+ }
- void SetSentenceGrammarFromString(const std::string& grammar_str) {
- assert(grammar_str != "");
- if (!SILENT) cerr << "Setting sentence grammar" << endl;
- usingSentenceGrammar = true;
- istringstream in(grammar_str);
- TextGrammar* sent_grammar = new TextGrammar(&in);
- sent_grammar->SetMaxSpan(max_span_limit);
- sent_grammar->SetGrammarName("__psg");
- grammars.erase(remove_if(grammars.begin(), grammars.end(), NameEquals("__psg")), grammars.end());
- grammars.push_back(GrammarPtr(sent_grammar));
+ void RemoveSupplementalGrammars() {
+ grammars.erase(remove_if(grammars.begin(), grammars.end(), ContainedIn(sup_grammars_)), grammars.end());
+ sup_grammars_.clear();
}
bool Translate(const string& input,
@@ -301,34 +300,22 @@ Check for grammar pointer in the sentence markup, for use with sentence specific
void SCFGTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) {
map<string,string>::const_iterator it = kv.find("grammar");
-
- if (it == kv.end()) {
- usingSentenceGrammar= false;
- return;
- }
- //Create sentence specific grammar from specified file name and load grammar into list of grammars
- usingSentenceGrammar = true;
TextGrammar* sentGrammar = new TextGrammar(it->second);
sentGrammar->SetMaxSpan(pimpl_->max_span_limit);
sentGrammar->SetGrammarName(it->second);
- pimpl_->grammars.push_back(GrammarPtr(sentGrammar));
-
+ pimpl_->AddSupplementalGrammar(GrammarPtr(sentGrammar));
}
-void SCFGTranslator::SetSupplementalGrammar(const std::string& grammar) {
- pimpl_->SetSupplementalGrammar(grammar);
+void SCFGTranslator::AddSupplementalGrammarFromString(const std::string& grammar) {
+ pimpl_->AddSupplementalGrammarFromString(grammar);
}
-void SCFGTranslator::SetSentenceGrammarFromString(const std::string& grammar_str) {
- pimpl_->SetSentenceGrammarFromString(grammar_str);
+void SCFGTranslator::AddSupplementalGrammar(GrammarPtr grammar) {
+ pimpl_->AddSupplementalGrammar(grammar);
}
void SCFGTranslator::SentenceCompleteImpl() {
-
- if(usingSentenceGrammar) // Drop the last sentence grammar from the list of grammars
- {
- pimpl_->grammars.pop_back();
- }
+ pimpl_->RemoveSupplementalGrammars();
}
std::string SCFGTranslator::GetDecoderType() const {
diff --git a/decoder/translator.h b/decoder/translator.h
index cfd3b08a..fc2bb760 100644
--- a/decoder/translator.h
+++ b/decoder/translator.h
@@ -58,8 +58,8 @@ class SCFGTranslatorImpl;
class SCFGTranslator : public Translator {
public:
SCFGTranslator(const boost::program_options::variables_map& conf);
- void SetSupplementalGrammar(const std::string& grammar);
- void SetSentenceGrammarFromString(const std::string& grammar);
+ void AddSupplementalGrammar(GrammarPtr gp);
+ void AddSupplementalGrammarFromString(const std::string& grammar);
virtual std::string GetDecoderType() const;
protected:
bool TranslateImpl(const std::string& src,
diff --git a/dtrain/dtrain.cc b/dtrain/dtrain.cc
index b3e62914..b7a4bb6f 100644
--- a/dtrain/dtrain.cc
+++ b/dtrain/dtrain.cc
@@ -364,7 +364,7 @@ main(int argc, char** argv)
boost::replace_all(in, "\t", "\n");
in += "\n";
grammar_buf_out << in << DTRAIN_GRAMMAR_DELIM << " " << in_split[0] << endl;
- decoder.SetSentenceGrammarFromString(in);
+ decoder.AddSupplementalGrammarFromString(in);
src_str_buf.push_back(in_split[1]);
// decode
observer->SetRef(ref_ids);
@@ -378,7 +378,7 @@ main(int argc, char** argv)
if (boost::starts_with(rule, DTRAIN_GRAMMAR_DELIM)) break;
grammar_str += rule + "\n";
}
- decoder.SetSentenceGrammarFromString(grammar_str);
+ decoder.AddSupplementalGrammarFromString(grammar_str);
// decode
observer->SetRef(ref_ids_buf[ii]);
decoder.Decode(src_str_buf[ii], observer);
diff --git a/python/src/_cdec.pyx b/python/src/_cdec.pyx
index 879e8b72..164d6570 100644
--- a/python/src/_cdec.pyx
+++ b/python/src/_cdec.pyx
@@ -98,7 +98,7 @@ cdef class Decoder:
else:
raise TypeError('Cannot translate input type %s' % type(sentence))
if grammar:
- self.dec.SetSentenceGrammarFromString(string(<char *> grammar))
+ self.dec.AddSupplementalGrammarFromString(string(<char *> grammar))
cdef decoder.BasicObserver observer = decoder.BasicObserver()
self.dec.Decode(string(<char *>inp), &observer)
if observer.hypergraph == NULL: