summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorarmatthews <armatthe@cmu.edu>2014-02-20 22:22:15 -0500
committerarmatthews <armatthe@cmu.edu>2014-02-20 22:22:15 -0500
commitfbdc905f6f201e2cc0dbee89ef81e36a53bb3c42 (patch)
tree0b1ccd0a64347bec4c63cb2e81c11e003d0f9e1a /decoder
parente84a9a146495ea3d42f555dac875ff9f74ad4c08 (diff)
parent639d57841510a230af6dead1c27777d45e07be6a (diff)
Merge branch 'master' of https://github.com/redpony/cdec
Diffstat (limited to 'decoder')
-rw-r--r--decoder/bottom_up_parser.cc80
-rw-r--r--decoder/cdec_ff.cc1
-rw-r--r--decoder/ff_ruleshape.cc138
-rw-r--r--decoder/ff_ruleshape.h46
4 files changed, 254 insertions, 11 deletions
diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc
index 606b8d7e..8738c8f1 100644
--- a/decoder/bottom_up_parser.cc
+++ b/decoder/bottom_up_parser.cc
@@ -45,6 +45,7 @@ class PassiveChart {
const float lattice_cost);
void ApplyUnaryRules(const int i, const int j);
+ void TopoSortUnaries();
const vector<GrammarPtr>& grammars_;
const Lattice& input_;
@@ -57,6 +58,7 @@ class PassiveChart {
TRulePtr goal_rule_;
int goal_idx_; // index of goal node, if found
const int lc_fid_;
+ vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars
static WordID kGOAL; // [Goal]
};
@@ -159,21 +161,78 @@ PassiveChart::PassiveChart(const string& goal,
goal_cat_(TD::Convert(goal) * -1),
goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")),
goal_idx_(-1),
- lc_fid_(FD::Convert("LatticeCost")) {
+ lc_fid_(FD::Convert("LatticeCost")),
+ unaries_() {
act_chart_.resize(grammars_.size());
- for (unsigned i = 0; i < grammars_.size(); ++i)
+ for (unsigned i = 0; i < grammars_.size(); ++i) {
act_chart_[i] = new ActiveChart(forest, *this);
+ const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules();
+ for (unsigned j = 0; j < u.size(); ++j)
+ unaries_.push_back(u[j]);
+ }
+ TopoSortUnaries();
if (!kGOAL) kGOAL = TD::Convert("Goal") * -1;
if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl;
}
+static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) {
+ if (mark[node] == 1) {
+ cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n";
+ return false; // cycle detected
+ } else if (mark[node] == 2) {
+ return true; // already been
+ }
+ mark[node] = 1;
+ const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node);
+ if (nit != g.end()) {
+ const vector<TRulePtr>& edges = nit->second;
+ vector<bool> okay(edges.size(), true);
+ for (unsigned i = 0; i < edges.size(); ++i) {
+ okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark);
+ if (!okay[i]) {
+ cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl;
+ }
+ }
+ for (unsigned i = 0; i < edges.size(); ++i) {
+ if (okay[i]) u.push_back(edges[i]);
+ //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl;
+ }
+ }
+ mark[node] = 2;
+ return true;
+}
+
+void PassiveChart::TopoSortUnaries() {
+ vector<TRulePtr> u(unaries_.size()); u.clear();
+ map<int, vector<TRulePtr> > g;
+ map<int, int> mark;
+ //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl;
+ mark[goal_cat_] = 2;
+ for (unsigned i = 0; i < unaries_.size(); ++i) {
+ //cerr << "Adding: " << unaries_[i]->AsString() << endl;
+ g[unaries_[i]->f()[0]].push_back(unaries_[i]);
+ }
+ //m[unaries_[i]->lhs_].push_back(unaries_[i]);
+ for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) {
+ //cerr << "PROC: " << TD::Convert(-it->first) << endl;
+ if (mark[it->first] > 0) {
+ //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n";
+ } else {
+ TopoSortVisit(it->first, u, g, mark);
+ }
+ }
+ unaries_.clear();
+ for (int i = u.size() - 1; i >= 0; --i)
+ unaries_.push_back(u[i]);
+}
+
void PassiveChart::ApplyRule(const int i,
const int j,
const TRulePtr& r,
const Hypergraph::TailNodeVector& ant_nodes,
const float lattice_cost) {
Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes);
- //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
+ // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl;
new_edge->prev_i_ = r->prev_i;
new_edge->prev_j_ = r->prev_j;
new_edge->i_ = i;
@@ -215,15 +274,14 @@ void PassiveChart::ApplyRules(const int i,
void PassiveChart::ApplyUnaryRules(const int i, const int j) {
const vector<int>& nodes = chart_(i,j); // reference is important!
- for (unsigned gi = 0; gi < grammars_.size(); ++gi) {
- if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue;
- for (unsigned di = 0; di < nodes.size(); ++di) {
- const WordID& cat = forest_->nodes_[nodes[di]].cat_;
- const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat);
- for (unsigned ri = 0; ri < unaries.size(); ++ri) {
- // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl;
+ for (unsigned di = 0; di < nodes.size(); ++di) {
+ const WordID& cat = forest_->nodes_[nodes[di]].cat_;
+ for (unsigned ri = 0; ri < unaries_.size(); ++ri) {
+ //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl;
+ if (unaries_[ri]->f()[0] == cat) {
+ //cerr << " --MATCH\n";
const Hypergraph::TailNodeVector ant(1, nodes[di]);
- ApplyRule(i, j, unaries[ri], ant, 0); // may update nodes
+ ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes
}
}
}
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index b2541722..0411908f 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -58,6 +58,7 @@ void register_feature_functions() {
ff_registry.Register("KLanguageModel", new KLanguageModelFactory());
ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>);
ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>);
+ ff_registry.Register("RuleShape2", new FFFactory<RuleShapeFeatures2>);
ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>);
ff_registry.Register("LexNullJump", new FFFactory<LexNullJump>);
ff_registry.Register("NewJump", new FFFactory<NewJump>);
diff --git a/decoder/ff_ruleshape.cc b/decoder/ff_ruleshape.cc
index 7bb548c4..35b41c46 100644
--- a/decoder/ff_ruleshape.cc
+++ b/decoder/ff_ruleshape.cc
@@ -1,5 +1,8 @@
#include "ff_ruleshape.h"
+#include "filelib.h"
+#include "stringlib.h"
+#include "verbose.h"
#include "trule.h"
#include "hg.h"
#include "fdict.h"
@@ -104,3 +107,138 @@ void RuleShapeFeatures::TraversalFeaturesImpl(const SentenceMetadata& /* smeta *
features->set_value(cur->fid_, 1.0);
}
+namespace {
+void ParseRSArgs(string const& in, string* emapfile, string* fmapfile, unsigned *pfxsize) {
+ vector<string> const& argv=SplitOnWhitespace(in);
+ *emapfile = "";
+ *fmapfile = "";
+ *pfxsize = 0;
+#define RSSPEC_NEXTARG if (i==argv.end()) { \
+ cerr << "Missing argument for "<<*last<<". "; goto usage; \
+ } else { ++i; }
+
+ for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) {
+ string const& s=*i;
+ if (s[0]=='-') {
+ if (s.size()>2) goto fail;
+ switch (s[1]) {
+ case 'e':
+ if (emapfile->size() > 0) { cerr << "Multiple -e specifications!\n"; abort(); }
+ RSSPEC_NEXTARG; *emapfile=*i;
+ break;
+ case 'f':
+ if (fmapfile->size() > 0) { cerr << "Multiple -f specifications!\n"; abort(); }
+ RSSPEC_NEXTARG; *fmapfile=*i;
+ break;
+ case 'p':
+ RSSPEC_NEXTARG; *pfxsize=atoi(i->c_str());
+ break;
+#undef RSSPEC_NEXTARG
+ default:
+ fail:
+ cerr<<"Unknown RuleShape2 option "<<s<<" ; ";
+ goto usage;
+ }
+ } else {
+ cerr << "RuleShape2 bad argument!\n";
+ abort();
+ }
+ }
+ return;
+usage:
+ cerr << "Bad parameters for RuleShape2\n";
+ abort();
+}
+
+inline void AddWordToClassMapping_(vector<WordID>* pv, unsigned f, unsigned t, unsigned pfx_size) {
+ if (pfx_size) {
+ const string& ts = TD::Convert(t);
+ if (pfx_size < ts.size())
+ t = TD::Convert(ts.substr(0, pfx_size));
+ }
+ if (f >= pv->size())
+ pv->resize((f + 1) * 1.2);
+ (*pv)[f] = t;
+}
+}
+
+RuleShapeFeatures2::~RuleShapeFeatures2() {}
+
+RuleShapeFeatures2::RuleShapeFeatures2(const string& param) : kNT(TD::Convert("NT")), kUNK(TD::Convert("<unk>")) {
+ string emap;
+ string fmap;
+ unsigned pfxsize = 0;
+ ParseRSArgs(param, &emap, &fmap, &pfxsize);
+ has_src_ = fmap.size();
+ has_trg_ = emap.size();
+ if (has_trg_) LoadWordClasses(emap, pfxsize, &e2class_);
+ if (has_src_) LoadWordClasses(fmap, pfxsize, &f2class_);
+ if (!has_trg_ && !has_src_) {
+ cerr << "RuleShapeFeatures2 requires [-e trg_map.gz] or [-f src_map.gz] or both, and optional [-p pfxsize]\n";
+ abort();
+ }
+}
+
+void RuleShapeFeatures2::LoadWordClasses(const string& file, const unsigned pfx_size, vector<WordID>* pv) {
+ ReadFile rf(file);
+ istream& in = *rf.stream();
+ string line;
+ vector<WordID> dummy;
+ int lc = 0;
+ if (!SILENT)
+ cerr << " Loading word classes from " << file << " ...\n";
+ AddWordToClassMapping_(pv, TD::Convert("<s>"), TD::Convert("<s>"), 0);
+ AddWordToClassMapping_(pv, TD::Convert("</s>"), TD::Convert("</s>"), 0);
+ while(getline(in, line)) {
+ dummy.clear();
+ TD::ConvertSentence(line, &dummy);
+ ++lc;
+ if (dummy.size() != 2 && dummy.size() != 3) {
+ cerr << " Class map file expects: CLASS WORD [freq]\n";
+ cerr << " Format error in " << file << ", line " << lc << ": " << line << endl;
+ abort();
+ }
+ AddWordToClassMapping_(pv, dummy[1], dummy[0], pfx_size);
+ }
+ if (!SILENT)
+ cerr << " Loaded word " << lc << " mapping rules.\n";
+}
+
+void RuleShapeFeatures2::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */,
+ const Hypergraph::Edge& edge,
+ const vector<const void*>& /* ant_contexts */,
+ SparseVector<double>* features,
+ SparseVector<double>* /* estimated_features */,
+ void* /* context */) const {
+ const vector<int>& f = edge.rule_->f();
+ const vector<int>& e = edge.rule_->e();
+ Node* fid = &fidtree_;
+ if (has_src_) {
+ for (unsigned i = 0; i < f.size(); ++i)
+ fid = &fid->next_[MapF(f[i])];
+ }
+ if (has_trg_) {
+ for (unsigned i = 0; i < e.size(); ++i)
+ fid = &fid->next_[MapE(e[i])];
+ }
+ if (!fid->fid_) {
+ ostringstream os;
+ os << "RS:";
+ if (has_src_) {
+ for (unsigned i = 0; i < f.size(); ++i) {
+ if (i) os << '_';
+ os << TD::Convert(MapF(f[i]));
+ }
+ if (has_trg_) os << "__";
+ }
+ if (has_trg_) {
+ for (unsigned i = 0; i < e.size(); ++i) {
+ if (i) os << '_';
+ os << TD::Convert(MapE(e[i]));
+ }
+ }
+ fid->fid_ = FD::Convert(os.str());
+ }
+ features->set_value(fid->fid_, 1);
+}
+
diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h
index 9f20faf3..488cfd84 100644
--- a/decoder/ff_ruleshape.h
+++ b/decoder/ff_ruleshape.h
@@ -2,6 +2,7 @@
#define _FF_RULESHAPE_H_
#include <vector>
+#include <map>
#include "ff.h"
class RuleShapeFeatures : public FeatureFunction {
@@ -28,4 +29,49 @@ class RuleShapeFeatures : public FeatureFunction {
}
};
+class RuleShapeFeatures2 : public FeatureFunction {
+ public:
+ ~RuleShapeFeatures2();
+ RuleShapeFeatures2(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const HG::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ private:
+ struct Node {
+ int fid_;
+ Node() : fid_() {}
+ std::map<WordID, Node> next_;
+ };
+ mutable Node fidtree_;
+
+ inline WordID MapE(WordID w) const {
+ if (w <= 0) return kNT;
+ unsigned res = 0;
+ if (w < e2class_.size()) res = e2class_[w];
+ if (!res) res = kUNK;
+ return res;
+ }
+
+ inline WordID MapF(WordID w) const {
+ if (w <= 0) return kNT;
+ unsigned res = 0;
+ if (w < f2class_.size()) res = f2class_[w];
+ if (!res) res = kUNK;
+ return res;
+ }
+
+ // prfx_size=0 => use full word classes otherwise truncate to specified length
+ void LoadWordClasses(const std::string& fname, unsigned pfxsize, std::vector<WordID>* pv);
+ const WordID kNT;
+ const WordID kUNK;
+ std::vector<WordID> e2class_;
+ std::vector<WordID> f2class_;
+ bool has_src_;
+ bool has_trg_;
+};
+
#endif