diff options
author | armatthews <armatthe@cmu.edu> | 2014-02-20 22:22:15 -0500 |
---|---|---|
committer | armatthews <armatthe@cmu.edu> | 2014-02-20 22:22:15 -0500 |
commit | f7d1893a35ae158e05503ec15b7125e89309aa16 (patch) | |
tree | e46c09e53f168966f6b56f72eb27e2e97f0a70d6 | |
parent | fb3a61574c90bb96f90fb390496eb1277b4cc83c (diff) | |
parent | 7ecdca1448cc709b21692a1a3b6247c45a3126c3 (diff) |
Merge branch 'master' of https://github.com/redpony/cdec
-rwxr-xr-x | corpus/support/quote-norm.pl | 2 | ||||
-rw-r--r-- | decoder/bottom_up_parser.cc | 80 | ||||
-rw-r--r-- | decoder/cdec_ff.cc | 1 | ||||
-rw-r--r-- | decoder/ff_ruleshape.cc | 138 | ||||
-rw-r--r-- | decoder/ff_ruleshape.h | 46 | ||||
-rw-r--r-- | tests/system_tests/multigram/cdec.ini | 1 | ||||
-rw-r--r-- | tests/system_tests/multigram/g1.scfg | 4 | ||||
-rw-r--r-- | tests/system_tests/multigram/g2.scfg | 1 | ||||
-rw-r--r-- | tests/system_tests/multigram/gold.statistics | 3 | ||||
-rw-r--r-- | tests/system_tests/multigram/gold.stdout | 1 | ||||
-rw-r--r-- | tests/system_tests/multigram/input.txt | 1 | ||||
-rw-r--r-- | tests/system_tests/multigram/weights | 1 | ||||
-rw-r--r-- | training/mira/kbest_cut_mira.cc | 4 |
13 files changed, 270 insertions, 13 deletions
diff --git a/corpus/support/quote-norm.pl b/corpus/support/quote-norm.pl index bed0844e..1d9bb96f 100755 --- a/corpus/support/quote-norm.pl +++ b/corpus/support/quote-norm.pl @@ -109,6 +109,8 @@ while(<STDIN>) { s/\x{ab}/\"/g; # opening guillemet s/\x{bb}/\"/g; # closing guillemet s/\x{0301}/'/g; # combining acute accent + s/\x{203a}/\"/g; # angle quotation mark + s/\x{2039}/\"/g; # angle quotation mark # Space inverted punctuation: s/¡/ ¡ /g; diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc index 606b8d7e..8738c8f1 100644 --- a/decoder/bottom_up_parser.cc +++ b/decoder/bottom_up_parser.cc @@ -45,6 +45,7 @@ class PassiveChart { const float lattice_cost); void ApplyUnaryRules(const int i, const int j); + void TopoSortUnaries(); const vector<GrammarPtr>& grammars_; const Lattice& input_; @@ -57,6 +58,7 @@ class PassiveChart { TRulePtr goal_rule_; int goal_idx_; // index of goal node, if found const int lc_fid_; + vector<TRulePtr> unaries_; // topologically sorted list of unary rules from all grammars static WordID kGOAL; // [Goal] }; @@ -159,21 +161,78 @@ PassiveChart::PassiveChart(const string& goal, goal_cat_(TD::Convert(goal) * -1), goal_rule_(new TRule("[Goal] ||| [" + goal + ",1] ||| [" + goal + ",1]")), goal_idx_(-1), - lc_fid_(FD::Convert("LatticeCost")) { + lc_fid_(FD::Convert("LatticeCost")), + unaries_() { act_chart_.resize(grammars_.size()); - for (unsigned i = 0; i < grammars_.size(); ++i) + for (unsigned i = 0; i < grammars_.size(); ++i) { act_chart_[i] = new ActiveChart(forest, *this); + const vector<TRulePtr>& u = grammars_[i]->GetAllUnaryRules(); + for (unsigned j = 0; j < u.size(); ++j) + unaries_.push_back(u[j]); + } + TopoSortUnaries(); if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl; } +static bool TopoSortVisit(int node, vector<TRulePtr>& u, const map<int, vector<TRulePtr> >& g, map<int, int>& mark) { + if (mark[node] == 1) { + cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n"; + return false; // cycle detected + } else if (mark[node] == 2) { + return true; // already been + } + mark[node] = 1; + const map<int, vector<TRulePtr> >::const_iterator nit = g.find(node); + if (nit != g.end()) { + const vector<TRulePtr>& edges = nit->second; + vector<bool> okay(edges.size(), true); + for (unsigned i = 0; i < edges.size(); ++i) { + okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); + if (!okay[i]) { + cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; + } + } + for (unsigned i = 0; i < edges.size(); ++i) { + if (okay[i]) u.push_back(edges[i]); + //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl; + } + } + mark[node] = 2; + return true; +} + +void PassiveChart::TopoSortUnaries() { + vector<TRulePtr> u(unaries_.size()); u.clear(); + map<int, vector<TRulePtr> > g; + map<int, int> mark; + //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl; + mark[goal_cat_] = 2; + for (unsigned i = 0; i < unaries_.size(); ++i) { + //cerr << "Adding: " << unaries_[i]->AsString() << endl; + g[unaries_[i]->f()[0]].push_back(unaries_[i]); + } + //m[unaries_[i]->lhs_].push_back(unaries_[i]); + for (map<int, vector<TRulePtr> >::iterator it = g.begin(); it != g.end(); ++it) { + //cerr << "PROC: " << TD::Convert(-it->first) << endl; + if (mark[it->first] > 0) { + //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n"; + } else { + TopoSortVisit(it->first, u, g, mark); + } + } + unaries_.clear(); + for (int i = u.size() - 1; i >= 0; --i) + unaries_.push_back(u[i]); +} + void PassiveChart::ApplyRule(const int i, const int j, const TRulePtr& r, const Hypergraph::TailNodeVector& ant_nodes, const float lattice_cost) { Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); - //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; + // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; new_edge->prev_i_ = r->prev_i; new_edge->prev_j_ = r->prev_j; new_edge->i_ = i; @@ -215,15 +274,14 @@ void PassiveChart::ApplyRules(const int i, void PassiveChart::ApplyUnaryRules(const int i, const int j) { const vector<int>& nodes = chart_(i,j); // reference is important! - for (unsigned gi = 0; gi < grammars_.size(); ++gi) { - if (!grammars_[gi]->HasRuleForSpan(i,j,input_.Distance(i,j))) continue; - for (unsigned di = 0; di < nodes.size(); ++di) { - const WordID& cat = forest_->nodes_[nodes[di]].cat_; - const vector<TRulePtr>& unaries = grammars_[gi]->GetUnaryRulesForRHS(cat); - for (unsigned ri = 0; ri < unaries.size(); ++ri) { - // cerr << "At (" << i << "," << j << "): applying " << unaries[ri]->AsString() << endl; + for (unsigned di = 0; di < nodes.size(); ++di) { + const WordID& cat = forest_->nodes_[nodes[di]].cat_; + for (unsigned ri = 0; ri < unaries_.size(); ++ri) { + //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; + if (unaries_[ri]->f()[0] == cat) { + //cerr << " --MATCH\n"; const Hypergraph::TailNodeVector ant(1, nodes[di]); - ApplyRule(i, j, unaries[ri], ant, 0); // may update nodes + ApplyRule(i, j, unaries_[ri], ant, 0); // may update nodes } } } diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index b2541722..0411908f 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -58,6 +58,7 @@ void register_feature_functions() { ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); ff_registry.Register("NonLatinCount", new FFFactory<NonLatinCount>); ff_registry.Register("RuleShape", new FFFactory<RuleShapeFeatures>); + ff_registry.Register("RuleShape2", new FFFactory<RuleShapeFeatures2>); ff_registry.Register("RelativeSentencePosition", new FFFactory<RelativeSentencePosition>); ff_registry.Register("LexNullJump", new FFFactory<LexNullJump>); ff_registry.Register("NewJump", new FFFactory<NewJump>); diff --git a/decoder/ff_ruleshape.cc b/decoder/ff_ruleshape.cc index 7bb548c4..35b41c46 100644 --- a/decoder/ff_ruleshape.cc +++ b/decoder/ff_ruleshape.cc @@ -1,5 +1,8 @@ #include "ff_ruleshape.h" +#include "filelib.h" +#include "stringlib.h" +#include "verbose.h" #include "trule.h" #include "hg.h" #include "fdict.h" @@ -104,3 +107,138 @@ void RuleShapeFeatures::TraversalFeaturesImpl(const SentenceMetadata& /* smeta * features->set_value(cur->fid_, 1.0); } +namespace { +void ParseRSArgs(string const& in, string* emapfile, string* fmapfile, unsigned *pfxsize) { + vector<string> const& argv=SplitOnWhitespace(in); + *emapfile = ""; + *fmapfile = ""; + *pfxsize = 0; +#define RSSPEC_NEXTARG if (i==argv.end()) { \ + cerr << "Missing argument for "<<*last<<". "; goto usage; \ + } else { ++i; } + + for (vector<string>::const_iterator last,i=argv.begin(),e=argv.end();i!=e;++i) { + string const& s=*i; + if (s[0]=='-') { + if (s.size()>2) goto fail; + switch (s[1]) { + case 'e': + if (emapfile->size() > 0) { cerr << "Multiple -e specifications!\n"; abort(); } + RSSPEC_NEXTARG; *emapfile=*i; + break; + case 'f': + if (fmapfile->size() > 0) { cerr << "Multiple -f specifications!\n"; abort(); } + RSSPEC_NEXTARG; *fmapfile=*i; + break; + case 'p': + RSSPEC_NEXTARG; *pfxsize=atoi(i->c_str()); + break; +#undef RSSPEC_NEXTARG + default: + fail: + cerr<<"Unknown RuleShape2 option "<<s<<" ; "; + goto usage; + } + } else { + cerr << "RuleShape2 bad argument!\n"; + abort(); + } + } + return; +usage: + cerr << "Bad parameters for RuleShape2\n"; + abort(); +} + +inline void AddWordToClassMapping_(vector<WordID>* pv, unsigned f, unsigned t, unsigned pfx_size) { + if (pfx_size) { + const string& ts = TD::Convert(t); + if (pfx_size < ts.size()) + t = TD::Convert(ts.substr(0, pfx_size)); + } + if (f >= pv->size()) + pv->resize((f + 1) * 1.2); + (*pv)[f] = t; +} +} + +RuleShapeFeatures2::~RuleShapeFeatures2() {} + +RuleShapeFeatures2::RuleShapeFeatures2(const string& param) : kNT(TD::Convert("NT")), kUNK(TD::Convert("<unk>")) { + string emap; + string fmap; + unsigned pfxsize = 0; + ParseRSArgs(param, &emap, &fmap, &pfxsize); + has_src_ = fmap.size(); + has_trg_ = emap.size(); + if (has_trg_) LoadWordClasses(emap, pfxsize, &e2class_); + if (has_src_) LoadWordClasses(fmap, pfxsize, &f2class_); + if (!has_trg_ && !has_src_) { + cerr << "RuleShapeFeatures2 requires [-e trg_map.gz] or [-f src_map.gz] or both, and optional [-p pfxsize]\n"; + abort(); + } +} + +void RuleShapeFeatures2::LoadWordClasses(const string& file, const unsigned pfx_size, vector<WordID>* pv) { + ReadFile rf(file); + istream& in = *rf.stream(); + string line; + vector<WordID> dummy; + int lc = 0; + if (!SILENT) + cerr << " Loading word classes from " << file << " ...\n"; + AddWordToClassMapping_(pv, TD::Convert("<s>"), TD::Convert("<s>"), 0); + AddWordToClassMapping_(pv, TD::Convert("</s>"), TD::Convert("</s>"), 0); + while(getline(in, line)) { + dummy.clear(); + TD::ConvertSentence(line, &dummy); + ++lc; + if (dummy.size() != 2 && dummy.size() != 3) { + cerr << " Class map file expects: CLASS WORD [freq]\n"; + cerr << " Format error in " << file << ", line " << lc << ": " << line << endl; + abort(); + } + AddWordToClassMapping_(pv, dummy[1], dummy[0], pfx_size); + } + if (!SILENT) + cerr << " Loaded word " << lc << " mapping rules.\n"; +} + +void RuleShapeFeatures2::TraversalFeaturesImpl(const SentenceMetadata& /* smeta */, + const Hypergraph::Edge& edge, + const vector<const void*>& /* ant_contexts */, + SparseVector<double>* features, + SparseVector<double>* /* estimated_features */, + void* /* context */) const { + const vector<int>& f = edge.rule_->f(); + const vector<int>& e = edge.rule_->e(); + Node* fid = &fidtree_; + if (has_src_) { + for (unsigned i = 0; i < f.size(); ++i) + fid = &fid->next_[MapF(f[i])]; + } + if (has_trg_) { + for (unsigned i = 0; i < e.size(); ++i) + fid = &fid->next_[MapE(e[i])]; + } + if (!fid->fid_) { + ostringstream os; + os << "RS:"; + if (has_src_) { + for (unsigned i = 0; i < f.size(); ++i) { + if (i) os << '_'; + os << TD::Convert(MapF(f[i])); + } + if (has_trg_) os << "__"; + } + if (has_trg_) { + for (unsigned i = 0; i < e.size(); ++i) { + if (i) os << '_'; + os << TD::Convert(MapE(e[i])); + } + } + fid->fid_ = FD::Convert(os.str()); + } + features->set_value(fid->fid_, 1); +} + diff --git a/decoder/ff_ruleshape.h b/decoder/ff_ruleshape.h index 9f20faf3..488cfd84 100644 --- a/decoder/ff_ruleshape.h +++ b/decoder/ff_ruleshape.h @@ -2,6 +2,7 @@ #define _FF_RULESHAPE_H_ #include <vector> +#include <map> #include "ff.h" class RuleShapeFeatures : public FeatureFunction { @@ -28,4 +29,49 @@ class RuleShapeFeatures : public FeatureFunction { } }; +class RuleShapeFeatures2 : public FeatureFunction { + public: + ~RuleShapeFeatures2(); + RuleShapeFeatures2(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const HG::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + private: + struct Node { + int fid_; + Node() : fid_() {} + std::map<WordID, Node> next_; + }; + mutable Node fidtree_; + + inline WordID MapE(WordID w) const { + if (w <= 0) return kNT; + unsigned res = 0; + if (w < e2class_.size()) res = e2class_[w]; + if (!res) res = kUNK; + return res; + } + + inline WordID MapF(WordID w) const { + if (w <= 0) return kNT; + unsigned res = 0; + if (w < f2class_.size()) res = f2class_[w]; + if (!res) res = kUNK; + return res; + } + + // prfx_size=0 => use full word classes otherwise truncate to specified length + void LoadWordClasses(const std::string& fname, unsigned pfxsize, std::vector<WordID>* pv); + const WordID kNT; + const WordID kUNK; + std::vector<WordID> e2class_; + std::vector<WordID> f2class_; + bool has_src_; + bool has_trg_; +}; + #endif diff --git a/tests/system_tests/multigram/cdec.ini b/tests/system_tests/multigram/cdec.ini new file mode 100644 index 00000000..f31becb8 --- /dev/null +++ b/tests/system_tests/multigram/cdec.ini @@ -0,0 +1 @@ +formalism=scfg diff --git a/tests/system_tests/multigram/g1.scfg b/tests/system_tests/multigram/g1.scfg new file mode 100644 index 00000000..a3a59699 --- /dev/null +++ b/tests/system_tests/multigram/g1.scfg @@ -0,0 +1,4 @@ +[X] ||| [Z] ||| [1] ||| Top=1 +[Y] ||| foo ||| foo ||| F1=1 +[Z] ||| [Z] [Y] ||| [1] [2] ||| W1=1 +[Z] ||| [Y] ||| [1] ||| W2=1 diff --git a/tests/system_tests/multigram/g2.scfg b/tests/system_tests/multigram/g2.scfg new file mode 100644 index 00000000..40962517 --- /dev/null +++ b/tests/system_tests/multigram/g2.scfg @@ -0,0 +1 @@ +[Y] ||| bar ||| bar ||| F2=1 diff --git a/tests/system_tests/multigram/gold.statistics b/tests/system_tests/multigram/gold.statistics new file mode 100644 index 00000000..ef23a685 --- /dev/null +++ b/tests/system_tests/multigram/gold.statistics @@ -0,0 +1,3 @@ +-lm_nodes 11 +-lm_edges 12 +-lm_paths 2 diff --git a/tests/system_tests/multigram/gold.stdout b/tests/system_tests/multigram/gold.stdout new file mode 100644 index 00000000..d675fa44 --- /dev/null +++ b/tests/system_tests/multigram/gold.stdout @@ -0,0 +1 @@ +foo bar diff --git a/tests/system_tests/multigram/input.txt b/tests/system_tests/multigram/input.txt new file mode 100644 index 00000000..2aef01a0 --- /dev/null +++ b/tests/system_tests/multigram/input.txt @@ -0,0 +1 @@ +<seg id="0" grammar="g1.scfg" grammar1="g2.scfg">foo bar</seg> diff --git a/tests/system_tests/multigram/weights b/tests/system_tests/multigram/weights new file mode 100644 index 00000000..a6b6698a --- /dev/null +++ b/tests/system_tests/multigram/weights @@ -0,0 +1 @@ +Glue -1 diff --git a/training/mira/kbest_cut_mira.cc b/training/mira/kbest_cut_mira.cc index e075bed3..1a6415be 100644 --- a/training/mira/kbest_cut_mira.cc +++ b/training/mira/kbest_cut_mira.cc @@ -82,14 +82,14 @@ bool InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("optimizer,o",po::value<int>()->default_value(1), "Optimizer (SGD=1, PA MIRA w/Delta=2, Cutting Plane MIRA=3, PA MIRA=4, Triple nbest list MIRA=5)") ("fear,f",po::value<int>()->default_value(1), "Fear selection (model-cost=1, maxcost=2, maxscore=3)") ("hope,h",po::value<int>()->default_value(1), "Hope selection (model+cost=1, mincost=2)") - ("max_step_size,C", po::value<double>()->default_value(0.01), "regularization strength (C)") + ("max_step_size,C", po::value<double>()->default_value(0.001), "regularization strength (C)") ("random_seed,S", po::value<uint32_t>(), "Random seed (if not specified, /dev/random will be used)") ("mt_metric_scale,s", po::value<double>()->default_value(1.0), "Amount to scale MT loss function by") ("sent_approx,a", "Use smoothed sentence-level BLEU score for approximate scoring") ("pseudo_doc,e", "Use pseudo-document BLEU score for approximate scoring") ("no_reweight,d","Do not reweight forest for cutting plane") ("no_select,n", "Do not use selection heuristic") - ("k_best_size,k", po::value<int>()->default_value(250), "Size of hypothesis list to search for oracles") + ("k_best_size,k", po::value<int>()->default_value(500), "Size of hypothesis list to search for oracles") ("update_k_best,b", po::value<int>()->default_value(1), "Size of good, bad lists to perform update with") ("unique_k_best,u", "Unique k-best translation list") ("stream,t", "Stream mode (used for realtime)") |