From ad418214fe3b3fcd33d81225eb3d3fb08b67f88a Mon Sep 17 00:00:00 2001 From: desaicwtf Date: Mon, 28 Jun 2010 23:14:21 +0000 Subject: add draft version of POS induction with HMM and L1 Linf constraints git-svn-id: https://ws10smt.googlecode.com/svn/trunk@47 ec762483-ff6d-05da-a07a-a48fb63a330f --- .../prjava/src/data/Corpus.java | 230 ++++++++ .../prjava/src/hmm/HMM.java | 576 +++++++++++++++++++++ .../prjava/src/hmm/HMMObjective.java | 348 +++++++++++++ .../prjava/src/hmm/POS.java | 126 +++++ .../prjava/src/io/FileUtil.java | 37 ++ .../prjava/src/io/SerializedObjects.java | 83 +++ .../prjava/src/test/CorpusTest.java | 60 +++ .../prjava/src/test/HMMModelStats.java | 96 ++++ .../prjava/src/test/IntDoublePair.java | 23 + .../prjava/src/test/X2y2WithConstraints.java | 131 +++++ 10 files changed, 1710 insertions(+) create mode 100644 gi/posterior-regularisation/prjava/src/data/Corpus.java create mode 100644 gi/posterior-regularisation/prjava/src/hmm/HMM.java create mode 100644 gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java create mode 100644 gi/posterior-regularisation/prjava/src/hmm/POS.java create mode 100644 gi/posterior-regularisation/prjava/src/io/FileUtil.java create mode 100644 gi/posterior-regularisation/prjava/src/io/SerializedObjects.java create mode 100644 gi/posterior-regularisation/prjava/src/test/CorpusTest.java create mode 100644 gi/posterior-regularisation/prjava/src/test/HMMModelStats.java create mode 100644 gi/posterior-regularisation/prjava/src/test/IntDoublePair.java create mode 100644 gi/posterior-regularisation/prjava/src/test/X2y2WithConstraints.java (limited to 'gi/posterior-regularisation/prjava/src') diff --git a/gi/posterior-regularisation/prjava/src/data/Corpus.java b/gi/posterior-regularisation/prjava/src/data/Corpus.java new file mode 100644 index 00000000..f0da0b33 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/data/Corpus.java @@ -0,0 +1,230 @@ +package data; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Scanner; + +public class Corpus { + + public static final String alphaFilename="../posdata/corpus.alphabet"; + public static final String tagalphaFilename="../posdata/corpus.tag.alphabet"; + +// public static final String START_SYM=""; + public static final String END_SYM=""; + public static final String NUM_TOK=""; + + public static final String UNK_TOK=""; + + private ArrayListsent; + private ArrayListdata; + + public ArrayListtag; + public ArrayListtagData; + + public static boolean convertNumTok=true; + + private HashMapfreq; + public HashMapvocab; + + public HashMaptagVocab; + private int tagV; + + private int V; + + public static void main(String[] args) { + Corpus c=new Corpus("../posdata/en_test.conll"); + System.out.println( + Arrays.toString(c.get(0)) + ); + System.out.println( + Arrays.toString(c.getInt(0)) + ); + + System.out.println( + Arrays.toString(c.get(1)) + ); + System.out.println( + Arrays.toString(c.getInt(1)) + ); + } + + public Corpus(String filename,HashMapdict){ + V=0; + tagV=0; + freq=new HashMap(); + tagVocab=new HashMap(); + vocab=dict; + + sent=new ArrayList(); + tag=new ArrayList(); + + Scanner sc=io.FileUtil.openInFile(filename); + ArrayLists=new ArrayList(); + // s.add(START_SYM); + while(sc.hasNextLine()){ + String line=sc.nextLine(); + String toks[]=line.split("\t"); + if(toks.length<2){ + s.add(END_SYM); + sent.add(s.toArray(new String[0])); + s=new ArrayList(); + // s.add(START_SYM); + continue; + } + String tok=toks[1].toLowerCase(); + s.add(tok); + } + sc.close(); + + buildData(); + } + + public Corpus(String filename){ + V=0; + freq=new HashMap(); + vocab=new HashMap(); + tagVocab=new HashMap(); + + sent=new ArrayList(); + tag=new ArrayList(); + + System.out.println("Reading:"+filename); + + Scanner sc=io.FileUtil.openInFile(filename); + ArrayLists=new ArrayList(); + ArrayListtags=new ArrayList(); + //s.add(START_SYM); + while(sc.hasNextLine()){ + String line=sc.nextLine(); + String toks[]=line.split("\t"); + if(toks.length<2){ + s.add(END_SYM); + tags.add(END_SYM); + if(s.size()>2){ + sent.add(s.toArray(new String[0])); + tag.add(tags.toArray(new String [0])); + } + s=new ArrayList(); + tags=new ArrayList(); + // s.add(START_SYM); + continue; + } + + String tok=toks[1].toLowerCase(); + if(convertNumTok && tok.matches(".*\\d.*")){ + tok=NUM_TOK; + } + s.add(tok); + + if(toks.length>3){ + tok=toks[3].toLowerCase(); + }else{ + tok="_"; + } + tags.add(tok); + + } + sc.close(); + + for(int i=0;i(); + for(int i=0;i(); + for(int i=0;i2){ + vocab.put(key, V); + V++; + } + } + io.SerializedObjects.writeSerializedObject(vocab, alphaFilename); + io.SerializedObjects.writeSerializedObject(tagVocab,tagalphaFilename); + } + + private void addTag(String tag){ + Integer i=tagVocab.get(tag); + if(i==null){ + tagVocab.put(tag, tagV); + tagV++; + } + } + +} diff --git a/gi/posterior-regularisation/prjava/src/hmm/HMM.java b/gi/posterior-regularisation/prjava/src/hmm/HMM.java new file mode 100644 index 00000000..1c4d7659 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/hmm/HMM.java @@ -0,0 +1,576 @@ +package hmm; + +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.Scanner; + +public class HMM { + + + //trans[i][j]=prob of going FROM i to j + double [][]trans; + double [][]emit; + double []pi; + int [][]data; + int [][]tagdata; + + double logtrans[][]; + + public HMMObjective o; + + public static void main(String[] args) { + + } + + public HMM(int n_state,int n_emit,int [][]data){ + trans=new double [n_state][n_state]; + emit=new double[n_state][n_emit]; + pi=new double [n_state]; + System.out.println(" random initial parameters"); + fillRand(trans); + fillRand(emit); + fillRand(pi); + + this.data=data; + + } + + private void fillRand(double [][] a){ + for(int i=0;i=0;n--){ + for(int i=0;imaxprob){ + maxprob=p[seq.length-1][i]; + maxIdx=i; + } + } + int ans[]=new int [seq.length]; + ans[seq.length-1]=maxIdx; + for(int i=seq.length-2;i>=0;i--){ + ans[i]=backp[i+1][ans[i+1]]; + } + return ans; + } + + public double l1norm(double a[]){ + double norm=0; + for(int i=0;i s=new ArrayList(); + int state=sample(pi); + int sym=sample(emit[state]); + while(sym!=terminalSym){ + s.add(sym); + state=sample(trans[state]); + sym=sample(emit[state]); + } + + int ans[]=new int [s.size()]; + for(int i=0;i=r){ + return i; + } + } + return p.length-1; + } + + public void train(int tagdata[][]){ + double trans_exp_cnt[][]=new double [trans.length][trans.length]; + double emit_exp_cnt[][]=new double[trans.length][emit[0].length]; + double start_exp_cnt[]=new double[trans.length]; + + for(int i=0;imaxwt[i][d[sentNum][n]]){ + maxwt[i][d[sentNum][n]]=py; + } + + } + } + + //the last state + int len=post.length; + for(int i=0;imaxwt[i][d[sentNum][len-1]]){ + maxwt[i][d[sentNum][len-1]]=py; + } + + } + + } + + } + +}//end of class diff --git a/gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java b/gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java new file mode 100644 index 00000000..551210c0 --- /dev/null +++ b/gi/posterior-regularisation/prjava/src/hmm/HMMObjective.java @@ -0,0 +1,348 @@ +package hmm; + +import gnu.trove.TIntArrayList; +import optimization.gradientBasedMethods.ProjectedGradientDescent; +import optimization.gradientBasedMethods.ProjectedObjective; +import optimization.gradientBasedMethods.stats.OptimizerStats; +import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc; +import optimization.linesearch.InterpolationPickFirstStep; +import optimization.linesearch.LineSearchMethod; +import optimization.projections.SimplexProjection; +import optimization.stopCriteria.CompositeStopingCriteria; +import optimization.stopCriteria.ProjectedGradientL2Norm; +import optimization.stopCriteria.StopingCriteria; +import optimization.stopCriteria.ValueDifference; + +public class HMMObjective extends ProjectedObjective{ + + + private static final double GRAD_DIFF = 3; + public static double INIT_STEP_SIZE=10; + public static double VAL_DIFF=2000; + + private HMM hmm; + double[] newPoint ; + + //posterior[sent num][tok num][tag]=index into lambda + private int posteriorMap[][][]; + //projection[word][tag].get(occurence)=index into lambda + private TIntArrayList projectionMap[][]; + + //Size of the simplex + public double scale=10; + private SimplexProjection projection; + + private int wordFreq[]; + private static int MIN_FREQ=3; + private int numWordsToProject=0; + + private int n_param; + + public double loglikelihood; + + public HMMObjective(HMM h){ + hmm=h; + + countWords(); + buildMap(); + + gradient=new double [n_param]; + projection = new SimplexProjection(scale); + newPoint = new double[n_param]; + setInitialParameters(new double[n_param]); + + } + + /**@brief counts word frequency in the corpus + * + */ + private void countWords(){ + wordFreq=new int [hmm.emit[0].length]; + for(int i=0;i