summaryrefslogtreecommitdiff
path: root/decoder/tagger.cc
diff options
context:
space:
mode:
Diffstat (limited to 'decoder/tagger.cc')
-rw-r--r--decoder/tagger.cc109
1 files changed, 109 insertions, 0 deletions
diff --git a/decoder/tagger.cc b/decoder/tagger.cc
new file mode 100644
index 00000000..5a0155cc
--- /dev/null
+++ b/decoder/tagger.cc
@@ -0,0 +1,109 @@
+#include "tagger.h"
+
+#include "tdict.h"
+#include "hg_io.h"
+#include "filelib.h"
+#include "hg.h"
+#include "wordid.h"
+#include "sentence_metadata.h"
+
+using namespace std;
+
+// This is a really simple linear chain tagger.
+// You specify a tagset, and it hypothesizes that each word in the
+// input can be tagged with any member of the tagset.
+// The are a couple sample features implemented in ff_tagger.h/cc
+// One thing to note, that while CRFs typically define the label
+// sequence as corresponding to the hidden states in a trellis,
+// in our model the labels are on edges, but mathematically
+// they are identical.
+//
+// Things to do if you want to make this a "real" tagger:
+// - support dictionaries (for each word, limit the tags considered)
+// - add latent variables - this is really easy to do
+
+static void ReadTagset(const string& file, vector<WordID>* tags) {
+ ReadFile rf(file);
+ istream& in(*rf.stream());
+ while(in) {
+ string tag;
+ in >> tag;
+ if (tag.empty()) continue;
+ tags->push_back(TD::Convert(tag));
+ }
+ cerr << "Read " << tags->size() << " labels (tags) from " << file << endl;
+}
+
+struct TaggerImpl {
+ TaggerImpl(const boost::program_options::variables_map& conf) :
+ kXCAT(TD::Convert("X")*-1),
+ kNULL(TD::Convert("<eps>")),
+ kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")),
+ kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) {
+ if (conf.count("tagger_tagset") == 0) {
+ cerr << "Tagger requires --tagger_tagset FILE!\n";
+ exit(1);
+ }
+ ReadTagset(conf["tagger_tagset"].as<string>(), &tagset_);
+ }
+
+ void BuildTrellis(const vector<WordID>& seq, Hypergraph* forest) {
+ int prev_node_id = -1;
+ for (int i = 0; i < seq.size(); ++i) {
+ const WordID& src = seq[i];
+ const int new_node_id = forest->AddNode(kXCAT)->id_;
+ for (int k = 0; k < tagset_.size(); ++k) {
+ TRulePtr rule(TRule::CreateLexicalRule(src, tagset_[k]));
+ Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
+ edge->i_ = i;
+ edge->j_ = i+1;
+ forest->ConnectEdgeToHeadNode(edge->id_, new_node_id);
+ }
+ if (prev_node_id >= 0) {
+ const int comb_node_id = forest->AddNode(kXCAT)->id_;
+ Hypergraph::TailNodeVector tail(2, prev_node_id);
+ tail[1] = new_node_id;
+ Hypergraph::Edge* edge = forest->AddEdge(kBINARY, tail);
+ edge->i_ = 0;
+ edge->j_ = i+1;
+ forest->ConnectEdgeToHeadNode(edge->id_, comb_node_id);
+ prev_node_id = comb_node_id;
+ } else {
+ prev_node_id = new_node_id;
+ }
+ }
+ Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
+ Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1);
+ Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
+ forest->ConnectEdgeToHeadNode(hg_edge, goal);
+ }
+
+ private:
+ vector<WordID> tagset_;
+ const WordID kXCAT;
+ const WordID kNULL;
+ const TRulePtr kBINARY;
+ const TRulePtr kGOAL_RULE;
+};
+
+Tagger::Tagger(const boost::program_options::variables_map& conf) :
+ pimpl_(new TaggerImpl(conf)) {}
+
+
+bool Tagger::Translate(const string& input,
+ SentenceMetadata* smeta,
+ const vector<double>& weights,
+ Hypergraph* forest) {
+ Lattice lattice;
+ LatticeTools::ConvertTextToLattice(input, &lattice);
+ smeta->SetSourceLength(lattice.size());
+ vector<WordID> sequence(lattice.size());
+ for (int i = 0; i < lattice.size(); ++i) {
+ assert(lattice[i].size() == 1);
+ sequence[i] = lattice[i][0].label;
+ }
+ pimpl_->BuildTrellis(sequence, forest);
+ forest->Reweight(weights);
+ return true;
+}
+