From 14b4d7dff699259bc5e606fa0d5beb77001e32fb Mon Sep 17 00:00:00 2001 From: redpony Date: Thu, 28 Oct 2010 21:10:12 +0000 Subject: psg for lex trans git-svn-id: https://ws10smt.googlecode.com/svn/trunk@699 ec762483-ff6d-05da-a07a-a48fb63a330f --- decoder/decoder.cc | 1 + decoder/lextrans.cc | 43 ++++++++++++++++++++-- decoder/translator.cc | 2 +- .../support/generate_per_sentence_grammars.pl | 28 ++++++++------ 4 files changed, 58 insertions(+), 16 deletions(-) diff --git a/decoder/decoder.cc b/decoder/decoder.cc index eb983419..2a8043db 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -324,6 +324,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("formalism,f",po::value(),"Decoding formalism; values include SCFG, FST, PB, LexTrans (lexical translation model, also disc training), CSplit (compound splitting), Tagger (sequence labeling), LexAlign (alignment only, or EM training)") ("input,i",po::value()->default_value("-"),"Source file") ("grammar,g",po::value >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") + ("per_sentence_grammar_file", po::value(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset") ("weights,w",po::value(),"Feature weights file") ("prelm_weights",po::value(),"Feature weights file for prelm_beam_prune. Requires --weights.") ("prelm_copy_weights","use --weights as value for --prelm_weights.") diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc index 1921f280..551e77e3 100644 --- a/decoder/lextrans.cc +++ b/decoder/lextrans.cc @@ -1,6 +1,7 @@ #include "lextrans.h" #include +#include #include "filelib.h" #include "hg.h" @@ -13,10 +14,14 @@ using namespace std; struct LexicalTransImpl { LexicalTransImpl(const boost::program_options::variables_map& conf) : use_null(conf.count("lextrans_use_null") > 0), + psg_file_(), kXCAT(TD::Convert("X")*-1), kNULL(TD::Convert("")), kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")), kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) { + if (conf.count("per_sentence_grammar_file")) { + psg_file_ = new ifstream(conf["per_sentence_grammar_file"].as().c_str()); + } vector gfiles = conf["grammar"].as >(); assert(gfiles.size() == 1); ReadFile rf(gfiles.front()); @@ -25,10 +30,10 @@ struct LexicalTransImpl { istream* in = rf.stream(); int lc = 0; bool flag = false; + string line; while(*in) { - string line; getline(*in, line); - if (line.empty()) continue; + if (!*in) continue; ++lc; TRulePtr r(TRule::CreateRulePhrasetable(line)); tg->AddRule(r); @@ -39,7 +44,31 @@ struct LexicalTransImpl { cerr << "Loaded " << lc << " rules\n"; } + void LoadSentenceGrammar(const string& s_offset) { + const unsigned long long int offset = strtoull(s_offset.c_str(), NULL, 10); + psg_file_->seekg(offset, ios::beg); + TextGrammar *tg = new TextGrammar; + sup_grammar.reset(tg); + const string kEND_MARKER = "###EOS###"; + string line; + while(true) { + assert(*psg_file_); + getline(*psg_file_, line); + if (line == kEND_MARKER) break; + TRulePtr r(TRule::CreateRulePhrasetable(line)); + tg->AddRule(r); + } + } + void BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) { + if (psg_file_) { + const string offset = smeta.GetSGMLValue("psg"); + if (offset.size() < 2 || offset[0] != '@') { + cerr << "per_sentence_grammar_file given but sentence id=" << smeta.GetSentenceID() << " doesn't have grammar info!\n"; + abort(); + } + LoadSentenceGrammar(offset.substr(1)); + } const int e_len = smeta.GetTargetLength(); assert(e_len > 0); const int f_len = lattice.size(); @@ -53,8 +82,12 @@ struct LexicalTransImpl { const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label); const GrammarIter* gi = grammar->GetRoot()->Extend(src_sym); if (!gi) { - cerr << "No translations found for: " << TD::Convert(src_sym) << "\n"; - abort(); + if (psg_file_) + gi = sup_grammar->GetRoot()->Extend(src_sym); + if (!gi) { + cerr << "No translations found for: " << TD::Convert(src_sym) << "\n"; + abort(); + } } const RuleBin* rb = gi->GetRules(); assert(rb); @@ -88,11 +121,13 @@ struct LexicalTransImpl { private: const bool use_null; + ifstream* psg_file_; const WordID kXCAT; const WordID kNULL; const TRulePtr kBINARY; const TRulePtr kGOAL_RULE; GrammarPtr grammar; + GrammarPtr sup_grammar; }; LexicalTrans::LexicalTrans(const boost::program_options::variables_map& conf) : diff --git a/decoder/translator.cc b/decoder/translator.cc index 277c3a2d..d1ca125b 100644 --- a/decoder/translator.cc +++ b/decoder/translator.cc @@ -43,7 +43,7 @@ void Translator::SentenceComplete() { // this may be overridden by translators that want to accept // metadata void Translator::ProcessMarkupHintsImpl(const map& kv) { - int unprocessed = kv.size() - kv.count("id"); + int unprocessed = kv.size() - kv.count("id") - kv.count("psg"); if (!SILENT) cerr << "Inside translator process hints\n"; if (unprocessed > 0) { cerr << "Sentence markup contains unprocessed data:\n"; diff --git a/word-aligner/support/generate_per_sentence_grammars.pl b/word-aligner/support/generate_per_sentence_grammars.pl index 80243419..8779ac9c 100755 --- a/word-aligner/support/generate_per_sentence_grammars.pl +++ b/word-aligner/support/generate_per_sentence_grammars.pl @@ -2,7 +2,7 @@ use strict; use utf8; -die "Usage: $0 f.voc corpus.f-e grammar.f-e.gz\n" unless scalar @ARGV == 3; +die "Usage: $0 f.voc corpus.f-e grammar.f-e.gz [OUT]filtered.f-e.gz [OUT]per_sentence_grammar.f-e [OUT]train.f-e.sgml\n" unless scalar @ARGV == 6; my $MAX_INMEM = 2500; @@ -10,7 +10,14 @@ open FV,"<$ARGV[0]" or die "Can't read $ARGV[0]: $!"; open C,"<$ARGV[1]" or die "Can't read $ARGV[1]: $!"; open G,"gunzip -c $ARGV[2]|" or die "Can't read $ARGV[2]: $!"; +open FILT,"|gzip -c > $ARGV[3]" or die "Can't write $ARGV[3]: $!"; +open PSG,">$ARGV[4]" or die "Can't write $ARGV[4]: $!"; +open OTRAIN,">$ARGV[5]" or die "Can't write $ARGV[5]: $!"; + +binmode FILT, ":utf8"; +binmode PSG, ":utf8"; binmode STDOUT, ":utf8"; +binmode STDERR, ":utf8"; binmode FV, ":utf8"; binmode C, ":utf8"; binmode G, ":utf8"; @@ -35,7 +42,7 @@ while() { chomp; my ($f, $e, $feats) = split / \|\|\| /; if ($most_freq{$f}) { - print "$_\n"; + print FILT "$_\n"; $memrc++; } else { $loadrc++; @@ -47,20 +54,19 @@ while() { push @$r, "$e ||| $feats"; } } +close FILT; close G; print STDERR " mem rc: $memrc\n"; print STDERR " load rc: $loadrc\n"; my $id = 0; -open O, ">ps.grammar" or die; -binmode(O,":utf8"); while() { chomp; my ($f,$e) = split / \|\|\| /; my @fwords = split /\s+/, $f; my $tot = 0; my %used; - my $fpos = tell(O); + my $fpos = tell(PSG); for my $f (@fwords) { next if $most_freq{$f}; next if $used{$f}; @@ -69,15 +75,15 @@ while() { my $num = scalar @$r; $tot += $num; for my $rule (@$r) { - print O "$f ||| $rule\n"; + print PSG "$f ||| $rule\n"; } $used{$f} = 1; } - print O "###EOS###\n"; - print STDERR " $_ \n"; - #print STDERR "id=$id POS=$fpos\n"; + print PSG "###EOS###\n"; + print OTRAIN " $_ \n"; $id++; - last if $id == 10; } +close PSG; +close OTRAIN; +print STDERR "Done.\n"; -close O; -- cgit v1.2.3