diff options
author | linh.kitty <linh.kitty@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-16 17:44:44 +0000 |
---|---|---|
committer | linh.kitty <linh.kitty@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-16 17:44:44 +0000 |
commit | 15a587e247dc0954de27e2627f5511126243943d (patch) | |
tree | 3fbcdfc8416814c528d1c686c10f757f798c8e1e /gi/scfg/abc/agrammar.cc | |
parent | 3de962259255a3621f5cd150e805c7dbcf7a7666 (diff) |
add
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@286 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/scfg/abc/agrammar.cc')
-rw-r--r-- | gi/scfg/abc/agrammar.cc | 153 |
1 files changed, 132 insertions, 21 deletions
diff --git a/gi/scfg/abc/agrammar.cc b/gi/scfg/abc/agrammar.cc index 585255e3..016a0189 100644 --- a/gi/scfg/abc/agrammar.cc +++ b/gi/scfg/abc/agrammar.cc @@ -8,6 +8,18 @@ #include "agrammar.h" #include "../utils/Util.h" + + +aTRule::aTRule(TRulePtr rule){ + + this -> e_ = rule->e_; + this -> f_ = rule->f_; + this ->lhs_ = rule->lhs_; + this -> arity_ = rule->arity_; + this -> scores_ = rule->scores_; + ResetScore(0.00000001); +} + bool equal(TRulePtr const & rule1, TRulePtr const & rule2){ if (rule1->lhs_ != rule2->lhs_) return false; if (rule1->f_.size() != rule2->f_.size()) return false; @@ -20,16 +32,25 @@ bool equal(TRulePtr const & rule1, TRulePtr const & rule2){ return true; } + //const vector<TRulePtr> Grammar::NO_RULES; void aRemoveRule(vector<TRulePtr> & v, const TRulePtr & rule){ // remove rule from v if found for (int i=0; i< v.size(); i++) if (equal(v[i], rule )){ - cout<<"erase rule from vector:"<<rule->AsString()<<endl; + // cout<<"erase rule from vector:"<<rule->AsString()<<endl; v.erase(v.begin()+i); } } +void aRemoveRule(vector<NTRule> & v, const NTRule & ntrule){ // remove rule from v if found + for (int i=0; i< v.size(); i++) + if (equal(v[i].rule_, ntrule.rule_ )){ + // cout<<"erase rule from vector:"<<rule->AsString()<<endl; + v.erase(v.begin()+i); + } +} + struct aTextRuleBin : public RuleBin { int GetNumRules() const { return rules_.size(); @@ -40,20 +61,16 @@ struct aTextRuleBin : public RuleBin { void AddRule(TRulePtr t) { rules_.push_back(t); } - void RemoveRule(TRulePtr t){ - for (int i=0; i<rules_.size(); i++){ - if (equal(rules_.at(i), t)){ - rules_.erase(rules_.begin() + i); - //cout<<"IntextRulebin removerulle\n"; - return; - } - } + + void RemoveRule(const TRulePtr & rule ){ + aRemoveRule(rules_, rule); } int Arity() const { return rules_.front()->Arity(); } + void Dump() const { for (int i = 0; i < rules_.size(); ++i) cerr << rules_[i]->AsString() << endl; @@ -62,6 +79,7 @@ struct aTextRuleBin : public RuleBin { vector<TRulePtr> rules_; }; + struct aTextGrammarNode : public GrammarIter { aTextGrammarNode() : rb_(NULL) {} ~aTextGrammarNode() { @@ -90,8 +108,8 @@ struct aTGImpl { aTextGrammar::aTextGrammar() : max_span_(10), pimpl_(new aTGImpl) {} aTextGrammar::aTextGrammar(const string& file) : - max_span_(10), - pimpl_(new aTGImpl) { + max_span_(10), + pimpl_(new aTGImpl) { ReadFromFile(file); } @@ -103,6 +121,7 @@ void aTextGrammar::SetGoalNT(const string & goal_str){ goalID = TD::Convert(goal_str); } + void getNTRule( const TRulePtr & rule, map<WordID, NTRule> & ntrule_map){ NTRule lhs_ntrule(rule, rule->lhs_ * -1); @@ -113,9 +132,9 @@ void getNTRule( const TRulePtr & rule, map<WordID, NTRule> & ntrule_map){ NTRule rhs_ntrule(rule, rule->f_.at(i) * -1); ntrule_map[(rule->f_).at(i) *-1] = rhs_ntrule; } - - } + + void aTextGrammar::AddRule(const TRulePtr& rule) { if (rule->IsUnary()) { rhs2unaries_[rule->f().front()].push_back(rule); @@ -141,7 +160,7 @@ void aTextGrammar::AddRule(const TRulePtr& rule) { } void aTextGrammar::RemoveRule(const TRulePtr & rule){ - cout<<"Remove rule: "<<rule->AsString()<<endl; + // cout<<"Remove rule: "<<rule->AsString()<<endl; if (rule->IsUnary()) { aRemoveRule(rhs2unaries_[rule->f().front()], rule); aRemoveRule(unaries_, rule); @@ -158,6 +177,14 @@ void aTextGrammar::RemoveRule(const TRulePtr & rule){ aRemoveRule(lhs_rules_[rule->lhs_ * -1] , rule); + + //remove the rule from nt_rules_ + map<WordID, NTRule> ntrule_map; + getNTRule (rule, ntrule_map); + for (map<WordID,NTRule>::const_iterator it= ntrule_map.begin(); it != ntrule_map.end(); it++){ + aRemoveRule(nt_rules_[it->first], it->second); + } + } void aTextGrammar::RemoveNonterminal(WordID wordID){ @@ -166,6 +193,8 @@ void aTextGrammar::RemoveNonterminal(WordID wordID){ nt_rules_.erase(wordID); for (int i =0; i<rules.size(); i++) RemoveRule(rules[i].rule_); + sum_probs_.erase(wordID); + cnt_rules.erase(wordID); } @@ -199,7 +228,7 @@ void aTextGrammar::AddSplitNonTerminal(WordID nt_old, vector<WordID> & nts){ map<WordID, int> cnt_addepsilon; //cnt_addepsilon and cont_minusepsilon to track the number of rules epsilon is added or minus for each lhs nonterminal, ideally we want these two numbers are equal - map<WordID, int> cnt_minusepsilon; //these two number also use to control the random generated add epsilon/minus epsilon of a new rule + map<WordID, int> cnt_minusepsilon; cnt_addepsilon[old_rule.rule_->lhs_] = 0; cnt_minusepsilon[old_rule.rule_->lhs_] = 0; for (int j =0; j<nts.size(); j++) { cnt_addepsilon[nts[j] ] = 0; cnt_minusepsilon[nts[j] ] = 0;} @@ -217,7 +246,7 @@ void aTextGrammar::AddSplitNonTerminal(WordID nt_old, vector<WordID> & nts){ // cout<<"print vector j_vector"<<endl; // for (int k=0; k<ntPos.size();k++) cout<<j_vector[k]<<" "; cout<<endl; //now use the vector to create a new rule - TRulePtr newrule(new TRule()); + TRulePtr newrule(new aTRule()); newrule -> e_ = (old_rule.rule_)->e_; newrule -> f_ = old_rule.rule_->f_; @@ -323,7 +352,7 @@ void aTextGrammar::splitNonterminal(WordID wordID){ if (wordID == goalID){ //add rule X-> X1; X->X2,... if X is the goal NT for (int i =0; i<v_splits.size(); i++){ - TRulePtr rule (new TRule()); + TRulePtr rule (new aTRule()); rule ->lhs_ = goalID * -1; rule ->f_.push_back(v_splits[i] * -1); rule->e_.push_back(0); @@ -334,20 +363,100 @@ void aTextGrammar::splitNonterminal(WordID wordID){ } +} + +void aTextGrammar::splitAllNonterminals(){ + map<WordID, vector<TRulePtr> >::const_iterator it; + vector<WordID> v ; // WordID >0 + for (it = lhs_rules_.begin(); it != lhs_rules_.end(); it++) //iterate through all nts + if (it->first != goalID || lhs_rules_.size() ==1) + v.push_back(it->first); + + for (int i=0; i< v.size(); i++) + splitNonterminal(v[i]); } +void aTextGrammar::PrintAllRules(const string & filename) const{ -void aTextGrammar::PrintAllRules() const{ - map<WordID, vector<TRulePtr> >::const_iterator it; + + cerr<<"print grammar to "<<filename<<endl; + + ofstream outfile(filename.c_str()); + if (!outfile.good()) { + cerr << "error opening output file " << filename << endl; + exit(1); + } + + map<WordID, vector<TRulePtr > >::const_iterator it; for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){ vector<TRulePtr> v = it-> second; for (int i =0; i< v.size(); i++){ - cout<<v[i]->AsString()<<"\t"<<endl; + outfile<<v[i]->AsString()<<"\t"<<endl; + } + } +} + + +void aTextGrammar::ResetScore(){ + + map<WordID, vector<TRulePtr > >::const_iterator it; + for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){ + vector<TRulePtr> v = it-> second; + for (int i =0; i< v.size(); i++){ + // cerr<<"Reset score of Rule "<<v[i]->AsString()<<endl; + boost::static_pointer_cast<aTRule>(v[i])->ResetScore(alpha_ /v.size()); + } + lhs_rules_[it->first] = v; + sum_probs_[it->first] = alpha_; + } + +} + +void aTextGrammar::UpdateScore(){ + + map<WordID, vector<TRulePtr > >::const_iterator it; + for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){ + vector<TRulePtr> v = it-> second; + for (int i =0; i< v.size(); i++){ + boost::static_pointer_cast<aTRule>(v[i])->UpdateScore(sum_probs_[it->first] ); } + + // cerr<<"sum_probs_[it->first] ="<<sum_probs_[it->first] <<endl; + sum_probs_[it->first] = alpha_; } + +} + + +void aTextGrammar::UpdateHgProsteriorProb(Hypergraph & hg){ + std::vector<prob_t> posts ; + + prob_t goal_score = hg.ComputeEdgePosteriors(1, &posts); + for (int i =0; i<posts.size(); i++){ + + //cout<<posts[i]<<endl; + Hypergraph::Edge& e = hg.edges_[i]; + string goalstr("Goal"); + string str_lhs = TD::Convert(e.rule_->lhs_ * -1); + + if (str_lhs.find(goalstr) != string::npos) + continue; + + // cerr<<e.rule_->AsString()<<endl; + // cerr<<e.rule_->parent_rule_->AsString()<<endl; + + boost::static_pointer_cast<aTRule>(e.rule_->parent_rule_)->AddProb(posts[i] / goal_score); + // cerr<<"add count for rule\n"; +// cerr<<"posts[i]="<<posts[i]<<" goal_score="<<goal_score<<endl; +// cerr<<"posts[i] /goal_score="<<(posts[i] /goal_score)<<endl; + sum_probs_[e.rule_->parent_rule_->lhs_* -1 ] += posts[i] /goal_score; + + } + + } @@ -364,7 +473,9 @@ void aTextGrammar::PrintNonterminalRules(WordID nt) const{ } static void AddRuleHelper(const TRulePtr& new_rule, void* extra) { - static_cast<aTextGrammar*>(extra)->AddRule(new_rule); + aTRule *p = new aTRule(new_rule); + + static_cast<aTextGrammar*>(extra)->AddRule(TRulePtr(p)); } void aTextGrammar::ReadFromFile(const string& filename) { |