summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--decoder/Makefile.am2
-rw-r--r--decoder/decoder.cc4
-rw-r--r--decoder/decoder.h5
-rw-r--r--decoder/grammar.cc12
-rw-r--r--decoder/grammar.h9
-rw-r--r--decoder/scfg_translator.cc19
-rw-r--r--decoder/translator.cc4
-rw-r--r--decoder/translator.h3
8 files changed, 54 insertions, 4 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index 4c688180..633542f0 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -79,4 +79,4 @@ libcdec_a_SOURCES = \
if GLC
# Until we build GLC as a library...
libcdec_a_SOURCES += ff_glc.cc
-endif \ No newline at end of file
+endif
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index a21b47c0..3551b584 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -638,6 +638,10 @@ bool Decoder::Decode(const string& input, DecoderObserver* o) {
return res;
}
void Decoder::SetWeights(const vector<double>& weights) { pimpl_->SetWeights(weights); }
+void Decoder::SetSupplementalGrammar(const std::string& grammar_string) {
+ assert(pimpl_->translator->GetDecoderType() == "SCFG");
+ static_cast<SCFGTranslator&>(*pimpl_->translator).SetSupplementalGrammar(grammar_string);
+}
bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
diff --git a/decoder/decoder.h b/decoder/decoder.h
index abaf3740..813400e3 100644
--- a/decoder/decoder.h
+++ b/decoder/decoder.h
@@ -29,6 +29,11 @@ struct Decoder {
void SetId(int id);
~Decoder();
const boost::program_options::variables_map& GetConf() const { return conf; }
+
+ // add grammar rules (currently only supported by SCFG decoders)
+ // that will be used on subsequent calls to Decode. rules should be in standard
+ // text format. This function does NOT read from a file.
+ void SetSupplementalGrammar(const std::string& grammar);
private:
boost::program_options::variables_map conf;
boost::shared_ptr<DecoderImpl> pimpl_;
diff --git a/decoder/grammar.cc b/decoder/grammar.cc
index bbe2f01a..7e6bbc66 100644
--- a/decoder/grammar.cc
+++ b/decoder/grammar.cc
@@ -77,6 +77,12 @@ TextGrammar::TextGrammar(const string& file) :
ReadFromFile(file);
}
+TextGrammar::TextGrammar(istream* in) :
+ max_span_(10),
+ pimpl_(new TGImpl) {
+ ReadFromStream(in);
+}
+
const GrammarIter* TextGrammar::GetRoot() const {
return &pimpl_->root_;
}
@@ -107,7 +113,11 @@ static void AddRuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level
void TextGrammar::ReadFromFile(const string& filename) {
ReadFile in(filename);
- RuleLexer::ReadRules(in.stream(), &AddRuleHelper, this);
+ ReadFromStream(in.stream());
+}
+
+void TextGrammar::ReadFromStream(istream* in) {
+ RuleLexer::ReadRules(in, &AddRuleHelper, this);
}
bool TextGrammar::HasRuleForSpan(int /* i */, int /* j */, int distance) const {
diff --git a/decoder/grammar.h b/decoder/grammar.h
index 1173e3cd..f5d00817 100644
--- a/decoder/grammar.h
+++ b/decoder/grammar.h
@@ -1,13 +1,15 @@
#ifndef GRAMMAR_H_
#define GRAMMAR_H_
+#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#include <set>
-#include <boost/shared_ptr.hpp>
#include <string>
+#include <boost/shared_ptr.hpp>
+
#include "lattice.h"
#include "trule.h"
@@ -62,12 +64,14 @@ typedef boost::shared_ptr<Grammar> GrammarPtr;
class TGImpl;
struct TextGrammar : public Grammar {
TextGrammar();
- TextGrammar(const std::string& file);
+ explicit TextGrammar(const std::string& file);
+ explicit TextGrammar(std::istream* in);
void SetMaxSpan(int m) { max_span_ = m; }
virtual const GrammarIter* GetRoot() const;
void AddRule(const TRulePtr& rule, const unsigned int ctf_level=0, const TRulePtr& coarse_parent=TRulePtr());
void ReadFromFile(const std::string& filename);
+ void ReadFromStream(std::istream* in);
virtual bool HasRuleForSpan(int i, int j, int distance) const;
const std::vector<TRulePtr>& GetUnaryRules(const WordID& cat) const;
@@ -92,4 +96,5 @@ struct PassThroughGrammar : public TextGrammar {
};
void RefineRule(TRulePtr pt, const unsigned int ctf_level);
+
#endif
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index afe796a5..a19e9d75 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -91,6 +91,17 @@ struct SCFGTranslatorImpl {
bool show_tree_structure_;
unsigned int ctf_iterations_;
vector<GrammarPtr> grammars;
+ GrammarPtr sup_grammar_;
+
+ struct Equals { Equals(const GrammarPtr& v) : v_(v) {}
+ bool operator()(const GrammarPtr& x) const { return x == v_; } const GrammarPtr& v_; };
+
+ void SetSupplementalGrammar(const std::string& grammar_string) {
+ grammars.erase(remove_if(grammars.begin(), grammars.end(), Equals(sup_grammar_)), grammars.end());
+ istringstream in(grammar_string);
+ sup_grammar_.reset(new TextGrammar(&in));
+ grammars.push_back(sup_grammar_);
+ }
bool Translate(const string& input,
SentenceMetadata* smeta,
@@ -290,6 +301,10 @@ void SCFGTranslator::ProcessMarkupHintsImpl(const map<string, string>& kv) {
}
+void SCFGTranslator::SetSupplementalGrammar(const std::string& grammar) {
+ pimpl_->SetSupplementalGrammar(grammar);
+}
+
void SCFGTranslator::SentenceCompleteImpl() {
if(usingSentenceGrammar) // Drop the last sentence grammar from the list of grammars
@@ -299,3 +314,7 @@ void SCFGTranslator::SentenceCompleteImpl() {
}
}
+std::string SCFGTranslator::GetDecoderType() const {
+ return "SCFG";
+}
+
diff --git a/decoder/translator.cc b/decoder/translator.cc
index d1ca125b..6ba74030 100644
--- a/decoder/translator.cc
+++ b/decoder/translator.cc
@@ -9,6 +9,10 @@ using namespace std;
Translator::~Translator() {}
+std::string Translator::GetDecoderType() const {
+ return "UNKNOWN";
+}
+
void Translator::ProcessMarkupHints(const map<string, string>& kv) {
if (state_ != kUninitialized) {
cerr << "Translator::ProcessMarkupHints in wrong state: " << state_ << endl;
diff --git a/decoder/translator.h b/decoder/translator.h
index 6b0a02e4..9d6dd97d 100644
--- a/decoder/translator.h
+++ b/decoder/translator.h
@@ -39,6 +39,7 @@ class Translator {
// Free any sentence-specific resources
void SentenceComplete();
+ virtual std::string GetDecoderType() const;
protected:
virtual bool TranslateImpl(const std::string& src,
SentenceMetadata* smeta,
@@ -55,6 +56,8 @@ class SCFGTranslatorImpl;
class SCFGTranslator : public Translator {
public:
SCFGTranslator(const boost::program_options::variables_map& conf);
+ void SetSupplementalGrammar(const std::string& grammar);
+ virtual std::string GetDecoderType() const;
protected:
bool TranslateImpl(const std::string& src,
SentenceMetadata* smeta,