From 8c16b9384e8b85abf9e51bc670e548bf5cd4a0c5 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Sun, 12 Dec 2010 15:32:19 -0500 Subject: facility for adding programmitically generated grammars --- decoder/Makefile.am | 2 +- decoder/decoder.cc | 4 ++++ decoder/decoder.h | 5 +++++ decoder/grammar.cc | 12 +++++++++++- decoder/grammar.h | 9 +++++++-- decoder/scfg_translator.cc | 19 +++++++++++++++++++ decoder/translator.cc | 4 ++++ decoder/translator.h | 3 +++ 8 files changed, 54 insertions(+), 4 deletions(-) (limited to 'decoder') diff --git a/decoder/Makefile.am b/decoder/Makefile.am index 4c688180..633542f0 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -79,4 +79,4 @@ libcdec_a_SOURCES = \ if GLC # Until we build GLC as a library... libcdec_a_SOURCES += ff_glc.cc -endif \ No newline at end of file +endif diff --git a/decoder/decoder.cc b/decoder/decoder.cc index a21b47c0..3551b584 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -638,6 +638,10 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) { return res; } void Decoder::SetWeights(const vector& weights) { pimpl_->SetWeights(weights); } +void Decoder::SetSupplementalGrammar(const std::string& grammar_string) { + assert(pimpl_->translator->GetDecoderType() == "SCFG"); + static_cast(*pimpl_->translator).SetSupplementalGrammar(grammar_string); +} bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { diff --git a/decoder/decoder.h b/decoder/decoder.h index abaf3740..813400e3 100644 --- a/decoder/decoder.h +++ b/decoder/decoder.h @@ -29,6 +29,11 @@ struct Decoder { void SetId(int id); ~Decoder(); const boost::program_options::variables_map& GetConf() const { return conf; } + + // add grammar rules (currently only supported by SCFG decoders) + // that will be used on subsequent calls to Decode. rules should be in standard + // text format. This function does NOT read from a file. + void SetSupplementalGrammar(const std::string& grammar); private: boost::program_options::variables_map conf; boost::shared_ptr pimpl_; diff --git a/decoder/grammar.cc b/decoder/grammar.cc index bbe2f01a..7e6bbc66 100644 --- a/decoder/grammar.cc +++ b/decoder/grammar.cc @@ -77,6 +77,12 @@ TextGrammar::TextGrammar(const string& file) : ReadFromFile(file); } +TextGrammar::TextGrammar(istream* in) : + max_span_(10), + pimpl_(new TGImpl) { + ReadFromStream(in); +} + const GrammarIter* TextGrammar::GetRoot() const { return &pimpl_->root_; } @@ -107,7 +113,11 @@ static void AddRuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level void TextGrammar::ReadFromFile(const string& filename) { ReadFile in(filename); - RuleLexer::ReadRules(in.stream(), &AddRuleHelper, this); + ReadFromStream(in.stream()); +} + +void TextGrammar::ReadFromStream(istream* in) { + RuleLexer::ReadRules(in, &AddRuleHelper, this); } bool TextGrammar::HasRuleForSpan(int /* i */, int /* j */, int distance) const { diff --git a/decoder/grammar.h b/decoder/grammar.h index 1173e3cd..f5d00817 100644 --- a/decoder/grammar.h +++ b/decoder/grammar.h @@ -1,13 +1,15 @@ #ifndef GRAMMAR_H_ #define GRAMMAR_H_ +#include #include #include #include #include -#include #include +#include + #include "lattice.h" #include "trule.h" @@ -62,12 +64,14 @@ typedef boost::shared_ptr GrammarPtr; class TGImpl; struct TextGrammar : public Grammar { TextGrammar(); - TextGrammar(const std::string& file); + explicit TextGrammar(const std::string& file); + explicit TextGrammar(std::istream* in); void SetMaxSpan(int m) { max_span_ = m; } virtual const GrammarIter* GetRoot() const; void AddRule(const TRulePtr& rule, const unsigned int ctf_level=0, const TRulePtr& coarse_parent=TRulePtr()); void ReadFromFile(const std::string& filename); + void ReadFromStream(std::istream* in); virtual bool HasRuleForSpan(int i, int j, int distance) const; const std::vector& GetUnaryRules(const WordID& cat) const; @@ -92,4 +96,5 @@ struct PassThroughGrammar : public TextGrammar { }; void RefineRule(TRulePtr pt, const unsigned int ctf_level); + #endif diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index afe796a5..a19e9d75 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -91,6 +91,17 @@ struct SCFGTranslatorImpl { bool show_tree_structure_; unsigned int ctf_iterations_; vector grammars; + GrammarPtr sup_grammar_; + + struct Equals { Equals(const GrammarPtr& v) : v_(v) {} + bool operator()(const GrammarPtr& x) const { return x == v_; } const GrammarPtr& v_; }; + + void SetSupplementalGrammar(const std::string& grammar_string) { + grammars.erase(remove_if(grammars.begin(), grammars.end(), Equals(sup_grammar_)), grammars.end()); + istringstream in(grammar_string); + sup_grammar_.reset(new TextGrammar(&in)); + grammars.push_back(sup_grammar_); + } bool Translate(const string& input, SentenceMetadata* smeta, @@ -290,6 +301,10 @@ void SCFGTranslator::ProcessMarkupHintsImpl(const map& kv) { } +void SCFGTranslator::SetSupplementalGrammar(const std::string& grammar) { + pimpl_->SetSupplementalGrammar(grammar); +} + void SCFGTranslator::SentenceCompleteImpl() { if(usingSentenceGrammar) // Drop the last sentence grammar from the list of grammars @@ -299,3 +314,7 @@ void SCFGTranslator::SentenceCompleteImpl() { } } +std::string SCFGTranslator::GetDecoderType() const { + return "SCFG"; +} + diff --git a/decoder/translator.cc b/decoder/translator.cc index d1ca125b..6ba74030 100644 --- a/decoder/translator.cc +++ b/decoder/translator.cc @@ -9,6 +9,10 @@ using namespace std; Translator::~Translator() {} +std::string Translator::GetDecoderType() const { + return "UNKNOWN"; +} + void Translator::ProcessMarkupHints(const map& kv) { if (state_ != kUninitialized) { cerr << "Translator::ProcessMarkupHints in wrong state: " << state_ << endl; diff --git a/decoder/translator.h b/decoder/translator.h index 6b0a02e4..9d6dd97d 100644 --- a/decoder/translator.h +++ b/decoder/translator.h @@ -39,6 +39,7 @@ class Translator { // Free any sentence-specific resources void SentenceComplete(); + virtual std::string GetDecoderType() const; protected: virtual bool TranslateImpl(const std::string& src, SentenceMetadata* smeta, @@ -55,6 +56,8 @@ class SCFGTranslatorImpl; class SCFGTranslator : public Translator { public: SCFGTranslator(const boost::program_options::variables_map& conf); + void SetSupplementalGrammar(const std::string& grammar); + virtual std::string GetDecoderType() const; protected: bool TranslateImpl(const std::string& src, SentenceMetadata* smeta, -- cgit v1.2.3