77
88@torch .no_grad ()
99def evaluate_link_prediction_performance (model : KGE , triples , er_vocab : Dict [Tuple , List ],
10- re_vocab : Dict [Tuple , List ], batch_size = 128 ) -> Dict :
10+ re_vocab : Dict [Tuple , List ]) -> Dict :
1111 """
1212
1313 Parameters
@@ -16,7 +16,6 @@ def evaluate_link_prediction_performance(model: KGE, triples, er_vocab: Dict[Tup
1616 triples
1717 er_vocab
1818 re_vocab
19- batch_size
2019
2120 Returns
2221 -------
@@ -32,88 +31,67 @@ def evaluate_link_prediction_performance(model: KGE, triples, er_vocab: Dict[Tup
3231 all_entities = torch .arange (0 , num_entities ).long ()
3332 all_entities = all_entities .reshape (len (all_entities ), )
3433 # Iterating one by one is not good when you are using batch norm
35- # Iterate over test triples in batches
36- for batch_start in tqdm (range (0 , len (triples )), batch_size ):
37- batch_end = min (batch_start + batch_size , len (triples ))
38- batch_triples = triples [batch_start :batch_end ]
39-
40- # Prepare batch data
41- str_h_batch = [data_point [0 ] for data_point in batch_triples ]
42- str_r_batch = [data_point [1 ] for data_point in batch_triples ]
43- str_t_batch = [data_point [2 ] for data_point in batch_triples ]
44-
45- h_batch = [model .get_entity_index (str_h ) for str_h in str_h_batch ]
46- r_batch = [model .get_entity_index (str_r ) for str_r in str_r_batch ]
47- t_batch = [model .get_entity_index (str_t ) for str_t in str_t_batch ]
48-
49- h_batch_tensor = torch .tensor (h_batch )
50- r_batch_tensor = torch .tensor (r_batch )
51- t_batch_tensor = torch .tensor (t_batch )
52-
53- batch_size_current = len (batch_triples )
54- num_entities = len (all_entities )
55-
34+ for i in tqdm (range (0 , len (triples ))):
35+ # (1) Get a triple (head entity, relation, tail entity)
36+ data_point = triples [i ]
37+ str_h , str_r , str_t = data_point [0 ], data_point [1 ], data_point [2 ]
38+
39+ h , r , t = model .get_entity_index (str_h ), model .get_relation_index (str_r ), model .get_entity_index (str_t )
5640 # (2) Predict missing heads and tails
57- x = torch .stack ((h_batch_tensor .repeat_interleave (num_entities ), r_batch_tensor .repeat_interleave (num_entities ), all_entities .repeat (batch_size_current )), dim = 1 )
58- predictions_tails = model .model .forward_triples (x ).view (batch_size_current , num_entities )
59-
60- x = torch .stack ((all_entities .repeat (batch_size_current ),
61- r_batch_tensor .repeat_interleave (num_entities ),
62- t_batch_tensor .repeat_interleave (num_entities )
41+ x = torch .stack ((torch .tensor (h ).repeat (num_entities , ),
42+ torch .tensor (r ).repeat (num_entities , ),
43+ all_entities ), dim = 1 )
44+
45+ predictions_tails = model .model .forward_triples (x )
46+ x = torch .stack ((all_entities ,
47+ torch .tensor (r ).repeat (num_entities , ),
48+ torch .tensor (t ).repeat (num_entities )
6349 ), dim = 1 )
6450
65- predictions_heads = model .model .forward_triples (x ). view ( batch_size_current , num_entities )
51+ predictions_heads = model .model .forward_triples (x )
6652 del x
6753
68- # Now process each triple in the batch
69- for i in range (batch_size_current ):
70- h = h_batch [i ]
71- r = r_batch [i ]
72- t = t_batch [i ]
73- str_h = str_h_batch [i ]
74- str_r = str_r_batch [i ]
75- str_t = str_t_batch [i ]
76- # 3. Computed filtered ranks for missing tail entities.
77- # 3.1. Compute filtered tail entity rankings
78- filt_tails = [model .entity_to_idx [i ] for i in er_vocab [(str_h , str_r )]]
79- # 3.2 Get the predicted target's score
80- target_value = predictions_tails [t ].item ()
81- # 3.3 Filter scores of all triples containing filtered tail entities
82- predictions_tails [filt_tails ] = - np .Inf
83- # 3.4 Reset the target's score
84- predictions_tails [t ] = target_value
85- # 3.5. Sort the score
86- _ , sort_idxs = torch .sort (predictions_tails , descending = True )
87- sort_idxs = sort_idxs .detach ()
88- filt_tail_entity_rank = np .where (sort_idxs == t )[0 ][0 ]
89-
90- # 4. Computed filtered ranks for missing head entities.
91- # 4.1. Retrieve head entities to be filtered
92- filt_heads = [model .entity_to_idx [i ] for i in re_vocab [(str_r , str_t )]]
93- # 4.2 Get the predicted target's score
94- target_value = predictions_heads [h ].item ()
95- # 4.3 Filter scores of all triples containing filtered head entities.
96- predictions_heads [filt_heads ] = - np .Inf
97- predictions_heads [h ] = target_value
98- _ , sort_idxs = torch .sort (predictions_heads , descending = True )
99- sort_idxs = sort_idxs .detach ()
100- filt_head_entity_rank = np .where (sort_idxs == h )[0 ][0 ]
101-
102- # 4. Add 1 to ranks as numpy array first item has the index of 0.
103- filt_head_entity_rank += 1
104- filt_tail_entity_rank += 1
105-
106- rr = 1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank )
107- # 5. Store reciprocal ranks.
108- reciprocal_ranks .append (rr )
109- # print(f'{i}.th triple: mean reciprical rank:{rr}')
110-
111- # 4. Compute Hit@N
112- for hits_level in range (1 , 11 ):
113- res = 1 if filt_head_entity_rank <= hits_level else 0
114- res += 1 if filt_tail_entity_rank <= hits_level else 0
115- if res > 0 :
116- hits .setdefault (hits_level , []).append (res )
54+ # 3. Compute filtered ranks for missing tail entities.
55+ # 3.1. Compute filtered tail entity rankings
56+ filt_tails = [model .entity_to_idx [i ] for i in er_vocab [(str_h , str_r )]]
57+ # 3.2 Get the predicted target's score
58+ target_value = predictions_tails [t ].item ()
59+ # 3.3 Filter scores of all triples containing filtered tail entities
60+ predictions_tails [filt_tails ] = - np .Inf
61+ # 3.4 Reset the target's score
62+ predictions_tails [t ] = target_value
63+ # 3.5. Sort the score
64+ _ , sort_idxs = torch .sort (predictions_tails , descending = True )
65+ sort_idxs = sort_idxs .detach ()
66+ filt_tail_entity_rank = np .where (sort_idxs == t )[0 ][0 ]
67+
68+ # 4. Compute filtered ranks for missing head entities.
69+ # 4.1. Retrieve head entities to be filtered
70+ filt_heads = [model .entity_to_idx [i ] for i in re_vocab [(str_r , str_t )]]
71+ # 4.2 Get the predicted target's score
72+ target_value = predictions_heads [h ].item ()
73+ # 4.3 Filter scores of all triples containing filtered head entities.
74+ predictions_heads [filt_heads ] = - np .Inf
75+ predictions_heads [h ] = target_value
76+ _ , sort_idxs = torch .sort (predictions_heads , descending = True )
77+ sort_idxs = sort_idxs .detach ()
78+ filt_head_entity_rank = np .where (sort_idxs == h )[0 ][0 ]
79+
80+ # 4. Add 1 to the ranks because NumPy array indices start at 0.
81+ filt_head_entity_rank += 1
82+ filt_tail_entity_rank += 1
83+
84+ rr = 1.0 / filt_head_entity_rank + (1.0 / filt_tail_entity_rank )
85+ # 5. Store reciprocal ranks.
86+ reciprocal_ranks .append (rr )
87+ # print(f'{i}.th triple: mean reciprical rank:{rr}')
88+
89+ # 4. Compute Hit@N
90+ for hits_level in range (1 , 11 ):
91+ res = 1 if filt_head_entity_rank <= hits_level else 0
92+ res += 1 if filt_tail_entity_rank <= hits_level else 0
93+ if res > 0 :
94+ hits .setdefault (hits_level , []).append (res )
11795
11896 mean_reciprocal_rank = sum (reciprocal_ranks ) / (float (len (triples ) * 2 ))
11997
0 commit comments