From 4ebd159797a5db525fce7433e03858f8de96dce6 Mon Sep 17 00:00:00 2001 From: Chris Dyer Date: Thu, 9 Oct 2014 00:43:57 -0400 Subject: make tree terminals available to feature functions --- decoder/aligner.cc | 6 +++--- decoder/csplit.cc | 1 + decoder/fst_translator.cc | 1 + decoder/hg_intersect.cc | 2 +- decoder/lattice.cc | 1 - decoder/lattice.h | 16 +++++++++------- decoder/lexalign.cc | 7 +++---- decoder/lextrans.cc | 7 +++---- decoder/phrasebased_translator.cc | 1 + decoder/rescore_translator.cc | 1 + decoder/scfg_translator.cc | 1 + decoder/sentence_metadata.h | 23 ++++++++++++++++++++++- decoder/tagger.cc | 2 ++ decoder/tree2string_translator.cc | 2 ++ decoder/tree_fragment.cc | 7 +++++++ decoder/tree_fragment.h | 2 ++ 16 files changed, 59 insertions(+), 21 deletions(-) (limited to 'decoder') diff --git a/decoder/aligner.cc b/decoder/aligner.cc index 232e022a..fd648370 100644 --- a/decoder/aligner.cc +++ b/decoder/aligner.cc @@ -198,13 +198,13 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice, } const Hypergraph* g = &in_g; HypergraphP new_hg; - if (!src_lattice.IsSentence() || - !trg_lattice.IsSentence()) { + if (!IsSentence(src_lattice) || + !IsSentence(trg_lattice)) { if (map_instead_of_viterbi) { cerr << " Lattice alignment: using Viterbi instead of MAP alignment\n"; } map_instead_of_viterbi = false; - fix_up_src_spans = !src_lattice.IsSentence(); + fix_up_src_spans = !IsSentence(src_lattice); } KBest::KBestDerivations, ViterbiPathTraversal> kbest(in_g, k_best); diff --git a/decoder/csplit.cc b/decoder/csplit.cc index 4a723822..7ee4092e 100644 --- a/decoder/csplit.cc +++ b/decoder/csplit.cc @@ -151,6 +151,7 @@ bool CompoundSplit::TranslateImpl(const string& input, smeta->SetSourceLength(in.size()); // TODO do utf8 or somethign for (int i = 0; i < in.size(); ++i) smeta->src_lattice_.push_back(vector(1, LatticeArc(TD::Convert(in[i]), 0.0, 1))); + smeta->ComputeInputLatticeType(); pimpl_->BuildTrellis(in, forest); forest->Reweight(weights); return true; diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc index 4253b652..50e6adcc 100644 --- a/decoder/fst_translator.cc +++ b/decoder/fst_translator.cc @@ -95,6 +95,7 @@ bool FSTTranslator::TranslateImpl(const string& input, const vector& weights, Hypergraph* minus_lm_forest) { smeta->SetSourceLength(0); // don't know how to compute this + smeta->input_type_ = cdec::kFOREST; return pimpl_->Translate(input, weights, minus_lm_forest); } diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc index 02f5a401..b9381d02 100644 --- a/decoder/hg_intersect.cc +++ b/decoder/hg_intersect.cc @@ -88,7 +88,7 @@ namespace HG { bool Intersect(const Lattice& target, Hypergraph* hg) { // there are a number of faster algorithms available for restricted // classes of hypergraph and/or target. - if (hg->IsLinearChain() && target.IsSentence()) + if (hg->IsLinearChain() && IsSentence(target)) return FastLinearIntersect(target, hg); vector rem(hg->edges_.size(), false); diff --git a/decoder/lattice.cc b/decoder/lattice.cc index 89da3cd0..1f97048d 100644 --- a/decoder/lattice.cc +++ b/decoder/lattice.cc @@ -50,7 +50,6 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) { l.resize(ids.size()); for (int i = 0; i < l.size(); ++i) l[i].push_back(LatticeArc(ids[i], 0.0, 1)); - l.is_sentence_ = true; } void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) { diff --git a/decoder/lattice.h b/decoder/lattice.h index ad4ca50d..39db0a0e 100644 --- a/decoder/lattice.h +++ b/decoder/lattice.h @@ -25,22 +25,24 @@ class Lattice : public std::vector > { friend void LatticeTools::ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl); friend void LatticeTools::ConvertTextToLattice(const std::string& text, Lattice* pl); public: - Lattice() : is_sentence_(false) {} + Lattice() {} explicit Lattice(size_t t, const std::vector& v = std::vector()) : - std::vector >(t, v), - is_sentence_(false) {} + std::vector>(t, v) {} int Distance(int from, int to) const { if (dist_.empty()) return (to - from); return dist_(from, to); } - // TODO this should actually be computed based on the contents - // of the lattice - bool IsSentence() const { return is_sentence_; } private: void ComputeDistances(); Array2D dist_; - bool is_sentence_; }; +inline bool IsSentence(const Lattice& in) { + bool res = true; + for (auto& alt : in) + if (alt.size() > 1) { res = false; break; } + return res; +} + #endif diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc index 11f20de7..dd529311 100644 --- a/decoder/lexalign.cc +++ b/decoder/lexalign.cc @@ -114,10 +114,9 @@ bool LexicalAlign::TranslateImpl(const string& input, Hypergraph* forest) { Lattice& lattice = smeta->src_lattice_; LatticeTools::ConvertTextOrPLF(input, &lattice); - if (!lattice.IsSentence()) { - // lexical models make independence assumptions - // that don't work with lattices or conf nets - cerr << "LexicalTrans: cannot deal with lattice source input!\n"; + smeta->ComputeInputLatticeType(); + if (smeta->GetInputType() != cdec::kSEQUENCE) { + cerr << "LexicalTrans: cannot deal with non-sequence input!"; abort(); } smeta->SetSourceLength(lattice.size()); diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc index 74a18c3f..d13a891a 100644 --- a/decoder/lextrans.cc +++ b/decoder/lextrans.cc @@ -271,10 +271,9 @@ bool LexicalTrans::TranslateImpl(const string& input, Hypergraph* forest) { Lattice& lattice = smeta->src_lattice_; LatticeTools::ConvertTextOrPLF(input, &lattice); - if (!lattice.IsSentence()) { - // lexical models make independence assumptions - // that don't work with lattices or conf nets - cerr << "LexicalTrans: cannot deal with lattice source input!\n"; + smeta->ComputeInputLatticeType(); + if (smeta->GetInputType() != cdec::kSEQUENCE) { + cerr << "LexicalTrans: cannot deal with non-sequence inputs\n"; abort(); } smeta->SetSourceLength(lattice.size()); diff --git a/decoder/phrasebased_translator.cc b/decoder/phrasebased_translator.cc index 8048248e..8415353a 100644 --- a/decoder/phrasebased_translator.cc +++ b/decoder/phrasebased_translator.cc @@ -114,6 +114,7 @@ struct PhraseBasedTranslatorImpl { Lattice lattice; LatticeTools::ConvertTextOrPLF(input, &lattice); smeta->SetSourceLength(lattice.size()); + smeta->ComputeInputLatticeType(); size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion); minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100); if (add_pass_through_rules) { diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc index 10192f7a..18c83c56 100644 --- a/decoder/rescore_translator.cc +++ b/decoder/rescore_translator.cc @@ -53,6 +53,7 @@ bool RescoreTranslator::TranslateImpl(const string& input, const vector& weights, Hypergraph* minus_lm_forest) { smeta->SetSourceLength(0); // don't know how to compute this + smeta->input_type_ = cdec::kFOREST; return pimpl_->Translate(input, weights, minus_lm_forest); } diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 83b65c28..538f82ec 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -195,6 +195,7 @@ struct SCFGTranslatorImpl { Lattice& lattice = smeta->src_lattice_; LatticeTools::ConvertTextOrPLF(input, &lattice); smeta->SetSourceLength(lattice.size()); + smeta->ComputeInputLatticeType(); if (add_pass_through_rules){ if (!SILENT) cerr << "Adding pass through grammar" << endl; PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_, num_pt_features); diff --git a/decoder/sentence_metadata.h b/decoder/sentence_metadata.h index f2a779f4..19f3721d 100644 --- a/decoder/sentence_metadata.h +++ b/decoder/sentence_metadata.h @@ -5,10 +5,16 @@ #include #include #include "lattice.h" +#include "tree_fragment.h" struct DocScorer; // deprecated, will be removed struct Score; // deprecated, will be removed +namespace cdec { +enum InputType { kSEQUENCE, kTREE, kLATTICE, kFOREST, kUNKNOWN }; +class TreeFragment; +} + class SentenceMetadata { public: friend class DecoderImpl; @@ -17,7 +23,17 @@ class SentenceMetadata { src_len_(-1), has_reference_(ref.size() > 0), trg_len_(ref.size()), - ref_(has_reference_ ? &ref : NULL) {} + ref_(has_reference_ ? &ref : NULL), + input_type_(cdec::kUNKNOWN) {} + + // helper function for lattice inputs + void ComputeInputLatticeType() { + input_type_ = cdec::kSEQUENCE; + for (auto& alt : src_lattice_) { + if (alt.size() > 1) { input_type_ = cdec::kLATTICE; break; } + } + } + cdec::InputType GetInputType() { return input_type_; } int GetSentenceId() const { return sent_id_; } @@ -25,6 +41,8 @@ class SentenceMetadata { // it has parsed the source void SetSourceLength(int sl) { src_len_ = sl; } + const cdec::TreeFragment& GetSourceTree() const { return src_tree_; } + // this should be called if a separate model needs to // specify how long the target sentence should be void SetTargetLength(int tl) { @@ -64,12 +82,15 @@ class SentenceMetadata { const Score* app_score; public: Lattice src_lattice_; // this will only be set if inputs are finite state! + cdec::TreeFragment src_tree_; // this will be set only if inputs are trees private: // you need to be very careful when depending on these values // they will only be set during training / alignment contexts const bool has_reference_; int trg_len_; const Lattice* const ref_; + public: + cdec::InputType input_type_; }; #endif diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 30fb055f..500d2061 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -100,6 +100,8 @@ bool Tagger::TranslateImpl(const string& input, Lattice& lattice = smeta->src_lattice_; LatticeTools::ConvertTextToLattice(input, &lattice); smeta->SetSourceLength(lattice.size()); + smeta->ComputeInputLatticeType(); + assert(smeta->GetInputType() == cdec::kSEQUENCE); vector sequence(lattice.size()); for (int i = 0; i < lattice.size(); ++i) { assert(lattice[i].size() == 1); diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index bd3b01d0..61a3aba5 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -287,6 +287,8 @@ struct Tree2StringTranslatorImpl { const vector& weights, Hypergraph* minus_lm_forest) { cdec::TreeFragment input_tree(input, false); + smeta->src_tree_ = input_tree; + smeta->input_type_ = cdec::kTREE; if (add_pass_through_rules) CreatePassThroughRules(input_tree); Hypergraph hg; hg.ReserveNodes(input_tree.nodes.size()); diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 42f7793a..5f717c5b 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -64,6 +64,13 @@ int TreeFragment::SetupSpansRec(unsigned cur, int left) { return right; } +vector TreeFragment::Terminals() const { + vector terms; + for (auto& x : *this) + if (IsTerminal(x)) terms.push_back(x); + return terms; +} + // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built // symp keeps track of the terminal symbols that have been built diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index 6b4842ee..e19b79fb 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -72,6 +72,8 @@ class TreeFragment { BreadthFirstIterator bfs_begin(unsigned node_idx) const; BreadthFirstIterator bfs_end() const; + std::vector Terminals() const; + private: // cp is the character index in the tree // np keeps track of the nodes (nonterminals) that have been built -- cgit v1.2.3