diff options
Diffstat (limited to 'gi')
-rw-r--r-- | gi/scfg/abc/Release/agrammar.d | 14 | ||||
-rwxr-xr-x | gi/scfg/abc/Release/scfg | bin | 4277125 -> 4437132 bytes | |||
-rw-r--r-- | gi/scfg/abc/Release/scfg.d | 50 | ||||
-rw-r--r-- | gi/scfg/abc/agrammar.cc | 153 | ||||
-rw-r--r-- | gi/scfg/abc/agrammar.h | 54 | ||||
-rw-r--r-- | gi/scfg/abc/scfg.cpp | 213 |
6 files changed, 389 insertions, 95 deletions
diff --git a/gi/scfg/abc/Release/agrammar.d b/gi/scfg/abc/Release/agrammar.d index 6cf14f0d..553752ca 100644 --- a/gi/scfg/abc/Release/agrammar.d +++ b/gi/scfg/abc/Release/agrammar.d @@ -59,7 +59,11 @@ agrammar.d agrammar.o: ../agrammar.cc \ /home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \ /home/tnguyen/ws10smt/decoder/grammar.h \ /home/tnguyen/ws10smt/decoder/lattice.h \ - /home/tnguyen/ws10smt/decoder/array2d.h ../../utils/Util.h \ + /home/tnguyen/ws10smt/decoder/array2d.h \ + /home/tnguyen/ws10smt/decoder/hg.h \ + /home/tnguyen/ws10smt/decoder/small_vector.h \ + /home/tnguyen/ws10smt/decoder/prob.h \ + /home/tnguyen/ws10smt/decoder/logval.h ../../utils/Util.h \ ../../utils/UtfConverter.h ../../utils/ConvertUTF.h /home/tnguyen/ws10smt/decoder/rule_lexer.h: @@ -186,6 +190,14 @@ agrammar.d agrammar.o: ../agrammar.cc \ /home/tnguyen/ws10smt/decoder/array2d.h: +/home/tnguyen/ws10smt/decoder/hg.h: + +/home/tnguyen/ws10smt/decoder/small_vector.h: + +/home/tnguyen/ws10smt/decoder/prob.h: + +/home/tnguyen/ws10smt/decoder/logval.h: + ../../utils/Util.h: ../../utils/UtfConverter.h: diff --git a/gi/scfg/abc/Release/scfg b/gi/scfg/abc/Release/scfg Binary files differindex 4b6cfb19..41b6b583 100755 --- a/gi/scfg/abc/Release/scfg +++ b/gi/scfg/abc/Release/scfg diff --git a/gi/scfg/abc/Release/scfg.d b/gi/scfg/abc/Release/scfg.d index ae7a87bb..b3cfbbb5 100644 --- a/gi/scfg/abc/Release/scfg.d +++ b/gi/scfg/abc/Release/scfg.d @@ -1,8 +1,4 @@ -scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ - /home/tnguyen/ws10smt/decoder/wordid.h \ - /home/tnguyen/ws10smt/decoder/array2d.h \ - /home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \ - /home/tnguyen/ws10smt/decoder/grammar.h \ +scfg.d scfg.o: ../scfg.cpp \ /export/ws10smt/software/include/boost/shared_ptr.hpp \ /export/ws10smt/software/include/boost/smart_ptr/shared_ptr.hpp \ /export/ws10smt/software/include/boost/config.hpp \ @@ -38,6 +34,12 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ /export/ws10smt/software/include/boost/smart_ptr/detail/yield_k.hpp \ /export/ws10smt/software/include/boost/memory_order.hpp \ /export/ws10smt/software/include/boost/smart_ptr/detail/operator_bool.hpp \ + /export/ws10smt/software/include/boost/pointer_cast.hpp \ + /home/tnguyen/ws10smt/decoder/lattice.h \ + /home/tnguyen/ws10smt/decoder/wordid.h \ + /home/tnguyen/ws10smt/decoder/array2d.h \ + /home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \ + /home/tnguyen/ws10smt/decoder/grammar.h \ /home/tnguyen/ws10smt/decoder/lattice.h \ /home/tnguyen/ws10smt/decoder/trule.h \ /home/tnguyen/ws10smt/decoder/sparse_vector.h \ @@ -57,27 +59,15 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ /export/ws10smt/software/include/boost/functional/hash/detail/hash_float_generic.hpp \ /export/ws10smt/software/include/boost/functional/hash/extensions.hpp \ /export/ws10smt/software/include/boost/detail/container_fwd.hpp \ - /home/tnguyen/ws10smt/decoder/bottom_up_parser.h \ - /home/tnguyen/ws10smt/decoder/grammar.h \ /home/tnguyen/ws10smt/decoder/hg.h \ /home/tnguyen/ws10smt/decoder/small_vector.h \ /home/tnguyen/ws10smt/decoder/prob.h \ /home/tnguyen/ws10smt/decoder/logval.h \ + /home/tnguyen/ws10smt/decoder/bottom_up_parser.h \ + /home/tnguyen/ws10smt/decoder/grammar.h \ /home/tnguyen/ws10smt/decoder/hg_intersect.h ../../utils/ParamsArray.h \ ../../utils/Util.h ../../utils/UtfConverter.h ../../utils/ConvertUTF.h -/home/tnguyen/ws10smt/decoder/lattice.h: - -/home/tnguyen/ws10smt/decoder/wordid.h: - -/home/tnguyen/ws10smt/decoder/array2d.h: - -/home/tnguyen/ws10smt/decoder/tdict.h: - -../agrammar.h: - -/home/tnguyen/ws10smt/decoder/grammar.h: - /export/ws10smt/software/include/boost/shared_ptr.hpp: /export/ws10smt/software/include/boost/smart_ptr/shared_ptr.hpp: @@ -148,6 +138,20 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ /export/ws10smt/software/include/boost/smart_ptr/detail/operator_bool.hpp: +/export/ws10smt/software/include/boost/pointer_cast.hpp: + +/home/tnguyen/ws10smt/decoder/lattice.h: + +/home/tnguyen/ws10smt/decoder/wordid.h: + +/home/tnguyen/ws10smt/decoder/array2d.h: + +/home/tnguyen/ws10smt/decoder/tdict.h: + +../agrammar.h: + +/home/tnguyen/ws10smt/decoder/grammar.h: + /home/tnguyen/ws10smt/decoder/lattice.h: /home/tnguyen/ws10smt/decoder/trule.h: @@ -186,10 +190,6 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ /export/ws10smt/software/include/boost/detail/container_fwd.hpp: -/home/tnguyen/ws10smt/decoder/bottom_up_parser.h: - -/home/tnguyen/ws10smt/decoder/grammar.h: - /home/tnguyen/ws10smt/decoder/hg.h: /home/tnguyen/ws10smt/decoder/small_vector.h: @@ -198,6 +198,10 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ /home/tnguyen/ws10smt/decoder/logval.h: +/home/tnguyen/ws10smt/decoder/bottom_up_parser.h: + +/home/tnguyen/ws10smt/decoder/grammar.h: + /home/tnguyen/ws10smt/decoder/hg_intersect.h: ../../utils/ParamsArray.h: diff --git a/gi/scfg/abc/agrammar.cc b/gi/scfg/abc/agrammar.cc index 585255e3..016a0189 100644 --- a/gi/scfg/abc/agrammar.cc +++ b/gi/scfg/abc/agrammar.cc @@ -8,6 +8,18 @@ #include "agrammar.h" #include "../utils/Util.h" + + +aTRule::aTRule(TRulePtr rule){ + + this -> e_ = rule->e_; + this -> f_ = rule->f_; + this ->lhs_ = rule->lhs_; + this -> arity_ = rule->arity_; + this -> scores_ = rule->scores_; + ResetScore(0.00000001); +} + bool equal(TRulePtr const & rule1, TRulePtr const & rule2){ if (rule1->lhs_ != rule2->lhs_) return false; if (rule1->f_.size() != rule2->f_.size()) return false; @@ -20,16 +32,25 @@ bool equal(TRulePtr const & rule1, TRulePtr const & rule2){ return true; } + //const vector<TRulePtr> Grammar::NO_RULES; void aRemoveRule(vector<TRulePtr> & v, const TRulePtr & rule){ // remove rule from v if found for (int i=0; i< v.size(); i++) if (equal(v[i], rule )){ - cout<<"erase rule from vector:"<<rule->AsString()<<endl; + // cout<<"erase rule from vector:"<<rule->AsString()<<endl; v.erase(v.begin()+i); } } +void aRemoveRule(vector<NTRule> & v, const NTRule & ntrule){ // remove rule from v if found + for (int i=0; i< v.size(); i++) + if (equal(v[i].rule_, ntrule.rule_ )){ + // cout<<"erase rule from vector:"<<rule->AsString()<<endl; + v.erase(v.begin()+i); + } +} + struct aTextRuleBin : public RuleBin { int GetNumRules() const { return rules_.size(); @@ -40,20 +61,16 @@ struct aTextRuleBin : public RuleBin { void AddRule(TRulePtr t) { rules_.push_back(t); } - void RemoveRule(TRulePtr t){ - for (int i=0; i<rules_.size(); i++){ - if (equal(rules_.at(i), t)){ - rules_.erase(rules_.begin() + i); - //cout<<"IntextRulebin removerulle\n"; - return; - } - } + + void RemoveRule(const TRulePtr & rule ){ + aRemoveRule(rules_, rule); } int Arity() const { return rules_.front()->Arity(); } + void Dump() const { for (int i = 0; i < rules_.size(); ++i) cerr << rules_[i]->AsString() << endl; @@ -62,6 +79,7 @@ struct aTextRuleBin : public RuleBin { vector<TRulePtr> rules_; }; + struct aTextGrammarNode : public GrammarIter { aTextGrammarNode() : rb_(NULL) {} ~aTextGrammarNode() { @@ -90,8 +108,8 @@ struct aTGImpl { aTextGrammar::aTextGrammar() : max_span_(10), pimpl_(new aTGImpl) {} aTextGrammar::aTextGrammar(const string& file) : - max_span_(10), - pimpl_(new aTGImpl) { + max_span_(10), + pimpl_(new aTGImpl) { ReadFromFile(file); } @@ -103,6 +121,7 @@ void aTextGrammar::SetGoalNT(const string & goal_str){ goalID = TD::Convert(goal_str); } + void getNTRule( const TRulePtr & rule, map<WordID, NTRule> & ntrule_map){ NTRule lhs_ntrule(rule, rule->lhs_ * -1); @@ -113,9 +132,9 @@ void getNTRule( const TRulePtr & rule, map<WordID, NTRule> & ntrule_map){ NTRule rhs_ntrule(rule, rule->f_.at(i) * -1); ntrule_map[(rule->f_).at(i) *-1] = rhs_ntrule; } - - } + + void aTextGrammar::AddRule(const TRulePtr& rule) { if (rule->IsUnary()) { rhs2unaries_[rule->f().front()].push_back(rule); @@ -141,7 +160,7 @@ void aTextGrammar::AddRule(const TRulePtr& rule) { } void aTextGrammar::RemoveRule(const TRulePtr & rule){ - cout<<"Remove rule: "<<rule->AsString()<<endl; + // cout<<"Remove rule: "<<rule->AsString()<<endl; if (rule->IsUnary()) { aRemoveRule(rhs2unaries_[rule->f().front()], rule); aRemoveRule(unaries_, rule); @@ -158,6 +177,14 @@ void aTextGrammar::RemoveRule(const TRulePtr & rule){ aRemoveRule(lhs_rules_[rule->lhs_ * -1] , rule); + + //remove the rule from nt_rules_ + map<WordID, NTRule> ntrule_map; + getNTRule (rule, ntrule_map); + for (map<WordID,NTRule>::const_iterator it= ntrule_map.begin(); it != ntrule_map.end(); it++){ + aRemoveRule(nt_rules_[it->first], it->second); + } + } void aTextGrammar::RemoveNonterminal(WordID wordID){ @@ -166,6 +193,8 @@ void aTextGrammar::RemoveNonterminal(WordID wordID){ nt_rules_.erase(wordID); for (int i =0; i<rules.size(); i++) RemoveRule(rules[i].rule_); + sum_probs_.erase(wordID); + cnt_rules.erase(wordID); } @@ -199,7 +228,7 @@ void aTextGrammar::AddSplitNonTerminal(WordID nt_old, vector<WordID> & nts){ map<WordID, int> cnt_addepsilon; //cnt_addepsilon and cont_minusepsilon to track the number of rules epsilon is added or minus for each lhs nonterminal, ideally we want these two numbers are equal - map<WordID, int> cnt_minusepsilon; //these two number also use to control the random generated add epsilon/minus epsilon of a new rule + map<WordID, int> cnt_minusepsilon; cnt_addepsilon[old_rule.rule_->lhs_] = 0; cnt_minusepsilon[old_rule.rule_->lhs_] = 0; for (int j =0; j<nts.size(); j++) { cnt_addepsilon[nts[j] ] = 0; cnt_minusepsilon[nts[j] ] = 0;} @@ -217,7 +246,7 @@ void aTextGrammar::AddSplitNonTerminal(WordID nt_old, vector<WordID> & nts){ // cout<<"print vector j_vector"<<endl; // for (int k=0; k<ntPos.size();k++) cout<<j_vector[k]<<" "; cout<<endl; //now use the vector to create a new rule - TRulePtr newrule(new TRule()); + TRulePtr newrule(new aTRule()); newrule -> e_ = (old_rule.rule_)->e_; newrule -> f_ = old_rule.rule_->f_; @@ -323,7 +352,7 @@ void aTextGrammar::splitNonterminal(WordID wordID){ if (wordID == goalID){ //add rule X-> X1; X->X2,... if X is the goal NT for (int i =0; i<v_splits.size(); i++){ - TRulePtr rule (new TRule()); + TRulePtr rule (new aTRule()); rule ->lhs_ = goalID * -1; rule ->f_.push_back(v_splits[i] * -1); rule->e_.push_back(0); @@ -334,20 +363,100 @@ void aTextGrammar::splitNonterminal(WordID wordID){ } +} + +void aTextGrammar::splitAllNonterminals(){ + map<WordID, vector<TRulePtr> >::const_iterator it; + vector<WordID> v ; // WordID >0 + for (it = lhs_rules_.begin(); it != lhs_rules_.end(); it++) //iterate through all nts + if (it->first != goalID || lhs_rules_.size() ==1) + v.push_back(it->first); + + for (int i=0; i< v.size(); i++) + splitNonterminal(v[i]); } +void aTextGrammar::PrintAllRules(const string & filename) const{ -void aTextGrammar::PrintAllRules() const{ - map<WordID, vector<TRulePtr> >::const_iterator it; + + cerr<<"print grammar to "<<filename<<endl; + + ofstream outfile(filename.c_str()); + if (!outfile.good()) { + cerr << "error opening output file " << filename << endl; + exit(1); + } + + map<WordID, vector<TRulePtr > >::const_iterator it; for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){ vector<TRulePtr> v = it-> second; for (int i =0; i< v.size(); i++){ - cout<<v[i]->AsString()<<"\t"<<endl; + outfile<<v[i]->AsString()<<"\t"<<endl; + } + } +} + + +void aTextGrammar::ResetScore(){ + + map<WordID, vector<TRulePtr > >::const_iterator it; + for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){ + vector<TRulePtr> v = it-> second; + for (int i =0; i< v.size(); i++){ + // cerr<<"Reset score of Rule "<<v[i]->AsString()<<endl; + boost::static_pointer_cast<aTRule>(v[i])->ResetScore(alpha_ /v.size()); + } + lhs_rules_[it->first] = v; + sum_probs_[it->first] = alpha_; + } + +} + +void aTextGrammar::UpdateScore(){ + + map<WordID, vector<TRulePtr > >::const_iterator it; + for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){ + vector<TRulePtr> v = it-> second; + for (int i =0; i< v.size(); i++){ + boost::static_pointer_cast<aTRule>(v[i])->UpdateScore(sum_probs_[it->first] ); } + + // cerr<<"sum_probs_[it->first] ="<<sum_probs_[it->first] <<endl; + sum_probs_[it->first] = alpha_; } + +} + + +void aTextGrammar::UpdateHgProsteriorProb(Hypergraph & hg){ + std::vector<prob_t> posts ; + + prob_t goal_score = hg.ComputeEdgePosteriors(1, &posts); + for (int i =0; i<posts.size(); i++){ + + //cout<<posts[i]<<endl; + Hypergraph::Edge& e = hg.edges_[i]; + string goalstr("Goal"); + string str_lhs = TD::Convert(e.rule_->lhs_ * -1); + + if (str_lhs.find(goalstr) != string::npos) + continue; + + // cerr<<e.rule_->AsString()<<endl; + // cerr<<e.rule_->parent_rule_->AsString()<<endl; + + boost::static_pointer_cast<aTRule>(e.rule_->parent_rule_)->AddProb(posts[i] / goal_score); + // cerr<<"add count for rule\n"; +// cerr<<"posts[i]="<<posts[i]<<" goal_score="<<goal_score<<endl; +// cerr<<"posts[i] /goal_score="<<(posts[i] /goal_score)<<endl; + sum_probs_[e.rule_->parent_rule_->lhs_* -1 ] += posts[i] /goal_score; + + } + + } @@ -364,7 +473,9 @@ void aTextGrammar::PrintNonterminalRules(WordID nt) const{ } static void AddRuleHelper(const TRulePtr& new_rule, void* extra) { - static_cast<aTextGrammar*>(extra)->AddRule(new_rule); + aTRule *p = new aTRule(new_rule); + + static_cast<aTextGrammar*>(extra)->AddRule(TRulePtr(p)); } void aTextGrammar::ReadFromFile(const string& filename) { diff --git a/gi/scfg/abc/agrammar.h b/gi/scfg/abc/agrammar.h index 8a7186bf..0a8a60ac 100644 --- a/gi/scfg/abc/agrammar.h +++ b/gi/scfg/abc/agrammar.h @@ -2,10 +2,39 @@ #define AGRAMMAR_H_ #include "grammar.h" +#include "hg.h" using namespace std; +class aTRule: public TRule{ + public: + aTRuleTRule : TRule(){ResetScore(0.00000001); } + aTRule(TRulePtr rule_); + + void ResetScore(double initscore){//cerr<<"Reset Score "<<this->AsString()<<endl; + sum_scores_.set_value(FD::Convert("Prob"), initscore);} + void AddProb(double p ){ + // cerr<<"in AddProb p="<<p<<endl; + // cerr<<"prob sumscores ="<<sum_scores_[FD::Convert("Prob")]<<endl; + sum_scores_.add_value(FD::Convert("Prob"), p); + // cerr<<"after AddProb\n"; + } + + void UpdateScore(double sumprob){ + double minuslogp = 0 - log( sum_scores_.value(FD::Convert("Prob")) /sumprob); + if (sumprob< sum_scores_.value(FD::Convert("Prob"))){ + cerr<<"UpdateScore sumprob="<<sumprob<< " sum_scores_.value(FD::Convert(\"Prob\"))="<< sum_scores_.value(FD::Convert("Prob"))<< this->AsString()<<endl; + exit(1); + } + this->scores_.set_value(FD::Convert("MinusLogP"), minuslogp); + + } + private: + SparseVector<double> sum_scores_; +}; + + class aTGImpl; struct NTRule{ @@ -20,17 +49,19 @@ struct NTRule{ for (int i=0; i< rule->f().size(); i++) if (rule->f().at(i) * -1 == nt) ntPos_.push_back(i); + + } TRulePtr rule_; - WordID nt_; //the labelID of the nt (WordID>0); + WordID nt_; //the labelID of the nt (nt_>0); vector<int> ntPos_; //position of nt_ -1: lhs, from 0...f_.size() for nt of f_() //i.e the rules is: NP-> DET NP; if nt_=5 is the labelID of NP then ntPos_ = (-1, 1): the indexes of nonterminal NP - }; + struct aTextGrammar : public Grammar { aTextGrammar(); aTextGrammar(const std::string& file); @@ -46,9 +77,20 @@ struct aTextGrammar : public Grammar { void setMaxSplit(int max_split); void splitNonterminal(WordID wordID); - void PrintAllRules() const; + + void splitAllNonterminals(); + + void PrintAllRules(const string & filename) const; void PrintNonterminalRules(WordID nt) const; void SetGoalNT(const string & goal_str); + + void ResetScore(); + + void UpdateScore(); + + void UpdateHgProsteriorProb(Hypergraph & hg); + + void set_alpha(double alpha){alpha_ = alpha;} private: void RemoveRule(const TRulePtr & rule); @@ -57,9 +99,15 @@ struct aTextGrammar : public Grammar { int max_span_; int max_split_; boost::shared_ptr<aTGImpl> pimpl_; + map <WordID, vector<TRulePtr> > lhs_rules_;// WordID >0 map <WordID, vector<NTRule> > nt_rules_; + map <WordID, double> sum_probs_; + map <WordID, double> cnt_rules; + + double alpha_; + // map<WordID, vector<WordID> > grSplitNonterminals; WordID goalID; }; diff --git a/gi/scfg/abc/scfg.cpp b/gi/scfg/abc/scfg.cpp index 4d094488..b3dbad34 100644 --- a/gi/scfg/abc/scfg.cpp +++ b/gi/scfg/abc/scfg.cpp @@ -1,3 +1,8 @@ +#include <iostream> +#include <fstream> + +#include <boost/shared_ptr.hpp> +#include <boost/pointer_cast.hpp> #include "lattice.h" #include "tdict.h" #include "agrammar.h" @@ -9,13 +14,53 @@ using namespace std; +vector<string> src_corpus; +vector<string> tgt_corpus; + +bool openParallelCorpora(string & input_filename){ + ifstream input_file; + + input_file.open(input_filename.c_str()); + if (!input_file) { + cerr << "Cannot open input file " << input_filename << ". Exiting..." << endl; + return false; + } + + int line =0; + while (!input_file.eof()) { + // get a line of source language data + // cerr<<"new line "<<ctr<<endl; + string str; + + getline(input_file, str); + line++; + if (str.length()==0){ + cerr<<" sentence number "<<line<<" is empty, skip the sentence\n"; + continue; + } + string delimiters("|||"); + + vector<string> v = tokenize(str, delimiters); + + if ( (v.size() != 2) and (v.size() != 3) ) { + cerr<<str<<endl; + cerr<<" source or target sentence is not found in sentence number "<<line<<" , skip the sentence\n"; + continue; + } + + src_corpus.push_back(v[0]); + tgt_corpus.push_back(v[1]); + } + return true; +} + + typedef aTextGrammar aGrammar; aGrammar * load_grammar(string & grammar_filename){ cerr<<"start_load_grammar "<<grammar_filename<<endl; aGrammar * test = new aGrammar(grammar_filename); - return test; } @@ -26,7 +71,6 @@ Lattice convertSentenceToLattice(const string & str){ Lattice lsentence; lsentence.resize(vID.size()); - for (int i=0; i<vID.size(); i++){ lsentence[i].push_back( LatticeArc(vID[i], 0.0, 1) ); @@ -41,6 +85,8 @@ Lattice convertSentenceToLattice(const string & str){ bool parseSentencePair(const string & goal_sym, const string & src, const string & tgt, GrammarPtr & g, Hypergraph &hg){ + + // cout<<" Start parse the sentence pairs\n"<<endl; Lattice lsource = convertSentenceToLattice(src); //parse the source sentence by the grammar @@ -51,7 +97,7 @@ bool parseSentencePair(const string & goal_sym, const string & src, const string if (!parser.Parse(lsource, &hg)){ - cerr<<"source sentence does not parse by the grammar!"<<endl; + cerr<<"source sentence is not parsed by the grammar!"<<endl; return false; } @@ -59,8 +105,15 @@ bool parseSentencePair(const string & goal_sym, const string & src, const string Lattice ltarget = convertSentenceToLattice(tgt); //forest.PrintGraphviz(); - return HG::Intersect(ltarget, & hg); + if (!HG::Intersect(ltarget, & hg)) return false; + + SparseVector<double> reweight; + + reweight.set_value(FD::Convert("MinusLogP"), -1 ); + hg.Reweight(reweight); + return true; + } @@ -71,74 +124,140 @@ int main(int argc, char** argv){ ParamsArray params(argc, argv); params.setDescription("scfg models"); - params.addConstraint("grammar_file", "grammar file ", true); // optional + params.addConstraint("grammar_file", "grammar file (default ./grammar.pr )", true); // optional + + params.addConstraint("input_file", "parallel input file (default ./parallel_corpora)", true); //optional + + params.addConstraint("output_file", "grammar output file (default ./grammar_output)", true); //optional + + params.addConstraint("goal_symbol", "top nonterminal symbol (default: X)", true); //optional + + params.addConstraint("split", "split one nonterminal into 'split' nonterminals (default: 2)", true); //optional - params.addConstraint("input_file", "parallel input file", true); //optional + params.addConstraint("prob_iters", "number of iterations (default: 10)", true); //optional + + params.addConstraint("split_iters", "number of splitting iterations (default: 3)", true); //optional + + params.addConstraint("alpha", "alpha (default: 0.1)", true); //optional if (!params.runConstraints("scfg")) { return 0; } cerr<<"get parametters\n\n\n"; - string input_file = params.asString("input_file", "parallel_corpora"); string grammar_file = params.asString("grammar_file", "./grammar.pr"); + string input_file = params.asString("input_file", "parallel_corpora"); - string src = "el gato ."; - - string tgt = "the cat ."; - - - string goal_sym = "X"; - srand(123); - /*load grammar*/ + string output_file = params.asString("output_file", "grammar_output"); + string goal_sym = params.asString("goal_symbol", "X"); + int max_split = atoi(params.asString("split", "2").c_str()); + + int prob_iters = atoi(params.asString("prob_iters", "2").c_str()); + int split_iters = atoi(params.asString("split_iters", "1").c_str()); + double alpha = atof(params.asString("alpha", ".001").c_str()); + + ///// + cerr<<"grammar_file ="<<grammar_file<<endl; + cerr<<"input_file ="<< input_file<<endl; + cerr<<"output_file ="<< output_file<<endl; + cerr<<"goal_sym ="<< goal_sym<<endl; + cerr<<"max_split ="<< max_split<<endl; + cerr<<"prob_iters ="<< prob_iters<<endl; + cerr<<"split_iters ="<< split_iters<<endl; + cerr<<"alpha ="<< alpha<<endl; + ////////////////////////// + + cerr<<"\n\nLoad parallel corpus...\n"; + if (! openParallelCorpora(input_file)) + exit(1); + + cerr<<"Load grammar file ...\n"; aGrammar * agrammar = load_grammar(grammar_file); agrammar->SetGoalNT(goal_sym); - cout<<"before split nonterminal"<<endl; - GrammarPtr g( agrammar); + agrammar->setMaxSplit(max_split); + agrammar->set_alpha(alpha); + srand(123); + + GrammarPtr g( agrammar); Hypergraph hg; - if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){ - cerr<<"target sentence is not parsed by the grammar!\n"; - return 1; - } - hg.PrintGraphviz(); + int data_size = src_corpus.size(); + for (int i =0; i <split_iters; i++){ + + cerr<<"Split Nonterminals, iteration "<<(i+1)<<endl; + agrammar->PrintAllRules(output_file+".s" + itos(i+1)); + agrammar->splitAllNonterminals(); + + //vector<string> src_corpus; + //vector<string> tgt_corpus; + + for (int j=0; j<prob_iters; j++){ + cerr<<"reset grammar score\n"; + agrammar->ResetScore(); + // cerr<<"done reset grammar score\n"; + for (int k=0; k <data_size; k++){ + string src = src_corpus[k]; + + string tgt = tgt_corpus[k]; + cerr <<"parse sentence pair: "<<src<<" ||| "<<tgt<<endl; + + if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){ + cerr<<"target sentence is not parsed by the grammar!\n"; + //return 1; + continue; + + } + cerr<<"update edge posterior prob"<<endl; + boost::static_pointer_cast<aGrammar>(g)->UpdateHgProsteriorProb(hg); + hg.clear(); + } + boost::static_pointer_cast<aGrammar>(g)->UpdateScore(); + } + boost::static_pointer_cast<aGrammar>(g)->PrintAllRules(output_file+".e" + itos(i+1)); + } - if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){ - cerr<<"target sentence is not parsed by the grammar!\n"; - return 1; - } - hg.PrintGraphviz(); - //hg.clear(); - if (1==1) return 1; + + + - agrammar->PrintAllRules(); - /*split grammar*/ - cout<<"split NTs\n"; - cerr<<"first of all write all nonterminals"<<endl; - // agrammar->printAllNonterminals(); - agrammar->setMaxSplit(2); - agrammar->splitNonterminal(4); - cout<<"after split nonterminal"<<endl; - agrammar->PrintAllRules(); - Hypergraph hg1; - if (! parseSentencePair(goal_sym, src, tgt, g, hg1) ){ - cerr<<"target sentence is not parsed by the grammar!\n"; - return 1; - } - hg1.PrintGraphviz(); + // // agrammar->ResetScore(); + // // agrammar->UpdateScore(); + // if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){ + // cerr<<"target sentence is not parsed by the grammar!\n"; + // return 1; + + // } + // // hg.PrintGraphviz(); + // //hg.clear(); + + // agrammar->PrintAllRules(); + // /*split grammar*/ + // cout<<"split NTs\n"; + // cerr<<"first of all write all nonterminals"<<endl; + // // agrammar->printAllNonterminals(); + // cout<<"after split nonterminal"<<endl; + // agrammar->PrintAllRules(); + // Hypergraph hg1; + // if (! parseSentencePair(goal_sym, src, tgt, g, hg1) ){ + // cerr<<"target sentence is not parsed by the grammar!\n"; + // return 1; + + // } + + // hg1.PrintGraphviz(); - agrammar->splitNonterminal(15); - cout<<"after split nonterminal"<<TD::Convert(15)<<endl; - agrammar->PrintAllRules(); + // agrammar->splitNonterminal(15); + // cout<<"after split nonterminal"<<TD::Convert(15)<<endl; + // agrammar->PrintAllRules(); /*load training corpus*/ |