diff --git a/beir/retrieval/search/dense/__init__.py b/beir/retrieval/search/dense/__init__.py index 03dca007..5f6bcffb 100644 --- a/beir/retrieval/search/dense/__init__.py +++ b/beir/retrieval/search/dense/__init__.py @@ -1,3 +1,3 @@ -from .exact_search import DenseRetrievalExactSearch +from .exact_search import DenseRetrievalExactSearch , DenseOfflineRetrievalExactSearch from .exact_search_multi_gpu import DenseRetrievalParallelExactSearch from .faiss_search import DenseRetrievalFaissSearch, BinaryFaissSearch, PQFaissSearch, HNSWFaissSearch, HNSWSQFaissSearch, FlatIPFaissSearch, PCAFaissSearch, SQFaissSearch \ No newline at end of file diff --git a/beir/retrieval/search/dense/exact_search.py b/beir/retrieval/search/dense/exact_search.py index 642b21b9..4266d9f7 100644 --- a/beir/retrieval/search/dense/exact_search.py +++ b/beir/retrieval/search/dense/exact_search.py @@ -1,5 +1,6 @@ from .. import BaseSearch from .util import cos_sim, dot_score +import numpy as np import logging import torch from typing import Dict @@ -22,6 +23,19 @@ def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000, self.convert_to_tensor = kwargs.get("convert_to_tensor", True) self.results = {} + def call_model_for_queries(self, queries, queries_ids=None): + query_embeddings = self.model.encode_queries( + queries, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_tensor=self.convert_to_tensor) + return query_embeddings + + def call_model_for_subcorpus(self, corpus_start_idx, corpus_end_idx, corpus, cor_ids=None): + return self.model.encode_corpus( + corpus[corpus_start_idx:corpus_end_idx], + batch_size=self.batch_size, + show_progress_bar=self.show_progress_bar, + convert_to_tensor = self.convert_to_tensor + ) + def search(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str], @@ -38,10 +52,13 @@ def search(self, logger.info("Encoding Queries...") query_ids = list(queries.keys()) self.results = {qid: {} for qid in query_ids} - queries = [queries[qid] for qid in queries] - query_embeddings = self.model.encode_queries( - queries, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_tensor=self.convert_to_tensor) - + queries_strings = [] + queries_ids = [] + for qid, qs in queries.items(): + queries_strings.append(qs) + queries_ids.append(qid) + + query_embeddings = self.call_model_for_queries(queries_strings, queries_ids) logger.info("Sorting Corpus by document length (Longest first)...") corpus_ids = sorted(corpus, key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")), reverse=True) @@ -58,12 +75,7 @@ def search(self, corpus_end_idx = min(corpus_start_idx + self.corpus_chunk_size, len(corpus)) # Encode chunk of corpus - sub_corpus_embeddings = self.model.encode_corpus( - corpus[corpus_start_idx:corpus_end_idx], - batch_size=self.batch_size, - show_progress_bar=self.show_progress_bar, - convert_to_tensor = self.convert_to_tensor - ) + sub_corpus_embeddings = self.call_model_for_subcorpus(corpus_start_idx, corpus_end_idx, corpus, corpus_ids) # Compute similarites using either cosine-similarity or dot product cos_scores = self.score_functions[score_function](query_embeddings, sub_corpus_embeddings) @@ -91,3 +103,15 @@ def search(self, self.results[qid][corpus_id] = score return self.results + +class DenseOfflineRetrievalExactSearch(DenseRetrievalExactSearch): + def __init__(self, model, batch_size: int = 128, corpus_chunk_size: int = 50000, **kwargs): + super(DenseOfflineRetrievalExactSearch, self).__init__(model, batch_size, corpus_chunk_size, **kwargs) + + def call_model_for_queries(self, queries, queries_ids): + query_embeddings = self.model.encode_queries(queries_ids) + return query_embeddings + + def call_model_for_subcorpus(self, corpus_start_idx, corpus_end_idx, corpus ,cor_ids): + sub_corpus_embeddings = self.model.encode_corpus(cor_ids[corpus_start_idx:corpus_end_idx]) + return sub_corpus_embeddings diff --git a/examples/retrieval/evaluation/dense/evaluate_offline_model.py b/examples/retrieval/evaluation/dense/evaluate_offline_model.py new file mode 100644 index 00000000..104c7cec --- /dev/null +++ b/examples/retrieval/evaluation/dense/evaluate_offline_model.py @@ -0,0 +1,95 @@ +from time import time +from beir import util, LoggingHandler +from beir.retrieval import models +from beir.datasets.data_loader import GenericDataLoader +from beir.retrieval.evaluation import EvaluateRetrieval +from beir.retrieval.search.dense import DenseOfflineRetrievalExactSearch +from typing import List, Dict, Optional +import numpy as np + +import logging +import pathlib, os +import random + +class OfflineModel: + def __init__(self, model_path=None, **kwargs): + self.model = None # We don't use the model in offline mode + # simply load the query, corpus and qrels + self.query_npy = np.load(f"{model_path}/queries.npy") + self.corpus_npy = np.load(f"{model_path}/corpus0.npy") + self.query_ids = self.load_ids(f"{model_path}/queries.ids") + self.corpus_ids = self.load_ids(f"{model_path}/corpus0.ids") + + def load_ids(self, id_strs_path): + id_strs = open(id_strs_path, 'r').read().splitlines() + return {id.strip(): idx for idx, id in enumerate(id_strs)} + + # Write your own encoding query function (Returns: Query embeddings as numpy array) + # For eg ==> return np.asarray(self.model.encode(queries, batch_size=batch_size, **kwargs)) + def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> np.ndarray: + idxs = [self.query_ids[q_id] for q_id in queries] + return self.query_npy[idxs] + + # Write your own encoding corpus function (Returns: Document embeddings as numpy array) + # For eg ==> sentences = [(doc["title"] + " " + doc["text"]).strip() for doc in corpus] + # ==> return np.asarray(self.model.encode(sentences, batch_size=batch_size, **kwargs)) + def encode_corpus(self, corpus: List[str], batch_size: int = 8, **kwargs) -> np.ndarray: + idxs = [self.corpus_ids[c_id] for c_id in corpus] + return self.corpus_npy[idxs] + + +#### Just some code to print debug information to stdout +logging.basicConfig(format='%(asctime)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + level=logging.INFO, + handlers=[LoggingHandler()]) +#### /print debug information to stdout + +dataset = "scifact" + +#### Download nfcorpus.zip dataset and unzip the dataset +url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset) +out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets") +data_path = util.download_and_unzip(url, out_dir) + +#### Provide the data path where nfcorpus has been downloaded and unzipped to the data loader +# data folder would contain these files: +# (1) nfcorpus/corpus.jsonl (format: jsonlines) +# (2) nfcorpus/queries.jsonl (format: jsonlines) +# (3) nfcorpus/qrels/test.tsv (format: tsv ("\t")) + +corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test") + +#### Dense Retrieval using SBERT (Sentence-BERT) #### +#### Provide any pretrained sentence-transformers model +#### The model was fine-tuned using cosine-similarity. +#### Complete list - https://www.sbert.net/docs/pretrained_models.html + +model = DenseOfflineRetrievalExactSearch(OfflineModel(model_path="./offline_model")) +retriever = EvaluateRetrieval(model, score_function="dot") + +#### Retrieve dense results (format of results is identical to qrels) +start_time = time() +results = retriever.retrieve(corpus, queries) +end_time = time() +print("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time)) +#### Evaluate your retrieval using NDCG@k, MAP@K ... + +logging.info("Retriever evaluation for k in: {}".format(retriever.k_values)) +ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values) + +mrr = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="mrr") +recall_cap = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="r_cap") +hole = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="hole") + +#### Print top-k documents retrieved #### +top_k = 10 + +query_id, ranking_scores = random.choice(list(results.items())) +scores_sorted = sorted(ranking_scores.items(), key=lambda item: item[1], reverse=True) +logging.info("Query : %s\n" % queries[query_id]) + +for rank in range(top_k): + doc_id = scores_sorted[rank][0] + # Format: Rank x: ID [Title] Body + logging.info("Rank %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))