summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src
diff options
context:
space:
mode:
authordesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 20:55:36 +0000
committerdesaicwtf <desaicwtf@ec762483-ff6d-05da-a07a-a48fb63a330f>2010-07-09 20:55:36 +0000
commit406f8775c704df9e288ca2896b2b596c49538178 (patch)
treefd2efa45252796327063dbd47dfcf617a75794c9 /gi/posterior-regularisation/prjava/src
parent78763d1966bc6bb7702906b73aeb6b154577418e (diff)
context->phrase
git-svn-id: https://ws10smt.googlecode.com/svn/trunk@211 ec762483-ff6d-05da-a07a-a48fb63a330f
Diffstat (limited to 'gi/posterior-regularisation/prjava/src')
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/C2F.java205
1 files changed, 200 insertions, 5 deletions
diff --git a/gi/posterior-regularisation/prjava/src/phrase/C2F.java b/gi/posterior-regularisation/prjava/src/phrase/C2F.java
index 2646d961..63dad2ab 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/C2F.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/C2F.java
@@ -1,17 +1,212 @@
package phrase;
+
+import gnu.trove.TIntArrayList;
+
+import io.FileUtil;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.Arrays;
+import java.util.List;
+
+import phrase.Corpus.Edge;
+
/**
* @brief context generates phrase
* @author desaic
*
*/
public class C2F {
-
- /**
- * @param args
+ public int K;
+ private int n_words, n_contexts, n_positions;
+ public Corpus c;
+
+ /**@brief
+ * emit[tag][position][word] = p(word | tag, position in phrase)
+ */
+ private double emit[][][];
+ /**@brief
+ * pi[context][tag] = p(tag | context)
+ */
+ private double pi[][];
+
+ public C2F(int numCluster, Corpus corpus){
+ K=numCluster;
+ c=corpus;
+ n_words=c.getNumWords();
+ n_contexts=c.getNumContexts();
+
+ //number of words in a phrase to be considered
+ //currently the first and last word
+ //if the phrase has length 1
+ //use the same word for two positions
+ n_positions=2;
+
+ emit=new double [K][n_positions][n_words];
+ pi=new double[n_contexts][K];
+
+ for(double [][]i:emit){
+ for(double []j:i){
+ arr.F.randomise(j);
+ }
+ }
+
+ for(double []j:pi){
+ arr.F.randomise(j);
+ }
+ }
+
+ /**@brief test
+ *
*/
- public static void main(String[] args) {
- // TODO Auto-generated method stub
+ public static void main(String args[]){
+ String in="../pdata/canned.con";
+ String out="../pdata/posterior.out";
+ int numCluster=5;
+ Corpus corpus = null;
+ File infile = new File(in);
+ try {
+ System.out.println("Reading concordance from " + infile);
+ corpus = Corpus.readFromFile(FileUtil.reader(infile));
+ corpus.printStats(System.out);
+ } catch (IOException e) {
+ System.err.println("Failed to open input file: " + infile);
+ e.printStackTrace();
+ System.exit(1);
+ }
+
+ C2F c2f=new C2F(numCluster,corpus);
+ int iter=20;
+ double llh=0;
+ for(int i=0;i<iter;i++){
+ llh=c2f.EM();
+ System.out.println("Iter"+i+", llh: "+llh);
+ }
+
+ File outfile = new File (out);
+ try {
+ PrintStream ps = FileUtil.printstream(outfile);
+ c2f.displayPosterior(ps);
+ ps.println();
+ c2f.displayModelParam(ps);
+ ps.close();
+ } catch (IOException e) {
+ System.err.println("Failed to open output file: " + outfile);
+ e.printStackTrace();
+ System.exit(1);
+ }
+
+ }
+
+ public double EM(){
+ double [][][]exp_emit=new double [K][n_positions][n_words];
+ double [][]exp_pi=new double[n_contexts][K];
+
+ double loglikelihood=0;
+
+ //E
+ for(int context=0; context< n_contexts; context++){
+
+ List<Edge> contexts = c.getEdgesForContext(context);
+
+ for (int ctx=0; ctx<contexts.size(); ctx++){
+ Edge edge = contexts.get(ctx);
+ double p[]=posterior(edge);
+ double z = arr.F.l1norm(p);
+ assert z > 0;
+ loglikelihood += edge.getCount() * Math.log(z);
+ arr.F.l1normalize(p);
+
+ int count = edge.getCount();
+ //increment expected count
+ TIntArrayList phrase= edge.getPhrase();
+ for(int tag=0;tag<K;tag++){
+
+ exp_emit[tag][0][phrase.get(0)]+=p[tag]*count;
+ exp_emit[tag][1][phrase.get(phrase.size()-1)]+=p[tag]*count;
+
+ exp_pi[context][tag]+=p[tag]*count;
+ }
+ }
+ }
+
+ //System.out.println("Log likelihood: "+loglikelihood);
+
+ //M
+ for(double [][]i:exp_emit){
+ for(double []j:i){
+ arr.F.l1normalize(j);
+ }
+ }
+
+ emit=exp_emit;
+
+ for(double []j:exp_pi){
+ arr.F.l1normalize(j);
+ }
+
+ pi=exp_pi;
+
+ return loglikelihood;
+ }
+ public double[] posterior(Corpus.Edge edge)
+ {
+ double[] prob=Arrays.copyOf(pi[edge.getPhraseId()], K);
+
+ TIntArrayList phrase = edge.getPhrase();
+ for(int tag=0;tag<K;tag++)
+ prob[tag]*=emit[tag][0][phrase.get(0)]
+ *emit[tag][1][phrase.get(phrase.size()-1)];
+ return prob;
}
+ public void displayPosterior(PrintStream ps)
+ {
+ for (Edge edge : c.getEdges())
+ {
+ double probs[] = posterior(edge);
+ arr.F.l1normalize(probs);
+
+ // emit phrase
+ ps.print(edge.getPhraseString());
+ ps.print("\t");
+ ps.print(edge.getContextString(true));
+ int t=arr.F.argmax(probs);
+ ps.println(" ||| C=" + t);
+ }
+ }
+
+ public void displayModelParam(PrintStream ps)
+ {
+ final double EPS = 1e-6;
+
+ ps.println("P(tag|context)");
+ for (int i = 0; i < n_contexts; ++i)
+ {
+ ps.print(c.getContext(i));
+ for(int j=0;j<pi[i].length;j++){
+ if (pi[i][j] > EPS)
+ ps.print("\t" + j + ": " + pi[i][j]);
+ }
+ ps.println();
+ }
+
+ ps.println("P(word|tag,position)");
+ for (int i = 0; i < K; ++i)
+ {
+ for(int position=0;position<n_positions;position++){
+ ps.println("tag " + i + " position " + position);
+ for(int word=0;word<emit[i][position].length;word++){
+ if (emit[i][position][word] > EPS)
+ ps.print(c.getWord(word)+"="+emit[i][position][word]+"\t");
+ }
+ ps.println();
+ }
+ ps.println();
+ }
+
+ }
+
}