diff options
-rw-r--r-- | extools/Makefile.am | 7 | ||||
-rw-r--r-- | extools/extractor_monolingual.cc | 196 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/io/FileUtil.java | 18 | ||||
-rw-r--r-- | gi/posterior-regularisation/prjava/src/phrase/Trainer.java | 8 |
4 files changed, 216 insertions, 13 deletions
diff --git a/extools/Makefile.am b/extools/Makefile.am index 1c0da21b..807fe7d6 100644 --- a/extools/Makefile.am +++ b/extools/Makefile.am @@ -4,7 +4,8 @@ bin_PROGRAMS = \ build_lexical_translation \ filter_grammar \ featurize_grammar \ - filter_score_grammar + filter_score_grammar \ + extractor_monolingual noinst_PROGRAMS = @@ -35,5 +36,9 @@ extractor_SOURCES = sentence_pair.cc extract.cc extractor.cc striped_grammar.cc extractor_LDADD = $(top_srcdir)/decoder/libcdec.a -lz extractor_LDFLAGS = -all-static +extractor_monolingual_SOURCES = extractor_monolingual.cc +extractor_monolingual_LDADD = $(top_srcdir)/decoder/libcdec.a -lz +extractor_monolingual_LDFLAGS = -all-static + AM_CPPFLAGS = -W -Wall -Wno-sign-compare $(GTEST_CPPFLAGS) -I$(top_srcdir)/decoder diff --git a/extools/extractor_monolingual.cc b/extools/extractor_monolingual.cc new file mode 100644 index 00000000..5db768e3 --- /dev/null +++ b/extools/extractor_monolingual.cc @@ -0,0 +1,196 @@ +#include <iostream> +#include <vector> +#include <utility> +#include <tr1/unordered_map> + +#include <boost/functional/hash.hpp> +#include <boost/program_options.hpp> +#include <boost/program_options/variables_map.hpp> +#include <boost/lexical_cast.hpp> + +#include "tdict.h" +#include "fdict.h" +#include "wordid.h" +#include "filelib.h" + +using namespace std; +using namespace std::tr1; +namespace po = boost::program_options; + +static const size_t MAX_LINE_LENGTH = 100000; +WordID kBOS, kEOS, kDIVIDER, kGAP; +int kCOUNT; + +void InitCommandLine(int argc, char** argv, po::variables_map* conf) { + po::options_description opts("Configuration options"); + opts.add_options() + ("input,i", po::value<string>()->default_value("-"), "Input file") + ("phrases,p", po::value<string>(), "File contatining phrases of interest") + ("phrase_context_size,S", po::value<int>()->default_value(2), "Use this many words of context on left and write when writing base phrase contexts") + ("silent", "Write nothing to stderr except errors") + ("help,h", "Print this help message and exit"); + po::options_description clo("Command line options"); + po::options_description dcmdline_options; + dcmdline_options.add(opts); + + po::store(parse_command_line(argc, argv, dcmdline_options), *conf); + po::notify(*conf); + + if (conf->count("help") || conf->count("input") != 1 || conf->count("phrases") != 1) { + cerr << "\nUsage: extractor_monolingual [-options]\n"; + cerr << dcmdline_options << endl; + exit(1); + } +} + +struct TrieNode +{ + TrieNode(int l) : finish(false), length(l) {}; + ~TrieNode() + { + for (unordered_map<int, TrieNode*>::iterator + it = next.begin(); it != next.end(); ++it) + delete it->second; + next.clear(); + } + + TrieNode *follow(int token) + { + unordered_map<int, TrieNode*>::iterator + found = next.find(token); + if (found != next.end()) + return found->second; + else + return 0; + } + + void insert(const vector<int> &tokens) + { + insert(tokens.begin(), tokens.end()); + } + + void insert(vector<int>::const_iterator begin, vector<int>::const_iterator end) + { + if (begin == end) + finish = true; + else + { + int token = *begin; + unordered_map<int, TrieNode*>::iterator + nit = next.find(token); + if (nit == next.end()) + nit = next.insert(make_pair(token, new TrieNode(length+1))).first; + ++begin; + nit->second->insert(begin, end); + } + } + + bool finish; + int length; + unordered_map<int, TrieNode*> next; +}; + +void WriteContext(const vector<int>& sentence, int start, int end, int ctx_size) +{ + for (int i = start; i < end; ++i) + { + if (i != start) cout << " "; + cout << sentence[i]; + } + cout << '\t'; + for (int i = ctx_size; i > 0; --i) + cout << TD::Convert(sentence[start-i]) << " "; + cout << " " << TD::Convert(kGAP); + for (int i = 0; i < ctx_size; ++i) + cout << " " << TD::Convert(sentence[end+i]); + cout << "\n"; +} + +inline bool IsWhitespace(char c) { + return c == ' ' || c == '\t'; +} + +inline void SkipWhitespace(const char* buf, int* ptr) { + while (buf[*ptr] && IsWhitespace(buf[*ptr])) { ++(*ptr); } +} + +vector<int> ReadSentence(const char *buf, int padding) +{ + int ptr = 0; + SkipWhitespace(buf, &ptr); + int start = ptr; + vector<int> sentence; + for (int i = 0; i < padding; ++i) + sentence.push_back(kBOS); + + while (char c = buf[ptr]) + { + if (!IsWhitespace(c)) + ++ptr; + else { + sentence.push_back(TD::Convert(string(buf, start, ptr-start))); + SkipWhitespace(buf, &ptr); + start = ptr; + } + } + for (int i = 0; i < padding; ++i) + sentence.push_back(kEOS); + + return sentence; +} + +int main(int argc, char** argv) +{ + po::variables_map conf; + InitCommandLine(argc, argv, &conf); + kBOS = TD::Convert("<s>"); + kEOS = TD::Convert("</s>"); + kDIVIDER = TD::Convert("|||"); + kGAP = TD::Convert("<PHRASE>"); + kCOUNT = FD::Convert("C"); + + bool silent = conf.count("silent") > 0; + const int ctx_size = conf["phrase_context_size"].as<int>(); + + char buf[MAX_LINE_LENGTH]; + TrieNode phrase_trie(0); + ReadFile rpf(conf["phrases"].as<string>()); + istream& pin = *rpf.stream(); + while (pin) { + pin.getline(buf, MAX_LINE_LENGTH); + phrase_trie.insert(ReadSentence(buf, 0)); + } + + ReadFile rif(conf["input"].as<string>()); + istream &iin = *rif.stream(); + int line = 0; + while (iin) { + ++line; + iin.getline(buf, MAX_LINE_LENGTH); + if (buf[0] == 0) continue; + if (!silent) { + if (line % 200 == 0) cerr << '.'; + if (line % 8000 == 0) cerr << " [" << line << "]\n" << flush; + } + + vector<int> sentence = ReadSentence(buf, ctx_size); + vector<TrieNode*> tries(1, &phrase_trie); + for (int i = ctx_size; i < (int)sentence.size() - ctx_size; ++i) + { + vector<TrieNode*> tries_prime(1, &phrase_trie); + for (vector<TrieNode*>::iterator tit = tries.begin(); tit != tries.end(); ++tit) + { + TrieNode* next = (*tit)->follow(sentence[i]); + if (next != 0) + { + if (next->finish) + WriteContext(sentence, i - next->length, i, ctx_size); + tries_prime.push_back(next); + } + } + swap(tries, tries_prime); + } + } + if (!silent) cerr << endl; + return 0; +} diff --git a/gi/posterior-regularisation/prjava/src/io/FileUtil.java b/gi/posterior-regularisation/prjava/src/io/FileUtil.java index 81e7747b..6720d087 100644 --- a/gi/posterior-regularisation/prjava/src/io/FileUtil.java +++ b/gi/posterior-regularisation/prjava/src/io/FileUtil.java @@ -8,24 +8,25 @@ public class FileUtil public static BufferedReader reader(File file) throws FileNotFoundException, IOException
{
if (file.getName().endsWith(".gz"))
- return new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(file))));
+ return new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(file)), "UTF8"));
else
- return new BufferedReader(new FileReader(file));
+ return new BufferedReader(new InputStreamReader(new FileInputStream(file), "UTF8"));
}
public static PrintStream printstream(File file) throws FileNotFoundException, IOException
{
if (file.getName().endsWith(".gz"))
- return new PrintStream(new GZIPOutputStream(new FileOutputStream(file)));
+ return new PrintStream(new GZIPOutputStream(new FileOutputStream(file)), true, "UTF8");
else
- return new PrintStream(new FileOutputStream(file));
+ return new PrintStream(new FileOutputStream(file), true, "UTF8");
}
- public static Scanner openInFile(String filename){
+ public static Scanner openInFile(String filename)
+ {
Scanner localsc=null;
try
{
- localsc=new Scanner (new FileInputStream(filename));
+ localsc=new Scanner(new FileInputStream(filename), "UTF8");
}catch(IOException ioe){
System.out.println(ioe.getMessage());
@@ -33,10 +34,11 @@ public class FileUtil return localsc;
}
- public static FileInputStream openInputStream(String infilename){
+ public static FileInputStream openInputStream(String infilename)
+ {
FileInputStream fis=null;
try {
- fis =(new FileInputStream(infilename));
+ fis = new FileInputStream(infilename);
} catch (IOException ioe) {
System.out.println(ioe.getMessage());
diff --git a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java index d1322c26..7f0b1970 100644 --- a/gi/posterior-regularisation/prjava/src/phrase/Trainer.java +++ b/gi/posterior-regularisation/prjava/src/phrase/Trainer.java @@ -34,10 +34,10 @@ public class Trainer parser.accepts("agree"); parser.accepts("no-parameter-cache"); parser.accepts("skip-large-phrases").withRequiredArg().ofType(Integer.class).defaultsTo(5); - parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(0); - parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(0); - parser.accepts("rare-phrase").withRequiredArg().ofType(Integer.class).defaultsTo(0); - parser.accepts("rare-context").withRequiredArg().ofType(Integer.class).defaultsTo(0); + parser.accepts("rare-word").withRequiredArg().ofType(Integer.class).defaultsTo(10); + parser.accepts("rare-edge").withRequiredArg().ofType(Integer.class).defaultsTo(1); + parser.accepts("rare-phrase").withRequiredArg().ofType(Integer.class).defaultsTo(2); + parser.accepts("rare-context").withRequiredArg().ofType(Integer.class).defaultsTo(2); OptionSet options = parser.parse(args); if (options.has("help") || !options.has("in")) |