summaryrefslogtreecommitdiff
path: root/gi/posterior-regularisation/prjava/src
diff options
context:
space:
mode:
Diffstat (limited to 'gi/posterior-regularisation/prjava/src')
-rw-r--r--gi/posterior-regularisation/prjava/src/arr/F.java2
-rw-r--r--gi/posterior-regularisation/prjava/src/io/FileUtil.java47
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Corpus.java8
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java71
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java2
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java17
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java2
-rw-r--r--gi/posterior-regularisation/prjava/src/phrase/Trainer.java150
8 files changed, 192 insertions, 107 deletions
diff --git a/gi/posterior-regularisation/prjava/src/arr/F.java b/gi/posterior-regularisation/prjava/src/arr/F.java
index 5821af42..7f2b140a 100644
--- a/gi/posterior-regularisation/prjava/src/arr/F.java
+++ b/gi/posterior-regularisation/prjava/src/arr/F.java
@@ -3,7 +3,7 @@ package arr;
import java.util.Random;
public class F {
- private static Random rng = new Random(); //(9562724l);
+ public static Random rng = new Random();
public static void randomise(double probs[])
{
diff --git a/gi/posterior-regularisation/prjava/src/io/FileUtil.java b/gi/posterior-regularisation/prjava/src/io/FileUtil.java
index 67ce571e..81e7747b 100644
--- a/gi/posterior-regularisation/prjava/src/io/FileUtil.java
+++ b/gi/posterior-regularisation/prjava/src/io/FileUtil.java
@@ -3,7 +3,24 @@ import java.util.*;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import java.io.*;
-public class FileUtil {
+public class FileUtil
+{
+ public static BufferedReader reader(File file) throws FileNotFoundException, IOException
+ {
+ if (file.getName().endsWith(".gz"))
+ return new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(file))));
+ else
+ return new BufferedReader(new FileReader(file));
+ }
+
+ public static PrintStream printstream(File file) throws FileNotFoundException, IOException
+ {
+ if (file.getName().endsWith(".gz"))
+ return new PrintStream(new GZIPOutputStream(new FileOutputStream(file)));
+ else
+ return new PrintStream(new FileOutputStream(file));
+ }
+
public static Scanner openInFile(String filename){
Scanner localsc=null;
try
@@ -16,34 +33,6 @@ public class FileUtil {
return localsc;
}
- public static BufferedReader openBufferedReader(String filename){
- BufferedReader r=null;
- try
- {
- if (filename.endsWith(".gz"))
- r=(new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(new File(filename))))));
- else
- r=(new BufferedReader(new FileReader(new File(filename))));
- }catch(IOException ioe){
- System.out.println(ioe.getMessage());
- }
- return r;
- }
-
- public static PrintStream openOutFile(String filename){
- PrintStream localps=null;
- try
- {
- if (filename.endsWith(".gz"))
- localps=new PrintStream (new GZIPOutputStream(new FileOutputStream(filename)));
- else
- localps=new PrintStream (new FileOutputStream(filename));
-
- }catch(IOException ioe){
- System.out.println(ioe.getMessage());
- }
- return localps;
- }
public static FileInputStream openInputStream(String infilename){
FileInputStream fis=null;
try {
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
index d5e856ca..81264ab9 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/Corpus.java
@@ -217,5 +217,11 @@ public class Corpus
}
return c;
+ }
+
+ public void printStats(PrintStream out)
+ {
+ out.println("Corpus has " + edges.size() + " edges " + phraseLexicon.size() + " phrases "
+ + contextLexicon.size() + " contexts and " + wordLexicon.size() + " word types");
}
-}
+} \ No newline at end of file
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
index 63a60682..7d7c46dd 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCluster.java
@@ -25,73 +25,6 @@ public class PhraseCluster {
// pi[phrase][tag] = p(tag | phrase)
private double pi[][];
- public static void main(String[] args)
- {
- String input_fname = args[0];
- int tags = Integer.parseInt(args[1]);
- String output_fname = args[2];
- int iterations = Integer.parseInt(args[3]);
- double scalePT = Double.parseDouble(args[4]);
- double scaleCT = Double.parseDouble(args[5]);
- int threads = Integer.parseInt(args[6]);
- boolean runEM = Boolean.parseBoolean(args[7]);
-
- assert(tags >= 2);
- assert(scalePT >= 0);
- assert(scaleCT >= 0);
-
- Corpus corpus = null;
- try {
- corpus = Corpus.readFromFile(FileUtil.openBufferedReader(input_fname));
- } catch (IOException e) {
- System.err.println("Failed to open input file: " + input_fname);
- e.printStackTrace();
- System.exit(1);
- }
- PhraseCluster cluster = new PhraseCluster(tags, corpus, scalePT, scaleCT, threads);
-
- //PhraseObjective.ps = FileUtil.openOutFile(outputDir + "/phrase_stat.out");
-
- double last = 0;
- for(int i=0;i<iterations;i++){
-
- double o;
- if (runEM || i < 3)
- o = cluster.EM();
- else if (scaleCT == 0)
- {
- if (threads >= 1)
- o = cluster.PREM_phrase_constraints_parallel();
- else
- o = cluster.PREM_phrase_constraints();
- }
- else
- o = cluster.PREM_phrase_context_constraints();
-
- //PhraseObjective.ps.
- System.out.println("ITER: "+i+" objective: " + o);
- last = o;
- }
-
- double pl1lmax = cluster.phrase_l1lmax();
- double cl1lmax = cluster.context_l1lmax();
- System.out.println("Final posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
- if (runEM) System.out.println("With PR objective " + (last - scalePT*pl1lmax - scaleCT*cl1lmax));
-
- PrintStream ps=io.FileUtil.openOutFile(output_fname);
- cluster.displayPosterior(ps);
- ps.close();
-
- //PhraseObjective.ps.close();
-
- //ps = io.FileUtil.openOutFile(outputDir + "/parameters.out");
- //cluster.displayModelParam(ps);
- //ps.close();
-
- if (cluster.pool != null)
- cluster.pool.shutdown();
- }
-
public PhraseCluster(int numCluster, Corpus corpus, double scalep, double scalec, int threads){
K=numCluster;
c=corpus;
@@ -134,7 +67,7 @@ public class PhraseCluster {
double p[]=posterior(edge);
double z = arr.F.l1norm(p);
assert z > 0;
- loglikelihood+=Math.log(z);
+ loglikelihood += edge.getCount() * Math.log(z);
arr.F.l1normalize(p);
int count = edge.getCount();
@@ -150,7 +83,7 @@ public class PhraseCluster {
}
}
- System.out.println("Log likelihood: "+loglikelihood);
+ //System.out.println("Log likelihood: "+loglikelihood);
//M
for(double [][]i:exp_emit){
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
index fbf43a7f..15bd29c2 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseContextObjective.java
@@ -26,7 +26,7 @@ import phrase.Corpus.Edge;
public class PhraseContextObjective extends ProjectedObjective
{
private static final double GRAD_DIFF = 0.00002;
- private static double INIT_STEP_SIZE = 10;
+ private static double INIT_STEP_SIZE = 300;
private static double VAL_DIFF = 1e-4; // FIXME needs to be tuned
private static int ITERATIONS = 100;
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
index 11e948ff..903e47c8 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseCorpus.java
@@ -1,7 +1,11 @@
package phrase;
+import io.FileUtil;
+
import java.io.BufferedInputStream;
import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
@@ -20,8 +24,9 @@ public class PhraseCorpus
public int data[][][];
public int numContexts;
- public PhraseCorpus(String filename){
- BufferedReader r=io.FileUtil.openBufferedReader(filename);
+ public PhraseCorpus(String filename) throws FileNotFoundException, IOException
+ {
+ BufferedReader r = FileUtil.reader(new File(filename));
phraseLex=new HashMap<String,Integer>();
wordLex=new HashMap<String,Integer>();
@@ -84,8 +89,9 @@ public class PhraseCorpus
}
//for debugging
- public void saveLex(String lexFilename){
- PrintStream ps=io.FileUtil.openOutFile(lexFilename);
+ public void saveLex(String lexFilename) throws FileNotFoundException, IOException
+ {
+ PrintStream ps = FileUtil.printstream(new File(lexFilename));
ps.println("Phrase Lexicon");
ps.println(phraseLex.size());
printDict(phraseLex,ps);
@@ -175,7 +181,8 @@ public class PhraseCorpus
return null;
}
- public static void main(String[] args) {
+ public static void main(String[] args) throws Exception
+ {
String LEX_FILENAME="../pdata/lex.out";
String DATA_FILENAME="../pdata/btec.con";
PhraseCorpus c=new PhraseCorpus(DATA_FILENAME);
diff --git a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
index 0a76e2dc..3314f74a 100644
--- a/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
+++ b/gi/posterior-regularisation/prjava/src/phrase/PhraseObjective.java
@@ -21,7 +21,7 @@ import optimization.util.MathUtils;
public class PhraseObjective extends ProjectedObjective
{
static final double GRAD_DIFF = 0.00002;
- static double INIT_STEP_SIZE = 10;
+ static double INIT_STEP_SIZE = 300;
static double VAL_DIFF = 1e-4; // FIXME needs to be tuned - and this might be too weak
static int ITERATIONS = 100;
//private double c1=0.0001; // wolf stuff
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
new file mode 100644
index 00000000..b19f3fb9
--- /dev/null
+++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java
@@ -0,0 +1,150 @@
+package phrase;
+
+import io.FileUtil;
+import joptsimple.OptionParser;
+import joptsimple.OptionSet;
+import java.io.File;
+import java.io.IOException;
+import java.io.PrintStream;
+import java.util.Random;
+
+import arr.F;
+
+public class Trainer
+{
+ public static void main(String[] args)
+ {
+ OptionParser parser = new OptionParser();
+ parser.accepts("help");
+ parser.accepts("in").withRequiredArg().ofType(File.class);
+ parser.accepts("out").withRequiredArg().ofType(File.class);
+ parser.accepts("parameters").withRequiredArg().ofType(File.class);
+ parser.accepts("topics").withRequiredArg().ofType(Integer.class).defaultsTo(5);
+ parser.accepts("em-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(5);
+ parser.accepts("pr-iterations").withRequiredArg().ofType(Integer.class).defaultsTo(0);
+ parser.accepts("threads").withRequiredArg().ofType(Integer.class).defaultsTo(0);
+ parser.accepts("scale-phrase").withRequiredArg().ofType(Double.class).defaultsTo(5.0);
+ parser.accepts("scale-context").withRequiredArg().ofType(Double.class).defaultsTo(0.0);
+ parser.accepts("seed").withRequiredArg().ofType(Long.class).defaultsTo(0l);
+ parser.accepts("convergence-threshold").withRequiredArg().ofType(Double.class).defaultsTo(1e-6);
+ OptionSet options = parser.parse(args);
+
+ if (options.has("help") || !options.has("in"))
+ {
+ try {
+ parser.printHelpOn(System.err);
+ } catch (IOException e) {
+ System.err.println("This should never happen. Really.");
+ e.printStackTrace();
+ }
+ System.exit(1);
+ }
+
+ int tags = (Integer) options.valueOf("topics");
+ int em_iterations = (Integer) options.valueOf("em-iterations");
+ int pr_iterations = (Integer) options.valueOf("pr-iterations");
+ double scale_phrase = (Double) options.valueOf("scale-phrase");
+ double scale_context = (Double) options.valueOf("scale-context");
+ int threads = (Integer) options.valueOf("threads");
+ double threshold = (Double) options.valueOf("convergence-threshold");
+
+ if (options.has("seed"))
+ F.rng = new Random((Long) options.valueOf("seed"));
+
+ if (tags <= 1 || scale_phrase < 0 || scale_context < 0 || threshold < 0)
+ {
+ System.err.println("Invalid arguments. Try again!");
+ System.exit(1);
+ }
+
+ Corpus corpus = null;
+ File infile = (File) options.valueOf("in");
+ try {
+ System.out.println("Reading concordance from " + infile);
+ corpus = Corpus.readFromFile(FileUtil.reader(infile));
+ corpus.printStats(System.out);
+ } catch (IOException e) {
+ System.err.println("Failed to open input file: " + infile);
+ e.printStackTrace();
+ System.exit(1);
+ }
+
+ System.out.println("Running with " + tags + " tags " +
+ "for " + em_iterations + " EM and " + pr_iterations + " PR iterations " +
+ "with scale " + scale_phrase + " phrase and " + scale_context + " context " +
+ "and " + threads + " threads");
+ System.out.println();
+
+ PhraseCluster cluster = new PhraseCluster(tags, corpus, scale_phrase, scale_context, threads);
+
+ double last = 0;
+ for (int i=0; i<em_iterations+pr_iterations; i++)
+ {
+ double o;
+ if (i < em_iterations)
+ o = cluster.EM();
+ else if (scale_context == 0)
+ {
+ if (threads >= 1)
+ o = cluster.PREM_phrase_constraints_parallel();
+ else
+ o = cluster.PREM_phrase_constraints();
+ }
+ else
+ o = cluster.PREM_phrase_context_constraints();
+
+ System.out.println("ITER: "+i+" objective: " + o);
+
+ if (i != 0 && Math.abs((o - last) / o) < threshold)
+ {
+ last = o;
+ if (i < em_iterations)
+ {
+ i = em_iterations - 1;
+ continue;
+ }
+ else
+ break;
+ }
+ last = o;
+ }
+
+ double pl1lmax = cluster.phrase_l1lmax();
+ double cl1lmax = cluster.context_l1lmax();
+ System.out.println("\nFinal posterior phrase l1lmax " + pl1lmax + " context l1lmax " + cl1lmax);
+ if (pr_iterations == 0)
+ System.out.println("With PR objective " + (last - scale_phrase*pl1lmax - scale_context*cl1lmax));
+
+ if (options.has("out"))
+ {
+ File outfile = (File) options.valueOf("out");
+ try {
+ PrintStream ps = FileUtil.printstream(outfile);
+ cluster.displayPosterior(ps);
+ ps.close();
+ } catch (IOException e) {
+ System.err.println("Failed to open output file: " + outfile);
+ e.printStackTrace();
+ System.exit(1);
+ }
+ }
+
+ if (options.has("parameters"))
+ {
+ File outfile = (File) options.valueOf("parameters");
+ PrintStream ps;
+ try {
+ ps = FileUtil.printstream(outfile);
+ cluster.displayModelParam(ps);
+ ps.close();
+ } catch (IOException e) {
+ System.err.println("Failed to open output parameters file: " + outfile);
+ e.printStackTrace();
+ System.exit(1);
+ }
+ }
+
+ if (cluster.pool != null)
+ cluster.pool.shutdown();
+ }
+}