diff options
author | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-28 23:14:21 +0000 |
---|---|---|
committer | desaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f> | 2010-06-28 23:14:21 +0000 |
commit | ebd00f59aab18446051f9838d3d08427b242b435 (patch) | |
tree | 155a32bba53f50c322c113f51a09a99c0a30475a /gi/posterior-regularisation/prjava/src/hmm/POS.java | |
parent | 95ecf39865c10d46b80e30021e1a838d77eaf09a (diff) |
add draft version of POS induction with HMM and L1 Linf constraints
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@47 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src/hmm/POS.java')
-rw-r--r-- | gi/posterior-regularisation/prjava/src/hmm/POS.java | 126 |
1 files changed, 126 insertions, 0 deletions
diff --git a/gi/posterior-regularisation/prjava/src/hmm/POS.java b/gi/posterior-regularisation/prjava/src/hmm/POS.java new file mode 100644 index 00000000..722d38e2 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/hmm/POS.java @@ -0,0 +1,126 @@ +package hmm;
+
+import java.io.PrintStream;
+import java.util.HashMap;
+
+import data.Corpus;
+
+public class POS {
+
+ //public String trainFilename="../posdata/en_train.conll";
+ //public static String trainFilename="../posdata/small_train.txt";
+ public static String trainFilename="../posdata/en_test.conll";
+// public static String trainFilename="../posdata/trial1.txt";
+
+ public static String testFilename="../posdata/en_test.conll";
+ //public static String testFilename="../posdata/trial1.txt";
+
+ public static String predFilename="../posdata/en_test.predict.conll";
+ public static String modelFilename="../posdata/posModel.out";
+ public static final int ITER=20;
+ public static final int N_STATE=30;
+
+ public static void main(String[] args) {
+ //POS p=new POS();
+ //POS p=new POS(true);
+ PRPOS();
+ }
+
+
+ public POS(){
+ Corpus c= new Corpus(trainFilename);
+ //size of vocabulary +1 for unknown tokens
+ HMM hmm =new HMM(N_STATE, c.getVocabSize()+1,c.getAllData());
+ for(int i=0;i<ITER;i++){
+ System.out.println("Iter"+i);
+ hmm.EM();
+ if((i+1)%10==0){
+ hmm.writeModel(modelFilename+i);
+ }
+ }
+
+ hmm.writeModel(modelFilename);
+
+ Corpus test=new Corpus(testFilename,c.vocab);
+
+ PrintStream ps= io.FileUtil.openOutFile(predFilename);
+
+ int [][]data=test.getAllData();
+ for(int i=0;i<data.length;i++){
+ int []tag=hmm.viterbi(data[i]);
+ String sent[]=test.get(i);
+ for(int j=0;j<data[i].length;j++){
+ ps.println(sent[j]+"\t"+tag[j]);
+ }
+ ps.println();
+ }
+ ps.close();
+ }
+
+ //POS induction with L1/Linf constraints
+ public static void PRPOS(){
+ Corpus c= new Corpus(trainFilename);
+ //size of vocabulary +1 for unknown tokens
+ HMM hmm =new HMM(N_STATE, c.getVocabSize()+1,c.getAllData());
+ hmm.o=new HMMObjective(hmm);
+ for(int i=0;i<ITER;i++){
+ System.out.println("Iter: "+i);
+ hmm.PREM();
+ if((i+1)%10==0){
+ hmm.writeModel(modelFilename+i);
+ }
+ }
+
+ hmm.writeModel(modelFilename);
+
+ Corpus test=new Corpus(testFilename,c.vocab);
+
+ PrintStream ps= io.FileUtil.openOutFile(predFilename);
+
+ int [][]data=test.getAllData();
+ for(int i=0;i<data.length;i++){
+ int []tag=hmm.viterbi(data[i]);
+ String sent[]=test.get(i);
+ for(int j=0;j<data[i].length;j++){
+ ps.println(sent[j]+"\t"+tag[j]);
+ }
+ ps.println();
+ }
+ ps.close();
+ }
+
+
+ public POS(boolean supervised){
+ Corpus c= new Corpus(trainFilename);
+ //size of vocabulary +1 for unknown tokens
+ HMM hmm =new HMM(c.tagVocab.size() , c.getVocabSize()+1,c.getAllData());
+ hmm.train(c.getTagData());
+
+ hmm.writeModel(modelFilename);
+
+ Corpus test=new Corpus(testFilename,c.vocab);
+
+ HashMap<String, Integer>tagVocab=
+ (HashMap<String, Integer>) io.SerializedObjects.readSerializedObject(Corpus.tagalphaFilename);
+ String [] tagdict=new String [tagVocab.size()+1];
+ for(String key:tagVocab.keySet()){
+ tagdict[tagVocab.get(key)]=key;
+ }
+ tagdict[tagdict.length-1]=Corpus.UNK_TOK;
+
+ System.out.println(c.vocab.get("<e>"));
+
+ PrintStream ps= io.FileUtil.openOutFile(predFilename);
+
+ int [][]data=test.getAllData();
+ for(int i=0;i<data.length;i++){
+ int []tag=hmm.viterbi(data[i]);
+ String sent[]=test.get(i);
+ for(int j=0;j<data[i].length;j++){
+ ps.println(sent[j]+"\t"+tagdict[tag[j]]);
+ }
+ ps.println();
+ }
+ ps.close();
+ }
+}
|