diff options
-rw-r--r-- | decoder/bottom_up_parser.cc | 22 | ||||
-rw-r--r-- | decoder/decoder.cc | 3 | ||||
-rw-r--r-- | decoder/grammar.cc | 4 | ||||
-rw-r--r-- | decoder/rule_lexer.h | 3 | ||||
-rw-r--r-- | decoder/rule_lexer.ll | 20 | ||||
-rw-r--r-- | decoder/scfg_translator.cc | 2 | ||||
-rw-r--r-- | decoder/trule.cc | 2 |
7 files changed, 37 insertions, 19 deletions
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index ed79aaf0..606b8d7e 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -14,6 +14,8 @@ using namespace std; +static WordID kEPS = 0; + class ActiveChart; class PassiveChart { public: @@ -74,9 +76,12 @@ class ActiveChart { gptr_(g), ant_nodes_(), lattice_cost(0.0) {} void ExtendTerminal(int symbol, float src_cost, vector<ActiveItem>* out_cell) const { - const GrammarIter* ni = gptr_->Extend(symbol); - if (ni) { - out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost)); + if (symbol == kEPS) { + out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_cost + src_cost)); + } else { + const GrammarIter* ni = gptr_->Extend(symbol); + if (ni) + out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost)); } } void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector<ActiveItem>* out_cell) const { @@ -127,8 +132,10 @@ class ActiveChart { const WordID& f = ai->label; const double& c = ai->cost; const int& len = ai->dist2next; - //VLOG(1) << "F: " << TD::Convert(f) << endl; + //cerr << "F: " << TD::Convert(f) << " dest=" << i << "," << (j+len-1) << endl; const vector<ActiveItem>& ec = act_chart_(i, j-1); + //cerr << " SRC=" << i << "," << (j-1) << " [ec=" << ec.size() << "]" << endl; + //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_cost << endl; } for (vector<ActiveItem>::const_iterator di = ec.begin(); di != ec.end(); ++di) di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1)); } @@ -166,6 +173,7 @@ void PassiveChart::ApplyRule(const int i, const Hypergraph::TailNodeVector& ant_nodes, const float lattice_cost) { Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); + //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; new_edge->prev_i_ = r->prev_i; new_edge->prev_j_ = r->prev_j; new_edge->i_ = i; @@ -198,8 +206,11 @@ void PassiveChart::ApplyRules(const int i, const Hypergraph::TailNodeVector& tail, const float lattice_cost) { const int n = rules->GetNumRules(); - for (int k = 0; k < n; ++k) + //cerr << i << " " << j << ": NUM RULES: " << n << endl; + for (int k = 0; k < n; ++k) { + //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl; ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost); + } } void PassiveChart::ApplyUnaryRules(const int i, const int j) { @@ -284,6 +295,7 @@ ExhaustiveBottomUpParser::ExhaustiveBottomUpParser( bool ExhaustiveBottomUpParser::Parse(const Lattice& input, Hypergraph* forest) const { + kEPS = TD::Convert("*EPS*"); PassiveChart chart(goal_sym_, grammars_, input, forest); const bool result = chart.Parse(); return result; diff --git a/decoder/decoder.cc b/decoder/decoder.cc index 9b41253b..5bb62710 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -387,6 +387,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("show_partition,z", "Compute and show the partition (inside score)") ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") ("show_cfg_search_space", "Show the search space as a CFG") + ("show_cfg_alignment_space", "Show the alignment hypergraph as a CFG") ("show_target_graph", po::value<string>(), "Directory to write the target hypergraphs to") ("incremental_search", po::value<string>(), "Run lazy search with this language model file") ("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") @@ -988,6 +989,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { cerr << " Contst. partition log(Z): " << log(z) << endl; } o->NotifyAlignmentForest(smeta, &forest); + if (conf.count("show_cfg_alignment_space")) + HypergraphIO::WriteAsCFG(forest); if (conf.count("forest_output")) { ForestWriter writer(str("forest_output",conf), sent_id); if (FileExists(writer.fname_)) { diff --git a/decoder/grammar.cc b/decoder/grammar.cc index 160d00e6..439e448d 100644 --- a/decoder/grammar.cc +++ b/decoder/grammar.cc @@ -121,11 +121,11 @@ static void AddRuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level void TextGrammar::ReadFromFile(const string& filename) { ReadFile in(filename); - ReadFromStream(in.stream()); + RuleLexer::ReadRules(in.stream(), &AddRuleHelper, filename, this); } void TextGrammar::ReadFromStream(istream* in) { - RuleLexer::ReadRules(in, &AddRuleHelper, this); + RuleLexer::ReadRules(in, &AddRuleHelper, "UNKNOWN", this); } bool TextGrammar::HasRuleForSpan(int /* i */, int /* j */, int distance) const { diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h index 976ea02b..f844e5b2 100644 --- a/decoder/rule_lexer.h +++ b/decoder/rule_lexer.h @@ -2,12 +2,13 @@ #define _RULE_LEXER_H_ #include <iostream> +#include <string> #include "trule.h" struct RuleLexer { typedef void (*RuleCallback)(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra); - static void ReadRules(std::istream* in, RuleCallback func, void* extra); + static void ReadRules(std::istream* in, RuleCallback func, const std::string& fname, void* extra); }; #endif diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll index 083a5bb1..c6a85919 100644 --- a/decoder/rule_lexer.ll +++ b/decoder/rule_lexer.ll @@ -18,6 +18,7 @@ std::istream* scfglex_stream = NULL; RuleLexer::RuleCallback rule_callback = NULL; void* rule_callback_extra = NULL; std::vector<int> scfglex_phrase_fnames; +std::string scfglex_fname; #undef YY_INPUT #define YY_INPUT(buf, result, max_size) (result = scfglex_stream->read(buf, max_size).gcount()) @@ -38,12 +39,12 @@ WordID scfglex_lhs; int scfglex_src_arity; int scfglex_trg_arity; -#define MAX_FEATS 100 +#define MAX_FEATS 10000 int scfglex_feat_ids[MAX_FEATS]; double scfglex_feat_vals[MAX_FEATS]; int scfglex_num_feats; -#define MAX_ARITY 200 +#define MAX_ARITY 1000 int scfglex_nt_sanity[MAX_ARITY]; int scfglex_src_nts[MAX_ARITY]; // float scfglex_nt_size_means[MAX_ARITY]; @@ -51,7 +52,7 @@ int scfglex_src_nts[MAX_ARITY]; std::stack<TRulePtr> ctf_rule_stack; unsigned int ctf_level = 0; -#define MAX_ALS 200 +#define MAX_ALS 2000 AlignmentPoint scfglex_als[MAX_ALS]; int scfglex_num_als; @@ -190,7 +191,7 @@ NT [^\t \[\],]+ BEGIN(SRC); } <INITIAL,LHS_END>. { - std::cerr << "Line " << lex_line << ": unexpected input in LHS: " << yytext << std::endl; + std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": unexpected input in LHS: " << yytext << std::endl; abort(); } @@ -217,7 +218,7 @@ NT [^\t \[\],]+ <TRG,FEATS,ALIGNS>\n { if (scfglex_src_arity != scfglex_trg_arity) { - std::cerr << "Line " << lex_line << ": LHS and RHS arity mismatch!\n"; + std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": LHS and RHS arity mismatch!\n"; abort(); } // const bool ignore_grammar_features = false; @@ -258,7 +259,7 @@ NT [^\t \[\],]+ BEGIN(FEATS); } <FEATVAL>. { - std::cerr << "Line " << lex_line << ": unexpected input in feature value: " << yytext << std::endl; + std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": unexpected input in feature value: " << yytext << std::endl; abort(); } <FEATS>{REAL} { @@ -267,7 +268,7 @@ NT [^\t \[\],]+ ++scfglex_num_feats; } <FEATS>. { - std::cerr << "Line " << lex_line << " unexpected input in features: " << yytext << std::endl; + std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << " unexpected input in features: " << yytext << std::endl; abort(); } <ALIGNS>[0-9]+-[0-9]+ { @@ -291,14 +292,14 @@ NT [^\t \[\],]+ } <ALIGNS>[ \t] ; <ALIGNS>. { - std::cerr << "Line " << lex_line << ": unexpected input in alignment: " << yytext << std::endl; + std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": unexpected input in alignment: " << yytext << std::endl; abort(); } %% #include "filelib.h" -void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, void* extra) { +void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const std::string& fname, void* extra) { if (scfglex_phrase_fnames.empty()) { scfglex_phrase_fnames.resize(100); for (int i = 0; i < scfglex_phrase_fnames.size(); ++i) { @@ -308,6 +309,7 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, void* } } lex_line = 1; + scfglex_fname = fname; scfglex_stream = in; rule_callback_extra = extra, rule_callback = func; diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index a506c591..236d7c90 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -78,7 +78,7 @@ PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat, } bool PassThroughGrammar::HasRuleForSpan(int, int, int distance) const { - return (distance < 2); + return (distance < 4); // TODO this isn't great, but helps with EPS lattices } struct SCFGTranslatorImpl { diff --git a/decoder/trule.cc b/decoder/trule.cc index 896f9f3d..c22baae3 100644 --- a/decoder/trule.cc +++ b/decoder/trule.cc @@ -117,7 +117,7 @@ bool TRule::ReadFromString(const string& line, bool strict, bool mono) { // use lexer istringstream il(line); n_assigned=0; - RuleLexer::ReadRules(&il,assign_trule,this); + RuleLexer::ReadRules(&il,assign_trule,"STRING",this); if (n_assigned>1) cerr<<"\nWARNING: more than one rule parsed from multi-line string; kept last: "<<line<<".\n"; return n_assigned; |