Diffstat (limited to 'gi')
-rw-r--r--  gi/scfg/abc/Release/agrammar.d  |  14
-rwxr-xr-x  gi/scfg/abc/Release/scfg        | bin 4277125 -> 4437132 bytes
-rw-r--r--  gi/scfg/abc/Release/scfg.d      |  50
-rw-r--r--  gi/scfg/abc/agrammar.cc         | 153
-rw-r--r--  gi/scfg/abc/agrammar.h          |  54
-rw-r--r--  gi/scfg/abc/scfg.cpp            | 213
6 files changed, 389 insertions, 95 deletions
diff --git a/gi/scfg/abc/Release/agrammar.d b/gi/scfg/abc/Release/agrammar.d
index 6cf14f0d..553752ca 100644
--- a/gi/scfg/abc/Release/agrammar.d
+++ b/gi/scfg/abc/Release/agrammar.d
@@ -59,7 +59,11 @@ agrammar.d agrammar.o: ../agrammar.cc \
  /home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \
  /home/tnguyen/ws10smt/decoder/grammar.h \
  /home/tnguyen/ws10smt/decoder/lattice.h \
- /home/tnguyen/ws10smt/decoder/array2d.h ../../utils/Util.h \
+ /home/tnguyen/ws10smt/decoder/array2d.h \
+ /home/tnguyen/ws10smt/decoder/hg.h \
+ /home/tnguyen/ws10smt/decoder/small_vector.h \
+ /home/tnguyen/ws10smt/decoder/prob.h \
+ /home/tnguyen/ws10smt/decoder/logval.h ../../utils/Util.h \
  ../../utils/UtfConverter.h ../../utils/ConvertUTF.h
 
 /home/tnguyen/ws10smt/decoder/rule_lexer.h:
@@ -186,6 +190,14 @@ agrammar.d agrammar.o: ../agrammar.cc \
 
 /home/tnguyen/ws10smt/decoder/array2d.h:
 
+/home/tnguyen/ws10smt/decoder/hg.h:
+
+/home/tnguyen/ws10smt/decoder/small_vector.h:
+
+/home/tnguyen/ws10smt/decoder/prob.h:
+
+/home/tnguyen/ws10smt/decoder/logval.h:
+
 ../../utils/Util.h:
 
 ../../utils/UtfConverter.h:
diff --git a/gi/scfg/abc/Release/scfg b/gi/scfg/abc/Release/scfg
index 4b6cfb19..41b6b583 100755
--- a/gi/scfg/abc/Release/scfg
+++ b/gi/scfg/abc/Release/scfg
Binary files differ
diff --git a/gi/scfg/abc/Release/scfg.d b/gi/scfg/abc/Release/scfg.d
index ae7a87bb..b3cfbbb5 100644
--- a/gi/scfg/abc/Release/scfg.d
+++ b/gi/scfg/abc/Release/scfg.d
@@ -1,8 +1,4 @@
-scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
- /home/tnguyen/ws10smt/decoder/wordid.h \
- /home/tnguyen/ws10smt/decoder/array2d.h \
- /home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \
- /home/tnguyen/ws10smt/decoder/grammar.h \
+scfg.d scfg.o: ../scfg.cpp \
  /export/ws10smt/software/include/boost/shared_ptr.hpp \
  /export/ws10smt/software/include/boost/smart_ptr/shared_ptr.hpp \
  /export/ws10smt/software/include/boost/config.hpp \
@@ -38,6 +34,12 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
  /export/ws10smt/software/include/boost/smart_ptr/detail/yield_k.hpp \
  /export/ws10smt/software/include/boost/memory_order.hpp \
  /export/ws10smt/software/include/boost/smart_ptr/detail/operator_bool.hpp \
+ /export/ws10smt/software/include/boost/pointer_cast.hpp \
+ /home/tnguyen/ws10smt/decoder/lattice.h \
+ /home/tnguyen/ws10smt/decoder/wordid.h \
+ /home/tnguyen/ws10smt/decoder/array2d.h \
+ /home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \
+ /home/tnguyen/ws10smt/decoder/grammar.h \
  /home/tnguyen/ws10smt/decoder/lattice.h \
  /home/tnguyen/ws10smt/decoder/trule.h \
  /home/tnguyen/ws10smt/decoder/sparse_vector.h \
