summaryrefslogtreecommitdiff
path: root/decoder
diff options
context:
space:
mode:
authorChris Dyer <redpony@gmail.com>2014-10-09 00:43:57 -0400
committerChris Dyer <redpony@gmail.com>2014-10-09 00:43:57 -0400
commit1d87fbb31502ba27f0469e4a576e410ee43ad77a (patch)
treedbc8a9d743a77fb1a6be229af2c49f9fe4858c8a /decoder
parent8601c7fa4ca6fe8093ec54cd2c150cf130484297 (diff)
make tree terminals available to feature functions
Diffstat (limited to 'decoder')
-rw-r--r--decoder/aligner.cc6
-rw-r--r--decoder/csplit.cc1
-rw-r--r--decoder/fst_translator.cc1
-rw-r--r--decoder/hg_intersect.cc2
-rw-r--r--decoder/lattice.cc1
-rw-r--r--decoder/lattice.h16
-rw-r--r--decoder/lexalign.cc7
-rw-r--r--decoder/lextrans.cc7
-rw-r--r--decoder/phrasebased_translator.cc1
-rw-r--r--decoder/rescore_translator.cc1
-rw-r--r--decoder/scfg_translator.cc1
-rw-r--r--decoder/sentence_metadata.h23
-rw-r--r--decoder/tagger.cc2
-rw-r--r--decoder/tree2string_translator.cc2
-rw-r--r--decoder/tree_fragment.cc7
-rw-r--r--decoder/tree_fragment.h2
16 files changed, 59 insertions, 21 deletions
diff --git a/decoder/aligner.cc b/decoder/aligner.cc
index 232e022a..fd648370 100644
--- a/decoder/aligner.cc
+++ b/decoder/aligner.cc
@@ -198,13 +198,13 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice,
}
const Hypergraph* g = &in_g;
HypergraphP new_hg;
- if (!src_lattice.IsSentence() ||
- !trg_lattice.IsSentence()) {
+ if (!IsSentence(src_lattice) ||
+ !IsSentence(trg_lattice)) {
if (map_instead_of_viterbi) {
cerr << " Lattice alignment: using Viterbi instead of MAP alignment\n";
}
map_instead_of_viterbi = false;
- fix_up_src_spans = !src_lattice.IsSentence();
+ fix_up_src_spans = !IsSentence(src_lattice);
}
KBest::KBestDerivations<vector<Hypergraph::Edge const*>, ViterbiPathTraversal> kbest(in_g, k_best);
diff --git a/decoder/csplit.cc b/decoder/csplit.cc
index 4a723822..7ee4092e 100644
--- a/decoder/csplit.cc
+++ b/decoder/csplit.cc
@@ -151,6 +151,7 @@ bool CompoundSplit::TranslateImpl(const string& input,
smeta->SetSourceLength(in.size()); // TODO do utf8 or somethign
for (int i = 0; i < in.size(); ++i)
smeta->src_lattice_.push_back(vector<LatticeArc>(1, LatticeArc(TD::Convert(in[i]), 0.0, 1)));
+ smeta->ComputeInputLatticeType();
pimpl_->BuildTrellis(in, forest);
forest->Reweight(weights);
return true;
diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc
index 4253b652..50e6adcc 100644
--- a/decoder/fst_translator.cc
+++ b/decoder/fst_translator.cc
@@ -95,6 +95,7 @@ bool FSTTranslator::TranslateImpl(const string& input,
const vector<double>& weights,
Hypergraph* minus_lm_forest) {
smeta->SetSourceLength(0); // don't know how to compute this
+ smeta->input_type_ = cdec::kFOREST;
return pimpl_->Translate(input, weights, minus_lm_forest);
}
diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc
index 02f5a401..b9381d02 100644
--- a/decoder/hg_intersect.cc
+++ b/decoder/hg_intersect.cc
@@ -88,7 +88,7 @@ namespace HG {
bool Intersect(const Lattice& target, Hypergraph* hg) {
// there are a number of faster algorithms available for restricted
// classes of hypergraph and/or target.
- if (hg->IsLinearChain() && target.IsSentence())
+ if (hg->IsLinearChain() && IsSentence(target))
return FastLinearIntersect(target, hg);
vector<bool> rem(hg->edges_.size(), false);
diff --git a/decoder/lattice.cc b/decoder/lattice.cc
index 89da3cd0..1f97048d 100644
--- a/decoder/lattice.cc
+++ b/decoder/lattice.cc
@@ -50,7 +50,6 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) {
l.resize(ids.size());
for (int i = 0; i < l.size(); ++i)
l[i].push_back(LatticeArc(ids[i], 0.0, 1));
- l.is_sentence_ = true;
}
void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) {
diff --git a/decoder/lattice.h b/decoder/lattice.h
index ad4ca50d..39db0a0e 100644
--- a/decoder/lattice.h
+++ b/decoder/lattice.h
@@ -25,22 +25,24 @@ class Lattice : public std::vector<std::vector<LatticeArc> > {
friend void LatticeTools::ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl);
friend void LatticeTools::ConvertTextToLattice(const std::string& text, Lattice* pl);
public:
- Lattice() : is_sentence_(false) {}
+ Lattice() {}
explicit Lattice(size_t t, const std::vector<LatticeArc>& v = std::vector<LatticeArc>()) :
- std::vector<std::vector<LatticeArc> >(t, v),
- is_sentence_(false) {}
+ std::vector<std::vector<LatticeArc>>(t, v) {}
int Distance(int from, int to) const {
if (dist_.empty())
return (to - from);
return dist_(from, to);
}
- // TODO this should actually be computed based on the contents
- // of the lattice
- bool IsSentence() const { return is_sentence_; }
private:
void ComputeDistances();
Array2D<int> dist_;
- bool is_sentence_;
};
+inline bool IsSentence(const Lattice& in) {
+ bool res = true;
+ for (auto& alt : in)
+ if (alt.size() > 1) { res = false; break; }
+ return res;
+}
+
#endif
diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc
index 11f20de7..dd529311 100644
--- a/decoder/lexalign.cc
+++ b/decoder/lexalign.cc
@@ -114,10 +114,9 @@ bool LexicalAlign::TranslateImpl(const string& input,
Hypergraph* forest) {
Lattice& lattice = smeta->src_lattice_;
LatticeTools::ConvertTextOrPLF(input, &lattice);
- if (!lattice.IsSentence()) {
- // lexical models make independence assumptions
- // that don't work with lattices or conf nets
- cerr << "LexicalTrans: cannot deal with lattice source input!\n";
+ smeta->ComputeInputLatticeType();
+ if (smeta->GetInputType() != cdec::kSEQUENCE) {
+ cerr << "LexicalTrans: cannot deal with non-sequence input!";
abort();
}
smeta->SetSourceLength(lattice.size());
diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc
index 74a18c3f..d13a891a 100644
--- a/decoder/lextrans.cc
+++ b/decoder/lextrans.cc
@@ -271,10 +271,9 @@ bool LexicalTrans::TranslateImpl(const string& input,
Hypergraph* forest) {
Lattice& lattice = smeta->src_lattice_;
LatticeTools::ConvertTextOrPLF(input, &lattice);
- if (!lattice.IsSentence()) {
- // lexical models make independence assumptions
- // that don't work with lattices or conf nets
- cerr << "LexicalTrans: cannot deal with lattice source input!\n";
+ smeta->ComputeInputLatticeType();
+ if (smeta->GetInputType() != cdec::kSEQUENCE) {
+ cerr << "LexicalTrans: cannot deal with non-sequence inputs\n";
abort();
}
smeta->SetSourceLength(lattice.size());
diff --git a/decoder/phrasebased_translator.cc b/decoder/phrasebased_translator.cc
index 8048248e..8415353a 100644
--- a/decoder/phrasebased_translator.cc
+++ b/decoder/phrasebased_translator.cc
@@ -114,6 +114,7 @@ struct PhraseBasedTranslatorImpl {
Lattice lattice;
LatticeTools::ConvertTextOrPLF(input, &lattice);
smeta->SetSourceLength(lattice.size());
+ smeta->ComputeInputLatticeType();
size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion);
minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100);
if (add_pass_through_rules) {
diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc
index 10192f7a..18c83c56 100644
--- a/decoder/rescore_translator.cc
+++ b/decoder/rescore_translator.cc
@@ -53,6 +53,7 @@ bool RescoreTranslator::TranslateImpl(const string& input,
const vector<double>& weights,
Hypergraph* minus_lm_forest) {
smeta->SetSourceLength(0); // don't know how to compute this
+ smeta->input_type_ = cdec::kFOREST;
return pimpl_->Translate(input, weights, minus_lm_forest);
}
diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc
index 83b65c28..538f82ec 100644
--- a/decoder/scfg_translator.cc
+++ b/decoder/scfg_translator.cc
@@ -195,6 +195,7 @@ struct SCFGTranslatorImpl {
Lattice& lattice = smeta->src_lattice_;
LatticeTools::ConvertTextOrPLF(input, &lattice);
smeta->SetSourceLength(lattice.size());
+ smeta->ComputeInputLatticeType();
if (add_pass_through_rules){
if (!SILENT) cerr << "Adding pass through grammar" << endl;
PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_, num_pt_features);
diff --git a/decoder/sentence_metadata.h b/decoder/sentence_metadata.h
index f2a779f4..19f3721d 100644
--- a/decoder/sentence_metadata.h
+++ b/decoder/sentence_metadata.h
@@ -5,10 +5,16 @@
#include <map>
#include <cassert>
#include "lattice.h"
+#include "tree_fragment.h"
struct DocScorer; // deprecated, will be removed
struct Score; // deprecated, will be removed
+namespace cdec {
+enum InputType { kSEQUENCE, kTREE, kLATTICE, kFOREST, kUNKNOWN };
+class TreeFragment;
+}
+
class SentenceMetadata {
public:
friend class DecoderImpl;
@@ -17,7 +23,17 @@ class SentenceMetadata {
src_len_(-1),
has_reference_(ref.size() > 0),
trg_len_(ref.size()),
- ref_(has_reference_ ? &ref : NULL) {}
+ ref_(has_reference_ ? &ref : NULL),
+ input_type_(cdec::kUNKNOWN) {}
+
+ // helper function for lattice inputs
+ void ComputeInputLatticeType() {
+ input_type_ = cdec::kSEQUENCE;
+ for (auto& alt : src_lattice_) {
+ if (alt.size() > 1) { input_type_ = cdec::kLATTICE; break; }
+ }
+ }
+ cdec::InputType GetInputType() { return input_type_; }
int GetSentenceId() const { return sent_id_; }
@@ -25,6 +41,8 @@ class SentenceMetadata {
// it has parsed the source
void SetSourceLength(int sl) { src_len_ = sl; }
+ const cdec::TreeFragment& GetSourceTree() const { return src_tree_; }
+
// this should be called if a separate model needs to
// specify how long the target sentence should be
void SetTargetLength(int tl) {
@@ -64,12 +82,15 @@ class SentenceMetadata {
const Score* app_score;
public:
Lattice src_lattice_; // this will only be set if inputs are finite state!
+ cdec::TreeFragment src_tree_; // this will be set only if inputs are trees
private:
// you need to be very careful when depending on these values
// they will only be set during training / alignment contexts
const bool has_reference_;
int trg_len_;
const Lattice* const ref_;
+ public:
+ cdec::InputType input_type_;
};
#endif
diff --git a/decoder/tagger.cc b/decoder/tagger.cc
index 30fb055f..500d2061 100644
--- a/decoder/tagger.cc
+++ b/decoder/tagger.cc
@@ -100,6 +100,8 @@ bool Tagger::TranslateImpl(const string& input,
Lattice& lattice = smeta->src_lattice_;
LatticeTools::ConvertTextToLattice(input, &lattice);
smeta->SetSourceLength(lattice.size());
+ smeta->ComputeInputLatticeType();
+ assert(smeta->GetInputType() == cdec::kSEQUENCE);
vector<WordID> sequence(lattice.size());
for (int i = 0; i < lattice.size(); ++i) {
assert(lattice[i].size() == 1);
diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc
index bd3b01d0..61a3aba5 100644
--- a/decoder/tree2string_translator.cc
+++ b/decoder/tree2string_translator.cc
@@ -287,6 +287,8 @@ struct Tree2StringTranslatorImpl {
const vector<double>& weights,
Hypergraph* minus_lm_forest) {
cdec::TreeFragment input_tree(input, false);
+ smeta->src_tree_ = input_tree;
+ smeta->input_type_ = cdec::kTREE;
if (add_pass_through_rules) CreatePassThroughRules(input_tree);
Hypergraph hg;
hg.ReserveNodes(input_tree.nodes.size());
diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc
index 42f7793a..5f717c5b 100644
--- a/decoder/tree_fragment.cc
+++ b/decoder/tree_fragment.cc
@@ -64,6 +64,13 @@ int TreeFragment::SetupSpansRec(unsigned cur, int left) {
return right;
}
+vector<int> TreeFragment::Terminals() const {
+ vector<int> terms;
+ for (auto& x : *this)
+ if (IsTerminal(x)) terms.push_back(x);
+ return terms;
+}
+
// cp is the character index in the tree
// np keeps track of the nodes (nonterminals) that have been built
// symp keeps track of the terminal symbols that have been built
diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h
index 6b4842ee..e19b79fb 100644
--- a/decoder/tree_fragment.h
+++ b/decoder/tree_fragment.h
@@ -72,6 +72,8 @@ class TreeFragment {
BreadthFirstIterator bfs_begin(unsigned node_idx) const;
BreadthFirstIterator bfs_end() const;
+ std::vector<int> Terminals() const;
+
private:
// cp is the character index in the tree
// np keeps track of the nodes (nonterminals) that have been built