From 15a587e247dc0954de27e2627f5511126243943d Mon Sep 17 00:00:00 2001 From: "linh.kitty" Date: Fri, 16 Jul 2010 17:44:44 +0000 Subject: add git-svn-id: https://ws10smt.googlecode.com/svn/trunk@286 ec762483-ff6d-05da-a07a-a48fb63a330f --- gi/scfg/abc/Release/agrammar.d | 14 ++- gi/scfg/abc/Release/scfg | Bin 4277125 -> 4437132 bytes gi/scfg/abc/Release/scfg.d | 50 +++++----- gi/scfg/abc/agrammar.cc | 153 +++++++++++++++++++++++++---- gi/scfg/abc/agrammar.h | 54 ++++++++++- gi/scfg/abc/scfg.cpp | 213 ++++++++++++++++++++++++++++++++--------- 6 files changed, 389 insertions(+), 95 deletions(-) (limited to 'gi/scfg') diff --git a/gi/scfg/abc/Release/agrammar.d b/gi/scfg/abc/Release/agrammar.d index 6cf14f0d..553752ca 100644 --- a/gi/scfg/abc/Release/agrammar.d +++ b/gi/scfg/abc/Release/agrammar.d @@ -59,7 +59,11 @@ agrammar.d agrammar.o: ../agrammar.cc \ /home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \ /home/tnguyen/ws10smt/decoder/grammar.h \ /home/tnguyen/ws10smt/decoder/lattice.h \ - /home/tnguyen/ws10smt/decoder/array2d.h ../../utils/Util.h \ + /home/tnguyen/ws10smt/decoder/array2d.h \ + /home/tnguyen/ws10smt/decoder/hg.h \ + /home/tnguyen/ws10smt/decoder/small_vector.h \ + /home/tnguyen/ws10smt/decoder/prob.h \ + /home/tnguyen/ws10smt/decoder/logval.h ../../utils/Util.h \ ../../utils/UtfConverter.h ../../utils/ConvertUTF.h /home/tnguyen/ws10smt/decoder/rule_lexer.h: @@ -186,6 +190,14 @@ agrammar.d agrammar.o: ../agrammar.cc \ /home/tnguyen/ws10smt/decoder/array2d.h: +/home/tnguyen/ws10smt/decoder/hg.h: + +/home/tnguyen/ws10smt/decoder/small_vector.h: + +/home/tnguyen/ws10smt/decoder/prob.h: + +/home/tnguyen/ws10smt/decoder/logval.h: + ../../utils/Util.h: ../../utils/UtfConverter.h: diff --git a/gi/scfg/abc/Release/scfg b/gi/scfg/abc/Release/scfg index 4b6cfb19..41b6b583 100755 Binary files a/gi/scfg/abc/Release/scfg and b/gi/scfg/abc/Release/scfg differ diff --git a/gi/scfg/abc/Release/scfg.d b/gi/scfg/abc/Release/scfg.d index ae7a87bb..b3cfbbb5 100644 --- a/gi/scfg/abc/Release/scfg.d +++ b/gi/scfg/abc/Release/scfg.d @@ -1,8 +1,4 @@ -scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ - /home/tnguyen/ws10smt/decoder/wordid.h \ - /home/tnguyen/ws10smt/decoder/array2d.h \ - /home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \ - /home/tnguyen/ws10smt/decoder/grammar.h \ +scfg.d scfg.o: ../scfg.cpp \ /export/ws10smt/software/include/boost/shared_ptr.hpp \ /export/ws10smt/software/include/boost/smart_ptr/shared_ptr.hpp \ /export/ws10smt/software/include/boost/config.hpp \ @@ -38,6 +34,12 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ /export/ws10smt/software/include/boost/smart_ptr/detail/yield_k.hpp \ /export/ws10smt/software/include/boost/memory_order.hpp \ /export/ws10smt/software/include/boost/smart_ptr/detail/operator_bool.hpp \ + /export/ws10smt/software/include/boost/pointer_cast.hpp \ + /home/tnguyen/ws10smt/decoder/lattice.h \ + /home/tnguyen/ws10smt/decoder/wordid.h \ + /home/tnguyen/ws10smt/decoder/array2d.h \ + /home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \ + /home/tnguyen/ws10smt/decoder/grammar.h \ /home/tnguyen/ws10smt/decoder/lattice.h \ /home/tnguyen/ws10smt/decoder/trule.h \ /home/tnguyen/ws10smt/decoder/sparse_vector.h \ @@ -57,27 +59,15 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ /export/ws10smt/software/include/boost/functional/hash/detail/hash_float_generic.hpp \ /export/ws10smt/software/include/boost/functional/hash/extensions.hpp \ /export/ws10smt/software/include/boost/detail/container_fwd.hpp \ - /home/tnguyen/ws10smt/decoder/bottom_up_parser.h \ - /home/tnguyen/ws10smt/decoder/grammar.h \ /home/tnguyen/ws10smt/decoder/hg.h \ /home/tnguyen/ws10smt/decoder/small_vector.h \ /home/tnguyen/ws10smt/decoder/prob.h \ /home/tnguyen/ws10smt/decoder/logval.h \ + /home/tnguyen/ws10smt/decoder/bottom_up_parser.h \ + /home/tnguyen/ws10smt/decoder/grammar.h \ /home/tnguyen/ws10smt/decoder/hg_intersect.h ../../utils/ParamsArray.h \ ../../utils/Util.h ../../utils/UtfConverter.h ../../utils/ConvertUTF.h -/home/tnguyen/ws10smt/decoder/lattice.h: - -/home/tnguyen/ws10smt/decoder/wordid.h: - -/home/tnguyen/ws10smt/decoder/array2d.h: - -/home/tnguyen/ws10smt/decoder/tdict.h: - -../agrammar.h: - -/home/tnguyen/ws10smt/decoder/grammar.h: - /export/ws10smt/software/include/boost/shared_ptr.hpp: /export/ws10smt/software/include/boost/smart_ptr/shared_ptr.hpp: @@ -148,6 +138,20 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ /export/ws10smt/software/include/boost/smart_ptr/detail/operator_bool.hpp: +/export/ws10smt/software/include/boost/pointer_cast.hpp: + +/home/tnguyen/ws10smt/decoder/lattice.h: + +/home/tnguyen/ws10smt/decoder/wordid.h: + +/home/tnguyen/ws10smt/decoder/array2d.h: + +/home/tnguyen/ws10smt/decoder/tdict.h: + +../agrammar.h: + +/home/tnguyen/ws10smt/decoder/grammar.h: + /home/tnguyen/ws10smt/decoder/lattice.h: /home/tnguyen/ws10smt/decoder/trule.h: @@ -186,10 +190,6 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ /export/ws10smt/software/include/boost/detail/container_fwd.hpp: -/home/tnguyen/ws10smt/decoder/bottom_up_parser.h: - -/home/tnguyen/ws10smt/decoder/grammar.h: - /home/tnguyen/ws10smt/decoder/hg.h: /home/tnguyen/ws10smt/decoder/small_vector.h: @@ -198,6 +198,10 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \ /home/tnguyen/ws10smt/decoder/logval.h: +/home/tnguyen/ws10smt/decoder/bottom_up_parser.h: + +/home/tnguyen/ws10smt/decoder/grammar.h: + /home/tnguyen/ws10smt/decoder/hg_intersect.h: ../../utils/ParamsArray.h: diff --git a/gi/scfg/abc/agrammar.cc b/gi/scfg/abc/agrammar.cc index 585255e3..016a0189 100644 --- a/gi/scfg/abc/agrammar.cc +++ b/gi/scfg/abc/agrammar.cc @@ -8,6 +8,18 @@ #include "agrammar.h" #include "../utils/Util.h" + + +aTRule::aTRule(TRulePtr rule){ + + this -> e_ = rule->e_; + this -> f_ = rule->f_; + this ->lhs_ = rule->lhs_; + this -> arity_ = rule->arity_; + this -> scores_ = rule->scores_; + ResetScore(0.00000001); +} + bool equal(TRulePtr const & rule1, TRulePtr const & rule2){ if (rule1->lhs_ != rule2->lhs_) return false; if (rule1->f_.size() != rule2->f_.size()) return false; @@ -20,16 +32,25 @@ bool equal(TRulePtr const & rule1, TRulePtr const & rule2){ return true; } + //const vector Grammar::NO_RULES; void aRemoveRule(vector & v, const TRulePtr & rule){ // remove rule from v if found for (int i=0; i< v.size(); i++) if (equal(v[i], rule )){ - cout<<"erase rule from vector:"<AsString()<AsString()< & v, const NTRule & ntrule){ // remove rule from v if found + for (int i=0; i< v.size(); i++) + if (equal(v[i].rule_, ntrule.rule_ )){ + // cout<<"erase rule from vector:"<AsString()<Arity(); } + void Dump() const { for (int i = 0; i < rules_.size(); ++i) cerr << rules_[i]->AsString() << endl; @@ -62,6 +79,7 @@ struct aTextRuleBin : public RuleBin { vector rules_; }; + struct aTextGrammarNode : public GrammarIter { aTextGrammarNode() : rb_(NULL) {} ~aTextGrammarNode() { @@ -90,8 +108,8 @@ struct aTGImpl { aTextGrammar::aTextGrammar() : max_span_(10), pimpl_(new aTGImpl) {} aTextGrammar::aTextGrammar(const string& file) : - max_span_(10), - pimpl_(new aTGImpl) { + max_span_(10), + pimpl_(new aTGImpl) { ReadFromFile(file); } @@ -103,6 +121,7 @@ void aTextGrammar::SetGoalNT(const string & goal_str){ goalID = TD::Convert(goal_str); } + void getNTRule( const TRulePtr & rule, map & ntrule_map){ NTRule lhs_ntrule(rule, rule->lhs_ * -1); @@ -113,9 +132,9 @@ void getNTRule( const TRulePtr & rule, map & ntrule_map){ NTRule rhs_ntrule(rule, rule->f_.at(i) * -1); ntrule_map[(rule->f_).at(i) *-1] = rhs_ntrule; } - - } + + void aTextGrammar::AddRule(const TRulePtr& rule) { if (rule->IsUnary()) { rhs2unaries_[rule->f().front()].push_back(rule); @@ -141,7 +160,7 @@ void aTextGrammar::AddRule(const TRulePtr& rule) { } void aTextGrammar::RemoveRule(const TRulePtr & rule){ - cout<<"Remove rule: "<AsString()<AsString()<IsUnary()) { aRemoveRule(rhs2unaries_[rule->f().front()], rule); aRemoveRule(unaries_, rule); @@ -158,6 +177,14 @@ void aTextGrammar::RemoveRule(const TRulePtr & rule){ aRemoveRule(lhs_rules_[rule->lhs_ * -1] , rule); + + //remove the rule from nt_rules_ + map ntrule_map; + getNTRule (rule, ntrule_map); + for (map::const_iterator it= ntrule_map.begin(); it != ntrule_map.end(); it++){ + aRemoveRule(nt_rules_[it->first], it->second); + } + } void aTextGrammar::RemoveNonterminal(WordID wordID){ @@ -166,6 +193,8 @@ void aTextGrammar::RemoveNonterminal(WordID wordID){ nt_rules_.erase(wordID); for (int i =0; i & nts){ map cnt_addepsilon; //cnt_addepsilon and cont_minusepsilon to track the number of rules epsilon is added or minus for each lhs nonterminal, ideally we want these two numbers are equal - map cnt_minusepsilon; //these two number also use to control the random generated add epsilon/minus epsilon of a new rule + map cnt_minusepsilon; cnt_addepsilon[old_rule.rule_->lhs_] = 0; cnt_minusepsilon[old_rule.rule_->lhs_] = 0; for (int j =0; j & nts){ // cout<<"print vector j_vector"< e_ = (old_rule.rule_)->e_; newrule -> f_ = old_rule.rule_->f_; @@ -323,7 +352,7 @@ void aTextGrammar::splitNonterminal(WordID wordID){ if (wordID == goalID){ //add rule X-> X1; X->X2,... if X is the goal NT for (int i =0; ilhs_ = goalID * -1; rule ->f_.push_back(v_splits[i] * -1); rule->e_.push_back(0); @@ -334,20 +363,100 @@ void aTextGrammar::splitNonterminal(WordID wordID){ } +} + +void aTextGrammar::splitAllNonterminals(){ + map >::const_iterator it; + vector v ; // WordID >0 + for (it = lhs_rules_.begin(); it != lhs_rules_.end(); it++) //iterate through all nts + if (it->first != goalID || lhs_rules_.size() ==1) + v.push_back(it->first); + + for (int i=0; i< v.size(); i++) + splitNonterminal(v[i]); } +void aTextGrammar::PrintAllRules(const string & filename) const{ -void aTextGrammar::PrintAllRules() const{ - map >::const_iterator it; + + cerr<<"print grammar to "< >::const_iterator it; for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){ vector v = it-> second; for (int i =0; i< v.size(); i++){ - cout<AsString()<<"\t"<AsString()<<"\t"< >::const_iterator it; + for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){ + vector v = it-> second; + for (int i =0; i< v.size(); i++){ + // cerr<<"Reset score of Rule "<AsString()<(v[i])->ResetScore(alpha_ /v.size()); + } + lhs_rules_[it->first] = v; + sum_probs_[it->first] = alpha_; + } + +} + +void aTextGrammar::UpdateScore(){ + + map >::const_iterator it; + for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){ + vector v = it-> second; + for (int i =0; i< v.size(); i++){ + boost::static_pointer_cast(v[i])->UpdateScore(sum_probs_[it->first] ); } + + // cerr<<"sum_probs_[it->first] ="<first] <first] = alpha_; } + +} + + +void aTextGrammar::UpdateHgProsteriorProb(Hypergraph & hg){ + std::vector posts ; + + prob_t goal_score = hg.ComputeEdgePosteriors(1, &posts); + for (int i =0; ilhs_ * -1); + + if (str_lhs.find(goalstr) != string::npos) + continue; + + // cerr<AsString()<parent_rule_->AsString()<(e.rule_->parent_rule_)->AddProb(posts[i] / goal_score); + // cerr<<"add count for rule\n"; +// cerr<<"posts[i]="<AsString()<AsString()<scores_.set_value(FD::Convert("MinusLogP"), minuslogp); + + } + private: + SparseVector sum_scores_; +}; + + class aTGImpl; struct NTRule{ @@ -20,17 +49,19 @@ struct NTRule{ for (int i=0; i< rule->f().size(); i++) if (rule->f().at(i) * -1 == nt) ntPos_.push_back(i); + + } TRulePtr rule_; - WordID nt_; //the labelID of the nt (WordID>0); + WordID nt_; //the labelID of the nt (nt_>0); vector ntPos_; //position of nt_ -1: lhs, from 0...f_.size() for nt of f_() //i.e the rules is: NP-> DET NP; if nt_=5 is the labelID of NP then ntPos_ = (-1, 1): the indexes of nonterminal NP - }; + struct aTextGrammar : public Grammar { aTextGrammar(); aTextGrammar(const std::string& file); @@ -46,9 +77,20 @@ struct aTextGrammar : public Grammar { void setMaxSplit(int max_split); void splitNonterminal(WordID wordID); - void PrintAllRules() const; + + void splitAllNonterminals(); + + void PrintAllRules(const string & filename) const; void PrintNonterminalRules(WordID nt) const; void SetGoalNT(const string & goal_str); + + void ResetScore(); + + void UpdateScore(); + + void UpdateHgProsteriorProb(Hypergraph & hg); + + void set_alpha(double alpha){alpha_ = alpha;} private: void RemoveRule(const TRulePtr & rule); @@ -57,9 +99,15 @@ struct aTextGrammar : public Grammar { int max_span_; int max_split_; boost::shared_ptr pimpl_; + map > lhs_rules_;// WordID >0 map > nt_rules_; + map sum_probs_; + map cnt_rules; + + double alpha_; + // map > grSplitNonterminals; WordID goalID; }; diff --git a/gi/scfg/abc/scfg.cpp b/gi/scfg/abc/scfg.cpp index 4d094488..b3dbad34 100644 --- a/gi/scfg/abc/scfg.cpp +++ b/gi/scfg/abc/scfg.cpp @@ -1,3 +1,8 @@ +#include +#include + +#include +#include #include "lattice.h" #include "tdict.h" #include "agrammar.h" @@ -9,13 +14,53 @@ using namespace std; +vector src_corpus; +vector tgt_corpus; + +bool openParallelCorpora(string & input_filename){ + ifstream input_file; + + input_file.open(input_filename.c_str()); + if (!input_file) { + cerr << "Cannot open input file " << input_filename << ". Exiting..." << endl; + return false; + } + + int line =0; + while (!input_file.eof()) { + // get a line of source language data + // cerr<<"new line "< v = tokenize(str, delimiters); + + if ( (v.size() != 2) and (v.size() != 3) ) { + cerr< reweight; + + reweight.set_value(FD::Convert("MinusLogP"), -1 ); + hg.Reweight(reweight); + return true; + } @@ -71,74 +124,140 @@ int main(int argc, char** argv){ ParamsArray params(argc, argv); params.setDescription("scfg models"); - params.addConstraint("grammar_file", "grammar file ", true); // optional + params.addConstraint("grammar_file", "grammar file (default ./grammar.pr )", true); // optional + + params.addConstraint("input_file", "parallel input file (default ./parallel_corpora)", true); //optional + + params.addConstraint("output_file", "grammar output file (default ./grammar_output)", true); //optional + + params.addConstraint("goal_symbol", "top nonterminal symbol (default: X)", true); //optional + + params.addConstraint("split", "split one nonterminal into 'split' nonterminals (default: 2)", true); //optional - params.addConstraint("input_file", "parallel input file", true); //optional + params.addConstraint("prob_iters", "number of iterations (default: 10)", true); //optional + + params.addConstraint("split_iters", "number of splitting iterations (default: 3)", true); //optional + + params.addConstraint("alpha", "alpha (default: 0.1)", true); //optional if (!params.runConstraints("scfg")) { return 0; } cerr<<"get parametters\n\n\n"; - string input_file = params.asString("input_file", "parallel_corpora"); string grammar_file = params.asString("grammar_file", "./grammar.pr"); + string input_file = params.asString("input_file", "parallel_corpora"); - string src = "el gato ."; - - string tgt = "the cat ."; - - - string goal_sym = "X"; - srand(123); - /*load grammar*/ + string output_file = params.asString("output_file", "grammar_output"); + string goal_sym = params.asString("goal_symbol", "X"); + int max_split = atoi(params.asString("split", "2").c_str()); + + int prob_iters = atoi(params.asString("prob_iters", "2").c_str()); + int split_iters = atoi(params.asString("split_iters", "1").c_str()); + double alpha = atof(params.asString("alpha", ".001").c_str()); + + ///// + cerr<<"grammar_file ="<SetGoalNT(goal_sym); - cout<<"before split nonterminal"<setMaxSplit(max_split); + agrammar->set_alpha(alpha); + srand(123); + + GrammarPtr g( agrammar); Hypergraph hg; - if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){ - cerr<<"target sentence is not parsed by the grammar!\n"; - return 1; - } - hg.PrintGraphviz(); + int data_size = src_corpus.size(); + for (int i =0; i PrintAllRules(output_file+".s" + itos(i+1)); + agrammar->splitAllNonterminals(); + + //vector src_corpus; + //vector tgt_corpus; + + for (int j=0; jResetScore(); + // cerr<<"done reset grammar score\n"; + for (int k=0; k (g)->UpdateHgProsteriorProb(hg); + hg.clear(); + } + boost::static_pointer_cast(g)->UpdateScore(); + } + boost::static_pointer_cast(g)->PrintAllRules(output_file+".e" + itos(i+1)); + } - if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){ - cerr<<"target sentence is not parsed by the grammar!\n"; - return 1; - } - hg.PrintGraphviz(); - //hg.clear(); - if (1==1) return 1; + + + - agrammar->PrintAllRules(); - /*split grammar*/ - cout<<"split NTs\n"; - cerr<<"first of all write all nonterminals"<printAllNonterminals(); - agrammar->setMaxSplit(2); - agrammar->splitNonterminal(4); - cout<<"after split nonterminal"<PrintAllRules(); - Hypergraph hg1; - if (! parseSentencePair(goal_sym, src, tgt, g, hg1) ){ - cerr<<"target sentence is not parsed by the grammar!\n"; - return 1; - } - hg1.PrintGraphviz(); + // // agrammar->ResetScore(); + // // agrammar->UpdateScore(); + // if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){ + // cerr<<"target sentence is not parsed by the grammar!\n"; + // return 1; + + // } + // // hg.PrintGraphviz(); + // //hg.clear(); + + // agrammar->PrintAllRules(); + // /*split grammar*/ + // cout<<"split NTs\n"; + // cerr<<"first of all write all nonterminals"<printAllNonterminals(); + // cout<<"after split nonterminal"<PrintAllRules(); + // Hypergraph hg1; + // if (! parseSentencePair(goal_sym, src, tgt, g, hg1) ){ + // cerr<<"target sentence is not parsed by the grammar!\n"; + // return 1; + + // } + + // hg1.PrintGraphviz(); - agrammar->splitNonterminal(15); - cout<<"after split nonterminal"<PrintAllRules(); + // agrammar->splitNonterminal(15); + // cout<<"after split nonterminal"<PrintAllRules(); /*load training corpus*/ -- cgit v1.2.3