Skip to content

Commit 1bdc300

Browse files
committed
Merge branch 'whale-pipeline-evaluation' of https://github.com/dice-group/dice-embeddings into whale-pipeline-evaluation
2 parents 8ee0e5f + 148711c commit 1bdc300

1 file changed

Lines changed: 57 additions & 79 deletions

File tree

dicee/eval_static_funcs.py

Lines changed: 57 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
@torch.no_grad()
99
def evaluate_link_prediction_performance(model: KGE, triples, er_vocab: Dict[Tuple, List],
10-
re_vocab: Dict[Tuple, List], batch_size=128) -> Dict:
10+
re_vocab: Dict[Tuple, List]) -> Dict:
1111
"""
1212
1313
Parameters
@@ -16,7 +16,6 @@ def evaluate_link_prediction_performance(model: KGE, triples, er_vocab: Dict[Tup
1616
triples
1717
er_vocab
1818
re_vocab
19-
batch_size
2019
2120
Returns
2221
-------
@@ -32,88 +31,67 @@ def evaluate_link_prediction_performance(model: KGE, triples, er_vocab: Dict[Tup
3231
all_entities = torch.arange(0, num_entities).long()
3332
all_entities = all_entities.reshape(len(all_entities), )
3433
# Iterating one by one is not good when you are using batch norm
35-
# Iterate over test triples in batches
36-
for batch_start in tqdm(range(0, len(triples)), batch_size):
37-
batch_end = min(batch_start + batch_size, len(triples))
38-
batch_triples = triples[batch_start:batch_end]
39-
40-
# Prepare batch data
41-
str_h_batch = [data_point[0] for data_point in batch_triples]
42-
str_r_batch = [data_point[1] for data_point in batch_triples]
43-
str_t_batch = [data_point[2] for data_point in batch_triples]
44-
45-
h_batch = [model.get_entity_index(str_h) for str_h in str_h_batch]
46-
r_batch = [model.get_entity_index(str_r) for str_r in str_r_batch]
47-
t_batch = [model.get_entity_index(str_t) for str_t in str_t_batch]
48-
49-
h_batch_tensor = torch.tensor(h_batch)
50-
r_batch_tensor = torch.tensor(r_batch)
51-
t_batch_tensor = torch.tensor(t_batch)
52-
53-
batch_size_current = len(batch_triples)
54-
num_entities = len(all_entities)
55-
34+
for i in tqdm(range(0, len(triples))):
35+
# (1) Get a triple (head entity, relation, tail entity)
36+
data_point = triples[i]
37+
str_h, str_r, str_t = data_point[0], data_point[1], data_point[2]
38+
39+
h, r, t = model.get_entity_index(str_h), model.get_relation_index(str_r), model.get_entity_index(str_t)
5640
# (2) Predict missing heads and tails
57-
x = torch.stack((h_batch_tensor.repeat_interleave(num_entities), r_batch_tensor.repeat_interleave(num_entities), all_entities.repeat(batch_size_current)), dim=1)
58-
predictions_tails = model.model.forward_triples(x).view(batch_size_current, num_entities)
59-
60-
x = torch.stack((all_entities.repeat(batch_size_current),
61-
r_batch_tensor.repeat_interleave(num_entities),
62-
t_batch_tensor.repeat_interleave(num_entities)
41+
x = torch.stack((torch.tensor(h).repeat(num_entities, ),
42+
torch.tensor(r).repeat(num_entities, ),
43+
all_entities), dim=1)
44+
45+
predictions_tails = model.model.forward_triples(x)
46+
x = torch.stack((all_entities,
47+
torch.tensor(r).repeat(num_entities, ),
48+
torch.tensor(t).repeat(num_entities)
6349
), dim=1)
6450

65-
predictions_heads = model.model.forward_triples(x).view(batch_size_current, num_entities)
51+
predictions_heads = model.model.forward_triples(x)
6652
del x
6753

68-
# Now process each triple in the batch
69-
for i in range(batch_size_current):
70-
h = h_batch[i]
71-
r = r_batch[i]
72-
t = t_batch[i]
73-
str_h = str_h_batch[i]
74-
str_r = str_r_batch[i]
75-
str_t = str_t_batch[i]
76-
# 3. Computed filtered ranks for missing tail entities.
77-
# 3.1. Compute filtered tail entity rankings
78-
filt_tails = [model.entity_to_idx[i] for i in er_vocab[(str_h, str_r)]]
79-
# 3.2 Get the predicted target's score
80-
target_value = predictions_tails[t].item()
81-
# 3.3 Filter scores of all triples containing filtered tail entities
82-
predictions_tails[filt_tails] = -np.Inf
83-
# 3.4 Reset the target's score
84-
predictions_tails[t] = target_value
85-
# 3.5. Sort the score
86-
_, sort_idxs = torch.sort(predictions_tails, descending=True)
87-
sort_idxs = sort_idxs.detach()
88-
filt_tail_entity_rank = np.where(sort_idxs == t)[0][0]
89-
90-
# 4. Computed filtered ranks for missing head entities.
91-
# 4.1. Retrieve head entities to be filtered
92-
filt_heads = [model.entity_to_idx[i] for i in re_vocab[(str_r, str_t)]]
93-
# 4.2 Get the predicted target's score
94-
target_value = predictions_heads[h].item()
95-
# 4.3 Filter scores of all triples containing filtered head entities.
96-
predictions_heads[filt_heads] = -np.Inf
97-
predictions_heads[h] = target_value
98-
_, sort_idxs = torch.sort(predictions_heads, descending=True)
99-
sort_idxs = sort_idxs.detach()
100-
filt_head_entity_rank = np.where(sort_idxs == h)[0][0]
101-
102-
# 4. Add 1 to ranks as numpy array first item has the index of 0.
103-
filt_head_entity_rank += 1
104-
filt_tail_entity_rank += 1
105-
106-
rr = 1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank)
107-
# 5. Store reciprocal ranks.
108-
reciprocal_ranks.append(rr)
109-
# print(f'{i}.th triple: mean reciprocal rank:{rr}')
110-
111-
# 4. Compute Hit@N
112-
for hits_level in range(1, 11):
113-
res = 1 if filt_head_entity_rank <= hits_level else 0
114-
res += 1 if filt_tail_entity_rank <= hits_level else 0
115-
if res > 0:
116-
hits.setdefault(hits_level, []).append(res)
54+
# 3. Computed filtered ranks for missing tail entities.
55+
# 3.1. Compute filtered tail entity rankings
56+
filt_tails = [model.entity_to_idx[i] for i in er_vocab[(str_h, str_r)]]
57+
# 3.2 Get the predicted target's score
58+
target_value = predictions_tails[t].item()
59+
# 3.3 Filter scores of all triples containing filtered tail entities
60+
predictions_tails[filt_tails] = -np.Inf
61+
# 3.4 Reset the target's score
62+
predictions_tails[t] = target_value
63+
# 3.5. Sort the score
64+
_, sort_idxs = torch.sort(predictions_tails, descending=True)
65+
sort_idxs = sort_idxs.detach()
66+
filt_tail_entity_rank = np.where(sort_idxs == t)[0][0]
67+
68+
# 4. Computed filtered ranks for missing head entities.
69+
# 4.1. Retrieve head entities to be filtered
70+
filt_heads = [model.entity_to_idx[i] for i in re_vocab[(str_r, str_t)]]
71+
# 4.2 Get the predicted target's score
72+
target_value = predictions_heads[h].item()
73+
# 4.3 Filter scores of all triples containing filtered head entities.
74+
predictions_heads[filt_heads] = -np.Inf
75+
predictions_heads[h] = target_value
76+
_, sort_idxs = torch.sort(predictions_heads, descending=True)
77+
sort_idxs = sort_idxs.detach()
78+
filt_head_entity_rank = np.where(sort_idxs == h)[0][0]
79+
80+
# 4. Add 1 to ranks as numpy array first item has the index of 0.
81+
filt_head_entity_rank += 1
82+
filt_tail_entity_rank += 1
83+
84+
rr = 1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank)
85+
# 5. Store reciprocal ranks.
86+
reciprocal_ranks.append(rr)
87+
# print(f'{i}.th triple: mean reciprocal rank:{rr}')
88+
89+
# 4. Compute Hit@N
90+
for hits_level in range(1, 11):
91+
res = 1 if filt_head_entity_rank <= hits_level else 0
92+
res += 1 if filt_tail_entity_rank <= hits_level else 0
93+
if res > 0:
94+
hits.setdefault(hits_level, []).append(res)
11795

11896
mean_reciprocal_rank = sum(reciprocal_ranks) / (float(len(triples) * 2))
11997

0 commit comments

Comments
 (0)