From 3368ed9579857982a51d78e834cd6f44e1915deb Mon Sep 17 00:00:00 2001 From: desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> Date: Mon, 12 Jul 2010 14:17:09 +0000 Subject: agreement model git-svn-id: https://ws10smt.googlecode.com/svn/trunk@221 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/phrase/Agree.java | 174 +++++++++++++++++++++ .../prjava/src/phrase/C2F.java | 4 +- 2 files changed, 176 insertions(+), 2 deletions(-) create mode 100644 gi/posterior-regularisation/prjava/src/phrase/Agree.java (limited to 'gi') diff --git a/gi/posterior-regularisation/prjava/src/phrase/Agree.java b/gi/posterior-regularisation/prjava/src/phrase/Agree.java new file mode 100644 index 00000000..091875ce --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/Agree.java @@ -0,0 +1,174 @@ +package phrase; + +import gnu.trove.TIntArrayList; + +import io.FileUtil; + +import java.io.File; +import java.io.IOException; +import java.io.PrintStream; +import java.util.List; + +import phrase.Corpus.Edge; + +public class Agree { + private PhraseCluster model1; + private C2F model2; + Corpus c; + private int K,n_phrases, n_words, n_contexts, n_positions1,n_positions2; + + /** + * + * @param numCluster + * @param corpus + */ + public Agree(int numCluster, Corpus corpus){ + + model1=new PhraseCluster(numCluster, corpus, 0, 0, 0); + model2=new C2F(numCluster,corpus); + c=corpus; + n_words=c.getNumWords(); + n_phrases=c.getNumPhrases(); + n_contexts=c.getNumContexts(); + n_positions1=c.getNumContextPositions(); + n_positions2=2; + K=numCluster; + + } + + /**@brief test + * + */ + public static void main(String args[]){ + String in="../pdata/canned.con"; + String out="../pdata/posterior.out"; + int numCluster=25; + Corpus corpus = null; + File infile = new File(in); + try { + System.out.println("Reading concordance from " + infile); + corpus = Corpus.readFromFile(FileUtil.reader(infile)); + corpus.printStats(System.out); + } catch (IOException e) { + System.err.println("Failed to open input file: " + infile); + e.printStackTrace(); + System.exit(1); + } + + Agree agree=new Agree(numCluster, corpus); + int iter=20; + double llh=0; + for(int i=0;i<iter;i++){ + llh=agree.EM(); + System.out.println("Iter"+i+", llh: "+llh); + } + + File outfile = new File (out); + try { + PrintStream ps = FileUtil.printstream(outfile); + agree.displayPosterior(ps); + // ps.println(); + // c2f.displayModelParam(ps); + ps.close(); + } catch (IOException e) { + System.err.println("Failed to open output file: " + outfile); + e.printStackTrace(); + System.exit(1); + } + + } + + public double EM(){ + + double [][][]exp_emit1=new double [K][n_positions1][n_words]; + double [][]exp_pi1=new double[n_phrases][K]; + + double [][][]exp_emit2=new double [K][n_positions2][n_words]; + double [][]exp_pi2=new double[n_contexts][K]; + + double loglikelihood=0; + + //E + for(int context=0; context< n_contexts; context++){ + + List<Edge> contexts = c.getEdgesForContext(context); + + for (int ctx=0; ctx<contexts.size(); ctx++){ + Edge edge = contexts.get(ctx); + int phrase=edge.getPhraseId(); + double p[]=posterior(edge); + double z = arr.F.l1norm(p); + assert z > 0; + loglikelihood += edge.getCount() * Math.log(z); + arr.F.l1normalize(p); + + int count = edge.getCount(); + //increment expected count + TIntArrayList phraseToks = edge.getPhrase(); + TIntArrayList contextToks = edge.getContext(); + for(int tag=0;tag<K;tag++){ + + for(int position=0;position<n_positions1;position++){ + exp_emit1[tag][position][contextToks.get(position)]+=p[tag]*count; + } + + exp_emit2[tag][0][phraseToks.get(0)]+=p[tag]*count; + exp_emit2[tag][1][phraseToks.get(phraseToks.size()-1)]+=p[tag]*count; + + exp_pi1[phrase][tag]+=p[tag]*count; + exp_pi2[context][tag]+=p[tag]*count; + } + } + } + + //System.out.println("Log likelihood: "+loglikelihood); + + //M + for(double [][]i:exp_emit1){ + for(double []j:i){ + arr.F.l1normalize(j); + } + } + + for(double []j:exp_pi1){ + arr.F.l1normalize(j); + } + + model1.emit=exp_emit1; + model1.pi=exp_pi1; + model2.emit=exp_emit2; + model2.pi=exp_pi2; + + return loglikelihood; + } + + public double[] posterior(Corpus.Edge edge) + { + double[] prob1=model1.posterior(edge); + double[] prob2=model2.posterior(edge); + + for(int i=0;i<prob1.length;i++){ + prob1[i]*=prob2[i]; + prob1[i]=Math.sqrt(prob1[i]); + } + + return prob1; + } + + public void displayPosterior(PrintStream ps) + { + for (Edge edge : c.getEdges()) + { + double probs[] = posterior(edge); + arr.F.l1normalize(probs); + + // emit phrase + ps.print(edge.getPhraseString()); + ps.print("\t"); + ps.print(edge.getContextString(true)); + int t=arr.F.argmax(probs); + ps.println(" ||| C=" + t); + } + } + +} diff --git a/gi/posterior-regularisation/prjava/src/phrase/C2F.java b/gi/posterior-regularisation/prjava/src/phrase/C2F.java index 3456c953..a8e557f2 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/C2F.java +++ b/gi/posterior-regularisation/prjava/src/phrase/C2F.java @@ -25,11 +25,11 @@ public class C2F { /**@brief * emit[tag][position][word] = p(word | tag, position in phrase) */ - private double emit[][][]; + public double emit[][][]; /**@brief * pi[context][tag] = p(tag | context) */ - private double pi[][]; + public double pi[][]; public C2F(int numCluster, Corpus corpus){ K=numCluster; -- cgit v1.2.3