diff options
author | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-12 14:17:09 +0000 |
---|---|---|
committer | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-07-12 14:17:09 +0000 |
commit | 1e4724dd169fbb20fc7448cc2cb1ae1bc539560c (patch) | |
tree | 5a949c59d453c13666e037c17cc2faaf21cfd37b /gi/posterior-regularisation/prjava | |
parent | 7b9c1f91e594c4b7783c72e4516d59d60a04dc91 (diff) |
agreement model
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@221 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Agree.java | 174 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/C2F.java | 4 |
2 files changed, 176 insertions, 2 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Agree.java b/gi/posterior-regularisation/prjava/src/phrase/Agree.java new file mode 100644 index 00000000..091875ce --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/phrase/Agree.java @@ -0,0 +1,174 @@ +package phrase;
+
+import gnu.trove.TIntArrayList;
+
+import io.FileUtil;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.List;
+
+import phrase.Corpus.Edge;
+
+public class Agree {
+ private PhraseCluster model1;
+ private C2F model2;
+ Corpus c;
+ private int K,n_phrases, n_words, n_contexts, n_positions1,n_positions2;
+
+ /**
+ *
+ * @param numCluster
+ * @param corpus
+ */
+ public Agree(int numCluster, Corpus corpus){
+
+ model1=new PhraseCluster(numCluster, corpus, 0, 0, 0);
+ model2=new C2F(numCluster,corpus);
+ c=corpus;
+ n_words=c.getNumWords();
+ n_phrases=c.getNumPhrases();
+ n_contexts=c.getNumContexts();
+ n_positions1=c.getNumContextPositions();
+ n_positions2=2;
+ K=numCluster;
+
+ }
+
+ /**@brief test
+ *
+ */
+ public static void main(String args[]){
+ String in="../pdata/canned.con";
+ String out="../pdata/posterior.out";
+ int numCluster=25;
+ Corpus corpus = null;
+ File infile = new File(in);
+ try {
+ System.out.println("Reading concordance from " + infile);
+ corpus = Corpus.readFromFile(FileUtil.reader(infile));
+ corpus.printStats(System.out);
+ } catch (IOException e) {
+ System.err.println("Failed to open input file: " + infile);
+ e.printStackTrace();
+ System.exit(1);
+ }
+
+ Agree agree=new Agree(numCluster, corpus);
+ int iter=20;
+ double llh=0;
+ for(int i=0;i<iter;i++){
+ llh=agree.EM();
+ System.out.println("Iter"+i+", llh: "+llh);
+ }
+
+ File outfile = new File (out);
+ try {
+ PrintStream ps = FileUtil.printstream(outfile);
+ agree.displayPosterior(ps);
+ // ps.println();
+ // c2f.displayModelParam(ps);
+ ps.close();
+ } catch (IOException e) {
+ System.err.println("Failed to open output file: " + outfile);
+ e.printStackTrace();
+ System.exit(1);
+ }
+
+ }
+
+ public double EM(){
+
+ double [][][]exp_emit1=new double [K][n_positions1][n_words];
+ double [][]exp_pi1=new double[n_phrases][K];
+
+ double [][][]exp_emit2=new double [K][n_positions2][n_words];
+ double [][]exp_pi2=new double[n_contexts][K];
+
+ double loglikelihood=0;
+
+ //E
+ for(int context=0; context< n_contexts; context++){
+
+ List<Edge> contexts = c.getEdgesForContext(context);
+
+ for (int ctx=0; ctx<contexts.size(); ctx++){
+ Edge edge = contexts.get(ctx);
+ int phrase=edge.getPhraseId();
+ double p[]=posterior(edge);
+ double z = arr.F.l1norm(p);
+ assert z > 0;
+ loglikelihood += edge.getCount() * Math.log(z);
+ arr.F.l1normalize(p);
+
+ int count = edge.getCount();
+ //increment expected count
+ TIntArrayList phraseToks = edge.getPhrase();
+ TIntArrayList contextToks = edge.getContext();
+ for(int tag=0;tag<K;tag++){
+
+ for(int position=0;position<n_positions1;position++){
+ exp_emit1[tag][position][contextToks.get(position)]+=p[tag]*count;
+ }
+
+ exp_emit2[tag][0][phraseToks.get(0)]+=p[tag]*count;
+ exp_emit2[tag][1][phraseToks.get(phraseToks.size()-1)]+=p[tag]*count;
+
+ exp_pi1[phrase][tag]+=p[tag]*count;
+ exp_pi2[context][tag]+=p[tag]*count;
+ }
+ }
+ }
+
+ //System.out.println("Log likelihood: "+loglikelihood);
+
+ //M
+ for(double [][]i:exp_emit1){
+ for(double []j:i){
+ arr.F.l1normalize(j);
+ }
+ }
+
+ for(double []j:exp_pi1){
+ arr.F.l1normalize(j);
+ }
+
+ model1.emit=exp_emit1;
+ model1.pi=exp_pi1;
+ model2.emit=exp_emit2;
+ model2.pi=exp_pi2;
+
+ return loglikelihood;
+ }
+
+ public double[] posterior(Corpus.Edge edge)
+ {
+ double[] prob1=model1.posterior(edge);
+ double[] prob2=model2.posterior(edge);
+
+ for(int i=0;i<prob1.length;i++){
+ prob1[i]*=prob2[i];
+ prob1[i]=Math.sqrt(prob1[i]);
+ }
+
+ return prob1;
+ }
+
+ public void displayPosterior(PrintStream ps)
+ {
+ for (Edge edge : c.getEdges())
+ {
+ double probs[] = posterior(edge);
+ arr.F.l1normalize(probs);
+
+ // emit phrase
+ ps.print(edge.getPhraseString());
+ ps.print("\t");
+ ps.print(edge.getContextString(true));
+ int t=arr.F.argmax(probs);
+ ps.println(" ||| C=" + t);
+ }
+ }
+
+}
diff --git a/gi/posterior-regularisation/prjava/src/phrase/C2F.java b/gi/posterior-regularisation/prjava/src/phrase/C2F.java index 3456c953..a8e557f2 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/C2F.java +++ b/gi/posterior-regularisation/prjava/src/phrase/C2F.java @@ -25,11 +25,11 @@ public class C2F { /**@brief
* emit[tag][position][word] = p(word | tag, position in phrase)
*/
- private double emit[][][];
+ public double emit[][][];
/**@brief
* pi[context][tag] = p(tag | context)
*/
- private double pi[][];
+ public double pi[][];
public C2F(int numCluster, Corpus corpus){
K=numCluster;
|