summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormjdenkowski <michael.j.denkowski@gmail.com>2013-12-29 17:16:43 -0500
committermjdenkowski <michael.j.denkowski@gmail.com>2013-12-29 17:16:43 -0500
commit791301036d6bfa406c2d5222032b204f7d4943ee (patch)
tree0a76fdd52700ac43864be62124a43b4f2e940a8b
parentb12f48112213cbbde519b8aabd02b2c57bf0948d (diff)
parent3c22963a360346381588350962499d1a76a89c10 (diff)
Merge branch 'master' of https://github.com/redpony/cdec
-rw-r--r--decoder/bottom_up_parser.cc22
-rw-r--r--decoder/decoder.cc3
-rw-r--r--decoder/grammar.cc4
-rw-r--r--decoder/rule_lexer.h3
-rw-r--r--decoder/rule_lexer.ll20
-rw-r--r--decoder/scfg_translator.cc2
-rw-r--r--decoder/trule.cc2
7 files changed, 37 insertions, 19 deletions
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc
index ed79aaf0..606b8d7e 100644
--- a/decoder/bottom_up_parser.cc
+++ b/decoder/bottom_up_parser.cc
@@ -14,6 +14,8 @@
using namespace std;
+static WordID kEPS = 0;
+
class ActiveChart;
class PassiveChart {
public:
@@ -74,9 +76,12 @@ class ActiveChart {
gptr_(g), ant_nodes_(), lattice_cost(0.0) {}
void ExtendTerminal(int symbol, float src_cost, vector<ActiveItem>* out_cell) const {
- const GrammarIter* ni = gptr_->Extend(symbol);
- if (ni) {
- out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost));
+ if (symbol == kEPS) {
+ out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_cost + src_cost));
+ } else {
+ const GrammarIter* ni = gptr_->Extend(symbol);
+ if (ni)
+ out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_cost + src_cost));
}
}
void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector<ActiveItem>* out_cell) const {
@@ -127,8 +132,10 @@ class ActiveChart {
const WordID& f = ai->label;
const double& c = ai->cost;
const int& len = ai->dist2next;
- //VLOG(1) << "F: " << TD::Convert(f) << endl;
+ //cerr << "F: " << TD::Convert(f) << " dest=" << i << "," << (j+len-1) << endl;
const vector<ActiveItem>& ec = act_chart_(i, j-1);
+ //cerr << " SRC=" << i << "," << (j-1) << " [ec=" << ec.size() << "]" << endl;
+ //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_cost << endl; }
for (vector<ActiveItem>::const_iterator di = ec.begin(); di != ec.end(); ++di)
di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1));
}
@@ -166,6 +173,7 @@ void PassiveChart::ApplyRule(const int i,
const Hypergraph::TailNodeVector& ant_nodes,
const float lattice_cost) {
Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
+ //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
new_edge->prev_i_ = r->prev_i;
new_edge->prev_j_ = r->prev_j;
new_edge->i_ = i;
@@ -198,8 +206,11 @@ void PassiveChart::ApplyRules(const int i,
const Hypergraph::TailNodeVector& tail,
const float lattice_cost) {
const int n = rules->GetNumRules();
- for (int k = 0; k < n; ++k)
+ //cerr << i << " " << j << ": NUM RULES: " << n << endl;
+ for (int k = 0; k < n; ++k) {
+ //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl;
ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_cost);
+ }
}
void PassiveChart::ApplyUnaryRules(const int i, const int j) {
@@ -284,6 +295,7 @@ ExhaustiveBottomUpParser::ExhaustiveBottomUpParser(
bool ExhaustiveBottomUpParser::Parse(const Lattice& input,
Hypergraph* forest) const {
+ kEPS = TD::Convert("*EPS*");
PassiveChart chart(goal_sym_, grammars_, input, forest);
const bool result = chart.Parse();
return result;
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index 9b41253b..5bb62710 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -387,6 +387,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("show_partition,z", "Compute and show the partition (inside score)")
("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation")
("show_cfg_search_space", "Show the search space as a CFG")
+ ("show_cfg_alignment_space", "Show the alignment hypergraph as a CFG")
("show_target_graph", po::value<string>(), "Directory to write the target hypergraphs to")
("incremental_search", po::value<string>(), "Run lazy search with this language model file")
("coarse_to_fine_beam_prune", po::value<double>(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)")
@@ -988,6 +989,8 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
cerr << " Contst. partition log(Z): " << log(z) << endl;
}
o->NotifyAlignmentForest(smeta, &forest);
+ if (conf.count("show_cfg_alignment_space"))
+ HypergraphIO::WriteAsCFG(forest);
if (conf.count("forest_output")) {
ForestWriter writer(str("forest_output",conf), sent_id);
if (FileExists(writer.fname_)) {
diff --git a/decoder/grammar.cc b/decoder/grammar.cc
index 160d00e6..439e448d 100644
--- a/decoder/grammar.cc
+++ b/decoder/grammar.cc
@@ -121,11 +121,11 @@ static void AddRuleHelper(const TRulePtr& new_rule, const unsigned int ctf_level
void TextGrammar::ReadFromFile(const string& filename) {
ReadFile in(filename);
- ReadFromStream(in.stream());
+ RuleLexer::ReadRules(in.stream(), &AddRuleHelper, filename, this);
}
void TextGrammar::ReadFromStream(istream* in) {
- RuleLexer::ReadRules(in, &AddRuleHelper, this);
+ RuleLexer::ReadRules(in, &AddRuleHelper, "UNKNOWN", this);
}
bool TextGrammar::HasRuleForSpan(int /* i */, int /* j */, int distance) const {
diff --git a/decoder/rule_lexer.h b/decoder/rule_lexer.h
index 976ea02b..f844e5b2 100644
--- a/decoder/rule_lexer.h
+++ b/decoder/rule_lexer.h
@@ -2,12 +2,13 @@
#define _RULE_LEXER_H_
#include <iostream>
+#include <string>
#include "trule.h"
struct RuleLexer {
typedef void (*RuleCallback)(const TRulePtr& new_rule, const unsigned int ctf_level, const TRulePtr& coarse_rule, void* extra);
- static void ReadRules(std::istream* in, RuleCallback func, void* extra);
+ static void ReadRules(std::istream* in, RuleCallback func, const std::string& fname, void* extra);
};
#endif
diff --git a/decoder/rule_lexer.ll b/decoder/rule_lexer.ll
index 083a5bb1..c6a85919 100644
--- a/decoder/rule_lexer.ll
+++ b/decoder/rule_lexer.ll
@@ -18,6 +18,7 @@ std::istream* scfglex_stream = NULL;
RuleLexer::RuleCallback rule_callback = NULL;
void* rule_callback_extra = NULL;
std::vector<int> scfglex_phrase_fnames;
+std::string scfglex_fname;
#undef YY_INPUT
#define YY_INPUT(buf, result, max_size) (result = scfglex_stream->read(buf, max_size).gcount())
@@ -38,12 +39,12 @@ WordID scfglex_lhs;
int scfglex_src_arity;
int scfglex_trg_arity;
-#define MAX_FEATS 100
+#define MAX_FEATS 10000
int scfglex_feat_ids[MAX_FEATS];
double scfglex_feat_vals[MAX_FEATS];
int scfglex_num_feats;
-#define MAX_ARITY 200
+#define MAX_ARITY 1000
int scfglex_nt_sanity[MAX_ARITY];
int scfglex_src_nts[MAX_ARITY];
// float scfglex_nt_size_means[MAX_ARITY];
@@ -51,7 +52,7 @@ int scfglex_src_nts[MAX_ARITY];
std::stack<TRulePtr> ctf_rule_stack;
unsigned int ctf_level = 0;
-#define MAX_ALS 200
+#define MAX_ALS 2000
AlignmentPoint scfglex_als[MAX_ALS];
int scfglex_num_als;
@@ -190,7 +191,7 @@ NT [^\t \[\],]+
BEGIN(SRC);
}
<INITIAL,LHS_END>. {
- std::cerr << "Line " << lex_line << ": unexpected input in LHS: " << yytext << std::endl;
+ std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": unexpected input in LHS: " << yytext << std::endl;
abort();
}
@@ -217,7 +218,7 @@ NT [^\t \[\],]+
<TRG,FEATS,ALIGNS>\n {
if (scfglex_src_arity != scfglex_trg_arity) {
- std::cerr << "Line " << lex_line << ": LHS and RHS arity mismatch!\n";
+ std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": LHS and RHS arity mismatch!\n";
abort();
}
// const bool ignore_grammar_features = false;
@@ -258,7 +259,7 @@ NT [^\t \[\],]+
BEGIN(FEATS);
}
<FEATVAL>. {
- std::cerr << "Line " << lex_line << ": unexpected input in feature value: " << yytext << std::endl;
+ std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": unexpected input in feature value: " << yytext << std::endl;
abort();
}
<FEATS>{REAL} {
@@ -267,7 +268,7 @@ NT [^\t \[\],]+
++scfglex_num_feats;
}
<FEATS>. {
- std::cerr << "Line " << lex_line << " unexpected input in features: " << yytext << std::endl;
+ std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << " unexpected input in features: " << yytext << std::endl;
abort();
}
<ALIGNS>[0-9]+-[0-9]+ {
@@ -291,14 +292,14 @@ NT [^\t \[\],]+
}
<ALIGNS>[ \t] ;
<ALIGNS>. {
- std::cerr << "Line " << lex_line << ": unexpected input in alignment: " << yytext << std::endl;
+ std::cerr << "Grammar " << scfglex_fname << " line " << lex_line << ": unexpected input in alignment: " << yytext << std::endl;
abort();
}
%%
#include "filelib.h"
-void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, void* extra) {
+void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, const std::string& fname, void* extra) {
if (scfglex_phrase_fnames.empty()) {
scfglex_phrase_fnames.resize(100);
for (int i = 0; i < scfglex_phrase_fnames.size(); ++i) {
@@ -308,6 +309,7 @@ void RuleLexer::ReadRules(std::istream* in, RuleLexer::RuleCallback func, void*
}
}
lex_line = 1;
+ scfglex_fname = fname;
scfglex_stream = in;
rule_callback_extra = extra,
rule_callback = func;
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index a506c591..236d7c90 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -78,7 +78,7 @@ PassThroughGrammar::PassThroughGrammar(const Lattice& input, const string& cat,
}
bool PassThroughGrammar::HasRuleForSpan(int, int, int distance) const {
- return (distance < 2);
+ return (distance < 4); // TODO this isn't great, but helps with EPS lattices
}
struct SCFGTranslatorImpl {
diff --git a/decoder/trule.cc b/decoder/trule.cc
index 896f9f3d..c22baae3 100644
--- a/decoder/trule.cc
+++ b/decoder/trule.cc
@@ -117,7 +117,7 @@ bool TRule::ReadFromString(const string& line, bool strict, bool mono) {
// use lexer
istringstream il(line);
n_assigned=0;
- RuleLexer::ReadRules(&il,assign_trule,this);
+ RuleLexer::ReadRules(&il,assign_trule,"STRING",this);
if (n_assigned>1)
cerr<<"\nWARNING: more than one rule parsed from multi-line string; kept last: "<<line<<".\n";
return n_assigned;