diff options
| -rw-r--r-- | decoder/aligner.cc | 6 | ||||
| -rw-r--r-- | decoder/csplit.cc | 1 | ||||
| -rw-r--r-- | decoder/fst_translator.cc | 1 | ||||
| -rw-r--r-- | decoder/hg_intersect.cc | 2 | ||||
| -rw-r--r-- | decoder/lattice.cc | 1 | ||||
| -rw-r--r-- | decoder/lattice.h | 16 | ||||
| -rw-r--r-- | decoder/lexalign.cc | 7 | ||||
| -rw-r--r-- | decoder/lextrans.cc | 7 | ||||
| -rw-r--r-- | decoder/phrasebased_translator.cc | 1 | ||||
| -rw-r--r-- | decoder/rescore_translator.cc | 1 | ||||
| -rw-r--r-- | decoder/scfg_translator.cc | 1 | ||||
| -rw-r--r-- | decoder/sentence_metadata.h | 23 | ||||
| -rw-r--r-- | decoder/tagger.cc | 2 | ||||
| -rw-r--r-- | decoder/tree2string_translator.cc | 2 | ||||
| -rw-r--r-- | decoder/tree_fragment.cc | 7 | ||||
| -rw-r--r-- | decoder/tree_fragment.h | 2 | 
16 files changed, 59 insertions, 21 deletions
| diff --git a/decoder/aligner.cc b/decoder/aligner.cc index 232e022a..fd648370 100644 --- a/decoder/aligner.cc +++ b/decoder/aligner.cc @@ -198,13 +198,13 @@ void AlignerTools::WriteAlignment(const Lattice& src_lattice,    }    const Hypergraph* g = &in_g;    HypergraphP new_hg; -  if (!src_lattice.IsSentence() || -      !trg_lattice.IsSentence()) { +  if (!IsSentence(src_lattice) || +      !IsSentence(trg_lattice)) {      if (map_instead_of_viterbi) {        cerr << "  Lattice alignment: using Viterbi instead of MAP alignment\n";      }      map_instead_of_viterbi = false; -    fix_up_src_spans = !src_lattice.IsSentence(); +    fix_up_src_spans = !IsSentence(src_lattice);    }    KBest::KBestDerivations<vector<Hypergraph::Edge const*>, ViterbiPathTraversal> kbest(in_g, k_best); diff --git a/decoder/csplit.cc b/decoder/csplit.cc index 4a723822..7ee4092e 100644 --- a/decoder/csplit.cc +++ b/decoder/csplit.cc @@ -151,6 +151,7 @@ bool CompoundSplit::TranslateImpl(const string& input,    smeta->SetSourceLength(in.size());  // TODO do utf8 or somethign    for (int i = 0; i < in.size(); ++i)      smeta->src_lattice_.push_back(vector<LatticeArc>(1, LatticeArc(TD::Convert(in[i]), 0.0, 1))); +  smeta->ComputeInputLatticeType();    pimpl_->BuildTrellis(in, forest);    forest->Reweight(weights);    return true; diff --git a/decoder/fst_translator.cc b/decoder/fst_translator.cc index 4253b652..50e6adcc 100644 --- a/decoder/fst_translator.cc +++ b/decoder/fst_translator.cc @@ -95,6 +95,7 @@ bool FSTTranslator::TranslateImpl(const string& input,                                const vector<double>& weights,                                Hypergraph* minus_lm_forest) {    smeta->SetSourceLength(0);  // don't know how to compute this +  smeta->input_type_ = cdec::kFOREST;    return pimpl_->Translate(input, weights, minus_lm_forest);  } diff --git a/decoder/hg_intersect.cc b/decoder/hg_intersect.cc index 02f5a401..b9381d02 100644 --- a/decoder/hg_intersect.cc +++ b/decoder/hg_intersect.cc @@ -88,7 +88,7 @@ namespace HG {  bool Intersect(const Lattice& target, Hypergraph* hg) {    // there are a number of faster algorithms available for restricted    // classes of hypergraph and/or target. -  if (hg->IsLinearChain() && target.IsSentence()) +  if (hg->IsLinearChain() && IsSentence(target))      return FastLinearIntersect(target, hg);    vector<bool> rem(hg->edges_.size(), false); diff --git a/decoder/lattice.cc b/decoder/lattice.cc index 89da3cd0..1f97048d 100644 --- a/decoder/lattice.cc +++ b/decoder/lattice.cc @@ -50,7 +50,6 @@ void LatticeTools::ConvertTextToLattice(const string& text, Lattice* pl) {    l.resize(ids.size());    for (int i = 0; i < l.size(); ++i)      l[i].push_back(LatticeArc(ids[i], 0.0, 1)); -  l.is_sentence_ = true;  }  void LatticeTools::ConvertTextOrPLF(const string& text_or_plf, Lattice* pl) { diff --git a/decoder/lattice.h b/decoder/lattice.h index ad4ca50d..39db0a0e 100644 --- a/decoder/lattice.h +++ b/decoder/lattice.h @@ -25,22 +25,24 @@ class Lattice : public std::vector<std::vector<LatticeArc> > {    friend void LatticeTools::ConvertTextOrPLF(const std::string& text_or_plf, Lattice* pl);    friend void LatticeTools::ConvertTextToLattice(const std::string& text, Lattice* pl);   public: -  Lattice() : is_sentence_(false) {} +  Lattice() {}    explicit Lattice(size_t t, const std::vector<LatticeArc>& v = std::vector<LatticeArc>()) : -   std::vector<std::vector<LatticeArc> >(t, v), -   is_sentence_(false) {} +   std::vector<std::vector<LatticeArc>>(t, v) {}    int Distance(int from, int to) const {      if (dist_.empty())        return (to - from);      return dist_(from, to);    } -  // TODO this should actually be computed based on the contents -  // of the lattice -  bool IsSentence() const { return is_sentence_; }   private:    void ComputeDistances();    Array2D<int> dist_; -  bool is_sentence_;  }; +inline bool IsSentence(const Lattice& in) { +  bool res = true; +  for (auto& alt : in) +    if (alt.size() > 1) { res = false; break; } +  return res; +} +  #endif diff --git a/decoder/lexalign.cc b/decoder/lexalign.cc index 11f20de7..dd529311 100644 --- a/decoder/lexalign.cc +++ b/decoder/lexalign.cc @@ -114,10 +114,9 @@ bool LexicalAlign::TranslateImpl(const string& input,                        Hypergraph* forest) {    Lattice& lattice = smeta->src_lattice_;    LatticeTools::ConvertTextOrPLF(input, &lattice); -  if (!lattice.IsSentence()) { -    // lexical models make independence assumptions -    // that don't work with lattices or conf nets -    cerr << "LexicalTrans: cannot deal with lattice source input!\n"; +  smeta->ComputeInputLatticeType(); +  if (smeta->GetInputType() != cdec::kSEQUENCE) { +    cerr << "LexicalTrans: cannot deal with non-sequence input!";      abort();    }    smeta->SetSourceLength(lattice.size()); diff --git a/decoder/lextrans.cc b/decoder/lextrans.cc index 74a18c3f..d13a891a 100644 --- a/decoder/lextrans.cc +++ b/decoder/lextrans.cc @@ -271,10 +271,9 @@ bool LexicalTrans::TranslateImpl(const string& input,                        Hypergraph* forest) {    Lattice& lattice = smeta->src_lattice_;    LatticeTools::ConvertTextOrPLF(input, &lattice); -  if (!lattice.IsSentence()) { -    // lexical models make independence assumptions -    // that don't work with lattices or conf nets -    cerr << "LexicalTrans: cannot deal with lattice source input!\n"; +  smeta->ComputeInputLatticeType(); +  if (smeta->GetInputType() != cdec::kSEQUENCE) { +    cerr << "LexicalTrans: cannot deal with non-sequence inputs\n";      abort();    }    smeta->SetSourceLength(lattice.size()); diff --git a/decoder/phrasebased_translator.cc b/decoder/phrasebased_translator.cc index 8048248e..8415353a 100644 --- a/decoder/phrasebased_translator.cc +++ b/decoder/phrasebased_translator.cc @@ -114,6 +114,7 @@ struct PhraseBasedTranslatorImpl {      Lattice lattice;      LatticeTools::ConvertTextOrPLF(input, &lattice);      smeta->SetSourceLength(lattice.size()); +    smeta->ComputeInputLatticeType();      size_t est_nodes = lattice.size() * lattice.size() * (1 << max_distortion);      minus_lm_forest->ReserveNodes(est_nodes, est_nodes * 100);      if (add_pass_through_rules) { diff --git a/decoder/rescore_translator.cc b/decoder/rescore_translator.cc index 10192f7a..18c83c56 100644 --- a/decoder/rescore_translator.cc +++ b/decoder/rescore_translator.cc @@ -53,6 +53,7 @@ bool RescoreTranslator::TranslateImpl(const string& input,                                const vector<double>& weights,                                Hypergraph* minus_lm_forest) {    smeta->SetSourceLength(0);  // don't know how to compute this +  smeta->input_type_ = cdec::kFOREST;    return pimpl_->Translate(input, weights, minus_lm_forest);  } diff --git a/decoder/scfg_translator.cc b/decoder/scfg_translator.cc index 83b65c28..538f82ec 100644 --- a/decoder/scfg_translator.cc +++ b/decoder/scfg_translator.cc @@ -195,6 +195,7 @@ struct SCFGTranslatorImpl {      Lattice& lattice = smeta->src_lattice_;      LatticeTools::ConvertTextOrPLF(input, &lattice);      smeta->SetSourceLength(lattice.size()); +    smeta->ComputeInputLatticeType();      if (add_pass_through_rules){        if (!SILENT) cerr << "Adding pass through grammar" << endl;        PassThroughGrammar* g = new PassThroughGrammar(lattice, default_nt, ctf_iterations_, num_pt_features); diff --git a/decoder/sentence_metadata.h b/decoder/sentence_metadata.h index f2a779f4..19f3721d 100644 --- a/decoder/sentence_metadata.h +++ b/decoder/sentence_metadata.h @@ -5,10 +5,16 @@  #include <map>  #include <cassert>  #include "lattice.h" +#include "tree_fragment.h"  struct DocScorer;  // deprecated, will be removed  struct Score;     // deprecated, will be removed +namespace cdec { +enum InputType { kSEQUENCE, kTREE, kLATTICE, kFOREST, kUNKNOWN }; +class TreeFragment; +} +  class SentenceMetadata {   public:    friend class DecoderImpl; @@ -17,7 +23,17 @@ class SentenceMetadata {      src_len_(-1),      has_reference_(ref.size() > 0),      trg_len_(ref.size()), -    ref_(has_reference_ ? &ref : NULL) {} +    ref_(has_reference_ ? &ref : NULL), +    input_type_(cdec::kUNKNOWN) {} + +  // helper function for lattice inputs +  void ComputeInputLatticeType() { +    input_type_ = cdec::kSEQUENCE; +    for (auto& alt : src_lattice_) { +      if (alt.size() > 1) { input_type_ = cdec::kLATTICE; break; } +    } +  } +  cdec::InputType GetInputType() { return input_type_; }    int GetSentenceId() const { return sent_id_; } @@ -25,6 +41,8 @@ class SentenceMetadata {    // it has parsed the source    void SetSourceLength(int sl) { src_len_ = sl; } +  const cdec::TreeFragment& GetSourceTree() const { return src_tree_; } +    // this should be called if a separate model needs to    // specify how long the target sentence should be    void SetTargetLength(int tl) { @@ -64,12 +82,15 @@ class SentenceMetadata {    const Score* app_score;   public:    Lattice src_lattice_;  // this will only be set if inputs are finite state! +  cdec::TreeFragment src_tree_; // this will be set only if inputs are trees   private:    // you need to be very careful when depending on these values    // they will only be set during training / alignment contexts    const bool has_reference_;    int trg_len_;    const Lattice* const ref_; + public: +  cdec::InputType input_type_;  };  #endif diff --git a/decoder/tagger.cc b/decoder/tagger.cc index 30fb055f..500d2061 100644 --- a/decoder/tagger.cc +++ b/decoder/tagger.cc @@ -100,6 +100,8 @@ bool Tagger::TranslateImpl(const string& input,    Lattice& lattice = smeta->src_lattice_;    LatticeTools::ConvertTextToLattice(input, &lattice);    smeta->SetSourceLength(lattice.size()); +  smeta->ComputeInputLatticeType(); +  assert(smeta->GetInputType() == cdec::kSEQUENCE);    vector<WordID> sequence(lattice.size());    for (int i = 0; i < lattice.size(); ++i) {      assert(lattice[i].size() == 1); diff --git a/decoder/tree2string_translator.cc b/decoder/tree2string_translator.cc index bd3b01d0..61a3aba5 100644 --- a/decoder/tree2string_translator.cc +++ b/decoder/tree2string_translator.cc @@ -287,6 +287,8 @@ struct Tree2StringTranslatorImpl {                   const vector<double>& weights,                   Hypergraph* minus_lm_forest) {      cdec::TreeFragment input_tree(input, false); +    smeta->src_tree_ = input_tree; +    smeta->input_type_ = cdec::kTREE;      if (add_pass_through_rules) CreatePassThroughRules(input_tree);      Hypergraph hg;      hg.ReserveNodes(input_tree.nodes.size()); diff --git a/decoder/tree_fragment.cc b/decoder/tree_fragment.cc index 42f7793a..5f717c5b 100644 --- a/decoder/tree_fragment.cc +++ b/decoder/tree_fragment.cc @@ -64,6 +64,13 @@ int TreeFragment::SetupSpansRec(unsigned cur, int left) {    return right;  } +vector<int> TreeFragment::Terminals() const { +  vector<int> terms; +  for (auto& x : *this) +    if (IsTerminal(x)) terms.push_back(x); +  return terms; +} +  // cp is the character index in the tree  // np keeps track of the nodes (nonterminals) that have been built  // symp keeps track of the terminal symbols that have been built diff --git a/decoder/tree_fragment.h b/decoder/tree_fragment.h index 6b4842ee..e19b79fb 100644 --- a/decoder/tree_fragment.h +++ b/decoder/tree_fragment.h @@ -72,6 +72,8 @@ class TreeFragment {    BreadthFirstIterator bfs_begin(unsigned node_idx) const;    BreadthFirstIterator bfs_end() const; +  std::vector<int> Terminals() const; +   private:    // cp is the character index in the tree    // np keeps track of the nodes (nonterminals) that have been built | 