@@ -57,27 +59,15 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
  /export/ws10smt/software/include/boost/functional/hash/detail/hash_float_generic.hpp \
  /export/ws10smt/software/include/boost/functional/hash/extensions.hpp \
  /export/ws10smt/software/include/boost/detail/container_fwd.hpp \
- /home/tnguyen/ws10smt/decoder/bottom_up_parser.h \
- /home/tnguyen/ws10smt/decoder/grammar.h \
  /home/tnguyen/ws10smt/decoder/hg.h \
  /home/tnguyen/ws10smt/decoder/small_vector.h \
  /home/tnguyen/ws10smt/decoder/prob.h \
  /home/tnguyen/ws10smt/decoder/logval.h \
+ /home/tnguyen/ws10smt/decoder/bottom_up_parser.h \
+ /home/tnguyen/ws10smt/decoder/grammar.h \
  /home/tnguyen/ws10smt/decoder/hg_intersect.h ../../utils/ParamsArray.h \
  ../../utils/Util.h ../../utils/UtfConverter.h ../../utils/ConvertUTF.h
 
-/home/tnguyen/ws10smt/decoder/lattice.h:
-
-/home/tnguyen/ws10smt/decoder/wordid.h:
-
-/home/tnguyen/ws10smt/decoder/array2d.h:
-
-/home/tnguyen/ws10smt/decoder/tdict.h:
-
-../agrammar.h:
-
-/home/tnguyen/ws10smt/decoder/grammar.h:
-
 /export/ws10smt/software/include/boost/shared_ptr.hpp:
 
 /export/ws10smt/software/include/boost/smart_ptr/shared_ptr.hpp:
@@ -148,6 +138,20 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
 
 /export/ws10smt/software/include/boost/smart_ptr/detail/operator_bool.hpp:
 
+/export/ws10smt/software/include/boost/pointer_cast.hpp:
+
+/home/tnguyen/ws10smt/decoder/lattice.h:
+
+/home/tnguyen/ws10smt/decoder/wordid.h:
+
+/home/tnguyen/ws10smt/decoder/array2d.h:
+
+/home/tnguyen/ws10smt/decoder/tdict.h:
+
+../agrammar.h:
+
+/home/tnguyen/ws10smt/decoder/grammar.h:
+
 /home/tnguyen/ws10smt/decoder/lattice.h:
 
 /home/tnguyen/ws10smt/decoder/trule.h:
@@ -186,10 +190,6 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
 
 /export/ws10smt/software/include/boost/detail/container_fwd.hpp:
 
-/home/tnguyen/ws10smt/decoder/bottom_up_parser.h:
-
-/home/tnguyen/ws10smt/decoder/grammar.h:
-
 /home/tnguyen/ws10smt/decoder/hg.h:
 
 /home/tnguyen/ws10smt/decoder/small_vector.h:
@@ -198,6 +198,10 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
 
 /home/tnguyen/ws10smt/decoder/logval.h:
 
+/home/tnguyen/ws10smt/decoder/bottom_up_parser.h:
+
+/home/tnguyen/ws10smt/decoder/grammar.h:
+
 /home/tnguyen/ws10smt/decoder/hg_intersect.h:
 
 ../../utils/ParamsArray.h:
diff --git a/gi/scfg/abc/agrammar.cc b/gi/scfg/abc/agrammar.cc
index 585255e3..016a0189 100644
--- a/gi/scfg/abc/agrammar.cc
+++ b/gi/scfg/abc/agrammar.cc
@@ -8,6 +8,18 @@
 #include "agrammar.h"
 #include "../utils/Util.h"
