author    Chris Dyer <cdyer@cs.cmu.edu>  2011-03-29 23:05:40 -0400
committer Chris Dyer <cdyer@cs.cmu.edu>  2011-03-29 23:05:40 -0400
commit    cf192109de3919e6e53d21c516531aa0d1018b5e (patch)
tree      a99f8bac42376f7fd926a3ce8c966b8dcbaa0e81 /decoder
parent    fc17a75cefc5d7b069a5605cb2176f7ee3ef8649 (diff)
dynasearch neighborhood option instead of default partition
Diffstat (limited to 'decoder')
-rw-r--r--  decoder/decoder.cc      5
-rw-r--r--  decoder/ff_tagger.cc    9
-rw-r--r--  decoder/ff_tagger.h     1
-rw-r--r--  decoder/lextrans.cc   101
4 files changed, 114 insertions(+), 2 deletions(-)
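Background note (not part of the commit): the DynaSearch neighborhood of Smith & Eisner (2005) is the set of sequences obtained from a reference by swapping any number of non-overlapping adjacent word pairs, and the new code in lextrans.cc below encodes exactly that set as a lattice. A minimal standalone C++ sketch of the neighborhood itself; every name in it is illustrative and none of it is cdec code:

// Enumerate the DynaSearch neighborhood of a word sequence: every result of
// applying a set of non-overlapping adjacent transpositions (Smith & Eisner, 2005).
#include <iostream>
#include <string>
#include <vector>

static void Recurse(const std::vector<std::string>& w, size_t i,
                    std::vector<std::string>& cur,
                    std::vector<std::vector<std::string> >& out) {
  if (i >= w.size()) { out.push_back(cur); return; }
  cur.push_back(w[i]);                 // option 1: keep w[i] in place
  Recurse(w, i + 1, cur, out);
  cur.pop_back();
  if (i + 1 < w.size()) {              // option 2: swap the pair (w[i], w[i+1])
    cur.push_back(w[i + 1]);
    cur.push_back(w[i]);
    Recurse(w, i + 2, cur, out);
    cur.pop_back();
    cur.pop_back();
  }
}

int main() {
  std::vector<std::string> sent;
  sent.push_back("a"); sent.push_back("b"); sent.push_back("c");
  std::vector<std::vector<std::string> > nbrs;
  std::vector<std::string> cur;
  Recurse(sent, 0, cur, nbrs);
  for (size_t i = 0; i < nbrs.size(); ++i) {   // prints: a b c / a c b / b a c
    for (size_t j = 0; j < nbrs[i].size(); ++j)
      std::cout << (j ? " " : "") << nbrs[i][j];
    std::cout << "\n";
  }
  return 0;
}

For n words this yields Fibonacci-many sequences (3 words give 3, counting the untouched original), which is why the decoder encodes the set compactly as a lattice rather than enumerating it.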
diff --git a/decoder/decoder.cc b/decoder/decoder.cc
index fdaf8cb1..81759a12 100644
--- a/decoder/decoder.cc
+++ b/decoder/decoder.cc
@@ -421,6 +421,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream
("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse")
("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails")
("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices")
+ ("lextrans_dynasearch", "'DynaSearch' neighborhood instead of usual partition, as defined by Smith & Eisner (2005)")
("lextrans_use_null", "Support source-side null words in lexical translation")
("lextrans_align_only", "Only used in alignment mode. Limit target words generated by reference")
("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set")
@@ -861,6 +862,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) {
}
for (int j = 0; j < in_edges.size(); ++j)
forest.edges_[in_edges[j]].feature_values_.set_value(rp.fid_summary, exp(log_np));
+// Hypergraph::Edge& example_edge = forest.edges_[in_edges[0]];
+// string n = "NONE";
+// if (forest.nodes_[i].cat_) n = TD::Convert(-forest.nodes_[i].cat_);
+// cerr << "[" << n << "," << example_edge.i_ << "," << example_edge.j_ << "] = " << exp(log_np) << endl;
}
}
} else if (summary_feature_type == kEDGE_RISK) {
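The hunk above registers --lextrans_dynasearch as a valueless switch; the code in lextrans.cc only ever tests its presence with conf.count(...). A self-contained sketch of that boost::program_options pattern (the option name comes from the hunk; everything else is illustrative):

// Standalone sketch of a presence-only switch with boost::program_options.
#include <boost/program_options.hpp>
#include <iostream>
namespace po = boost::program_options;

int main(int argc, char** argv) {
  po::options_description opts("Configuration options");
  opts.add_options()
    ("lextrans_dynasearch", "use the DynaSearch neighborhood");  // no value type: a bare switch
  po::variables_map conf;
  po::store(po::parse_command_line(argc, argv, opts), conf);
  po::notify(conf);
  // Presence test, the same pattern as LexicalTransImpl's initializer list below:
  const bool dyna_search = conf.count("lextrans_dynasearch") > 0;
  std::cout << "dyna_search = " << dyna_search << std::endl;
  return 0;
}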
diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc
index 46c85cf3..019315a2 100644
--- a/decoder/ff_tagger.cc
+++ b/decoder/ff_tagger.cc
@@ -9,11 +9,14 @@
using namespace std;
Tagger_BigramIndicator::Tagger_BigramIndicator(const std::string& param) :
- FeatureFunction(sizeof(WordID)) {}
+ FeatureFunction(sizeof(WordID)) {
+ no_uni_ = (LowercaseString(param) == "no_uni");
+}
void Tagger_BigramIndicator::FireFeature(const WordID& left,
const WordID& right,
SparseVector<double>* features) const {
+ if (no_uni_ && right == 0) return;
int& fid = fmap_[left][right];
if (!fid) {
ostringstream os;
@@ -41,6 +44,8 @@ void Tagger_BigramIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta
if (arity == 0) {
out_context = edge.rule_->e_[0];
FireFeature(out_context, 0, features);
+ } else if (arity == 1) {
+ out_context = *static_cast<const WordID*>(ant_contexts[0]);
} else if (arity == 2) {
WordID left = *static_cast<const WordID*>(ant_contexts[0]);
WordID right = *static_cast<const WordID*>(ant_contexts[1]);
@@ -50,6 +55,8 @@ void Tagger_BigramIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta
if (edge.i_ == 0 && edge.j_ == smeta.GetSourceLength())
FireFeature(right, -1, features);
out_context = right;
+ } else {
+ assert(!"shouldn't happen");
}
}
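FireFeature above interns each (left, right) tag pair once: fmap_[left][right] returns a reference to an int that is zero exactly on first access, so the feature-name string is built a single time per pair. A standalone sketch of the idiom, with a toy Intern() standing in for cdec's FD::Convert string-to-feature-id mapping (names below are illustrative):

// Sketch of the feature-id caching idiom used by FireFeature.
#include <iostream>
#include <map>
#include <sstream>
#include <string>

static int Intern(const std::string& name) {   // stand-in for FD::Convert
  static std::map<std::string, int> ids;
  static int next_id = 1;                      // 0 is reserved to mean "unseen"
  int& id = ids[name];
  if (!id) id = next_id++;
  return id;
}

typedef std::map<int, std::map<int, int> > Class2Class2FID;

static int BigramFID(Class2Class2FID& fmap, int left, int right) {
  int& fid = fmap[left][right];  // default-constructed to 0 on first access
  if (!fid) {                    // build the name string only once per pair
    std::ostringstream os;
    os << "Bi:" << left << '_' << right;
    fid = Intern(os.str());
  }
  return fid;
}

int main() {
  Class2Class2FID fmap;
  std::cout << BigramFID(fmap, 7, 9) << ' '
            << BigramFID(fmap, 7, 9) << std::endl;  // same id printed twice
  return 0;
}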
diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h
index 3066866a..bd5b62c0 100644
--- a/decoder/ff_tagger.h
+++ b/decoder/ff_tagger.h
@@ -28,6 +28,7 @@ class Tagger_BigramIndicator : public FeatureFunction {
const WordID& right,
SparseVector<double>* features) const;
mutable Class2Class2FID fmap_;
+ bool no_uni_;
};
// for each pair of symbols co-occurring in a lexicalized rule, fire
diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc
index f237295c..8c3269bf 100644
--- a/decoder/lextrans.cc
+++ b/decoder/lextrans.cc
@@ -15,9 +15,11 @@ struct LexicalTransImpl {
LexicalTransImpl(const boost::program_options::variables_map& conf) :
use_null(conf.count("lextrans_use_null") > 0),
align_only_(conf.count("lextrans_align_only") > 0),
+ dyna_search_(conf.count("lextrans_dynasearch") > 0),
psg_file_(),
kXCAT(TD::Convert("X")*-1),
kNULL(TD::Convert("<eps>")),
+ kUNARY(new TRule("[X] ||| [X,1] ||| [1]")),
kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")),
kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) {
if (conf.count("per_sentence_grammar_file")) {
@@ -61,7 +63,103 @@ struct LexicalTransImpl {
}
}
+ void CreateEdgeHelper(int label_node, int src, int dest, Hypergraph* forest, map<int,int>* nl2node) {
+ assert(src != dest);
+ assert(label_node < forest->nodes_.size());
+ int& next_node_id = (*nl2node)[dest];
+ if (!next_node_id)
+ next_node_id = forest->AddNode(kXCAT)->id_;
+ if (src < 0) { // edge from the start node
+ Hypergraph::TailNodeVector tail(1, label_node);
+ Hypergraph::Edge* edge = forest->AddEdge(kUNARY, tail);
+ forest->ConnectEdgeToHeadNode(edge->id_, next_node_id);
+ } else { // edge connecting two nodes
+ map<int,int>::iterator it = nl2node->find(src);
+ assert(it != nl2node->end());
+ int prev_node_id = it->second;
+ Hypergraph::TailNodeVector tail(2, prev_node_id);
+ tail[1] = label_node;
+ Hypergraph::Edge* edge = forest->AddEdge(kBINARY, tail);
+ forest->ConnectEdgeToHeadNode(edge->id_, next_node_id);
+ }
+ }
+
+ bool BuildDynaSearchTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) {
+ const int e_len = smeta.GetTargetLength();
+ assert(e_len > 0);
+ const int f_len = lattice.size();
+ // hack to tell the feature function system how big the sentence pair is
+ map<WordID, int> words;
+ int wc = 0;
+ vector<WordID> ref_sent;
+ for (int i = 0; i < e_len; ++i) {
+ WordID word = smeta.GetReference()[i][0].label;
+ ref_sent.push_back(word);
+ if (words.find(word) == words.end()) {
+ words[word] = forest->AddNode(kXCAT)->id_;
+ }
+ }
+
+ // create zero-arity rules representing edge contents
+ for (int j = 0; j < f_len; ++j) { // for each word in the source
+ const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label);
+ const GrammarIter* gi = grammar->GetRoot()->Extend(src_sym);
+ if (!gi) {
+ cerr << "No translations found for: " << TD::Convert(src_sym) << "\n";
+ return false;
+ }
+ const RuleBin* rb = gi->GetRules();
+ assert(rb);
+ for (int k = 0; k < rb->GetNumRules(); ++k) {
+ TRulePtr rule = rb->GetIthRule(k);
+ const WordID trg_word = rule->e_[0];
+ const map<WordID, int>::iterator wordit = words.find(trg_word);
+ if (wordit == words.end()) continue;
+ Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector());
+ edge->i_ = j;
+ edge->j_ = j+1;
+ edge->feature_values_ += edge->rule_->GetFeatureValues();
+ forest->ConnectEdgeToHeadNode(edge->id_, wordit->second);
+ }
+ }
+
+ map<int,int> nl2node;
+
+ int num_nodes = e_len * 2 - 1;
+ for (int i = 0; i < num_nodes; ++i) {
+ const bool is_leaf_node = (i <= 1);
+ if (i % 2 == 0) { // has two previous words
+ int prev_index1 = i - 2;
+ WordID trg1 = ref_sent[i / 2];
+ //cerr << prev_index1 << "-->" << i << "\t" << TD::Convert(trg1) << endl;
+ CreateEdgeHelper(words[trg1], prev_index1, i, forest, &nl2node);
+ if (!is_leaf_node) {
+ int prev_index2 = i - 1;
+ WordID trg2 = ref_sent[(i - 1) / 2];
+ //cerr << prev_index2 << "-->" << i << "\t" << TD::Convert(trg2) << endl;
+ CreateEdgeHelper(words[trg2], prev_index2, i, forest, &nl2node);
+ }
+ } else {
+ WordID trg_word = ref_sent[(i + 1) / 2];
+ int prev_index = i - 3;
+ //cerr << prev_index << "-->" << i << "\t" << TD::Convert(trg_word) << endl;
+ CreateEdgeHelper(words[trg_word], prev_index, i, forest, &nl2node);
+ }
+ //cerr << endl;
+ }
+ Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1);
+ Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1);
+ Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail);
+ forest->ConnectEdgeToHeadNode(hg_edge, goal);
+ forest->is_linear_chain_ = false;
+ return true;
+ }
+
bool BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) {
+ if (dyna_search_) {
+ return BuildDynaSearchTrellis(lattice, smeta, forest);
+ }
+ forest->is_linear_chain_ = true;
if (psg_file_) {
const string offset = smeta.GetSGMLValue("psg");
if (offset.size() < 2 || offset[0] != '@') {
@@ -153,9 +251,11 @@ struct LexicalTransImpl {
private:
const bool use_null;
const bool align_only_;
+ const bool dyna_search_;
ifstream* psg_file_;
const WordID kXCAT;
const WordID kNULL;
+ const TRulePtr kUNARY;
const TRulePtr kBINARY;
const TRulePtr kGOAL_RULE;
GrammarPtr grammar;
@@ -179,7 +279,6 @@ bool LexicalTrans::TranslateImpl(const string& input,
}
smeta->SetSourceLength(lattice.size());
if (!pimpl_->BuildTrellis(lattice, *smeta, forest)) return false;
- forest->is_linear_chain_ = true;
forest->Reweight(weights);
return true;
}
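Taken together, BuildDynaSearchTrellis lays out e_len * 2 - 1 lattice positions over the reference: an even position 2k means the first k+1 words have been emitted, an odd position means an adjacent pair is half-swapped (the later word already emitted, the earlier one pending), and the goal rule attaches to the final even node. A standalone sketch (illustrative only, not cdec code) that replays the loop's index arithmetic and prints the resulting lattice edges for a 3-word reference:

// Print the edge layout BuildDynaSearchTrellis creates, for ref = w0 w1 w2.
// A negative source index means "from the start node", as in CreateEdgeHelper.
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> ref;
  ref.push_back("w0"); ref.push_back("w1"); ref.push_back("w2");
  const int e_len = static_cast<int>(ref.size());
  const int num_nodes = e_len * 2 - 1;
  for (int i = 0; i < num_nodes; ++i) {
    const bool is_leaf_node = (i <= 1);  // reachable directly from the start
    if (i % 2 == 0) {                    // completes a prefix of the reference
      std::cout << (i - 2) << " -> " << i << " emits " << ref[i / 2] << "\n";
      if (!is_leaf_node)                 // second half of an adjacent swap
        std::cout << (i - 1) << " -> " << i << " emits " << ref[(i - 1) / 2] << "\n";
    } else {                             // first half of a swap: later word early
      std::cout << (i - 3) << " -> " << i << " emits " << ref[(i + 1) / 2] << "\n";
    }
  }
  return 0;
}

Every start-to-goal path through these edges spells one member of the DynaSearch neighborhood: for three words, w0 w1 w2, w1 w0 w2, and w0 w2 w1.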