author    Chris Dyer <redpony@gmail.com>  2009-12-17 13:57:54 -0500
committer Chris Dyer <redpony@gmail.com>  2009-12-17 13:57:54 -0500
commit    bba4ff830c8722cdcaf29e36c1ff5821a912ae5d (patch)
tree      268f2f8118aca09b3cc40dca8b2be7de8295acd5
parent    04ae1beeaeceb0161a64d33112f21956f9741bde (diff)
added non-pruning intersection and a CRF tagger
- the linear-chain tagger is more of a proof of concept than a real tagger -- the context-free assumptions made in a number of places mean that the algorithms used may not be as efficient as they could be, but the model is as powerful as any CRF
- it would be easy to add latent variables or semi-CRF support (or both!)
- i've added a couple basic features that are often used for POS tagging
- non-pruning intersection is useful for lexical word alignment models and the tagger
- a sample POS tagger model will be committed later
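For reference, the model the tagger implements is a standard linear-chain CRF. With weights w and edge features f (notation mine, not taken from the commit), the probability of a tag sequence y for an input x of length n is

    p(y | x) = exp( \sum_{i=1}^{n} w · f(y_{i-1}, y_i, x, i) ) / Z(x)

where the partition function Z(x) sums the same exponentiated score over all tag sequences. Non-pruning (full) intersection keeps every state, so Z(x) -- the "partition (inside score)" that the -z option below reports -- is computed exactly rather than approximated by cube pruning.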
-rw-r--r--  decoder/Makefile.am         |   2
-rw-r--r--  decoder/apply_models.cc     |  89
-rw-r--r--  decoder/apply_models.h      |   6
-rw-r--r--  decoder/cdec.cc             |  22
-rw-r--r--  decoder/cdec_ff.cc          |   3
-rw-r--r--  decoder/ff_tagger.cc        |  96
-rw-r--r--  decoder/ff_tagger.h         |  51
-rw-r--r--  decoder/lexcrf.cc           |   2
-rw-r--r--  decoder/tagger.cc           | 109
-rw-r--r--  decoder/tagger.h            |  17
-rw-r--r--  decoder/trule.h             |   5
-rwxr-xr-x  training/cluster-ptrain.pl  |   1
12 files changed, 382 insertions, 21 deletions
diff --git a/decoder/Makefile.am b/decoder/Makefile.am
index f3843102..4c86ae6f 100644
--- a/decoder/Makefile.am
+++ b/decoder/Makefile.am
@@ -60,8 +60,10 @@ libcdec_a_SOURCES = \
ff_lm.cc \
ff_wordalign.cc \
ff_csplit.cc \
+ ff_tagger.cc \
freqdict.cc \
lexcrf.cc \
+ tagger.cc \
bottom_up_parser.cc \
phrasebased_translator.cc \
JSON_parser.c \
diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc
index b1d002f4..a340aa1a 100644
--- a/decoder/apply_models.cc
+++ b/decoder/apply_models.cc
@@ -296,14 +296,69 @@ public:
};
struct NoPruningRescorer {
- NoPruningRescorer(const ModelSet& m, const Hypergraph& i, Hypergraph* o) :
+ NoPruningRescorer(const ModelSet& m, const SentenceMetadata &sm, const Hypergraph& i, Hypergraph* o) :
models(m),
+ smeta(sm),
in(i),
- out(*o) {
+ out(*o),
+ nodemap(i.nodes_.size()) {
cerr << " Rescoring forest (full intersection)\n";
}
- void RescoreNode(const int node_num, const bool is_goal) {
+ typedef unordered_map<string, int, boost::hash<string> > State2NodeIndex;
+
+ void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, State2NodeIndex* state2node) {
+ const int arity = in_edge.Arity();
+ Hypergraph::TailNodeVector ends(arity);
+ for (int i = 0; i < arity; ++i)
+ ends[i] = nodemap[in_edge.tail_nodes_[i]].size();
+
+ Hypergraph::TailNodeVector tail_iter(arity, 0);
+ bool done = false;
+ while (!done) {
+ Hypergraph::TailNodeVector tail(arity);
+ for (int i = 0; i < arity; ++i)
+ tail[i] = nodemap[in_edge.tail_nodes_[i]][tail_iter[i]];
+ Hypergraph::Edge* new_edge = out.AddEdge(in_edge.rule_, tail);
+ new_edge->feature_values_ = in_edge.feature_values_;
+ new_edge->i_ = in_edge.i_;
+ new_edge->j_ = in_edge.j_;
+ new_edge->prev_i_ = in_edge.prev_i_;
+ new_edge->prev_j_ = in_edge.prev_j_;
+ string head_state;
+ if (is_goal) {
+ assert(tail.size() == 1);
+ const string& ant_state = out.nodes_[tail.front()].state_;
+ models.AddFinalFeatures(ant_state, new_edge);
+ } else {
+ prob_t edge_estimate; // this is a full intersection, so we disregard this
+ models.AddFeaturesToEdge(smeta, out, new_edge, &head_state, &edge_estimate);
+ }
+ int& head_plus1 = (*state2node)[head_state];
+ if (!head_plus1) {
+ head_plus1 = out.AddNode(in_edge.rule_->GetLHS(), head_state)->id_ + 1;
+ nodemap[in_edge.head_node_].push_back(head_plus1 - 1);
+ }
+ const int head_index = head_plus1 - 1;
+ out.ConnectEdgeToHeadNode(new_edge->id_, head_index);
+
+ int ii = 0;
+ for (; ii < arity; ++ii) {
+ ++tail_iter[ii];
+ if (tail_iter[ii] < ends[ii]) break;
+ tail_iter[ii] = 0;
+ }
+ done = (ii == arity);
+ }
+ }
+
+ void ProcessOneNode(const int node_num, const bool is_goal) {
+ State2NodeIndex state2node;
+ const Hypergraph::Node& node = in.nodes_[node_num];
+ for (int i = 0; i < node.in_edges_.size(); ++i) {
+ const Hypergraph::Edge& edge = in.edges_[node.in_edges_[i]];
+ ExpandEdge(edge, is_goal, &state2node);
+ }
}
void Apply() {
@@ -316,29 +371,41 @@ struct NoPruningRescorer {
cerr << " ";
for (int i = 0; i < in.nodes_.size(); ++i) {
if (i % every == 0) cerr << '.';
- RescoreNode(i, i == goal_id);
+ ProcessOneNode(i, i == goal_id);
}
cerr << endl;
}
private:
const ModelSet& models;
+ const SentenceMetadata& smeta;
const Hypergraph& in;
Hypergraph& out;
+
+ vector<vector<int> > nodemap;
};
// each node in the graph has one of these, it keeps track of
void ApplyModelSet(const Hypergraph& in,
const SentenceMetadata& smeta,
const ModelSet& models,
- const PruningConfiguration& config,
+ const IntersectionConfiguration& config,
Hypergraph* out) {
- int pl = config.pop_limit;
- if (pl > 100 && in.nodes_.size() > 80000) {
- cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n";
- pl = 30;
+ // TODO special handling when all models are stateless
+ if (config.algorithm == 1) {
+ int pl = config.pop_limit;
+ if (pl > 100 && in.nodes_.size() > 80000) {
+ pl = 30;
+ cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n";
+ }
+ CubePruningRescorer ma(models, smeta, in, pl, out);
+ ma.Apply();
+ } else if (config.algorithm == 0) {
+ NoPruningRescorer ma(models, smeta, in, out);
+ ma.Apply();
+ } else {
+ cerr << "Don't understand intersection algorithm " << config.algorithm << endl;
+ exit(1);
}
- CubePruningRescorer ma(models, smeta, in, pl, out);
- ma.Apply();
}
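The while-loop in ExpandEdge above walks every combination of antecedent nodes using a mixed-radix ("odometer") counter. A self-contained sketch of that pattern, with illustrative names that are not part of the commit:

    #include <vector>
    using std::vector;

    // Visit every tuple (t[0], ..., t[arity-1]) with 0 <= t[i] < ends[i],
    // incrementing position 0 first and carrying upward, exactly as
    // ExpandEdge iterates over the per-tail node lists in nodemap.
    // Assumes every ends[i] >= 1.
    void VisitAllCombinations(const vector<int>& ends) {
      const int arity = ends.size();
      vector<int> t(arity, 0);
      bool done = false;
      while (!done) {
        // ... instantiate one new edge from tuple t here ...
        int i = 0;
        for (; i < arity; ++i) {
          ++t[i];
          if (t[i] < ends[i]) break;  // digit advanced; no carry needed
          t[i] = 0;                   // wrap this digit, carry to the next
        }
        done = (i == arity);  // carried past the last digit: all tuples seen
      }
    }

Note that for arity == 0 the body still executes once, which is exactly what a purely lexical (nullary) edge requires.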
diff --git a/decoder/apply_models.h b/decoder/apply_models.h
index 08fce037..d6d8b34a 100644
--- a/decoder/apply_models.h
+++ b/decoder/apply_models.h
@@ -5,16 +5,16 @@ struct ModelSet;
struct Hypergraph;
struct SentenceMetadata;
-struct PruningConfiguration {
+struct IntersectionConfiguration {
const int algorithm; // 0 = full intersection, 1 = cube pruning
const int pop_limit; // max number of pops off the heap at each node
- explicit PruningConfiguration(int k) : algorithm(1), pop_limit(k) {}
+ IntersectionConfiguration(int alg, int k) : algorithm(alg), pop_limit(k) {}
};
void ApplyModelSet(const Hypergraph& in,
const SentenceMetadata& smeta,
const ModelSet& models,
- const PruningConfiguration& config,
+ const IntersectionConfiguration& config,
Hypergraph* out);
#endif
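The call site now selects the algorithm explicitly; for example (values illustrative):

    // algorithm 1 = cube pruning with a pop limit of 200 (the old default path)
    IntersectionConfiguration cube(1, 200);
    // algorithm 0 = full intersection; NoPruningRescorer ignores the pop limit
    IntersectionConfiguration full(0, 0);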
diff --git a/decoder/cdec.cc b/decoder/cdec.cc
index 6185c79b..c6773cce 100644
--- a/decoder/cdec.cc
+++ b/decoder/cdec.cc
@@ -17,6 +17,7 @@
#include "filelib.h"
#include "sampler.h"
#include "sparse_vector.h"
+#include "tagger.h"
#include "lexcrf.h"
#include "csplit.h"
#include "weights.h"
@@ -48,7 +49,7 @@ void ShowBanner() {
void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
po::options_description opts("Configuration options");
opts.add_options()
- ("formalism,f",po::value<string>(),"Translation formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSplit (compound splitting)")
+ ("formalism,f",po::value<string>(),"Decoding formalism; values include SCFG, FST, PB, LexCRF (lexical translation model), CSplit (compound splitting), Tagger (sequence labeling)")
("input,i",po::value<string>()->default_value("-"),"Source file")
("grammar,g",po::value<vector<string> >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)")
("weights,w",po::value<string>(),"Feature weights file")
@@ -58,16 +59,18 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
("k_best,k",po::value<int>(),"Extract the k best derivations")
("unique_k_best,r", "Unique k-best translation list")
("aligner,a", "Run as a word/phrase aligner (src & ref required)")
+ ("intersection_strategy,I",po::value<string>()->default_value("cube_pruning"), "Intersection strategy for incorporating finite-state features; values include Cube_pruning, Full")
("cubepruning_pop_limit,K",po::value<int>()->default_value(200), "Max number of pops from the candidate heap at each node")
("goal",po::value<string>()->default_value("S"),"Goal symbol (SCFG & FST)")
("scfg_extra_glue_grammar", po::value<string>(), "Extra glue grammar file (Glue grammars apply when i=0 but have no other span restrictions)")
("scfg_no_hiero_glue_grammar,n", "No Hiero glue grammar (nb. by default the SCFG decoder adds Hiero glue rules)")
("scfg_default_nt,d",po::value<string>()->default_value("X"),"Default non-terminal symbol in SCFG")
("scfg_max_span_limit,S",po::value<int>()->default_value(10),"Maximum non-terminal span limit (except \"glue\" grammar)")
- ("show_tree_structure,T", "Show the Viterbi derivation structure")
+ ("show_tree_structure", "Show the Viterbi derivation structure")
("show_expected_length", "Show the expected translation length under the model")
("show_partition,z", "Compute and show the partition (inside score)")
("beam_prune", po::value<double>(), "Prune paths from +LM forest")
+ ("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set")
("csplit_output_plf", "(Compound splitter) Output lattice in PLF format")
("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice")
("extract_rules", po::value<string>(), "Extract the rules used in translation (de-duped) to this file")
@@ -111,8 +114,8 @@ void InitCommandLine(int argc, char** argv, po::variables_map* conf) {
}
const string formalism = LowercaseString((*conf)["formalism"].as<string>());
- if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb" && formalism != "csplit") {
- cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit' or 'lexcrf'\n";
+ if (formalism != "scfg" && formalism != "fst" && formalism != "lexcrf" && formalism != "pb" && formalism != "csplit" && formalism != "tagger") {
+ cerr << "Error: --formalism takes only 'scfg', 'fst', 'pb', 'csplit', 'lexcrf', or 'tagger'\n";
cerr << dcmdline_options << endl;
exit(1);
}
@@ -255,6 +258,8 @@ int main(int argc, char** argv) {
translator.reset(new CompoundSplit(conf));
else if (formalism == "lexcrf")
translator.reset(new LexicalCRF(conf));
+ else if (formalism == "tagger")
+ translator.reset(new Tagger(conf));
else
assert(!"error");
@@ -285,6 +290,12 @@ int main(int argc, char** argv) {
}
}
ModelSet late_models(feature_weights, late_ffs);
+ int palg = 1;
+ if (LowercaseString(conf["intersection_strategy"].as<string>()) == "full") {
+ palg = 0;
+ cerr << "Using full intersection (no pruning).\n";
+ }
+ const IntersectionConfiguration inter_conf(palg, conf["cubepruning_pop_limit"].as<int>());
const int sample_max_trans = conf.count("max_translation_sample") ?
conf["max_translation_sample"].as<int>() : 0;
@@ -374,11 +385,10 @@ int main(int argc, char** argv) {
forest.Reweight(feature_weights);
forest.SortInEdgesByEdgeWeights();
Hypergraph lm_forest;
- int cubepruning_pop_limit = conf["cubepruning_pop_limit"].as<int>();
ApplyModelSet(forest,
smeta,
late_models,
- PruningConfiguration(cubepruning_pop_limit),
+ inter_conf,
&lm_forest);
forest.swap(lm_forest);
forest.Reweight(feature_weights);
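Putting the new options together, a tagging run would look something like: cdec -f tagger -t tagset.txt -I full -i input.txt -w weights.txt (short flags as defined in the options table above; file names are placeholders).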
diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc
index 0a4f3d5e..bb2c9d34 100644
--- a/decoder/cdec_ff.cc
+++ b/decoder/cdec_ff.cc
@@ -4,6 +4,7 @@
#include "ff_lm.h"
#include "ff_csplit.h"
#include "ff_wordalign.h"
+#include "ff_tagger.h"
#include "ff_factory.h"
boost::shared_ptr<FFRegistry> global_ff_registry;
@@ -18,5 +19,7 @@ void register_feature_functions() {
global_ff_registry->Register("AlignerResults", new FFFactory<AlignerResults>);
global_ff_registry->Register("CSplit_BasicFeatures", new FFFactory<BasicCSplitFeatures>);
global_ff_registry->Register("CSplit_ReverseCharLM", new FFFactory<ReverseCharLMCSplitFeature>);
+ global_ff_registry->Register("Tagger_BigramIdentity", new FFFactory<Tagger_BigramIdentity>);
+ global_ff_registry->Register("LexicalPairIdentity", new FFFactory<LexicalPairIdentity>);
};
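Registration only makes these names resolvable through the factory; a feature is still instantiated on demand when a configuration requests it by name (via cdec's feature_function mechanism), so Tagger_BigramIdentity costs nothing for translation runs that never ask for it.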
diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc
new file mode 100644
index 00000000..7a9d1def
--- /dev/null
+++ b/decoder/ff_tagger.cc
@@ -0,0 +1,96 @@
+#include "ff_tagger.h"
+
+#include "tdict.h"
+#include "sentence_metadata.h"
+
+#include <sstream>
+
+using namespace std;
+
+Tagger_BigramIdentity::Tagger_BigramIdentity(const std::string& param) :
+ FeatureFunction(sizeof(WordID)) {}
+
+void Tagger_BigramIdentity::FireFeature(const WordID& left,
+ const WordID& right,
+ SparseVector<double>* features) const {
+ int& fid = fmap_[left][right];
+ if (!fid) {
+ ostringstream os;
+ if (right == 0) {
+ os << "Uni:" << TD::Convert(left);
+ } else {
+ os << "Bi:";
+ if (left < 0) { os << "BOS"; } else { os << TD::Convert(left); }
+ os << '_';
+ if (right < 0) { os << "EOS"; } else { os << TD::Convert(right); }
+ }
+ fid = FD::Convert(os.str());
+ }
+ features->set_value(fid, 1.0);
+}
+
+void Tagger_BigramIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ WordID& out_context = *static_cast<WordID*>(context);
+ const int arity = edge.Arity();
+ if (arity == 0) {
+ out_context = edge.rule_->e_[0];
+ FireFeature(out_context, 0, features);
+ } else if (arity == 2) {
+ WordID left = *static_cast<const WordID*>(ant_contexts[0]);
+ WordID right = *static_cast<const WordID*>(ant_contexts[1]);
+ if (edge.i_ == 0 && edge.j_ == 2)
+ FireFeature(-1, left, features);
+ FireFeature(left, right, features);
+ if (edge.i_ == 0 && edge.j_ == smeta.GetSourceLength())
+ FireFeature(right, -1, features);
+ out_context = right;
+ }
+}
+
+LexicalPairIdentity::LexicalPairIdentity(const std::string& param) {}
+
+void LexicalPairIdentity::FireFeature(WordID src,
+ WordID trg,
+ SparseVector<double>* features) const {
+ int& fid = fmap_[src][trg];
+ if (!fid) {
+ static map<WordID, WordID> escape;
+ if (escape.empty()) {
+ escape[TD::Convert("=")] = TD::Convert("__EQ");
+ escape[TD::Convert(";")] = TD::Convert("__SC");
+ escape[TD::Convert(",")] = TD::Convert("__CO");
+ }
+ if (escape.count(src)) src = escape[src];
+ if (escape.count(trg)) trg = escape[trg];
+ ostringstream os;
+ os << "Id:" << TD::Convert(src) << ':' << TD::Convert(trg);
+ fid = FD::Convert(os.str());
+ }
+ features->set_value(fid, 1.0);
+}
+
+void LexicalPairIdentity::TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const {
+ const vector<WordID>& ew = edge.rule_->e_;
+ const vector<WordID>& fw = edge.rule_->f_;
+ for (int i = 0; i < ew.size(); ++i) {
+ const WordID& e = ew[i];
+ if (e <= 0) continue;
+ for (int j = 0; j < fw.size(); ++j) {
+ const WordID& f = fw[j];
+ if (f <= 0) continue;
+ FireFeature(f, e, features);
+ }
+ }
+}
+
+
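Two idioms above deserve a note. The lookup `int& fid = fmap_[left][right];` relies on std::map value-initializing missing entries to 0, so a zero fid means "not named yet" and the feature string is built at most once per pair. The escape map rewrites '=', ';', and ',' before they can appear inside a feature name -- presumably because those characters act as delimiters in cdec's weight and configuration files (my inference, not stated in the commit). A minimal sketch of the caching idiom with hypothetical names:

    #include <map>
    #include <string>

    std::map<std::string, int> name2id;  // hypothetical registry
    int next_id = 1;                     // 0 is reserved for "unassigned"

    int GetOrAssignId(const std::string& name) {
      int& id = name2id[name];   // inserts 0 the first time name is seen
      if (!id) id = next_id++;   // assign an id exactly once
      return id;
    }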
diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h
new file mode 100644
index 00000000..41c3ee5b
--- /dev/null
+++ b/decoder/ff_tagger.h
@@ -0,0 +1,51 @@
+#ifndef _FF_TAGGER_H_
+#define _FF_TAGGER_H_
+
+#include <map>
+#include "ff.h"
+
+typedef std::map<WordID, int> Class2FID;
+typedef std::map<WordID, Class2FID> Class2Class2FID;
+
+// the reason this is a "tagger" feature is that it assumes that
+// the sequence unfolds from left to right, which means it doesn't
+// have to split states based on left context.
+// fires unigram features as well
+class Tagger_BigramIdentity : public FeatureFunction {
+ public:
+ Tagger_BigramIdentity(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ private:
+ void FireFeature(const WordID& left,
+ const WordID& right,
+ SparseVector<double>* features) const;
+ mutable Class2Class2FID fmap_;
+};
+
+// for each pair of symbols cooccurring in a lexicalized rule, fire
+// a feature (mostly used for tagging, but could be used for any model)
+class LexicalPairIdentity : public FeatureFunction {
+ public:
+ LexicalPairIdentity(const std::string& param);
+ protected:
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta,
+ const Hypergraph::Edge& edge,
+ const std::vector<const void*>& ant_contexts,
+ SparseVector<double>* features,
+ SparseVector<double>* estimated_features,
+ void* context) const;
+ private:
+ void FireFeature(WordID src,
+ WordID trg,
+ SparseVector<double>* features) const;
+ mutable Class2Class2FID fmap_;
+};
+
+
+#endif
diff --git a/decoder/lexcrf.cc b/decoder/lexcrf.cc
index 33455a3d..816506e4 100644
--- a/decoder/lexcrf.cc
+++ b/decoder/lexcrf.cc
@@ -81,7 +81,7 @@ struct LexicalCRFImpl {
}
}
Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
- Hypergraph::Node* goal = forest->AddNode(TD::Convert("[Goal]")*-1);
+ Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1);
Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
forest->ConnectEdgeToHeadNode(hg_edge, goal);
}
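A node category in cdec's hypergraph is a negated vocabulary ID (hence the * -1); this one-line change names the goal category "Goal" instead of the literal string "[Goal]", matching the spelling used by the new tagger below.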
diff --git a/decoder/tagger.cc b/decoder/tagger.cc
new file mode 100644
index 00000000..5a0155cc
--- /dev/null
+++ b/decoder/tagger.cc
@@ -0,0 +1,109 @@
+#include "tagger.h"
+
+#include "tdict.h"
+#include "hg_io.h"
+#include "filelib.h"
+#include "hg.h"
+#include "wordid.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+// This is a really simple linear chain tagger.
+// You specify a tagset, and it hypothesizes that each word in the
+// input can be tagged with any member of the tagset.
+// There are a couple of sample features implemented in ff_tagger.h/cc
+// One thing to note: while CRFs typically define the label
+// sequence as corresponding to the hidden states in a trellis,
+// in our model the labels are on the edges; mathematically,
+// the two formulations are identical.
+//
+// Things to do if you want to make this a "real" tagger:
+// - support dictionaries (for each word, limit the tags considered)
+// - add latent variables - this is really easy to do
+
+static void ReadTagset(const string& file, vector<WordID>* tags) {
+ ReadFile rf(file);
+ istream& in(*rf.stream());
+ while(in) {
+ string tag;
+ in >> tag;
+ if (tag.empty()) continue;
+ tags->push_back(TD::Convert(tag));
+ }
+ cerr << "Read " << tags->size() << " labels (tags) from " << file << endl;
+}
+
+struct TaggerImpl {
+ TaggerImpl(const boost::program_options::variables_map& conf) :
+ kXCAT(TD::Convert("X")*-1),
+ kNULL(TD::Convert("<eps>")),
+ kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")),
+ kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) {
+ if (conf.count("tagger_tagset") == 0) {
+ cerr << "Tagger requires --tagger_tagset FILE!\n";
+ exit(1);
+ }
+ ReadTagset(conf["tagger_tagset"].as<string>(), &tagset_);
+ }
+
+ void BuildTrellis(const vector<WordID>& seq, Hypergraph* forest) {
+ int prev_node_id = -1;
+ for (int i = 0; i < seq.size(); ++i) {
+ const WordID& src = seq[i];
+ const int new_node_id = forest->AddNode(kXCAT)->id_;
+ for (int k = 0; k < tagset_.size(); ++k) {
+ TRulePtr rule(TRule::CreateLexicalRule(src, tagset_[k]));
+ Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
+ edge->i_ = i;
+ edge->j_ = i+1;
+ forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
+ }
+ if (prev_node_id >= 0) {
+ const int comb_node_id = forest->AddNode(kXCAT)->id_;
+ Hypergraph::TailNodeVector tail(2, prev_node_id);
+ tail[1] = new_node_id;
+ Hypergraph::Edge* edge = forest->AddEdge(kBINARY, tail);
+ edge->i_ = 0;
+ edge->j_ = i+1;
+ forest->ConnectEdgeToHeadNode(edge->id_, comb_node_id);
+ prev_node_id = comb_node_id;
+ } else {
+ prev_node_id = new_node_id;
+ }
+ }
+ Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
+ Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1);
+ Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
+ forest->ConnectEdgeToHeadNode(hg_edge, goal);
+ }
+
+ private:
+ vector<WordID> tagset_;
+ const WordID kXCAT;
+ const WordID kNULL;
+ const TRulePtr kBINARY;
+ const TRulePtr kGOAL_RULE;
+};
+
+Tagger::Tagger(const boost::program_options::variables_map& conf) :
+ pimpl_(new TaggerImpl(conf)) {}
+
+
+bool Tagger::Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* forest) {
+ Lattice lattice;
+ LatticeTools::ConvertTextToLattice(input, &lattice);
+ smeta->SetSourceLength(lattice.size());
+ vector<WordID> sequence(lattice.size());
+ for (int i = 0; i < lattice.size(); ++i) {
+ assert(lattice[i].size() == 1);
+ sequence[i] = lattice[i][0].label;
+ }
+ pimpl_->BuildTrellis(sequence, forest);
+ forest->Reweight(weights);
+ return true;
+}
+
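To make BuildTrellis concrete, here is the forest it produces for a three-word input w1 w2 w3 with tagset {A, B} (my rendering, not part of the commit):

    n1:   lexical edges w1/A, w1/B         (span [0,1])
    n2:   lexical edges w2/A, w2/B         (span [1,2])
    c2:   binary edge with tail [n1, n2]   (span [0,2])
    n3:   lexical edges w3/A, w3/B         (span [2,3])
    c3:   binary edge with tail [c2, n3]   (span [0,3])
    goal: unary edge with tail [c3]

Derivations of the goal node correspond one-to-one with tag sequences (one lexical edge per word), and because Tagger_BigramIdentity carries only the previous tag as state, intersecting this forest with the model reproduces the usual linear-chain trellis.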
diff --git a/decoder/tagger.h b/decoder/tagger.h
new file mode 100644
index 00000000..900019f2
--- /dev/null
+++ b/decoder/tagger.h
@@ -0,0 +1,17 @@
+#ifndef _TAGGER_H_
+#define _TAGGER_H_
+
+#include "translator.h"
+
+struct TaggerImpl;
+struct Tagger : public Translator {
+ Tagger(const boost::program_options::variables_map& conf);
+ bool Translate(const std::string& input,
+ SentenceMetadata* smeta,
+ const std::vector<double>& weights,
+ Hypergraph* forest);
+ private:
+ boost::shared_ptr<TaggerImpl> pimpl_;
+};
+
+#endif
diff --git a/decoder/trule.h b/decoder/trule.h
index d2b1babe..42edfa2c 100644
--- a/decoder/trule.h
+++ b/decoder/trule.h
@@ -39,6 +39,10 @@ class TRule {
// [LHS] ||| term1 [NT] term2 [OTHER_NT] [YET_ANOTHER_NT]
static TRule* CreateRuleMonolingual(const std::string& rule);
+ static TRule* CreateLexicalRule(const WordID& src, const WordID& trg) {
+ return new TRule(src, trg);
+ }
+
void ESubstitute(const std::vector<const std::vector<WordID>* >& var_values,
std::vector<WordID>* result) const {
int vc = 0;
@@ -116,6 +120,7 @@ class TRule {
short int prev_j;
private:
+ TRule(const WordID& src, const WordID& trg) : e_(1, trg), f_(1, src), lhs_(), arity_(), prev_i(), prev_j() {}
bool SanityCheck() const;
};
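A quick illustration of the new factory, which the tagger uses to propose each tag for each word (words illustrative):

    // rewrites source word "walk" as tag "V": f_ = {walk}, e_ = {V}
    TRulePtr r(TRule::CreateLexicalRule(TD::Convert("walk"), TD::Convert("V")));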
diff --git a/training/cluster-ptrain.pl b/training/cluster-ptrain.pl
index 8b06f162..33aab25d 100755
--- a/training/cluster-ptrain.pl
+++ b/training/cluster-ptrain.pl
@@ -36,6 +36,7 @@ GetOptions("cdec=s" => \$DECODER,
"sigma_squared=f" => \$sigsq,
"means=s" => \$means_file,
"optimizer=s" => \$OALG,
+ "gaussian_prior" => \$PRIOR,
"jobs=i" => \$nodes,
"pmem=s" => \$pmem
) or usage();