diff --git a/retrieval.py b/retrieval.py
index 82b2851..3d56162 100644
--- a/retrieval.py
+++ b/retrieval.py
@@ -12,7 +12,6 @@
 from bs4 import BeautifulSoup
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 import optimum.bettertransformer.transformation
-from sentence_transformers import SentenceTransformer
 
 try:
     from .retrievers.faiss_retriever import FaissRetriever
@@ -20,14 +19,14 @@
     from .retrievers.splade_retriever import SpladeRetriever
     from .chunkers.semantic_chunker import BoundedSemanticChunker
     from .chunkers.character_chunker import RecursiveCharacterTextSplitter
-    from .utils import Document
+    from .utils import Document, MySentenceTransformer
 except ImportError:
     from retrievers.faiss_retriever import FaissRetriever
     from retrievers.bm25_retriever import BM25Retriever
     from retrievers.splade_retriever import SpladeRetriever
     from chunkers.semantic_chunker import BoundedSemanticChunker
     from chunkers.character_chunker import RecursiveCharacterTextSplitter
-    from utils import Document
+    from utils import Document, MySentenceTransformer
 
 
 class DocumentRetriever:
@@ -37,9 +36,9 @@ def __init__(self, device="cuda", num_results: int = 5, similarity_threshold: fl
                  model_cache_dir: str = None, chunking_method: str = "character-based",
                  chunker_breakpoint_threshold_amount: int = 10, client_timeout: int = 10):
         self.device = device
-        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=model_cache_dir,
-                                                   device=device,
-                                                   model_kwargs={"torch_dtype": torch.float32 if device == "cpu" else torch.float16})
+        self.embedding_model = MySentenceTransformer("all-MiniLM-L6-v2", cache_folder=model_cache_dir,
+                                                     device=device,
+                                                     model_kwargs={"torch_dtype": torch.float32 if device == "cpu" else torch.float16})
         if keyword_retriever == "splade":
             splade_kwargs = {"cache_dir": model_cache_dir,
                              "torch_dtype": torch.float32 if device == "cpu" else torch.float16,
diff --git a/retrievers/faiss_retriever.py b/retrievers/faiss_retriever.py
index 4ecf7a9..aa9933d 100644
--- a/retrievers/faiss_retriever.py
+++ b/retrievers/faiss_retriever.py
@@ -2,17 +2,16 @@
 
 import faiss
 import numpy as np
-from sentence_transformers import SentenceTransformer
 
 try:
-    from ..utils import Document, cosine_similarity
+    from ..utils import Document, cosine_similarity, MySentenceTransformer, SimilarLengthsBatchifyer
 except:
-    from utils import Document, cosine_similarity
+    from utils import Document, cosine_similarity, MySentenceTransformer, SimilarLengthsBatchifyer
 
 
 class FaissRetriever:
-    def __init__(self, embedding_model: SentenceTransformer, num_results: int = 5, similarity_threshold: float = 0.5):
+    def __init__(self, embedding_model: MySentenceTransformer, num_results: int = 5, similarity_threshold: float = 0.5):
         self.embedding_model = embedding_model
         self.num_results = num_results
         self.similarity_threshold = similarity_threshold
@@ -24,7 +23,7 @@ def add_documents(self, documents: List[Document]):
         if not documents:
             return
         self.documents = documents
-        self.document_embeddings = self.embedding_model.encode([doc.page_content for doc in documents])
+        self.document_embeddings = self.embedding_model.batch_encode([doc.page_content for doc in documents])
         self.index.add(self.document_embeddings)
 
     def get_relevant_documents(self, query: str) -> List[Document]:
diff --git a/retrievers/splade_retriever.py b/retrievers/splade_retriever.py
index f053333..fb88d5a 100644
--- a/retrievers/splade_retriever.py
+++ b/retrievers/splade_retriever.py
@@ -11,80 +11,9 @@
 from scipy.sparse import csr_array
 
 try:
-    from ..utils import Document
+    from ..utils import Document, SimilarLengthsBatchifyer
 except:
-    from utils import Document
-
-
-class SimilarLengthsBatchifyer:
-    """
-    Generator class to split samples into batches. Groups sample sequences
-    of equal/similar length together to minimize the need for padding within a batch.
-    """
-    def __init__(self, batch_size, inputs, max_padding_len=10):
-        # Remember number of samples
-        self.num_samples = len(inputs)
-
-        self.unique_lengths = set()
-        self.length_to_sample_indices = {}
-
-        for i in range(0, len(inputs)):
-            len_input = len(inputs[i])
-
-            self.unique_lengths.add(len_input)
-
-            # For each length, keep track of the indices of the samples that have this length
-            # E.g.: self.length_to_sample_indices = { 3: [3,5,11], 4: [1,2], ...}
-            if len_input in self.length_to_sample_indices:
-                self.length_to_sample_indices[len_input].append(i)
-            else:
-                self.length_to_sample_indices[len_input] = [i]
-
-        # Use a dynamic batch size to speed up inference at a constant VRAM usage
-        self.unique_lengths = sorted(list(self.unique_lengths))
-        max_chars_per_batch = self.unique_lengths[-1] * batch_size
-        self.length_to_batch_size = {length: int(max_chars_per_batch / (length * batch_size)) * batch_size for length in self.unique_lengths}
-
-        # Merge samples of similar lengths in those cases where the amount of samples
-        # of a particular length is < dynamic batch size
-        accum_len_diff = 0
-        for i in range(1, len(self.unique_lengths)):
-            if accum_len_diff >= max_padding_len:
-                accum_len_diff = 0
-                continue
-            curr_len = self.unique_lengths[i]
-            prev_len = self.unique_lengths[i-1]
-            len_diff = curr_len - prev_len
-            if (len_diff <= max_padding_len and
-                    (len(self.length_to_sample_indices[curr_len]) < self.length_to_batch_size[curr_len]
-                     or len(self.length_to_sample_indices[prev_len]) < self.length_to_batch_size[prev_len])):
-                self.length_to_sample_indices[curr_len].extend(self.length_to_sample_indices[prev_len])
-                self.length_to_sample_indices[prev_len] = []
-                accum_len_diff += len_diff
-            else:
-                accum_len_diff = 0
-
-    def __len__(self):
-        return self.num_samples
-
-    def __iter__(self):
-        # Iterate over all possible sentence lengths
-        for length in self.unique_lengths:
-
-            # Get indices of all samples for the current length
-            # for example, all indices of samples with a length of 7
-            sequence_indices = self.length_to_sample_indices[length]
-            if len(sequence_indices) == 0:
-                continue
-
-            dyn_batch_size = self.length_to_batch_size[length]
-
-            # Compute the number of batches
-            num_batches = np.ceil(len(sequence_indices) / dyn_batch_size)
-
-            # Loop over all possible batches
-            for batch_indices in np.array_split(sequence_indices, num_batches):
-                yield batch_indices
+    from utils import Document, SimilarLengthsBatchifyer
 
 
 def neg_dot_dist(x, y):
@@ -112,7 +41,9 @@ def __init__(self, splade_doc_tokenizer, splade_doc_model, splade_query_tokenize
     def compute_document_vectors(self, texts: List[str], batch_size: int) -> Tuple[List[List[int]], List[List[float]]]:
         indices = []
         values = []
-        batchifyer = SimilarLengthsBatchifyer(batch_size, texts)
+        tokenized_texts = self.splade_doc_tokenizer(texts, truncation=False, padding=False,
+                                                    return_tensors="np")["input_ids"]
+        batchifyer = SimilarLengthsBatchifyer(batch_size, tokenized_texts)
         texts = np.array(texts)
         batch_indices = []
         for index_batch in batchifyer:
diff --git a/utils.py b/utils.py
index 035a404..53e75b1 100644
--- a/utils.py
+++ b/utils.py
@@ -1,7 +1,14 @@
-from typing import Dict
+from typing import Dict, Literal
+import warnings
+import math
+import copy
 from dataclasses import dataclass
 
+from torch import Tensor
+import torch
 import numpy as np
+from sentence_transformers import SentenceTransformer, quantize_embeddings
+from sentence_transformers.util import batch_to_device, truncate_embeddings
 
 
 @dataclass
@@ -54,4 +61,262 @@ def dict_list_to_pretty_str(data: list[dict]) -> str:
             ret_str += f"Source URL: {d['href']}\n"
         return ret_str
     else:
-        raise ValueError("Input must be dict or list[dict]")
\ No newline at end of file
+        raise ValueError("Input must be dict or list[dict]")
+
+
+class SimilarLengthsBatchifyer:
+    """
+    Generator class to split samples into batches. Groups sample sequences
+    of equal/similar length together to minimize the need for padding within a batch.
+    """
+    def __init__(self, batch_size, inputs, max_padding_len=10):
+        # Remember number of samples
+        self.num_samples = len(inputs)
+
+        self.unique_lengths = set()
+        self.length_to_sample_indices = {}
+
+        for i in range(0, len(inputs)):
+            len_input = len(inputs[i])
+
+            self.unique_lengths.add(len_input)
+
+            # For each length, keep track of the indices of the samples that have this length
+            # E.g.: self.length_to_sample_indices = { 3: [3,5,11], 4: [1,2], ...}
+            if len_input in self.length_to_sample_indices:
+                self.length_to_sample_indices[len_input].append(i)
+            else:
+                self.length_to_sample_indices[len_input] = [i]
+
+        # Use a dynamic batch size to speed up inference at a constant VRAM usage
+        self.unique_lengths = sorted(list(self.unique_lengths))
+        max_chars_per_batch = self.unique_lengths[-1] * batch_size
+        self.length_to_batch_size = {length: int(max_chars_per_batch / (length * batch_size)) * batch_size for length in self.unique_lengths}
+
+        # Merge samples of similar lengths in those cases where the amount of samples
+        # of a particular length is < dynamic batch size
+        accum_len_diff = 0
+        for i in range(1, len(self.unique_lengths)):
+            if accum_len_diff >= max_padding_len:
+                accum_len_diff = 0
+                continue
+            curr_len = self.unique_lengths[i]
+            prev_len = self.unique_lengths[i-1]
+            len_diff = curr_len - prev_len
+            if (len_diff <= max_padding_len and
+                    (len(self.length_to_sample_indices[curr_len]) < self.length_to_batch_size[curr_len]
+                     or len(self.length_to_sample_indices[prev_len]) < self.length_to_batch_size[prev_len])):
+                self.length_to_sample_indices[curr_len].extend(self.length_to_sample_indices[prev_len])
+                self.length_to_sample_indices[prev_len] = []
+                accum_len_diff += len_diff
+            else:
+                accum_len_diff = 0
+
+    def __len__(self):
+        return self.num_samples
+
+    def __iter__(self):
+        # Iterate over all possible sentence lengths
+        for length in self.unique_lengths:
+
+            # Get indices of all samples for the current length
+            # for example, all indices of samples with a length of 7
+            sequence_indices = self.length_to_sample_indices[length]
+            if len(sequence_indices) == 0:
+                continue
+
+            dyn_batch_size = self.length_to_batch_size[length]
+
+            # Compute the number of batches
+            num_batches = np.ceil(len(sequence_indices) / dyn_batch_size)
+
+            # Loop over all possible batches
+            for batch_indices in np.array_split(sequence_indices, num_batches):
+                yield batch_indices
+
+
+class MySentenceTransformer(SentenceTransformer):
+    def batch_encode(
+        self,
+        sentences: str | list[str],
+        prompt_name: str | None = None,
+        prompt: str | None = None,
+        batch_size: int = 32,
+        output_value: Literal["sentence_embedding", "token_embeddings"] | None = "sentence_embedding",
+        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
+        convert_to_numpy: bool = True,
+        convert_to_tensor: bool = False,
+        device: str = None,
+        normalize_embeddings: bool = False,
+        **kwargs,
+    ) -> list[Tensor] | np.ndarray | Tensor:
+        if self.device.type == "hpu" and not self.is_hpu_graph_enabled:
+            import habana_frameworks.torch as ht
+
+            ht.hpu.wrap_in_hpu_graph(self, disable_tensor_cache=True)
+            self.is_hpu_graph_enabled = True
+
+        self.eval()
+        if convert_to_tensor:
+            convert_to_numpy = False
+
+        if output_value != "sentence_embedding":
+            convert_to_tensor = False
+            convert_to_numpy = False
+
+        input_was_string = False
+        if isinstance(sentences, str) or not hasattr(
+            sentences, "__len__"
+        ):  # Cast an individual sentence to a list with length 1
+            sentences = [sentences]
+            input_was_string = True
+
+        if prompt is None:
+            if prompt_name is not None:
+                try:
+                    prompt = self.prompts[prompt_name]
+                except KeyError:
+                    raise ValueError(
+                        f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(self.prompts.keys())!r}."
+                    )
+            elif self.default_prompt_name is not None:
+                prompt = self.prompts.get(self.default_prompt_name, None)
+        else:
+            if prompt_name is not None:
+                warnings.warn(
+                    "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
+                    "Ignoring the `prompt_name` in favor of `prompt`."
+                )
+
+        extra_features = {}
+        if prompt is not None:
+            sentences = [prompt + sentence for sentence in sentences]
+
+            # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
+            # Tracking the prompt length allow us to remove the prompt during pooling
+            tokenized_prompt = self.tokenize([prompt])
+            if "input_ids" in tokenized_prompt:
+                extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
+
+        if device is None:
+            device = self.device
+        else:
+            device = torch.device(device)
+
+        self.to(device)
+
+        all_embeddings = []
+        tokenized_sentences = self.tokenizer(sentences, verbose=False)["input_ids"]
+        batchifyer = SimilarLengthsBatchifyer(batch_size, tokenized_sentences)
+        sentences = np.array(sentences)
+        batch_indices = []
+        for index_batch in batchifyer:
+            batch_indices.append(index_batch)
+            sentences_batch = sentences[index_batch]
+            features = self.tokenize(sentences_batch)
+            if self.device.type == "hpu":
+                if "input_ids" in features:
+                    curr_tokenize_len = features["input_ids"].shape
+                    additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
+                    features["input_ids"] = torch.cat(
+                        (
+                            features["input_ids"],
+                            torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                        ),
+                        -1,
+                    )
+                    features["attention_mask"] = torch.cat(
+                        (
+                            features["attention_mask"],
+                            torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                        ),
+                        -1,
+                    )
+                    if "token_type_ids" in features:
+                        features["token_type_ids"] = torch.cat(
+                            (
+                                features["token_type_ids"],
+                                torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
+                            ),
+                            -1,
+                        )
+
+            features = batch_to_device(features, device)
+            features.update(extra_features)
+
+            with torch.no_grad():
+                out_features = self.forward(features, **kwargs)
+                if self.device.type == "hpu":
+                    out_features = copy.deepcopy(out_features)
+
+                out_features["sentence_embedding"] = truncate_embeddings(
+                    out_features["sentence_embedding"], self.truncate_dim
+                )
+
+                if output_value == "token_embeddings":
+                    embeddings = []
+                    for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
+                        last_mask_id = len(attention) - 1
+                        while last_mask_id > 0 and attention[last_mask_id].item() == 0:
+                            last_mask_id -= 1
+
+                        embeddings.append(token_emb[0: last_mask_id + 1])
+                elif output_value is None:  # Return all outputs
+                    embeddings = []
+                    for sent_idx in range(len(out_features["sentence_embedding"])):
+                        row = {name: out_features[name][sent_idx] for name in out_features}
+                        embeddings.append(row)
+                else:  # Sentence embeddings
+                    embeddings = out_features[output_value]
+                    embeddings = embeddings.detach()
+                    if normalize_embeddings:
+                        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+                    # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
+                    if convert_to_numpy:
+                        embeddings = embeddings.to("cpu", non_blocking=True)
+                        sync_device(device)
+
+                all_embeddings.extend(embeddings)
+
+        # Restore order after SimilarLengthsBatchifyer disrupted it:
+        # Ensure that the order of the returned embeddings matches the order of the input sentences
+        batch_indices = np.concatenate(batch_indices)
+        sorted_indices = np.argsort(batch_indices)
+        all_embeddings = [all_embeddings[i] for i in sorted_indices]
+
+        if precision and precision != "float32":
+            all_embeddings = quantize_embeddings(all_embeddings, precision=precision)
+
+        if convert_to_tensor:
+            if len(all_embeddings):
+                if isinstance(all_embeddings, np.ndarray):
+                    all_embeddings = torch.from_numpy(all_embeddings)
+                else:
+                    all_embeddings = torch.stack(all_embeddings)
+            else:
+                all_embeddings = torch.Tensor()
+        elif convert_to_numpy:
+            if not isinstance(all_embeddings, np.ndarray):
+                if all_embeddings and all_embeddings[0].dtype == torch.bfloat16:
+                    all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings])
+                else:
+                    all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+        elif isinstance(all_embeddings, np.ndarray):
+            all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings]
+
+        if input_was_string:
+            all_embeddings = all_embeddings[0]
+
+        return all_embeddings
+
+
+def sync_device(device: torch.device):
+    if device.type == "cpu":
+        return
+    elif device.type == "cuda":
+        torch.cuda.synchronize()
+    elif device.type == "mps":
+        torch.mps.synchronize()
+    elif device.type == "xpu":
+        torch.xpu.synchronize(device)
+    else:
+        warnings.warn("Device type does not match 'cuda', 'xpu' or 'mps'. Not synchronizing")
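
For reviewers, here is a minimal, illustrative sketch (not part of the patch) of how the new pieces fit together. It assumes the patched `utils.py` is importable as a module; the model name and CPU/float32 settings mirror the defaults used in `retrieval.py`, while the sample texts and the printed shape are examples only (all-MiniLM-L6-v2 produces 384-dimensional embeddings).

```python
# Illustrative usage sketch: exercises the new MySentenceTransformer.batch_encode
# path, which tokenizes all inputs first, lets SimilarLengthsBatchifyer build
# length-homogeneous batches, and restores the original input order before
# returning the embeddings.
import torch

from utils import MySentenceTransformer, SimilarLengthsBatchifyer

model = MySentenceTransformer(
    "all-MiniLM-L6-v2",
    device="cpu",
    model_kwargs={"torch_dtype": torch.float32},
)

texts = [
    "Short query.",
    "A slightly longer sentence about dense retrieval with FAISS.",
    "Another short one.",
]

# Returns a numpy array by default (convert_to_numpy=True), in input order.
embeddings = model.batch_encode(texts, batch_size=32)
print(embeddings.shape)  # e.g. (3, 384) for all-MiniLM-L6-v2

# The batchifyer on its own yields index batches grouped by similar token length.
token_ids = model.tokenizer(texts)["input_ids"]
for index_batch in SimilarLengthsBatchifyer(batch_size=32, inputs=token_ids):
    print(index_batch, [len(token_ids[i]) for i in index_batch])
```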