summaryrefslogtreecommitdiff
path: root/gi
diff options
context:
space:
mode:
Diffstat (limited to 'gi')
-rw-r--r--gi/scfg/abc/Release/agrammar.d14
-rwxr-xr-xgi/scfg/abc/Release/scfgbin4277125 -> 4437132 bytes
-rw-r--r--gi/scfg/abc/Release/scfg.d50
-rw-r--r--gi/scfg/abc/agrammar.cc153
-rw-r--r--gi/scfg/abc/agrammar.h54
-rw-r--r--gi/scfg/abc/scfg.cpp213
6 files changed, 389 insertions, 95 deletions
diff --git a/gi/scfg/abc/Release/agrammar.d b/gi/scfg/abc/Release/agrammar.d
index 6cf14f0d..553752ca 100644
--- a/gi/scfg/abc/Release/agrammar.d
+++ b/gi/scfg/abc/Release/agrammar.d
@@ -59,7 +59,11 @@ agrammar.d agrammar.o: ../agrammar.cc \
/home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \
/home/tnguyen/ws10smt/decoder/grammar.h \
/home/tnguyen/ws10smt/decoder/lattice.h \
- /home/tnguyen/ws10smt/decoder/array2d.h ../../utils/Util.h \
+ /home/tnguyen/ws10smt/decoder/array2d.h \
+ /home/tnguyen/ws10smt/decoder/hg.h \
+ /home/tnguyen/ws10smt/decoder/small_vector.h \
+ /home/tnguyen/ws10smt/decoder/prob.h \
+ /home/tnguyen/ws10smt/decoder/logval.h ../../utils/Util.h \
../../utils/UtfConverter.h ../../utils/ConvertUTF.h
/home/tnguyen/ws10smt/decoder/rule_lexer.h:
@@ -186,6 +190,14 @@ agrammar.d agrammar.o: ../agrammar.cc \
/home/tnguyen/ws10smt/decoder/array2d.h:
+/home/tnguyen/ws10smt/decoder/hg.h:
+
+/home/tnguyen/ws10smt/decoder/small_vector.h:
+
+/home/tnguyen/ws10smt/decoder/prob.h:
+
+/home/tnguyen/ws10smt/decoder/logval.h:
+
../../utils/Util.h:
../../utils/UtfConverter.h:
diff --git a/gi/scfg/abc/Release/scfg b/gi/scfg/abc/Release/scfg
index 4b6cfb19..41b6b583 100755
--- a/gi/scfg/abc/Release/scfg
+++ b/gi/scfg/abc/Release/scfg
Binary files differ
diff --git a/gi/scfg/abc/Release/scfg.d b/gi/scfg/abc/Release/scfg.d
index ae7a87bb..b3cfbbb5 100644
--- a/gi/scfg/abc/Release/scfg.d
+++ b/gi/scfg/abc/Release/scfg.d
@@ -1,8 +1,4 @@
-scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
- /home/tnguyen/ws10smt/decoder/wordid.h \
- /home/tnguyen/ws10smt/decoder/array2d.h \
- /home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \
- /home/tnguyen/ws10smt/decoder/grammar.h \
+scfg.d scfg.o: ../scfg.cpp \
/export/ws10smt/software/include/boost/shared_ptr.hpp \
/export/ws10smt/software/include/boost/smart_ptr/shared_ptr.hpp \
/export/ws10smt/software/include/boost/config.hpp \
@@ -38,6 +34,12 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
/export/ws10smt/software/include/boost/smart_ptr/detail/yield_k.hpp \
/export/ws10smt/software/include/boost/memory_order.hpp \
/export/ws10smt/software/include/boost/smart_ptr/detail/operator_bool.hpp \
+ /export/ws10smt/software/include/boost/pointer_cast.hpp \
+ /home/tnguyen/ws10smt/decoder/lattice.h \
+ /home/tnguyen/ws10smt/decoder/wordid.h \
+ /home/tnguyen/ws10smt/decoder/array2d.h \
+ /home/tnguyen/ws10smt/decoder/tdict.h ../agrammar.h \
+ /home/tnguyen/ws10smt/decoder/grammar.h \
/home/tnguyen/ws10smt/decoder/lattice.h \
/home/tnguyen/ws10smt/decoder/trule.h \
/home/tnguyen/ws10smt/decoder/sparse_vector.h \
@@ -57,27 +59,15 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
/export/ws10smt/software/include/boost/functional/hash/detail/hash_float_generic.hpp \
/export/ws10smt/software/include/boost/functional/hash/extensions.hpp \
/export/ws10smt/software/include/boost/detail/container_fwd.hpp \
- /home/tnguyen/ws10smt/decoder/bottom_up_parser.h \
- /home/tnguyen/ws10smt/decoder/grammar.h \
/home/tnguyen/ws10smt/decoder/hg.h \
/home/tnguyen/ws10smt/decoder/small_vector.h \
/home/tnguyen/ws10smt/decoder/prob.h \
/home/tnguyen/ws10smt/decoder/logval.h \
+ /home/tnguyen/ws10smt/decoder/bottom_up_parser.h \
+ /home/tnguyen/ws10smt/decoder/grammar.h \
/home/tnguyen/ws10smt/decoder/hg_intersect.h ../../utils/ParamsArray.h \
../../utils/Util.h ../../utils/UtfConverter.h ../../utils/ConvertUTF.h
-/home/tnguyen/ws10smt/decoder/lattice.h:
-
-/home/tnguyen/ws10smt/decoder/wordid.h:
-
-/home/tnguyen/ws10smt/decoder/array2d.h:
-
-/home/tnguyen/ws10smt/decoder/tdict.h:
-
-../agrammar.h:
-
-/home/tnguyen/ws10smt/decoder/grammar.h:
-
/export/ws10smt/software/include/boost/shared_ptr.hpp:
/export/ws10smt/software/include/boost/smart_ptr/shared_ptr.hpp:
@@ -148,6 +138,20 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
/export/ws10smt/software/include/boost/smart_ptr/detail/operator_bool.hpp:
+/export/ws10smt/software/include/boost/pointer_cast.hpp:
+
+/home/tnguyen/ws10smt/decoder/lattice.h:
+
+/home/tnguyen/ws10smt/decoder/wordid.h:
+
+/home/tnguyen/ws10smt/decoder/array2d.h:
+
+/home/tnguyen/ws10smt/decoder/tdict.h:
+
+../agrammar.h:
+
+/home/tnguyen/ws10smt/decoder/grammar.h:
+
/home/tnguyen/ws10smt/decoder/lattice.h:
/home/tnguyen/ws10smt/decoder/trule.h:
@@ -186,10 +190,6 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
/export/ws10smt/software/include/boost/detail/container_fwd.hpp:
-/home/tnguyen/ws10smt/decoder/bottom_up_parser.h:
-
-/home/tnguyen/ws10smt/decoder/grammar.h:
-
/home/tnguyen/ws10smt/decoder/hg.h:
/home/tnguyen/ws10smt/decoder/small_vector.h:
@@ -198,6 +198,10 @@ scfg.d scfg.o: ../scfg.cpp /home/tnguyen/ws10smt/decoder/lattice.h \
/home/tnguyen/ws10smt/decoder/logval.h:
+/home/tnguyen/ws10smt/decoder/bottom_up_parser.h:
+
+/home/tnguyen/ws10smt/decoder/grammar.h:
+
/home/tnguyen/ws10smt/decoder/hg_intersect.h:
../../utils/ParamsArray.h:
diff --git a/gi/scfg/abc/agrammar.cc b/gi/scfg/abc/agrammar.cc
index 585255e3..016a0189 100644
--- a/gi/scfg/abc/agrammar.cc
+++ b/gi/scfg/abc/agrammar.cc
@@ -8,6 +8,18 @@
#include "agrammar.h"
#include "../utils/Util.h"
+
+
+aTRule::aTRule(TRulePtr rule){
+
+ this -> e_ = rule->e_;
+ this -> f_ = rule->f_;
+ this ->lhs_ = rule->lhs_;
+ this -> arity_ = rule->arity_;
+ this -> scores_ = rule->scores_;
+ ResetScore(0.00000001);
+}
+
bool equal(TRulePtr const & rule1, TRulePtr const & rule2){
if (rule1->lhs_ != rule2->lhs_) return false;
if (rule1->f_.size() != rule2->f_.size()) return false;
@@ -20,16 +32,25 @@ bool equal(TRulePtr const & rule1, TRulePtr const & rule2){
return true;
}
+
//const vector<TRulePtr> Grammar::NO_RULES;
void aRemoveRule(vector<TRulePtr> & v, const TRulePtr & rule){ // remove rule from v if found
for (int i=0; i< v.size(); i++)
if (equal(v[i], rule )){
- cout<<"erase rule from vector:"<<rule->AsString()<<endl;
+ // cout<<"erase rule from vector:"<<rule->AsString()<<endl;
v.erase(v.begin()+i);
}
}
+void aRemoveRule(vector<NTRule> & v, const NTRule & ntrule){ // remove rule from v if found
+ for (int i=0; i< v.size(); i++)
+ if (equal(v[i].rule_, ntrule.rule_ )){
+ // cout<<"erase rule from vector:"<<rule->AsString()<<endl;
+ v.erase(v.begin()+i);
+ }
+}
+
struct aTextRuleBin : public RuleBin {
int GetNumRules() const {
return rules_.size();
@@ -40,20 +61,16 @@ struct aTextRuleBin : public RuleBin {
void AddRule(TRulePtr t) {
rules_.push_back(t);
}
- void RemoveRule(TRulePtr t){
- for (int i=0; i<rules_.size(); i++){
- if (equal(rules_.at(i), t)){
- rules_.erase(rules_.begin() + i);
- //cout<<"IntextRulebin removerulle\n";
- return;
- }
- }
+
+ void RemoveRule(const TRulePtr & rule ){
+ aRemoveRule(rules_, rule);
}
int Arity() const {
return rules_.front()->Arity();
}
+
void Dump() const {
for (int i = 0; i < rules_.size(); ++i)
cerr << rules_[i]->AsString() << endl;
@@ -62,6 +79,7 @@ struct aTextRuleBin : public RuleBin {
vector<TRulePtr> rules_;
};
+
struct aTextGrammarNode : public GrammarIter {
aTextGrammarNode() : rb_(NULL) {}
~aTextGrammarNode() {
@@ -90,8 +108,8 @@ struct aTGImpl {
aTextGrammar::aTextGrammar() : max_span_(10), pimpl_(new aTGImpl) {}
aTextGrammar::aTextGrammar(const string& file) :
- max_span_(10),
- pimpl_(new aTGImpl) {
+ max_span_(10),
+ pimpl_(new aTGImpl) {
ReadFromFile(file);
}
@@ -103,6 +121,7 @@ void aTextGrammar::SetGoalNT(const string & goal_str){
goalID = TD::Convert(goal_str);
}
+
void getNTRule( const TRulePtr & rule, map<WordID, NTRule> & ntrule_map){
NTRule lhs_ntrule(rule, rule->lhs_ * -1);
@@ -113,9 +132,9 @@ void getNTRule( const TRulePtr & rule, map<WordID, NTRule> & ntrule_map){
NTRule rhs_ntrule(rule, rule->f_.at(i) * -1);
ntrule_map[(rule->f_).at(i) *-1] = rhs_ntrule;
}
-
-
}
+
+
void aTextGrammar::AddRule(const TRulePtr& rule) {
if (rule->IsUnary()) {
rhs2unaries_[rule->f().front()].push_back(rule);
@@ -141,7 +160,7 @@ void aTextGrammar::AddRule(const TRulePtr& rule) {
}
void aTextGrammar::RemoveRule(const TRulePtr & rule){
- cout<<"Remove rule: "<<rule->AsString()<<endl;
+ // cout<<"Remove rule: "<<rule->AsString()<<endl;
if (rule->IsUnary()) {
aRemoveRule(rhs2unaries_[rule->f().front()], rule);
aRemoveRule(unaries_, rule);
@@ -158,6 +177,14 @@ void aTextGrammar::RemoveRule(const TRulePtr & rule){
aRemoveRule(lhs_rules_[rule->lhs_ * -1] , rule);
+
+ //remove the rule from nt_rules_
+ map<WordID, NTRule> ntrule_map;
+ getNTRule (rule, ntrule_map);
+ for (map<WordID,NTRule>::const_iterator it= ntrule_map.begin(); it != ntrule_map.end(); it++){
+ aRemoveRule(nt_rules_[it->first], it->second);
+ }
+
}
void aTextGrammar::RemoveNonterminal(WordID wordID){
@@ -166,6 +193,8 @@ void aTextGrammar::RemoveNonterminal(WordID wordID){
nt_rules_.erase(wordID);
for (int i =0; i<rules.size(); i++)
RemoveRule(rules[i].rule_);
+ sum_probs_.erase(wordID);
+ cnt_rules.erase(wordID);
}
@@ -199,7 +228,7 @@ void aTextGrammar::AddSplitNonTerminal(WordID nt_old, vector<WordID> & nts){
map<WordID, int> cnt_addepsilon; //cnt_addepsilon and cont_minusepsilon to track the number of rules epsilon is added or minus for each lhs nonterminal, ideally we want these two numbers are equal
- map<WordID, int> cnt_minusepsilon; //these two number also use to control the random generated add epsilon/minus epsilon of a new rule
+ map<WordID, int> cnt_minusepsilon;
cnt_addepsilon[old_rule.rule_->lhs_] = 0;
cnt_minusepsilon[old_rule.rule_->lhs_] = 0;
for (int j =0; j<nts.size(); j++) { cnt_addepsilon[nts[j] ] = 0; cnt_minusepsilon[nts[j] ] = 0;}
@@ -217,7 +246,7 @@ void aTextGrammar::AddSplitNonTerminal(WordID nt_old, vector<WordID> & nts){
// cout<<"print vector j_vector"<<endl;
// for (int k=0; k<ntPos.size();k++) cout<<j_vector[k]<<" "; cout<<endl;
//now use the vector to create a new rule
- TRulePtr newrule(new TRule());
+ TRulePtr newrule(new aTRule());
newrule -> e_ = (old_rule.rule_)->e_;
newrule -> f_ = old_rule.rule_->f_;
@@ -323,7 +352,7 @@ void aTextGrammar::splitNonterminal(WordID wordID){
if (wordID == goalID){ //add rule X-> X1; X->X2,... if X is the goal NT
for (int i =0; i<v_splits.size(); i++){
- TRulePtr rule (new TRule());
+ TRulePtr rule (new aTRule());
rule ->lhs_ = goalID * -1;
rule ->f_.push_back(v_splits[i] * -1);
rule->e_.push_back(0);
@@ -334,20 +363,100 @@ void aTextGrammar::splitNonterminal(WordID wordID){
}
+}
+
+void aTextGrammar::splitAllNonterminals(){
+ map<WordID, vector<TRulePtr> >::const_iterator it;
+ vector<WordID> v ; // WordID >0
+ for (it = lhs_rules_.begin(); it != lhs_rules_.end(); it++) //iterate through all nts
+ if (it->first != goalID || lhs_rules_.size() ==1)
+ v.push_back(it->first);
+
+ for (int i=0; i< v.size(); i++)
+ splitNonterminal(v[i]);
}
+void aTextGrammar::PrintAllRules(const string & filename) const{
-void aTextGrammar::PrintAllRules() const{
- map<WordID, vector<TRulePtr> >::const_iterator it;
+
+ cerr<<"print grammar to "<<filename<<endl;
+
+ ofstream outfile(filename.c_str());
+ if (!outfile.good()) {
+ cerr << "error opening output file " << filename << endl;
+ exit(1);
+ }
+
+ map<WordID, vector<TRulePtr > >::const_iterator it;
for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){
vector<TRulePtr> v = it-> second;
for (int i =0; i< v.size(); i++){
- cout<<v[i]->AsString()<<"\t"<<endl;
+ outfile<<v[i]->AsString()<<"\t"<<endl;
+ }
+ }
+}
+
+
+void aTextGrammar::ResetScore(){
+
+ map<WordID, vector<TRulePtr > >::const_iterator it;
+ for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){
+ vector<TRulePtr> v = it-> second;
+ for (int i =0; i< v.size(); i++){
+ // cerr<<"Reset score of Rule "<<v[i]->AsString()<<endl;
+ boost::static_pointer_cast<aTRule>(v[i])->ResetScore(alpha_ /v.size());
+ }
+ lhs_rules_[it->first] = v;
+ sum_probs_[it->first] = alpha_;
+ }
+
+}
+
+void aTextGrammar::UpdateScore(){
+
+ map<WordID, vector<TRulePtr > >::const_iterator it;
+ for (it= lhs_rules_.begin(); it != lhs_rules_.end(); it++){
+ vector<TRulePtr> v = it-> second;
+ for (int i =0; i< v.size(); i++){
+ boost::static_pointer_cast<aTRule>(v[i])->UpdateScore(sum_probs_[it->first] );
}
+
+ // cerr<<"sum_probs_[it->first] ="<<sum_probs_[it->first] <<endl;
+ sum_probs_[it->first] = alpha_;
}
+
+}
+
+
+void aTextGrammar::UpdateHgProsteriorProb(Hypergraph & hg){
+ std::vector<prob_t> posts ;
+
+ prob_t goal_score = hg.ComputeEdgePosteriors(1, &posts);
+ for (int i =0; i<posts.size(); i++){
+
+ //cout<<posts[i]<<endl;
+ Hypergraph::Edge& e = hg.edges_[i];
+ string goalstr("Goal");
+ string str_lhs = TD::Convert(e.rule_->lhs_ * -1);
+
+ if (str_lhs.find(goalstr) != string::npos)
+ continue;
+
+ // cerr<<e.rule_->AsString()<<endl;
+ // cerr<<e.rule_->parent_rule_->AsString()<<endl;
+
+ boost::static_pointer_cast<aTRule>(e.rule_->parent_rule_)->AddProb(posts[i] / goal_score);
+ // cerr<<"add count for rule\n";
+// cerr<<"posts[i]="<<posts[i]<<" goal_score="<<goal_score<<endl;
+// cerr<<"posts[i] /goal_score="<<(posts[i] /goal_score)<<endl;
+ sum_probs_[e.rule_->parent_rule_->lhs_* -1 ] += posts[i] /goal_score;
+
+ }
+
+
}
@@ -364,7 +473,9 @@ void aTextGrammar::PrintNonterminalRules(WordID nt) const{
}
static void AddRuleHelper(const TRulePtr& new_rule, void* extra) {
- static_cast<aTextGrammar*>(extra)->AddRule(new_rule);
+ aTRule *p = new aTRule(new_rule);
+
+ static_cast<aTextGrammar*>(extra)->AddRule(TRulePtr(p));
}
void aTextGrammar::ReadFromFile(const string& filename) {
diff --git a/gi/scfg/abc/agrammar.h b/gi/scfg/abc/agrammar.h
index 8a7186bf..0a8a60ac 100644
--- a/gi/scfg/abc/agrammar.h
+++ b/gi/scfg/abc/agrammar.h
@@ -2,10 +2,39 @@
#define AGRAMMAR_H_
#include "grammar.h"
+#include "hg.h"
using namespace std;
+class aTRule: public TRule{
+ public:
+ aTRuleTRule : TRule(){ResetScore(0.00000001); }
+ aTRule(TRulePtr rule_);
+
+ void ResetScore(double initscore){//cerr<<"Reset Score "<<this->AsString()<<endl;
+ sum_scores_.set_value(FD::Convert("Prob"), initscore);}
+ void AddProb(double p ){
+ // cerr<<"in AddProb p="<<p<<endl;
+ // cerr<<"prob sumscores ="<<sum_scores_[FD::Convert("Prob")]<<endl;
+ sum_scores_.add_value(FD::Convert("Prob"), p);
+ // cerr<<"after AddProb\n";
+ }
+
+ void UpdateScore(double sumprob){
+ double minuslogp = 0 - log( sum_scores_.value(FD::Convert("Prob")) /sumprob);
+ if (sumprob< sum_scores_.value(FD::Convert("Prob"))){
+ cerr<<"UpdateScore sumprob="<<sumprob<< " sum_scores_.value(FD::Convert(\"Prob\"))="<< sum_scores_.value(FD::Convert("Prob"))<< this->AsString()<<endl;
+ exit(1);
+ }
+ this->scores_.set_value(FD::Convert("MinusLogP"), minuslogp);
+
+ }
+ private:
+ SparseVector<double> sum_scores_;
+};
+
+
class aTGImpl;
struct NTRule{
@@ -20,17 +49,19 @@ struct NTRule{
for (int i=0; i< rule->f().size(); i++)
if (rule->f().at(i) * -1 == nt)
ntPos_.push_back(i);
+
+
}
TRulePtr rule_;
- WordID nt_; //the labelID of the nt (WordID>0);
+ WordID nt_; //the labelID of the nt (nt_>0);
vector<int> ntPos_; //position of nt_ -1: lhs, from 0...f_.size() for nt of f_()
//i.e the rules is: NP-> DET NP; if nt_=5 is the labelID of NP then ntPos_ = (-1, 1): the indexes of nonterminal NP
-
};
+
struct aTextGrammar : public Grammar {
aTextGrammar();
aTextGrammar(const std::string& file);
@@ -46,9 +77,20 @@ struct aTextGrammar : public Grammar {
void setMaxSplit(int max_split);
void splitNonterminal(WordID wordID);
- void PrintAllRules() const;
+
+ void splitAllNonterminals();
+
+ void PrintAllRules(const string & filename) const;
void PrintNonterminalRules(WordID nt) const;
void SetGoalNT(const string & goal_str);
+
+ void ResetScore();
+
+ void UpdateScore();
+
+ void UpdateHgProsteriorProb(Hypergraph & hg);
+
+ void set_alpha(double alpha){alpha_ = alpha;}
private:
void RemoveRule(const TRulePtr & rule);
@@ -57,9 +99,15 @@ struct aTextGrammar : public Grammar {
int max_span_;
int max_split_;
boost::shared_ptr<aTGImpl> pimpl_;
+
map <WordID, vector<TRulePtr> > lhs_rules_;// WordID >0
map <WordID, vector<NTRule> > nt_rules_;
+ map <WordID, double> sum_probs_;
+ map <WordID, double> cnt_rules;
+
+ double alpha_;
+
// map<WordID, vector<WordID> > grSplitNonterminals;
WordID goalID;
};
diff --git a/gi/scfg/abc/scfg.cpp b/gi/scfg/abc/scfg.cpp
index 4d094488..b3dbad34 100644
--- a/gi/scfg/abc/scfg.cpp
+++ b/gi/scfg/abc/scfg.cpp
@@ -1,3 +1,8 @@
+#include <iostream>
+#include <fstream>
+
+#include <boost/shared_ptr.hpp>
+#include <boost/pointer_cast.hpp>
#include "lattice.h"
#include "tdict.h"
#include "agrammar.h"
@@ -9,13 +14,53 @@
using namespace std;
+vector<string> src_corpus;
+vector<string> tgt_corpus;
+
+bool openParallelCorpora(string & input_filename){
+ ifstream input_file;
+
+ input_file.open(input_filename.c_str());
+ if (!input_file) {
+ cerr << "Cannot open input file " << input_filename << ". Exiting..." << endl;
+ return false;
+ }
+
+ int line =0;
+ while (!input_file.eof()) {
+ // get a line of source language data
+ // cerr<<"new line "<<ctr<<endl;
+ string str;
+
+ getline(input_file, str);
+ line++;
+ if (str.length()==0){
+ cerr<<" sentence number "<<line<<" is empty, skip the sentence\n";
+ continue;
+ }
+ string delimiters("|||");
+
+ vector<string> v = tokenize(str, delimiters);
+
+ if ( (v.size() != 2) and (v.size() != 3) ) {
+ cerr<<str<<endl;
+ cerr<<" source or target sentence is not found in sentence number "<<line<<" , skip the sentence\n";
+ continue;
+ }
+
+ src_corpus.push_back(v[0]);
+ tgt_corpus.push_back(v[1]);
+ }
+ return true;
+}
+
+
typedef aTextGrammar aGrammar;
aGrammar * load_grammar(string & grammar_filename){
cerr<<"start_load_grammar "<<grammar_filename<<endl;
aGrammar * test = new aGrammar(grammar_filename);
-
return test;
}
@@ -26,7 +71,6 @@ Lattice convertSentenceToLattice(const string & str){
Lattice lsentence;
lsentence.resize(vID.size());
-
for (int i=0; i<vID.size(); i++){
lsentence[i].push_back( LatticeArc(vID[i], 0.0, 1) );
@@ -41,6 +85,8 @@ Lattice convertSentenceToLattice(const string & str){
bool parseSentencePair(const string & goal_sym, const string & src, const string & tgt, GrammarPtr & g, Hypergraph &hg){
+
+ // cout<<" Start parse the sentence pairs\n"<<endl;
Lattice lsource = convertSentenceToLattice(src);
//parse the source sentence by the grammar
@@ -51,7 +97,7 @@ bool parseSentencePair(const string & goal_sym, const string & src, const string
if (!parser.Parse(lsource, &hg)){
- cerr<<"source sentence does not parse by the grammar!"<<endl;
+ cerr<<"source sentence is not parsed by the grammar!"<<endl;
return false;
}
@@ -59,8 +105,15 @@ bool parseSentencePair(const string & goal_sym, const string & src, const string
Lattice ltarget = convertSentenceToLattice(tgt);
//forest.PrintGraphviz();
- return HG::Intersect(ltarget, & hg);
+ if (!HG::Intersect(ltarget, & hg)) return false;
+
+ SparseVector<double> reweight;
+
+ reweight.set_value(FD::Convert("MinusLogP"), -1 );
+ hg.Reweight(reweight);
+ return true;
+
}
@@ -71,74 +124,140 @@ int main(int argc, char** argv){
ParamsArray params(argc, argv);
params.setDescription("scfg models");
- params.addConstraint("grammar_file", "grammar file ", true); // optional
+ params.addConstraint("grammar_file", "grammar file (default ./grammar.pr )", true); // optional
+
+ params.addConstraint("input_file", "parallel input file (default ./parallel_corpora)", true); //optional
+
+ params.addConstraint("output_file", "grammar output file (default ./grammar_output)", true); //optional
+
+ params.addConstraint("goal_symbol", "top nonterminal symbol (default: X)", true); //optional
+
+ params.addConstraint("split", "split one nonterminal into 'split' nonterminals (default: 2)", true); //optional
- params.addConstraint("input_file", "parallel input file", true); //optional
+ params.addConstraint("prob_iters", "number of iterations (default: 10)", true); //optional
+
+ params.addConstraint("split_iters", "number of splitting iterations (default: 3)", true); //optional
+
+ params.addConstraint("alpha", "alpha (default: 0.1)", true); //optional
if (!params.runConstraints("scfg")) {
return 0;
}
cerr<<"get parametters\n\n\n";
- string input_file = params.asString("input_file", "parallel_corpora");
string grammar_file = params.asString("grammar_file", "./grammar.pr");
+ string input_file = params.asString("input_file", "parallel_corpora");
- string src = "el gato .";
-
- string tgt = "the cat .";
-
-
- string goal_sym = "X";
- srand(123);
- /*load grammar*/
+ string output_file = params.asString("output_file", "grammar_output");
+ string goal_sym = params.asString("goal_symbol", "X");
+ int max_split = atoi(params.asString("split", "2").c_str());
+
+ int prob_iters = atoi(params.asString("prob_iters", "2").c_str());
+ int split_iters = atoi(params.asString("split_iters", "1").c_str());
+ double alpha = atof(params.asString("alpha", ".001").c_str());
+
+ /////
+ cerr<<"grammar_file ="<<grammar_file<<endl;
+ cerr<<"input_file ="<< input_file<<endl;
+ cerr<<"output_file ="<< output_file<<endl;
+ cerr<<"goal_sym ="<< goal_sym<<endl;
+ cerr<<"max_split ="<< max_split<<endl;
+ cerr<<"prob_iters ="<< prob_iters<<endl;
+ cerr<<"split_iters ="<< split_iters<<endl;
+ cerr<<"alpha ="<< alpha<<endl;
+ //////////////////////////
+
+ cerr<<"\n\nLoad parallel corpus...\n";
+ if (! openParallelCorpora(input_file))
+ exit(1);
+
+ cerr<<"Load grammar file ...\n";
aGrammar * agrammar = load_grammar(grammar_file);
agrammar->SetGoalNT(goal_sym);
- cout<<"before split nonterminal"<<endl;
- GrammarPtr g( agrammar);
+ agrammar->setMaxSplit(max_split);
+ agrammar->set_alpha(alpha);
+ srand(123);
+
+ GrammarPtr g( agrammar);
Hypergraph hg;
- if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){
- cerr<<"target sentence is not parsed by the grammar!\n";
- return 1;
- }
- hg.PrintGraphviz();
+ int data_size = src_corpus.size();
+ for (int i =0; i <split_iters; i++){
+
+ cerr<<"Split Nonterminals, iteration "<<(i+1)<<endl;
+ agrammar->PrintAllRules(output_file+".s" + itos(i+1));
+ agrammar->splitAllNonterminals();
+
+ //vector<string> src_corpus;
+ //vector<string> tgt_corpus;
+
+ for (int j=0; j<prob_iters; j++){
+ cerr<<"reset grammar score\n";
+ agrammar->ResetScore();
+ // cerr<<"done reset grammar score\n";
+ for (int k=0; k <data_size; k++){
+ string src = src_corpus[k];
+
+ string tgt = tgt_corpus[k];
+ cerr <<"parse sentence pair: "<<src<<" ||| "<<tgt<<endl;
+
+ if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){
+ cerr<<"target sentence is not parsed by the grammar!\n";
+ //return 1;
+ continue;
+
+ }
+ cerr<<"update edge posterior prob"<<endl;
+ boost::static_pointer_cast<aGrammar>(g)->UpdateHgProsteriorProb(hg);
+ hg.clear();
+ }
+ boost::static_pointer_cast<aGrammar>(g)->UpdateScore();
+ }
+ boost::static_pointer_cast<aGrammar>(g)->PrintAllRules(output_file+".e" + itos(i+1));
+ }
- if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){
- cerr<<"target sentence is not parsed by the grammar!\n";
- return 1;
- }
- hg.PrintGraphviz();
- //hg.clear();
- if (1==1) return 1;
+
+
+
- agrammar->PrintAllRules();
- /*split grammar*/
- cout<<"split NTs\n";
- cerr<<"first of all write all nonterminals"<<endl;
- // agrammar->printAllNonterminals();
- agrammar->setMaxSplit(2);
- agrammar->splitNonterminal(4);
- cout<<"after split nonterminal"<<endl;
- agrammar->PrintAllRules();
- Hypergraph hg1;
- if (! parseSentencePair(goal_sym, src, tgt, g, hg1) ){
- cerr<<"target sentence is not parsed by the grammar!\n";
- return 1;
- }
- hg1.PrintGraphviz();
+ // // agrammar->ResetScore();
+ // // agrammar->UpdateScore();
+ // if (! parseSentencePair(goal_sym, src, tgt, g, hg) ){
+ // cerr<<"target sentence is not parsed by the grammar!\n";
+ // return 1;
+
+ // }
+ // // hg.PrintGraphviz();
+ // //hg.clear();
+
+ // agrammar->PrintAllRules();
+ // /*split grammar*/
+ // cout<<"split NTs\n";
+ // cerr<<"first of all write all nonterminals"<<endl;
+ // // agrammar->printAllNonterminals();
+ // cout<<"after split nonterminal"<<endl;
+ // agrammar->PrintAllRules();
+ // Hypergraph hg1;
+ // if (! parseSentencePair(goal_sym, src, tgt, g, hg1) ){
+ // cerr<<"target sentence is not parsed by the grammar!\n";
+ // return 1;
+
+ // }
+
+ // hg1.PrintGraphviz();
- agrammar->splitNonterminal(15);
- cout<<"after split nonterminal"<<TD::Convert(15)<<endl;
- agrammar->PrintAllRules();
+ // agrammar->splitNonterminal(15);
+ // cout<<"after split nonterminal"<<TD::Convert(15)<<endl;
+ // agrammar->PrintAllRules();
/*load training corpus*/