diff options
author | Chris Dyer <redpony@gmail.com> | 2009-12-17 13:57:54 -0500 |
---|---|---|
committer | Chris Dyer <redpony@gmail.com> | 2009-12-17 13:57:54 -0500 |
commit | bba4ff830c8722cdcaf29e36c1ff5821a912ae5d (patch) | |
tree | 268f2f8118aca09b3cc40dca8b2be7de8295acd5 | |
parent | 04ae1beeaeceb0161a64d33112f21956f9741bde (diff) |
added non-pruning intersection and a CRF tagger
- the linear-chain tagger is more of a proof of concept than a real tagger-- the context-free assumptions made in a number of places mean that the algorithms used may not be as efficient as they could be, but the model is as powerful as any CRF
- it would be easy to add latent variables or semi-CRF support (or both!)
- i've added a couple basic features that are often used for POS tagging
- non-pruning intersection is useful for lexical word alignment models and the tagger
- a sample POS tagger model will be committed later
-rw-r--r-- | decoder/Makefile.am | 2 | ||||
-rw-r--r-- | decoder/apply_models.cc | 89 | ||||
-rw-r--r-- | decoder/apply_models.h | 6 | ||||
-rw-r--r-- | decoder/cdec.cc | 22 | ||||
-rw-r--r-- | decoder/cdec_ff.cc | 3 | ||||
-rw-r--r-- | decoder/ff_tagger.cc | 96 | ||||
-rw-r--r-- | decoder/ff_tagger.h | 51 | ||||
-rw-r--r-- | decoder/lexcrf.cc | 2 | ||||
-rw-r--r-- | decoder/tagger.cc | 109 | ||||
-rw-r--r-- | decoder/tagger.h | 17 | ||||
-rw-r--r-- | decoder/trule.h | 5 | ||||
-rwxr-xr-x | training/cluster-ptrain.pl | 1 |
12 files changed, 382 insertions, 21 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am index f3843102..4c86ae6f 100644 --- a/decoder/Makefile.am +++ b/decoder/Makefile.am @@ -60,8 +60,10 @@ libcdec_a_SOURCES = \ ff_lm.cc \ ff_wordalign.cc \ ff_csplit.cc \ + ff_tagger.cc \ freqdict.cc \ lexcrf.cc \ + tagger.cc \ bottom_up_parser.cc \ phrasebased_translator.cc \ JSON_parser.c \ diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc index b1d002f4..a340aa1a 100644 --- a/decoder/apply_models.cc +++ b/decoder/apply_models.cc @@ -296,14 +296,69 @@ public: }; struct NoPruningRescorer { - NoPruningRescorer(const ModelSet& m, const Hypergraph& i, Hypergraph* o) : + NoPruningRescorer(const ModelSet& m, const SentenceMetadata &sm, const Hypergraph& i, Hypergraph* o) : models(m), + smeta(sm), in(i), - out(*o) { + out(*o), + nodemap(i.nodes_.size()) { cerr << " Rescoring forest (full intersection)\n"; } - void RescoreNode(const int node_num, const bool is_goal) { + typedef unordered_map<string, int, boost::hash<string> > State2NodeIndex; + + void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, State2NodeIndex* state2node) { + const int arity = in_edge.Arity(); + Hypergraph::TailNodeVector ends(arity); + for (int i = 0; i < arity; ++i) + ends[i] = nodemap[in_edge.tail_nodes_[i]].size(); + + Hypergraph::TailNodeVector tail_iter(arity, 0); + bool done = false; + while (!done) { + Hypergraph::TailNodeVector tail(arity); + for (int i = 0; i < arity; ++i) + tail[i] = nodemap[in_edge.tail_nodes_[i]][tail_iter[i]]; + Hypergraph::Edge* new_edge = out.AddEdge(in_edge.rule_, tail); + new_edge->feature_values_ = in_edge.feature_values_; + new_edge->i_ = in_edge.i_; + new_edge->j_ = in_edge.j_; + new_edge->prev_i_ = in_edge.prev_i_; + new_edge->prev_j_ = in_edge.prev_j_; + string head_state; + if (is_goal) { + assert(tail.size() == 1); + const string& ant_state = out.nodes_[tail.front()].state_; + models.AddFinalFeatures(ant_state, new_edge); + } else { + prob_t edge_estimate; // this is a full intersection, so we disregard this + models.AddFeaturesToEdge(smeta, out, new_edge, &head_state, &edge_estimate); + } + int& head_plus1 = (*state2node)[head_state]; + if (!head_plus1) { + head_plus1 = out.AddNode(in_edge.rule_->GetLHS(), head_state)->id_ + 1; + nodemap[in_edge.head_node_].push_back(head_plus1 - 1); + } + const int head_index = head_plus1 - 1; + out.ConnectEdgeToHeadNode(new_edge->id_, head_index); + + int ii = 0; + for (; ii < arity; ++ii) { + ++tail_iter[ii]; + if (tail_iter[ii] < ends[ii]) break; + tail_iter[ii] = 0; + } + done = (ii == arity); + } + } + + void ProcessOneNode(const int node_num, const bool is_goal) { + State2NodeIndex state2node; + const Hypergraph::Node& node = in.nodes_[node_num]; + for (int i = 0; i < node.in_edges_.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[node.in_edges_[i]]; + ExpandEdge(edge, is_goal, &state2node); + } } void Apply() { @@ -316,29 +371,41 @@ struct NoPruningRescorer { cerr << " "; for (int i = 0; i < in.nodes_.size(); ++i) { if (i % every == 0) cerr << '.'; - RescoreNode(i, i == goal_id); + ProcessOneNode(i, i == goal_id); } cerr << endl; } private: const ModelSet& models; + const SentenceMetadata& smeta; const Hypergraph& in; Hypergraph& out; + + vector<vector<int> > nodemap; }; // each node in the graph has one of these, it keeps track of void ApplyModelSet(const Hypergraph& in, const SentenceMetadata& smeta, const ModelSet& models, - const PruningConfiguration& config, + const IntersectionConfiguration& config, Hypergraph* out) { - int pl = config.pop_limit; - if (pl > 100 && in.nodes_.size() > 80000) { - cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; - pl = 30; + // TODO special handling when all models are stateless + if (config.algorithm == 1) { + int pl = config.pop_limit; + if (pl > 100 && in.nodes_.size() > 80000) { + cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; + pl = 30; + } + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); + } else if (config.algorithm == 0) { + NoPruningRescorer ma(models, smeta, in, out); + ma.Apply(); + } else { + cerr << "Don't understand intersection algorithm " << config.algorithm << endl; + exit(1); } - CubePruningRescorer ma(models, smeta, in, pl, out); - ma.Apply(); } diff --git a/decoder/apply_models.h b/decoder/apply_models.h index 08fce037..d6d8b34a 100644 --- a/decoder/apply_models.h +++ b/decoder/apply_models.h @@ -5,16 +5,16 @@ struct ModelSet; struct Hypergraph; struct SentenceMetadata; -struct PruningConfiguration { +struct IntersectionConfiguration { const int algorithm; // 0 = full intersection, 1 = cube pruning const int pop_limit; // max number of pops off the heap at each node - explicit PruningConfiguration(int k) : algorithm(1), pop_limit(k) {} + IntersectionConfiguration(int alg, int k) : algorithm(alg), pop_limit(k) {} }; void ApplyModelSet(const Hypergraph& in, const SentenceMetadata& smeta, const ModelSet& models, - const PruningConfiguration& config, + const IntersectionConfiguration& config, Hypergraph* out); #endif diff --git a/decoder/cdec.cc b/decoder/cdec.cc index 6185c79b..c6773cce 100644 --- a/decoder/cdec.cc +++ b/decoder/cdec.cc @@ -17,6 +17,7 @@ #include "filelib.h" #include "sampler.h" #include "sparse_vector.h" +#include "tagger.h" #include "lexcrf.h" #include "csplit.h" #include "weights.h" @@ -48,7 +49,7 @@ void ShowBanner() { void InitCommandLine(int argc, char** argv, po::variables_map* conf) { po::options_description opts("Configuration options"); opts.add_options() - ("formalism,f",po::value<string>(),"Translation formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSplit (compound splitting)") + ("formalism,f",po::value<string>(),"Decoding formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSplit (compound splitting), Tagger (sequence labeling)") ("input,i",po::value<string>()->default_value("-"),"Source file") ("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") ("weights,w",po::value<string>(),"Feature weights file") @@ -58,16 +59,18 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { ("k_best,k",po::value<int>(),"Extract the k best derivations") ("unique_k_best,r", "Unique k-best translation list") ("aligner,a", "Run as a word/phrase aligner (src & ref required)") + ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Intersection strategy for incorporating finite-state features; values include Cube_pruning, Full") ("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node") ("goal",po::value<string>()->default_value("S"),"Goal symbol (SCFG & FST)") ("scfg_extra_glue_grammar", po::value<string>(), "Extra glue grammar file (Glue grammars apply when i=0 but have no other span restrictions)") ("scfg_no_hiero_glue_grammar,n", "No Hiero glue grammar (nb. by default the SCFG decoder adds Hiero glue rules)") ("scfg_default_nt,d",po::value<string>()->default_value("X"),"Default non-terminal symbol in SCFG") ("scfg_max_span_limit,S",po::value<int>()->default_value(10),"Maximum non-terminal span limit (except \"glue\" grammar)") - ("show_tree_structure,T", "Show the Viterbi derivation structure") + ("show_tree_structure", "Show the Viterbi derivation structure") ("show_expected_length", "Show the expected translation length under the model") ("show_partition,z", "Compute and show the partition (inside score)") ("beam_prune", po::value<double>(), "Prune paths from +LM forest") + ("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set") ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") ("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file") @@ -111,8 +114,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) { } const string formalism = LowercaseString((*conf)["formalism"].as<string>()); - if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb" && formalism != "csplit") { - cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit' or 'lexcrf'\n"; + if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb" && formalism != "csplit" && formalism != "tagger") { + cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lexcrf', or 'tagger'\n"; cerr << dcmdline_options << endl; exit(1); } @@ -255,6 +258,8 @@ int main(int argc, char** argv) { translator.reset(new CompoundSplit(conf)); else if (formalism == "lexcrf") translator.reset(new LexicalCRF(conf)); + else if (formalism == "tagger") + translator.reset(new Tagger(conf)); else assert(!"error"); @@ -285,6 +290,12 @@ int main(int argc, char** argv) { } } ModelSet late_models(feature_weights, late_ffs); + int palg = 1; + if (LowercaseString(conf["intersection_strategy"].as<string>()) == "full") { + palg = 0; + cerr << "Using full intersection (no pruning).\n"; + } + const IntersectionConfiguration inter_conf(palg, conf["cubepruning_pop_limit"].as<int>()); const int sample_max_trans = conf.count("max_translation_sample") ? conf["max_translation_sample"].as<int>() : 0; @@ -374,11 +385,10 @@ int main(int argc, char** argv) { forest.Reweight(feature_weights); forest.SortInEdgesByEdgeWeights(); Hypergraph lm_forest; - int cubepruning_pop_limit = conf["cubepruning_pop_limit"].as<int>(); ApplyModelSet(forest, smeta, late_models, - PruningConfiguration(cubepruning_pop_limit), + inter_conf, &lm_forest); forest.swap(lm_forest); forest.Reweight(feature_weights); diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc index 0a4f3d5e..bb2c9d34 100644 --- a/decoder/cdec_ff.cc +++ b/decoder/cdec_ff.cc @@ -4,6 +4,7 @@ #include "ff_lm.h" #include "ff_csplit.h" #include "ff_wordalign.h" +#include "ff_tagger.h" #include "ff_factory.h" boost::shared_ptr<FFRegistry> global_ff_registry; @@ -18,5 +19,7 @@ void register_feature_functions() { global_ff_registry->Register("AlignerResults", new FFFactory<AlignerResults>); global_ff_registry->Register("CSplit_BasicFeatures", new FFFactory<BasicCSplitFeatures>); global_ff_registry->Register("CSplit_ReverseCharLM", new FFFactory<ReverseCharLMCSplitFeature>); + global_ff_registry->Register("Tagger_BigramIdentity", new FFFactory<Tagger_BigramIdentity>); + global_ff_registry->Register("LexicalPairIdentity", new FFFactory<LexicalPairIdentity>); }; diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc new file mode 100644 index 00000000..7a9d1def --- /dev/null +++ b/decoder/ff_tagger.cc @@ -0,0 +1,96 @@ +#include "ff_tagger.h" + +#include "tdict.h" +#include "sentence_metadata.h" + +#include <sstream> + +using namespace std; + +Tagger_BigramIdentity::Tagger_BigramIdentity(const std::string& param) : + FeatureFunction(sizeof(WordID)) {} + +void Tagger_BigramIdentity::FireFeature(const WordID& left, + const WordID& right, + SparseVector<double>* features) const { + int& fid = fmap_[left][right]; + if (!fid) { + ostringstream os; + if (right == 0) { + os << "Uni:" << TD::Convert(left); + } else { + os << "Bi:"; + if (left < 0) { os << "BOS"; } else { os << TD::Convert(left); } + os << '_'; + if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); } + } + fid = FD::Convert(os.str()); + } + features->set_value(fid, 1.0); +} + +void Tagger_BigramIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const { + WordID& out_context = *static_cast<WordID*>(context); + const int arity = edge.Arity(); + if (arity == 0) { + out_context = edge.rule_->e_[0]; + FireFeature(out_context, 0, features); + } else if (arity == 2) { + WordID left = *static_cast<const WordID*>(ant_contexts[0]); + WordID right = *static_cast<const WordID*>(ant_contexts[1]); + if (edge.i_ == 0 && edge.j_ == 2) + FireFeature(-1, left, features); + FireFeature(left, right, features); + if (edge.i_ == 0 && edge.j_ == smeta.GetSourceLength()) + FireFeature(right, -1, features); + out_context = right; + } +} + +LexicalPairIdentity::LexicalPairIdentity(const std::string& param) {} + +void LexicalPairIdentity::FireFeature(WordID src, + WordID trg, + SparseVector<double>* features) const { + int& fid = fmap_[src][trg]; + if (!fid) { + static map<WordID, WordID> escape; + if (escape.empty()) { + escape[TD::Convert("=")] = TD::Convert("__EQ"); + escape[TD::Convert(";")] = TD::Convert("__SC"); + escape[TD::Convert(",")] = TD::Convert("__CO"); + } + if (escape.count(src)) src = escape[src]; + if (escape.count(trg)) trg = escape[trg]; + ostringstream os; + os << "Id:" << TD::Convert(src) << ':' << TD::Convert(trg); + fid = FD::Convert(os.str()); + } + features->set_value(fid, 1.0); +} + +void LexicalPairIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const { + const vector<WordID>& ew = edge.rule_->e_; + const vector<WordID>& fw = edge.rule_->f_; + for (int i = 0; i < ew.size(); ++i) { + const WordID& e = ew[i]; + if (e <= 0) continue; + for (int j = 0; j < fw.size(); ++j) { + const WordID& f = fw[j]; + if (f <= 0) continue; + FireFeature(f, e, features); + } + } +} + + diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h new file mode 100644 index 00000000..41c3ee5b --- /dev/null +++ b/decoder/ff_tagger.h @@ -0,0 +1,51 @@ +#ifndef _FF_TAGGER_H_ +#define _FF_TAGGER_H_ + +#include <map> +#include "ff.h" + +typedef std::map<WordID, int> Class2FID; +typedef std::map<WordID, Class2FID> Class2Class2FID; + +// the reason this is a "tagger" feature is that it assumes that +// the sequence unfolds from left to right, which means it doesn't +// have to split states based on left context. +// fires unigram features as well +class Tagger_BigramIdentity : public FeatureFunction { + public: + Tagger_BigramIdentity(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + private: + void FireFeature(const WordID& left, + const WordID& right, + SparseVector<double>* features) const; + mutable Class2Class2FID fmap_; +}; + +// for each pair of symbols cooccuring in a lexicalized rule, fire +// a feature (mostly used for tagging, but could be used for any model) +class LexicalPairIdentity : public FeatureFunction { + public: + LexicalPairIdentity(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector<const void*>& ant_contexts, + SparseVector<double>* features, + SparseVector<double>* estimated_features, + void* context) const; + private: + void FireFeature(WordID src, + WordID trg, + SparseVector<double>* features) const; + mutable Class2Class2FID fmap_; +}; + + +#endif diff --git a/decoder/lexcrf.cc b/decoder/lexcrf.cc index 33455a3d..816506e4 100644 --- a/decoder/lexcrf.cc +++ b/decoder/lexcrf.cc @@ -81,7 +81,7 @@ struct LexicalCRFImpl { } } Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); - Hypergraph::Node* goal = forest->AddNode(TD::Convert("[Goal]")*-1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); forest->ConnectEdgeToHeadNode(hg_edge, goal); } diff --git a/decoder/tagger.cc b/decoder/tagger.cc new file mode 100644 index 00000000..5a0155cc --- /dev/null +++ b/decoder/tagger.cc @@ -0,0 +1,109 @@ +#include "tagger.h" + +#include "tdict.h" +#include "hg_io.h" +#include "filelib.h" +#include "hg.h" +#include "wordid.h" +#include "sentence_metadata.h" + +using namespace std; + +// This is a really simple linear chain tagger. +// You specify a tagset, and it hypothesizes that each word in the +// input can be tagged with any member of the tagset. +// The are a couple sample features implemented in ff_tagger.h/cc +// One thing to note, that while CRFs typically define the label +// sequence as corresponding to the hidden states in a trellis, +// in our model the labels are on edges, but mathematically +// they are identical. +// +// Things to do if you want to make this a "real" tagger: +// - support dictionaries (for each word, limit the tags considered) +// - add latent variables - this is really easy to do + +static void ReadTagset(const string& file, vector<WordID>* tags) { + ReadFile rf(file); + istream& in(*rf.stream()); + while(in) { + string tag; + in >> tag; + if (tag.empty()) continue; + tags->push_back(TD::Convert(tag)); + } + cerr << "Read " << tags->size() << " labels (tags) from " << file << endl; +} + +struct TaggerImpl { + TaggerImpl(const boost::program_options::variables_map& conf) : + kXCAT(TD::Convert("X")*-1), + kNULL(TD::Convert("<eps>")), + kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")), + kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) { + if (conf.count("tagger_tagset") == 0) { + cerr << "Tagger requires --tagger_tagset FILE!\n"; + exit(1); + } + ReadTagset(conf["tagger_tagset"].as<string>(), &tagset_); + } + + void BuildTrellis(const vector<WordID>& seq, Hypergraph* forest) { + int prev_node_id = -1; + for (int i = 0; i < seq.size(); ++i) { + const WordID& src = seq[i]; + const int new_node_id = forest->AddNode(kXCAT)->id_; + for (int k = 0; k < tagset_.size(); ++k) { + TRulePtr rule(TRule::CreateLexicalRule(src, tagset_[k])); + Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector()); + edge->i_ = i; + edge->j_ = i+1; + forest->ConnectEdgeToHeadNode(edge->id_, new_node_id); + } + if (prev_node_id >= 0) { + const int comb_node_id = forest->AddNode(kXCAT)->id_; + Hypergraph::TailNodeVector tail(2, prev_node_id); + tail[1] = new_node_id; + Hypergraph::Edge* edge = forest->AddEdge(kBINARY, tail); + edge->i_ = 0; + edge->j_ = i+1; + forest->ConnectEdgeToHeadNode(edge->id_, comb_node_id); + prev_node_id = comb_node_id; + } else { + prev_node_id = new_node_id; + } + } + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + } + + private: + vector<WordID> tagset_; + const WordID kXCAT; + const WordID kNULL; + const TRulePtr kBINARY; + const TRulePtr kGOAL_RULE; +}; + +Tagger::Tagger(const boost::program_options::variables_map& conf) : + pimpl_(new TaggerImpl(conf)) {} + + +bool Tagger::Translate(const string& input, + SentenceMetadata* smeta, + const vector<double>& weights, + Hypergraph* forest) { + Lattice lattice; + LatticeTools::ConvertTextToLattice(input, &lattice); + smeta->SetSourceLength(lattice.size()); + vector<WordID> sequence(lattice.size()); + for (int i = 0; i < lattice.size(); ++i) { + assert(lattice[i].size() == 1); + sequence[i] = lattice[i][0].label; + } + pimpl_->BuildTrellis(sequence, forest); + forest->Reweight(weights); + return true; +} + diff --git a/decoder/tagger.h b/decoder/tagger.h new file mode 100644 index 00000000..900019f2 --- /dev/null +++ b/decoder/tagger.h @@ -0,0 +1,17 @@ +#ifndef _TAGGER_H_ +#define _TAGGER_H_ + +#include "translator.h" + +struct TaggerImpl; +struct Tagger : public Translator { + Tagger(const boost::program_options::variables_map& conf); + bool Translate(const std::string& input, + SentenceMetadata* smeta, + const std::vector<double>& weights, + Hypergraph* forest); + private: + boost::shared_ptr<TaggerImpl> pimpl_; +}; + +#endif diff --git a/decoder/trule.h b/decoder/trule.h index d2b1babe..42edfa2c 100644 --- a/decoder/trule.h +++ b/decoder/trule.h @@ -39,6 +39,10 @@ class TRule { // [LHS] ||| term1 [NT] term2 [OTHER_NT] [YET_ANOTHER_NT] static TRule* CreateRuleMonolingual(const std::string& rule); + static TRule* CreateLexicalRule(const WordID& src, const WordID& trg) { + return new TRule(src, trg); + } + void ESubstitute(const std::vector<const std::vector<WordID>* >& var_values, std::vector<WordID>* result) const { int vc = 0; @@ -116,6 +120,7 @@ class TRule { short int prev_j; private: + TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {} bool SanityCheck() const; }; diff --git a/training/cluster-ptrain.pl b/training/cluster-ptrain.pl index 8b06f162..33aab25d 100755 --- a/training/cluster-ptrain.pl +++ b/training/cluster-ptrain.pl @@ -36,6 +36,7 @@ GetOptions("cdec=s" => \$DECODER, "sigma_squared=f" => \$sigsq, "means=s" => \$means_file, "optimizer=s" => \$OALG, + "gaussian_prior" => \$PRIOR, "jobs=i" => \$nodes, "pmem=s" => \$pmem ) or usage(); |