summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java
diff options
context:
space:
mode:
author: desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-06-28 23:14:21 +0000
committer: desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> 2010-06-28 23:14:21 +0000
commit: ebd00f59aab18446051f9838d3d08427b242b435 (patch)
tree: 155a32bba53f50c322c113f51a09a99c0a30475a /gi/posterior-regularisation/prjava/src/test/HMMModelStats.java
parent: 95ecf39865c10d46b80e30021e1a838d77eaf09a (diff)
add draft version of POS induction with HMM and L1 Linf constraints
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@47 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/test/HMMModelStats.java')
-rw-r--r--gi/posterior-regularisation/prjava/src/test/HMMModelStats.java96
1 files changed, 96 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java b/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java
new file mode 100644
index 00000000..26d7abec
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/test/HMMModelStats.java
@@ -0,0 +1,96 @@
+package test;
+
+import hmm.HMM;
+import hmm.POS;
+
+import java.io.PrintStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+
+import data.Corpus;
+
+/**
+ * Command-line utility that loads a serialized vocabulary and a saved HMM and
+ * writes a plain-text report to {@link #statsFilename}. The report contains,
+ * in order: the first {@link #NUM_WORD} words of every emission row after
+ * sorting, one summary line produced by {@link #checkMaxwt}, and ten sentences
+ * sampled from the model.
+ */
+public class HMMModelStats {
+
+    /** Path handed to {@code HMM.readModel}. */
+    public static String modelFilename="../posdata/posModel.out";
+    /** Path of the serialized word-to-index map. */
+    public static String alphaFilename="../posdata/corpus.alphabet";
+    /** Path of the report written by {@link #main}. */
+    public static String statsFilename="../posdata/model.stats";
+
+    /** Maximum number of words printed per emission row. */
+    public static final int NUM_WORD=50;
+
+    /** Path of the corpus whose data is passed to {@link #checkMaxwt}. */
+    public static String testFilename="../posdata/en_test.conll";
+
+    /** Scratch matrix allocated by {@link #checkMaxwt}; same shape as the emission table. */
+    public static double [][]maxwt;
+
+    /**
+     * Builds the report described in the class comment.
+     *
+     * @param args ignored; all paths come from the static fields above
+     */
+    public static void main(String[] args) {
+        // The cast cannot be verified at runtime (erasure); the alphabet file
+        // is expected to contain a HashMap<String, Integer>.
+        @SuppressWarnings("unchecked")
+        HashMap<String, Integer> vocab=
+            (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(alphaFilename);
+
+        Corpus test=new Corpus(testFilename,vocab);
+
+        // Invert the vocabulary into an index -> word table; the extra final
+        // slot is reserved for the unknown-word token.
+        String [] dict=new String [vocab.size()+1];
+        for(String key:vocab.keySet()){
+            dict[vocab.get(key)]=key;
+        }
+        dict[dict.length-1]=Corpus.UNK_TOK;
+
+        HMM hmm=new HMM();
+        hmm.readModel(modelFilename);
+
+        PrintStream ps=io.FileUtil.openOutFile(statsFilename);
+        try{
+            printTopWords(ps,hmm.getEmitProb(),dict);
+
+            checkMaxwt(hmm,ps,test.getAllData());
+
+            int terminalSym=vocab.get(Corpus.END_SYM);
+            //sample 10 sentences
+            for(int i=0;i<10;i++){
+                int []sent=hmm.sample(terminalSym);
+                for(int j=0;j<sent.length;j++){
+                    ps.print(dict[sent[j]]+"\t");
+                }
+                ps.println();
+            }
+        }finally{
+            // Close the report even when a step above throws, so the file
+            // handle is not leaked and buffered output is not lost.
+            ps.close();
+        }
+    }
+
+    /**
+     * For every row of {@code emit}, prints the row index followed by the
+     * first {@link #NUM_WORD} words of that row after sorting, ten per line.
+     * The order is whatever {@code IntDoublePair}'s natural ordering yields
+     * (defined outside this file; presumably highest probability first -- TODO confirm).
+     *
+     * @param ps   destination stream
+     * @param emit emission table, indexed [row][word index]
+     * @param dict index-to-word lookup table
+     */
+    private static void printTopWords(PrintStream ps,double [][]emit,String []dict){
+        for(int i=0;i<emit.length;i++){
+            ArrayList<IntDoublePair> l=new ArrayList<IntDoublePair>(emit[i].length);
+            for(int j=0;j<emit[i].length;j++){
+                l.add(new IntDoublePair(j,emit[i][j]));
+            }
+            Collections.sort(l);
+            ps.println(i);
+            // Never read past the end of the sorted list (a short row used to
+            // throw IndexOutOfBoundsException); the dict.length cap is kept
+            // from the original code.
+            int limit=Math.min(NUM_WORD,Math.min(l.size(),dict.length));
+            for(int j=0;j<limit;j++){
+                ps.print(dict[l.get(j).idx]+"\t");
+                if((1+j)%10==0){
+                    ps.println();
+                }
+            }
+            ps.println("\n");
+        }
+    }
+
+    /**
+     * Allocates {@link #maxwt} with the shape of the emission table, passes it
+     * to {@code hmm.computeMaxwt} together with {@code data}, and prints the
+     * sum of all its entries on one line of {@code ps}.
+     *
+     * @param hmm  model whose emission table fixes the matrix shape
+     * @param ps   destination stream for the summary line
+     * @param data corpus data forwarded unchanged to {@code computeMaxwt}
+     */
+    public static void checkMaxwt(HMM hmm,PrintStream ps,int [][]data){
+        double [][]emit=hmm.getEmitProb();
+        maxwt=new double[emit.length][emit[0].length];
+
+        hmm.computeMaxwt(maxwt,data);
+        double sum=0;
+        for(int i=0;i<maxwt.length;i++){
+            // Bound by the row's own length: the matrix is not square, so
+            // using maxwt.length here skipped most columns.
+            for(int j=0;j<maxwt[i].length;j++){
+                sum+=maxwt[i][j];
+            }
+        }
+
+        ps.println("max w t P(w_i|t)"+sum);
+
+    }
+
+}