summaryrefslogtreecommitdiff
path: root/extractor/translation_table.cc
blob: 1b1ba11245ca88472029064c7cc4f3ce0aa242ab (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include "translation_table.h"

#include <string>
#include <vector>

#include <boost/functional/hash.hpp>

#include "alignment.h"
#include "data_array.h"

using namespace std;

namespace extractor {

// Builds lexical translation probabilities from a word-aligned parallel corpus
// by relative frequency of alignment links.
//
// For every sentence pair, each alignment link (f, e) adds one to count(f, e),
// count(f) and count(e). Words that take part in no link are paired with
// DataArray::NULL_WORD on the other side, so they are counted as well. For each
// observed (source id, target id) pair the table then stores:
//   first  = p(e | f) = count(e, f) / count(f)
//   second = p(f | e) = count(e, f) / count(e)
//
// source_data_array / target_data_array: the two sides of the corpus, matched
//   by sentence index; both are kept by the table for later word lookups.
// alignment: per-sentence word links given as (source position, target
//   position) relative to the start of the sentence; only read here.
TranslationTable::TranslationTable(shared_ptr<DataArray> source_data_array,
                                   shared_ptr<DataArray> target_data_array,
                                   shared_ptr<Alignment> alignment) :
    source_data_array(source_data_array), target_data_array(target_data_array) {
  const vector<int>& source_data = source_data_array->GetData();
  const vector<int>& target_data = target_data_array->GetData();

  unordered_map<int, int> source_links_count;
  unordered_map<int, int> target_links_count;
  unordered_map<pair<int, int>, int, PairHash> links_count;

  for (size_t sentence = 0; sentence < source_data_array->GetNumSentences();
       ++sentence) {
    const vector<pair<int, int>>& links = alignment->GetLinks(sentence);

    // Sentence bounds inside the flat data arrays. The "- 1" drops the
    // END_OF_LINE marker that sits just before the next sentence's start.
    const int source_start = source_data_array->GetSentenceStart(sentence);
    const int target_start = target_data_array->GetSentenceStart(sentence);
    const int source_end =
        source_data_array->GetSentenceStart(sentence + 1) - 1;
    const int target_end =
        target_data_array->GetSentenceStart(sentence + 1) - 1;

    // 1 at every sentence position that takes part in at least one link.
    vector<int> source_linked(source_end - source_start);
    vector<int> target_linked(target_end - target_start);

    // NOTE(review): link positions are assumed to lie inside the sentence;
    // they are not range-checked (same as before).
    for (const pair<int, int>& link : links) {
      source_linked[link.first] = 1;
      target_linked[link.second] = 1;
      IncrementLinksCount(source_links_count, target_links_count, links_count,
                          source_data[source_start + link.first],
                          target_data[target_start + link.second]);
    }

    // Unaligned source words are linked to the NULL target word.
    for (size_t pos = 0; pos < source_linked.size(); ++pos) {
      if (!source_linked[pos]) {
        IncrementLinksCount(source_links_count, target_links_count, links_count,
                            source_data[source_start + pos],
                            DataArray::NULL_WORD);
      }
    }

    // Unaligned target words are linked to the NULL source word.
    for (size_t pos = 0; pos < target_linked.size(); ++pos) {
      if (!target_linked[pos]) {
        IncrementLinksCount(source_links_count, target_links_count, links_count,
                            DataArray::NULL_WORD,
                            target_data[target_start + pos]);
      }
    }
  }

  // Relative frequencies. Every word in links_count also has a marginal count,
  // so at() never throws here; it just avoids accidental insertion.
  for (const auto& entry : links_count) {
    const pair<int, int>& word_pair = entry.first;
    const double joint_count = entry.second;
    const double score1 = joint_count / source_links_count.at(word_pair.first);
    const double score2 = joint_count / target_links_count.at(word_pair.second);
    translation_probabilities[word_pair] = make_pair(score1, score2);
  }
}

// Default constructor: leaves every member default-initialized (no corpus, no
// probabilities).
TranslationTable::TranslationTable() = default;

// Destructor: members release their own resources.
TranslationTable::~TranslationTable() = default;

void TranslationTable::IncrementLinksCount(
    unordered_map<int, int>& source_links_count,
    unordered_map<int, int>& target_links_count,
    unordered_map<pair<int, int>, int, PairHash>& links_count,
    int source_word_id,
    int target_word_id) const {
  ++source_links_count[source_word_id];
  ++target_links_count[target_word_id];
  ++links_count[make_pair(source_word_id, target_word_id)];
}

// Returns p(e | f), the probability of target_word given source_word.
// Returns -1 if either word is missing from its side's vocabulary.
// Returns 0 if both words are known but no probability was stored for the pair.
double TranslationTable::GetTargetGivenSourceScore(
    const string& source_word, const string& target_word) {
  const bool both_known = source_data_array->HasWord(source_word) &&
                          target_data_array->HasWord(target_word);
  if (!both_known) {
    return -1;
  }

  const int source_id = source_data_array->GetWordId(source_word);
  const int target_id = target_data_array->GetWordId(target_word);
  const auto match =
      translation_probabilities.find(make_pair(source_id, target_id));
  return match != translation_probabilities.end() ? match->second.first : 0;
}

// Returns p(f | e), the probability of source_word given target_word.
// Returns -1 if either word is missing from its side's vocabulary.
// Returns 0 if both words are known but no probability was stored for the pair.
double TranslationTable::GetSourceGivenTargetScore(
    const string& source_word, const string& target_word) {
  if (!source_data_array->HasWord(source_word)) {
    return -1;
  }
  if (!target_data_array->HasWord(target_word)) {
    return -1;
  }

  const int source_id = source_data_array->GetWordId(source_word);
  const int target_id = target_data_array->GetWordId(target_word);
  const pair<int, int> key = make_pair(source_id, target_id);
  const auto match = translation_probabilities.find(key);
  if (match != translation_probabilities.end()) {
    return match->second.second;
  }
  return 0;
}

// Value equality. Two tables are equal when:
//   - both source data arrays compare equal via DataArray::operator==,
//   - both target data arrays compare equal the same way, and
//   - the probability maps hold the same entries.
// The default constructor does not set the data arrays, so they may be null.
// A null array is treated as equal only to another null array instead of
// being dereferenced, which was undefined behavior.
bool TranslationTable::operator==(const TranslationTable& other) const {
  // Null-safe by-value comparison of two DataArray handles.
  auto same_data = [](const shared_ptr<DataArray>& lhs,
                      const shared_ptr<DataArray>& rhs) -> bool {
    if (lhs == nullptr || rhs == nullptr) {
      return lhs == rhs;
    }
    return *lhs == *rhs;
  };
  return same_data(source_data_array, other.source_data_array) &&
         same_data(target_data_array, other.target_data_array) &&
         translation_probabilities == other.translation_probabilities;
}

} // namespace extractor