summaryrefslogtreecommitdiff
path: root/gi/scfg/abc/agrammar.cc
diff options
context:
space:
mode:
Diffstat (limited to 'gi/scfg/abc/agrammar.cc')
-rw-r--r--gi/scfg/abc/agrammar.cc153
1 files changed, 132 insertions, 21 deletions
diff --git a/gi/scfg/abc/agrammar.cc b/gi/scfg/abc/agrammar.cc
index 585255e3..016a0189 100644
--- a/gi/scfg/abc/agrammar.cc
+++ b/gi/scfg/abc/agrammar.cc
@@ -8,6 +8,18 @@
#include "agrammar.h"
#include "../utils/Util.h"
+
+
+aTRule::aTRule(TRulePtr rule){
+
+ this -> e_ = rule->e_;
+ this -> f_ = rule->f_;
+ this ->lhs_ = rule->lhs_;
+ this -> arity_ = rule->arity_;
+ this -> scores_ = rule->scores_;
+ ResetScore(0.00000001);
+}
+
bool equal(TRulePtr const & rule1, TRulePtr const & rule2){
if (rule1->lhs_ != rule2->lhs_) return false;
if (rule1->f_.size() != rule2->f_.size()) return false;
@@ -20,16 +32,25 @@ bool equal(TRulePtr const & rule1, TRulePtr const & rule2){
return true;
}
+
//const vector<TRulePtr> Grammar::NO_RULES;
void aRemoveRule(vector<TRulePtr> & v, const TRulePtr & rule){ // remove rule from v if found
for (int i=0; i< v.size(); i++)
if (equal(v[i], rule )){
- cout<<"erase rule from vector:"<<rule->AsString()<<endl;
+ // cout<<"erase rule from vector:"<<rule->AsString()<<endl;
v.erase(v.begin()+i);
}
}
+void aRemoveRule(vector<NTRule> & v, const NTRule & ntrule){ // remove rule from v if found
+ for (int i=0; i< v.size(); i++)
+ if (equal(v[i].rule_, ntrule.rule_ )){
+ // cout<<"erase rule from vector:"<<rule->AsString()<<endl;
+ v.erase(v.begin()+i);
+ }
+}
+
struct aTextRuleBin : public RuleBin {
int GetNumRules() const {
return rules_.size();
@@ -40,20 +61,16 @@ struct aTextRuleBin : public RuleBin {
void AddRule(TRulePtr t) {
rules_.push_back(t);
}
- void RemoveRule(TRulePtr t){
- for (int i=0; i<rules_.size(); i++){
- if (equal(rules_.at(i), t)){
- rules_.erase(rules_.begin() + i);
- //cout<<"IntextRulebin removerulle\n";
- return;
- }
- }
+
+ void RemoveRule(const TRulePtr & rule ){
+ aRemoveRule(rules_, rule);
}
int Arity() const {
return rules_.front()->Arity();
}
+
void Dump() const {
for (int i = 0; i < rules_.size(); ++i)
cerr << rules_[i]->AsString() << endl;
@@ -62,6 +79,7 @@ struct aTextRuleBin : public RuleBin {
vector<TRulePtr> rules_;
};
+
struct aTextGrammarNode : public GrammarIter {
aTextGrammarNode() : rb_(NULL) {}
~aTextGrammarNode() {
@@ -90,8 +108,8 @@ struct aTGImpl {
aTextGrammar::aTextGrammar() : max_span_(10), pimpl_(new aTGImpl) {}
aTextGrammar::aTextGrammar(const string& file) :
- max_span_(10),
- pimpl_(new aTGImpl) {
+ max_span_(10),
+ pimpl_(new aTGImpl) {
ReadFromFile(file);
}
@@ -103,6 +121,7 @@ void aTextGrammar::SetGoalNT(const string & goal_str){
goalID = TD::Convert(goal_str);
}
+
void getNTRule( const TRulePtr & rule, map<WordID, NTRule> & ntrule_map){
NTRule lhs_ntrule(rule, rule->lhs_ * -1);
@@ -113,9 +132,9 @@ void getNTRule( const TRulePtr & rule, map<WordID, NTRule> & ntrule_map){
NTRule rhs_ntrule(rule, rule->f_.at(i) * -1);
ntrule_map[(rule->f_).at(i) *-1] = rhs_ntrule;
}
-
-
}
+
+
void aTextGrammar::AddRule(const TRulePtr& rule) {
if (rule->IsUnary()) {
rhs2unaries_[rule->f().front()].push_back(rule);
@@ -141,7 +160,7 @@ void aTextGrammar::AddRule(const TRulePtr& rule) {
}
void aTextGrammar::RemoveRule(const TRulePtr & rule){
- cout<<"Remove rule: "<<rule->AsString()<<endl;
+ // cout<<"Remove rule: "<<rule->AsString()<<endl;
if (rule->IsUnary()) {
aRemoveRule(rhs2unaries_[rule->f().front()], rule);
aRemoveRule(unaries_, rule);
@@ -158,6 +177,14 @@ void aTextGrammar::RemoveRule(const TRulePtr & rule){
aRemoveRule(lhs_rules_[rule->lhs_ * -1] , rule);
+
+ //remove the rule from nt_rules_
+ map<WordID, NTRule> ntrule_map;
+ getNTRule (rule, ntrule_map);
+ for (map<WordID,NTRule>::const_iterator it= ntrule_map.begin(); it != ntrule_map.end(); it++){
+ aRemoveRule(nt_rules_[it->first], it->second);
+ }
+
}
void aTextGrammar::RemoveNonterminal(WordID wordID){
@@ -166,6 +193,8 @@ void aTextGrammar::RemoveNonterminal(WordID wordID){
nt_rules_.erase(wordID);
for (int i =0; i<rules.size(); i++)
RemoveRule(rules[i].rule_);
+ sum_probs_.erase(wordID);
+ cnt_rules.erase(wordID);
}
@@ -199,7 +228,7 @@ void aTextGrammar::AddSplitNonTerminal(WordID nt_old, vector<WordID> & nts){
map<WordID, int> cnt_addepsilon; //cnt_addepsilon and cont_minusepsilon to track the number of rules epsilon is added or minus for each lhs nonterminal, ideally we want these two numbers are equal
- map<WordID, int> cnt_minusepsilon; //these two number also use to control the random generated add epsilon/minus epsilon of a new rule
+ map<WordID, int> cnt_minusepsilon;
cnt_addepsilon[old_rule.rule_->lhs_] = 0;
cnt_minusepsilon[old_rule.rule_->lhs_] = 0;
for (int j =0; j<nts.size(); j++) { cnt_addepsilon[nts[j] ] = 0; cnt_minusepsilon[nts[j] ] = 0;}
@@ -217,7 +246,7 @@ void aTextGrammar::AddSplitNonTerminal(WordID nt_old, vector<WordID> & nts){
// cout<<"print vector j_vector"<<endl;
// for (int k=0; k<ntPos.size();k++) cout<<j_vector[k]<<" "; cout<<endl;
//now use the vector to create a new rule
- TRulePtr newrule(new TRule());
+ TRulePtr newrule(new aTRule());
newrule -> e_ = (old_rule.rule_)->e_;
newrule -> f_ = old_rule.rule_->f_;
@@ -323,7 +352,7 @@ void aTextGrammar::splitNonterminal(WordID wordID){
if (wordID == goalID){ //add rule X-> X1; X->X2,... if X is the goal NT
for (int i =0; i<v_splits.size(); i++){
- TRulePtr rule (new TRule());
+ TRulePtr rule (new aTRule());
rule ->lhs_ = goalID * -1;
rule ->f_.push_back(v_splits[i] * -1);
rule->e_.push_back(0);
@@ -334,20 +363,100 @@ void aTextGrammar::splitNonterminal(WordID wordID){
}
+}
+
+void aTextGrammar::splitAllNonterminals(){
+ map<WordID, vector<TRulePtr> >::const_iterator it;
+ vector<WordID> v ; // WordID >0
+ for (it = lhs_rules_.begin(); it != lhs_rules_.end(); it++) //iterate through all nts
+ if (it->first != goalID || lhs_rules_.size() ==1)
+ v.push_back(it->first);
+
+ for (int i=0; i< v.size(); i++)
+ splitNonterminal(v[i]);
}
+void aTextGrammar::PrintAllRules(const string & filename) const{
-void aTextGrammar::PrintAllRules() const{
- map<WordID, vector<TRulePtr> >::const_iterator it;
+
+ cerr<<"print grammar to "<<filename<<endl;
+
+ ofstream outfile(filename.c_str());
+ if (!outfile.good()) {
+ cerr << "error opening output file " << filename << endl;
+ exit(1);
+ }
+
+ map<WordID, vector<TRulePtr > >::const_iterator it;
for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){
vector<TRulePtr> v = it-> second;
for (int i =0; i< v.size(); i++){
- cout<<v[i]->AsString()<<"\t"<<endl;
+ outfile<<v[i]->AsString()<<"\t"<<endl;
+ }
+ }
+}
+
+
+void aTextGrammar::ResetScore(){
+
+ map<WordID, vector<TRulePtr > >::const_iterator it;
+ for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){
+ vector<TRulePtr> v = it-> second;
+ for (int i =0; i< v.size(); i++){
+ // cerr<<"Reset score of Rule "<<v[i]->AsString()<<endl;
+ boost::static_pointer_cast<aTRule>(v[i])->ResetScore(alpha_ /v.size());
+ }
+ lhs_rules_[it->first] = v;
+ sum_probs_[it->first] = alpha_;
+ }
+
+}
+
+void aTextGrammar::UpdateScore(){
+
+ map<WordID, vector<TRulePtr > >::const_iterator it;
+ for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){
+ vector<TRulePtr> v = it-> second;
+ for (int i =0; i< v.size(); i++){
+ boost::static_pointer_cast<aTRule>(v[i])->UpdateScore(sum_probs_[it->first] );
}
+
+ // cerr<<"sum_probs_[it->first] ="<<sum_probs_[it->first] <<endl;
+ sum_probs_[it->first] = alpha_;
}
+
+}
+
+
+void aTextGrammar::UpdateHgProsteriorProb(Hypergraph & hg){
+ std::vector<prob_t> posts ;
+
+ prob_t goal_score = hg.ComputeEdgePosteriors(1, &posts);
+ for (int i =0; i<posts.size(); i++){
+
+ //cout<<posts[i]<<endl;
+ Hypergraph::Edge& e = hg.edges_[i];
+ string goalstr("Goal");
+ string str_lhs = TD::Convert(e.rule_->lhs_ * -1);
+
+ if (str_lhs.find(goalstr) != string::npos)
+ continue;
+
+ // cerr<<e.rule_->AsString()<<endl;
+ // cerr<<e.rule_->parent_rule_->AsString()<<endl;
+
+ boost::static_pointer_cast<aTRule>(e.rule_->parent_rule_)->AddProb(posts[i] / goal_score);
+ // cerr<<"add count for rule\n";
+// cerr<<"posts[i]="<<posts[i]<<" goal_score="<<goal_score<<endl;
+// cerr<<"posts[i] /goal_score="<<(posts[i] /goal_score)<<endl;
+ sum_probs_[e.rule_->parent_rule_->lhs_* -1 ] += posts[i] /goal_score;
+
+ }
+
+
}
@@ -364,7 +473,9 @@ void aTextGrammar::PrintNonterminalRules(WordID nt) const{
}
static void AddRuleHelper(const TRulePtr& new_rule, void* extra) {
- static_cast<aTextGrammar*>(extra)->AddRule(new_rule);
+ aTRule *p = new aTRule(new_rule);
+
+ static_cast<aTextGrammar*>(extra)->AddRule(TRulePtr(p));
}
void aTextGrammar::ReadFromFile(const string& filename) {