+
+
+aTRule::aTRule(TRulePtr rule){
+
+  this -> e_ = rule->e_;
+  this -> f_ = rule->f_;
+  this ->lhs_ = rule->lhs_;
+  this -> arity_ = rule->arity_;
+  this -> scores_ = rule->scores_;
+  ResetScore(0.00000001);
+}
+
 
 bool equal(TRulePtr const & rule1, TRulePtr const & rule2){
   if (rule1->lhs_ != rule2->lhs_) return false;
   if (rule1->f_.size() != rule2->f_.size()) return false;
@@ -20,16 +32,25 @@ bool equal(TRulePtr const & rule1, TRulePtr const & rule2){
   return true;
 }
 
+
 //const vector<TRulePtr> Grammar::NO_RULES;
 
 void aRemoveRule(vector<TRulePtr> & v, const TRulePtr  & rule){ // remove rule from v if found
   for (int i=0; i< v.size(); i++)
     if (equal(v[i], rule )){
-      cout<<"erase rule from vector:"<<rule->AsString()<<endl;
+      //      cout<<"erase rule from vector:"<<rule->AsString()<<endl;
       v.erase(v.begin()+i);
     }
 }
 
+void aRemoveRule(vector<NTRule> & v, const NTRule  & ntrule){ // remove rule from v if found
+  for (int i=0; i< v.size(); i++)
+    if (equal(v[i].rule_, ntrule.rule_ )){
+      //      cout<<"erase rule from vector:"<<rule->AsString()<<endl;
+       v.erase(v.begin()+i);
+    }
+}
+
 struct aTextRuleBin : public RuleBin {
   int GetNumRules() const {
     return rules_.size();
@@ -40,20 +61,16 @@ struct aTextRuleBin : public RuleBin {
   void AddRule(TRulePtr t) {
     rules_.push_back(t);
   }
-  void RemoveRule(TRulePtr t){
-    for (int i=0; i<rules_.size(); i++){
-      if (equal(rules_.at(i), t)){
-	rules_.erase(rules_.begin() + i);
-	//cout<<"IntextRulebin removerulle\n";
-	return;
-      }
-    }
+
+  void RemoveRule(const TRulePtr & rule ){
+    aRemoveRule(rules_, rule);
   }
 
   int Arity() const {
     return rules_.front()->Arity();
   }
 
+
   void Dump() const {
     for (int i = 0; i < rules_.size(); ++i)
       cerr << rules_[i]->AsString() << endl;
@@ -62,6 +79,7 @@ struct aTextRuleBin : public RuleBin {
   vector<TRulePtr> rules_;
 };
 
+
 struct aTextGrammarNode : public GrammarIter {
   aTextGrammarNode() : rb_(NULL) {}
   ~aTextGrammarNode() {
@@ -90,8 +108,8 @@ struct aTGImpl {
 
 aTextGrammar::aTextGrammar() : max_span_(10), pimpl_(new aTGImpl) {}
 
 aTextGrammar::aTextGrammar(const string& file) : 
-    max_span_(10),
-    pimpl_(new aTGImpl) {
+  max_span_(10),
+  pimpl_(new aTGImpl) {
   ReadFromFile(file);
 }
@@ -103,6 +121,7 @@ void aTextGrammar::SetGoalNT(const string & goal_str){
   goalID = TD::Convert(goal_str);
 }
 
+
 void getNTRule( const TRulePtr & rule, map<WordID, NTRule> & ntrule_map){
 
   NTRule lhs_ntrule(rule, rule->lhs_ * -1);
@@ -113,9 +132,9 @@ void getNTRule( const TRulePtr & rule, map<WordID, NTRule> & ntrule_map){
         NTRule rhs_ntrule(rule, rule->f_.at(i) * -1);
 	ntrule_map[(rule->f_).at(i) *-1] = rhs_ntrule;
     }
-  
-  
 }
+
+
 void aTextGrammar::AddRule(const TRulePtr& rule) {
   if (rule->IsUnary()) {
     rhs2unaries_[rule->f().front()].push_back(rule);
@@ -141,7 +160,7 @@ void aTextGrammar::AddRule(const TRulePtr& rule) {
 }
 
 void aTextGrammar::RemoveRule(const TRulePtr & rule){
-  cout<<"Remove rule:  "<<rule->AsString()<<endl;
+  //  cout<<"Remove rule:  "<<rule->AsString()<<endl;
   if (rule->IsUnary()) {
     aRemoveRule(rhs2unaries_[rule->f().front()], rule);
     aRemoveRule(unaries_, rule);
@@ -158,6 +177,14 @@ void aTextGrammar::RemoveRule(const TRulePtr & rule){
 
   aRemoveRule(lhs_rules_[rule->lhs_ * -1] , rule);
 
+
+  //remove the rule from nt_rules_
+  map<WordID, NTRule> ntrule_map;
+  getNTRule (rule, ntrule_map);
+  for (map<WordID,NTRule>::const_iterator it= ntrule_map.begin(); it != ntrule_map.end(); it++){
+    aRemoveRule(nt_rules_[it->first], it->second);
+  }
+
 }
 
 void aTextGrammar::RemoveNonterminal(WordID wordID){
@@ -166,6 +193,8 @@ void aTextGrammar::RemoveNonterminal(WordID wordID){
   nt_rules_.erase(wordID);
   for (int i =0; i<rules.size(); i++)
     RemoveRule(rules[i].rule_);
+  sum_probs_.erase(wordID);
+  cnt_rules.erase(wordID);
 }
 
@@ -199,7 +228,7 @@ void aTextGrammar::AddSplitNonTerminal(WordID nt_old, vector<WordID> & nts){
 
 
     map<WordID, int> cnt_addepsilon; //cnt_addepsilon and cont_minusepsilon to track the number of rules epsilon is added or minus for each lhs nonterminal, ideally we want these two numbers are equal
-    map<WordID, int> cnt_minusepsilon; //these two number also use to control the random generated add epsilon/minus epsilon of a new rule
+    map<WordID, int> cnt_minusepsilon; 
     cnt_addepsilon[old_rule.rule_->lhs_] = 0;
     cnt_minusepsilon[old_rule.rule_->lhs_] = 0;
     for (int j =0; j<nts.size(); j++) {   cnt_addepsilon[nts[j] ] = 0;   cnt_minusepsilon[nts[j] ] = 0;}
@@ -217,7 +246,7 @@ void aTextGrammar::AddSplitNonTerminal(WordID nt_old, vector<WordID> & nts){
 
       //      cout<<"print vector j_vector"<<endl;
       //      for (int k=0; k<ntPos.size();k++) cout<<j_vector[k]<<"  "; cout<<endl;
       //now use the vector to create a new rule
-      TRulePtr newrule(new TRule());
+      TRulePtr newrule(new aTRule());
       newrule -> e_   = (old_rule.rule_)->e_;
       newrule -> f_ = old_rule.rule_->f_;
@@ -323,7 +352,7 @@ void aTextGrammar::splitNonterminal(WordID wordID){
 
   if (wordID == goalID){ //add rule X-> X1; X->X2,... if X is the goal NT
     for (int i =0; i<v_splits.size(); i++){
-      TRulePtr rule (new TRule());
+      TRulePtr rule (new aTRule());
       rule ->lhs_ = goalID * -1;
       rule ->f_.push_back(v_splits[i] * -1);
       rule->e_.push_back(0);
@@ -334,20 +363,100 @@ void aTextGrammar::splitNonterminal(WordID wordID){
   }
 
+}
+
+void aTextGrammar::splitAllNonterminals(){
+  map<WordID, vector<TRulePtr> >::const_iterator it;
+  vector<WordID> v ; // WordID >0
+  for (it = lhs_rules_.begin(); it != lhs_rules_.end(); it++) //iterate through all nts
+    if (it->first != goalID || lhs_rules_.size() ==1)
+      v.push_back(it->first);
+  
+  for (int i=0; i< v.size(); i++)
+    splitNonterminal(v[i]);
 }
 
+void aTextGrammar::PrintAllRules(const string & filename) const{
 
-void aTextGrammar::PrintAllRules() const{
-  map<WordID, vector<TRulePtr> >::const_iterator it;
+  
+  cerr<<"print grammar to "<<filename<<endl;
+
+  ofstream outfile(filename.c_str());
+  if (!outfile.good()) {
+    cerr << "error opening output file " << filename << endl;
+    exit(1);
+  }
+
+  map<WordID, vector<TRulePtr > >::const_iterator it;
   for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){
     vector<TRulePtr> v = it-> second;
     for (int i =0; i< v.size(); i++){
-      cout<<v[i]->AsString()<<"\t"<<endl;
+      outfile<<v[i]->AsString()<<"\t"<<endl;
+    }
+  }
+}
+
+
+void aTextGrammar::ResetScore(){
+
+  map<WordID, vector<TRulePtr > >::const_iterator it;
+  for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){
+    vector<TRulePtr> v = it-> second;
+    for (int i =0; i< v.size(); i++){
+      //      cerr<<"Reset score of Rule "<<v[i]->AsString()<<endl;
+      boost::static_pointer_cast<aTRule>(v[i])->ResetScore(alpha_ /v.size());
+    }
+    lhs_rules_[it->first] = v;
+    sum_probs_[it->first] = alpha_;
+  }
+
+}
+
+void aTextGrammar::UpdateScore(){
+
+  map<WordID, vector<TRulePtr > >::const_iterator it;
+  for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){
+    vector<TRulePtr> v = it-> second;
+    for (int i =0; i< v.size(); i++){
+      boost::static_pointer_cast<aTRule>(v[i])->UpdateScore(sum_probs_[it->first] );
     }
+
+    //    cerr<<"sum_probs_[it->first]  ="<<sum_probs_[it->first] <<endl;
+    sum_probs_[it->first] = alpha_;
   }
+
+}
+
+
+void aTextGrammar::UpdateHgProsteriorProb(Hypergraph & hg){
+  std::vector<prob_t> posts ;
+  
+  prob_t goal_score = hg.ComputeEdgePosteriors(1, &posts);
+  for (int i =0; i<posts.size(); i++){
+
+    //cout<<posts[i]<<endl;
+    Hypergraph::Edge& e = hg.edges_[i];
+    string goalstr("Goal");
+    string str_lhs = TD::Convert(e.rule_->lhs_ * -1);
+
+    if (str_lhs.find(goalstr) != string::npos)
+      continue;
+
+    //    cerr<<e.rule_->AsString()<<endl;
+    //    cerr<<e.rule_->parent_rule_->AsString()<<endl;
+
+    boost::static_pointer_cast<aTRule>(e.rule_->parent_rule_)->AddProb(posts[i] / goal_score);
+    //    cerr<<"add count for rule\n";
+//     cerr<<"posts[i]="<<posts[i]<<"  goal_score="<<goal_score<<endl;
+//     cerr<<"posts[i] /goal_score="<<(posts[i] /goal_score)<<endl;
+    sum_probs_[e.rule_->parent_rule_->lhs_* -1 ] += posts[i] /goal_score;
+
+  }
+  
+}
@@ -364,7 +473,9 @@ void aTextGrammar::PrintNonterminalRules(WordID nt) const{
 }
 
 static void AddRuleHelper(const TRulePtr& new_rule, void* extra) {
-  static_cast<aTextGrammar*>(extra)->AddRule(new_rule);
+  aTRule  *p = new aTRule(new_rule); 
+  
+  static_cast<aTextGrammar*>(extra)->AddRule(TRulePtr(p));
 }
 
 void aTextGrammar::ReadFromFile(const string& filename) {
diff --git a/gi/scfg/abc/agrammar.h b/gi/scfg/abc/agrammar.h
index 8a7186bf..0a8a60ac 100644
--- a/gi/scfg/abc/agrammar.h
+++ b/gi/scfg/abc/agrammar.h
@@ -2,10 +2,39 @@
 #define AGRAMMAR_H_
 
 #include "grammar.h"
+#include "hg.h"
 
 using namespace std;
 
+class aTRule: public TRule{
+ public:
+ aTRule() : TRule(){ResetScore(0.00000001); }
+  aTRule(TRulePtr rule_);
+
+  void ResetScore(double initscore){//cerr<<"Reset Score "<<this->AsString()<<endl;
+    sum_scores_.set_value(FD::Convert("Prob"), initscore);}
+  void AddProb(double p ){
+    //    cerr<<"in AddProb p="<<p<<endl;
+    //    cerr<<"prob sumscores ="<<sum_scores_[FD::Convert("Prob")]<<endl;
+    sum_scores_.add_value(FD::Convert("Prob"), p);
+    //    cerr<<"after AddProb\n";
+  }
+
+  void UpdateScore(double sumprob){
+    double minuslogp = 0 - log( sum_scores_.value(FD::Convert("Prob")) /sumprob);
+    if (sumprob<  sum_scores_.value(FD::Convert("Prob"))){
+      cerr<<"UpdateScore sumprob="<<sumprob<< "  sum_scores_.value(FD::Convert(\"Prob\"))="<< sum_scores_.value(FD::Convert("Prob"))<< this->AsString()<<endl;
+      exit(1);
+    }
+    this->scores_.set_value(FD::Convert("MinusLogP"), minuslogp);
+
+  }
+ private:
+  SparseVector<double> sum_scores_;
+};
+
+
 class aTGImpl;
 
 struct NTRule{
@@ -20,17 +49,19 @@ struct NTRule{
 
     for (int i=0; i< rule->f().size(); i++)
       if (rule->f().at(i) * -1 == nt)
 	ntPos_.push_back(i);
+
+
   }
 
   TRulePtr rule_;
-  WordID nt_; //the labelID of the nt (WordID>0);
+  WordID nt_; //the labelID of the nt (nt_>0);
   vector<int> ntPos_; //position of nt_ -1: lhs, from 0...f_.size() for nt of f_()
   //i.e the rules is: NP-> DET NP; if nt_=5 is the labelID of NP then ntPos_ = (-1, 1): the indexes of nonterminal NP
-
 };
 
+
 struct aTextGrammar : public Grammar {
   aTextGrammar();
   aTextGrammar(const std::string& file);
@@ -46,9 +77,20 @@ struct aTextGrammar : public Grammar {
 
   void setMaxSplit(int max_split);
   void splitNonterminal(WordID wordID);
-  void PrintAllRules() const;
+
+  void splitAllNonterminals();
+
+  void PrintAllRules(const string & filename) const;
   void PrintNonterminalRules(WordID nt) const;
   void SetGoalNT(const string & goal_str);
+
+  void ResetScore();
+
+  void UpdateScore();
+
+  void UpdateHgProsteriorProb(Hypergraph & hg);
+
+  void set_alpha(double alpha){alpha_ = alpha;}
 
  private:
   void RemoveRule(const TRulePtr & rule);
@@ -57,9 +99,15 @@ struct aTextGrammar : public Grammar {
   int max_span_;
   int max_split_;
   boost::shared_ptr<aTGImpl> pimpl_;
+
   map <WordID, vector<TRulePtr> > lhs_rules_;// WordID >0
   map <WordID, vector<NTRule> > nt_rules_; 
+  map <WordID, double> sum_probs_;
+  map <WordID, double> cnt_rules;
+
+  double alpha_;
+
   //  map<WordID, vector<WordID> > grSplitNonterminals;
   WordID goalID;
 };
diff --git a/gi/scfg/abc/scfg.cpp b/gi/scfg/abc/scfg.cpp
index 4d094488..b3dbad34 100644
--- a/gi/scfg/abc/scfg.cpp
+++ b/gi/scfg/abc/scfg.cpp
@@ -1,3 +1,8 @@
+#include <iostream>
+#include <fstream>
+
+#include <boost/shared_ptr.hpp>
+#include <boost/pointer_cast.hpp>
 #include "lattice.h"
 #include "tdict.h"
 #include "agrammar.h"
@@ -9,13 +14,53 @@
 
 
 using namespace std;
 
+vector<string> src_corpus;
+vector<string> tgt_corpus;
+
+bool openParallelCorpora(string & input_filename){
+  ifstream input_file;
+
+  input_file.open(input_filename.c_str());
+  if (!input_file) {
+    cerr << "Cannot open input file " << input_filename << ". Exiting..." << endl;
+    return false;
+  } 
+
+  int line =0;
+  while (!input_file.eof()) {
+    // get a line of source language data
+    //    cerr<<"new line "<<ctr<<endl;
+    string str;
+
+    getline(input_file, str);
+    line++;
+    if (str.length()==0){
+      cerr<<" sentence number "<<line<<" is empty, skip the sentence\n";
+      continue;
+    }
+    string delimiters("|||");
+
+    vector<string> v = tokenize(str, delimiters);
+
+    if ( (v.size() != 2)  and (v.size() != 3) )  {
+      cerr<<str<<endl;
+      cerr<<" source or target sentence is not found in sentence number "<<line<<" , skip the sentence\n";
+      continue;
+    }
+
+    src_corpus.push_back(v[0]);
+    tgt_corpus.push_back(v[1]);
+  }
+  return true;
+}
+
+
 typedef aTextGrammar aGrammar;
 
 aGrammar * load_grammar(string & grammar_filename){
   cerr<<"start_load_grammar "<<grammar_filename<<endl;
   aGrammar * test = new aGrammar(grammar_filename);
-
   return test;
 }
@@ -26,7 +71,6 @@ Lattice convertSentenceToLattice(const string & str){
 
   Lattice lsentence;
   lsentence.resize(vID.size());
-
   for (int i=0; i<vID.size(); i++){
     lsentence[i].push_back( LatticeArc(vID[i], 0.0, 1) ); 
 
@@ -41,6 +85,8 @@ bool parseSentencePair(const string & goal_sym, const string & src, const string & tgt,
  GrammarPtr & g, Hypergraph &hg){
+
+  //  cout<<"  Start parse the sentence pairs\n"<<endl;
   Lattice lsource = convertSentenceToLattice(src);
 
   //parse the source sentence by the grammar
@@ -51,7 +97,7 @@ bool parseSentencePair(const string & goal_sym, const string & src, const string
 
 
   if (!parser.Parse(lsource, &hg)){
-     cerr<<"source sentence does not parse by the grammar!"<<endl;
+     cerr<<"source sentence is not parsed by the grammar!"<<endl;
      return false;
    }
 
@@ -59,8 +105,15 @@ bool parseSentencePair(const string & goal_sym, const string & src, const string
 
 
   Lattice ltarget = convertSentenceToLattice(tgt);
 
   //forest.PrintGraphviz();
-  return HG::Intersect(ltarget, & hg);
+  if (!HG::Intersect(ltarget, & hg)) return false;
+
+  SparseVector<double> reweight;
+  
+  reweight.set_value(FD::Convert("MinusLogP"), -1 );
+  hg.Reweight(reweight);
+  return true;
+  
 }
@@ -71,74 +124,140 @@ int main(int argc, char** argv){
 
   ParamsArray params(argc, argv);
   params.setDescription("scfg models");
-  params.addConstraint("grammar_file", "grammar file ", true); //  optional
+  params.addConstraint("grammar_file", "grammar file (default ./grammar.pr )", true); //  optional
+
+  params.addConstraint("input_file", "parallel input file (default ./parallel_corpora)", true); //optional
+
+  params.addConstraint("output_file", "grammar output file (default ./grammar_output)", true); //optional
+
+  params.addConstraint("goal_symbol", "top nonterminal symbol (default: X)", true); //optional
+
+  params.addConstraint("split", "split one nonterminal into 'split' nonterminals (default: 2)", true); //optional
 
-  params.addConstraint("input_file", "parallel input file", true); //optional
+  params.addConstraint("prob_iters", "number of iterations (default: 10)", true); //optional
+
+  params.addConstraint("split_iters", "number of splitting iterations (default: 3)", true); //optional
+
+  params.addConstraint("alpha", "alpha (default: 0.1)", true); //optional
 
   if (!params.runConstraints("scfg")) {
     return 0;
   }
   cerr<<"get parametters\n\n\n";
 
-  string input_file = params.asString("input_file", "parallel_corpora");
   string grammar_file = params.asString("grammar_file", "./grammar.pr");
+  string input_file = params.asString("input_file", "parallel_corpora");
 
-  string src = "el gato .";
-  
-  string tgt = "the cat .";
-
-
-  string goal_sym = "X";
-  srand(123);
-  /*load grammar*/
+  string output_file = params.asString("output_file", "grammar_output");
+  string goal_sym = params.asString("goal_symbol", "X");
+  int max_split = atoi(params.asString("split", "2").c_str());
+  
+  int prob_iters = atoi(params.asString("prob_iters", "2").c_str());
+  int split_iters = atoi(params.asString("split_iters", "1").c_str());
+  double alpha = atof(params.asString("alpha", ".001").c_str());
+
+  /////
+  cerr<<"grammar_file ="<<grammar_file<<endl;
+  cerr<<"input_file ="<< input_file<<endl;
+  cerr<<"output_file ="<< output_file<<endl;
+  cerr<<"goal_sym ="<< goal_sym<<endl;
+  cerr<<"max_split ="<< max_split<<endl;
+  cerr<<"prob_iters ="<< prob_iters<<endl;
+  cerr<<"split_iters ="<< split_iters<<endl;
+  cerr<<"alpha ="<< alpha<<endl;
+  //////////////////////////
+
+  cerr<<"\n\nLoad parallel corpus...\n";
+  if (! openParallelCorpora(input_file))
+    exit(1);
+
+  cerr<<"Load grammar file ...\n";
   aGrammar * agrammar = load_grammar(grammar_file);
   agrammar->SetGoalNT(goal_sym);
-  cout<<"before split nonterminal"<<endl;
-  GrammarPtr g( agrammar);
+  agrammar->setMaxSplit(max_split);
+  agrammar->set_alpha(alpha);
+  srand(123);
+
+  GrammarPtr g( agrammar);
 
   Hypergraph hg;
-  if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){
-    cerr<<"target sentence is not parsed by the grammar!\n";
-    return 1;
-   }
-   hg.PrintGraphviz();
+  int data_size = src_corpus.size();
+  for (int i =0; i <split_iters; i++){
+    
+    cerr<<"Split Nonterminals, iteration "<<(i+1)<<endl;
+    agrammar->PrintAllRules(output_file+".s" + itos(i+1));
+    agrammar->splitAllNonterminals();
+
+    //vector<string> src_corpus;
+    //vector<string> tgt_corpus;
+    
+    for (int j=0; j<prob_iters; j++){
+      cerr<<"reset grammar score\n";
+      agrammar->ResetScore();
+      //      cerr<<"done reset grammar score\n";
+      for (int k=0; k <data_size; k++){
+	string src = src_corpus[k];
+  
+	string tgt = tgt_corpus[k];
+	cerr <<"parse sentence pair: "<<src<<"  |||  "<<tgt<<endl;
+
+	if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){
+	  cerr<<"target sentence is not parsed by the grammar!\n";
+	  //return 1;
+	  continue;
+
+	} 
+	cerr<<"update edge posterior prob"<<endl;
+	boost::static_pointer_cast<aGrammar>(g)->UpdateHgProsteriorProb(hg);
+	hg.clear();
+      }
+      boost::static_pointer_cast<aGrammar>(g)->UpdateScore();
+    }
+    boost::static_pointer_cast<aGrammar>(g)->PrintAllRules(output_file+".e" + itos(i+1));
+  }
 
-  if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){
-    cerr<<"target sentence is not parsed by the grammar!\n";
-    return 1;
-   }
-   hg.PrintGraphviz();
-   //hg.clear();
-  if (1==1) return 1;
+
+  
 
-  agrammar->PrintAllRules();
-  /*split grammar*/
-  cout<<"split NTs\n"; 
-  cerr<<"first of all write all nonterminals"<<endl;
-  // agrammar->printAllNonterminals();
-  agrammar->setMaxSplit(2);
-  agrammar->splitNonterminal(4);
-  cout<<"after split nonterminal"<<endl;
-  agrammar->PrintAllRules();
-  Hypergraph hg1;
-  if (! parseSentencePair(goal_sym, src, tgt,  g, hg1) ){
-    cerr<<"target sentence is not parsed by the grammar!\n";
-    return 1;
-  }
-  hg1.PrintGraphviz();
+  // // agrammar->ResetScore();
+  // // agrammar->UpdateScore();
+  // if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){
+  //   cerr<<"target sentence is not parsed by the grammar!\n";
+  //   return 1;
+
+  //  }
+  // //   hg.PrintGraphviz();
+  //  //hg.clear();
+
+  // agrammar->PrintAllRules();
+  // /*split grammar*/
+  // cout<<"split NTs\n"; 
+  // cerr<<"first of all write all nonterminals"<<endl;
+  // // agrammar->printAllNonterminals();
+  // cout<<"after split nonterminal"<<endl;
+  // agrammar->PrintAllRules();
+  // Hypergraph hg1;
+  // if (! parseSentencePair(goal_sym, src, tgt,  g, hg1) ){
+  //   cerr<<"target sentence is not parsed by the grammar!\n";
+  //   return 1;
+
+  // }
+
+  // hg1.PrintGraphviz();
 
-  agrammar->splitNonterminal(15);
-  cout<<"after split nonterminal"<<TD::Convert(15)<<endl;
-  agrammar->PrintAllRules();
+  // agrammar->splitNonterminal(15);
+  // cout<<"after split nonterminal"<<TD::Convert(15)<<endl;
+  // agrammar->PrintAllRules();
 
   /*load training corpus*/
