diff options
author | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-29 23:05:40 -0400 |
---|---|---|
committer | Chris Dyer <cdyer@cs.cmu.edu> | 2011-03-29 23:05:40 -0400 |
commit | cf192109de3919e6e53d21c516531aa0d1018b5e (patch) | |
tree | a99f8bac42376f7fd926a3ce8c966b8dcbaa0e81 | |
parent | fc17a75cefc5d7b069a5605cb2176f7ee3ef8649 (diff) |
dynasearch neighborhood option instead of default partition
-rw-r--r-- | decoder/decoder.cc | 5 | ||||
-rw-r--r-- | decoder/ff_tagger.cc | 9 | ||||
-rw-r--r-- | decoder/ff_tagger.h | 1 | ||||
-rw-r--r-- | decoder/lextrans.cc | 101 |
4 files changed, 114 insertions, 2 deletions
diff --git a/decoder/decoder.cc b/decoder/decoder.cc index fdaf8cb1..81759a12 100644 --- a/decoder/decoder.cc +++ b/decoder/decoder.cc @@ -421,6 +421,7 @@ DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream ("ctf_num_widenings", po::value<int>()->default_value(2), "Widen coarse beam this many times before backing off to full parse") ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails") ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") + ("lextrans_dynasearch", "'DynaSearch' neighborhood instead of usual partition, as defined by Smith & Eisner (2005)") ("lextrans_use_null", "Support source-side null words in lexical translation") ("lextrans_align_only", "Only used in alignment mode. Limit target words generated by reference") ("tagger_tagset,t", po::value<string>(), "(Tagger) file containing tag set") @@ -861,6 +862,10 @@ bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { } for (int j = 0; j < in_edges.size(); ++j) forest.edges_[in_edges[j]].feature_values_.set_value(rp.fid_summary, exp(log_np)); +// Hypergraph::Edge& example_edge = forest.edges_[in_edges[0]]; +// string n = "NONE"; +// if (forest.nodes_[i].cat_) n = TD::Convert(-forest.nodes_[i].cat_); +// cerr << "[" << n << "," << example_edge.i_ << "," << example_edge.j_ << "] = " << exp(log_np) << endl; } } } else if (summary_feature_type == kEDGE_RISK) { diff --git a/decoder/ff_tagger.cc b/decoder/ff_tagger.cc index 46c85cf3..019315a2 100644 --- a/decoder/ff_tagger.cc +++ b/decoder/ff_tagger.cc @@ -9,11 +9,14 @@ using namespace std; Tagger_BigramIndicator::Tagger_BigramIndicator(const std::string& param) : - FeatureFunction(sizeof(WordID)) {} + FeatureFunction(sizeof(WordID)) { + no_uni_ = (LowercaseString(param) == "no_uni"); +} void Tagger_BigramIndicator::FireFeature(const WordID& left, const WordID& right, SparseVector<double>* features) const { + if (no_uni_ && right == 0) return; int& fid = fmap_[left][right]; if (!fid) { ostringstream os; @@ -41,6 +44,8 @@ void Tagger_BigramIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta if (arity == 0) { out_context = edge.rule_->e_[0]; FireFeature(out_context, 0, features); + } else if (arity == 1) { + out_context = *static_cast<const WordID*>(ant_contexts[0]); } else if (arity == 2) { WordID left = *static_cast<const WordID*>(ant_contexts[0]); WordID right = *static_cast<const WordID*>(ant_contexts[1]); @@ -50,6 +55,8 @@ void Tagger_BigramIndicator::TraversalFeaturesImpl(const SentenceMetadata& smeta if (edge.i_ == 0 && edge.j_ == smeta.GetSourceLength()) FireFeature(right, -1, features); out_context = right; + } else { + assert(!"shouldn't happen"); } } diff --git a/decoder/ff_tagger.h b/decoder/ff_tagger.h index 3066866a..bd5b62c0 100644 --- a/decoder/ff_tagger.h +++ b/decoder/ff_tagger.h @@ -28,6 +28,7 @@ class Tagger_BigramIndicator : public FeatureFunction { const WordID& right, SparseVector<double>* features) const; mutable Class2Class2FID fmap_; + bool no_uni_; }; // for each pair of symbols cooccuring in a lexicalized rule, fire diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc index f237295c..8c3269bf 100644 --- a/decoder/lextrans.cc +++ b/decoder/lextrans.cc @@ -15,9 +15,11 @@ struct LexicalTransImpl { LexicalTransImpl(const boost::program_options::variables_map& conf) : use_null(conf.count("lextrans_use_null") > 0), align_only_(conf.count("lextrans_align_only") > 0), + dyna_search_(conf.count("lextrans_dynasearch") > 0), psg_file_(), kXCAT(TD::Convert("X")*-1), kNULL(TD::Convert("<eps>")), + kUNARY(new TRule("[X] ||| [X,1] ||| [1]")), kBINARY(new TRule("[X] ||| [X,1] [X,2] ||| [1] [2]")), kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")) { if (conf.count("per_sentence_grammar_file")) { @@ -61,7 +63,103 @@ struct LexicalTransImpl { } } + void CreateEdgeHelper(int label_node, int src, int dest, Hypergraph* forest, map<int,int>* nl2node) { + assert(src != dest); + assert(label_node < forest->nodes_.size()); + int& next_node_id = (*nl2node)[dest]; + if (!next_node_id) + next_node_id = forest->AddNode(kXCAT)->id_; + if (src < 0) { // edge from the start node + Hypergraph::TailNodeVector tail(1, label_node); + Hypergraph::Edge* edge = forest->AddEdge(kUNARY, tail); + forest->ConnectEdgeToHeadNode(edge->id_, next_node_id); + } else { // edge connecting two nodes + map<int,int>::iterator it = nl2node->find(src); + assert(it != nl2node->end()); + int prev_node_id = it->second; + Hypergraph::TailNodeVector tail(2, prev_node_id); + tail[1] = label_node; + Hypergraph::Edge* edge = forest->AddEdge(kBINARY, tail); + forest->ConnectEdgeToHeadNode(edge->id_, next_node_id); + } + } + + bool BuildDynaSearchTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) { + const int e_len = smeta.GetTargetLength(); + assert(e_len > 0); + const int f_len = lattice.size(); + // hack to tell the feature function system how big the sentence pair is + map<WordID, int> words; + int wc = 0; + vector<WordID> ref_sent; + for (int i = 0; i < e_len; ++i) { + WordID word = smeta.GetReference()[i][0].label; + ref_sent.push_back(word); + if (words.find(word) == words.end()) { + words[word] = forest->AddNode(kXCAT)->id_; + } + } + + // create zero-arity rules representing edge contents + for (int j = 0; j < f_len; ++j) { // for each word in the source + const WordID src_sym = (j < 0 ? kNULL : lattice[j][0].label); + const GrammarIter* gi = grammar->GetRoot()->Extend(src_sym); + if (!gi) { + cerr << "No translations found for: " << TD::Convert(src_sym) << "\n"; + return false; + } + const RuleBin* rb = gi->GetRules(); + assert(rb); + for (int k = 0; k < rb->GetNumRules(); ++k) { + TRulePtr rule = rb->GetIthRule(k); + const WordID trg_word = rule->e_[0]; + const map<WordID, int>::iterator wordit = words.find(trg_word); + if (wordit == words.end()) continue; + Hypergraph::Edge* edge = forest->AddEdge(rule, Hypergraph::TailNodeVector()); + edge->i_ = j; + edge->j_ = j+1; + edge->feature_values_ += edge->rule_->GetFeatureValues(); + forest->ConnectEdgeToHeadNode(edge->id_, wordit->second); + } + } + + map<int,int> nl2node; + + int num_nodes = e_len * 2 - 1; + for (int i = 0; i < num_nodes; ++i) { + const bool is_leaf_node = (i <= 1); + if (i % 2 == 0) { // has two previous words + int prev_index1 = i - 2; + WordID trg1 = ref_sent[i / 2]; + //cerr << prev_index1 << "-->" << i << "\t" << TD::Convert(trg1) << endl; + CreateEdgeHelper(words[trg1], prev_index1, i, forest, &nl2node); + if (!is_leaf_node) { + int prev_index2 = i - 1; + WordID trg2 = ref_sent[(i - 1) / 2]; + //cerr << prev_index2 << "-->" << i << "\t" << TD::Convert(trg2) << endl; + CreateEdgeHelper(words[trg2], prev_index2, i, forest, &nl2node); + } + } else { + WordID trg_word = ref_sent[(i + 1) / 2]; + int prev_index = i - 3; + //cerr << prev_index << "-->" << i << "\t" << TD::Convert(trg_word) << endl; + CreateEdgeHelper(words[trg_word], prev_index, i, forest, &nl2node); + } + //cerr << endl; + } + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + forest->is_linear_chain_ = false; + return true; + } + bool BuildTrellis(const Lattice& lattice, const SentenceMetadata& smeta, Hypergraph* forest) { + if (dyna_search_) { + return BuildDynaSearchTrellis(lattice, smeta, forest); + } + forest->is_linear_chain_ = true; if (psg_file_) { const string offset = smeta.GetSGMLValue("psg"); if (offset.size() < 2 || offset[0] != '@') { @@ -153,9 +251,11 @@ struct LexicalTransImpl { private: const bool use_null; const bool align_only_; + const bool dyna_search_; ifstream* psg_file_; const WordID kXCAT; const WordID kNULL; + const TRulePtr kUNARY; const TRulePtr kBINARY; const TRulePtr kGOAL_RULE; GrammarPtr grammar; @@ -179,7 +279,6 @@ bool LexicalTrans::TranslateImpl(const string& input, } smeta->SetSourceLength(lattice.size()); if (!pimpl_->BuildTrellis(lattice, *smeta, forest)) return false; - forest->is_linear_chain_ = true; forest->Reweight(weights); return true; } |