Commit 4b8f85c

Use token counts for SimilarLengthsBatchifyer (#155)
* use token counts for SimilarLengthsBatchifyer and use the batchifyer for the dense embedding model too

* make synchronization device agnostic
mamei16 authored Feb 23, 2025
1 parent fc49f9d commit 4b8f85c
Showing 4 changed files with 281 additions and 87 deletions.
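For context: SimilarLengthsBatchifyer previously grouped samples by the character length of the raw string, while padding inside the model happens per token, so grouping by token count matches the real padding cost. A minimal sketch of the difference, assuming a Hugging Face tokenizer as a stand-in (the model name below is illustrative, not taken from this repository):

from transformers import AutoTokenizer

# Stand-in tokenizer; the repository loads its own models elsewhere.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

texts = ["short text", "a considerably longer sentence with many more words in it"]

# Old behaviour: group by character length of the raw string.
char_lengths = [len(t) for t in texts]

# New behaviour: tokenize once (no padding, no truncation) and group by token count,
# which is what actually determines the padding needed within a batch.
token_ids = tokenizer(texts, truncation=False, padding=False)["input_ids"]
token_lengths = [len(ids) for ids in token_ids]

print(char_lengths, token_lengths)  # character length is only a rough proxy for token count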
retrieval.py (11 changes: 5 additions & 6 deletions)
@@ -12,22 +12,21 @@
 from bs4 import BeautifulSoup
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 import optimum.bettertransformer.transformation
-from sentence_transformers import SentenceTransformer
 
 try:
     from .retrievers.faiss_retriever import FaissRetriever
     from .retrievers.bm25_retriever import BM25Retriever
     from .retrievers.splade_retriever import SpladeRetriever
     from .chunkers.semantic_chunker import BoundedSemanticChunker
     from .chunkers.character_chunker import RecursiveCharacterTextSplitter
-    from .utils import Document
+    from .utils import Document, MySentenceTransformer
 except ImportError:
     from retrievers.faiss_retriever import FaissRetriever
     from retrievers.bm25_retriever import BM25Retriever
     from retrievers.splade_retriever import SpladeRetriever
     from chunkers.semantic_chunker import BoundedSemanticChunker
     from chunkers.character_chunker import RecursiveCharacterTextSplitter
-    from utils import Document
+    from utils import Document, MySentenceTransformer
 
 
 class DocumentRetriever:
@@ -37,9 +36,9 @@ def __init__(self, device="cuda", num_results: int = 5, similarity_threshold: fl
                  model_cache_dir: str = None, chunking_method: str = "character-based",
                  chunker_breakpoint_threshold_amount: int = 10, client_timeout: int = 10):
         self.device = device
-        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=model_cache_dir,
-                                                   device=device,
-                                                   model_kwargs={"torch_dtype": torch.float32 if device == "cpu" else torch.float16})
+        self.embedding_model = MySentenceTransformer("all-MiniLM-L6-v2", cache_folder=model_cache_dir,
+                                                     device=device,
+                                                     model_kwargs={"torch_dtype": torch.float32 if device == "cpu" else torch.float16})
         if keyword_retriever == "splade":
             splade_kwargs = {"cache_dir": model_cache_dir,
                              "torch_dtype": torch.float32 if device == "cpu" else torch.float16,
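The utils.py side of this commit, where MySentenceTransformer now lives, is not loaded on this page, so the following is only a sketch of a batch_encode wrapper that would be consistent with the constructor call above and the batch_encode call in the faiss_retriever.py diff below: tokenize once to get token counts, let SimilarLengthsBatchifyer yield index batches, encode each batch, and restore the original order. Everything beyond what the diffs show is an assumption.

from typing import List

import numpy as np
from sentence_transformers import SentenceTransformer

# SimilarLengthsBatchifyer is assumed to be defined in the same module (utils.py).


class MySentenceTransformer(SentenceTransformer):
    """Hypothetical SentenceTransformer subclass with token-count-aware batching."""

    def batch_encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        # Token counts, not character counts, drive the batching.
        token_ids = self.tokenizer(texts, truncation=False, padding=False)["input_ids"]
        batchifyer = SimilarLengthsBatchifyer(batch_size, token_ids)

        texts = np.array(texts)
        index_batches = []
        embedding_batches = []
        for index_batch in batchifyer:
            index_batches.append(index_batch)
            # One forward pass per batch of similar-length inputs.
            embedding_batches.append(self.encode(list(texts[index_batch]),
                                                 batch_size=len(index_batch)))

        # Reassemble the embeddings in the original document order.
        order = np.argsort(np.concatenate(index_batches))
        return np.concatenate(embedding_batches, axis=0)[order]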
retrievers/faiss_retriever.py (9 changes: 4 additions & 5 deletions)
@@ -2,17 +2,16 @@
 
 import faiss
 import numpy as np
-from sentence_transformers import SentenceTransformer
 
 try:
-    from ..utils import Document, cosine_similarity
+    from ..utils import Document, cosine_similarity, MySentenceTransformer, SimilarLengthsBatchifyer
 except:
-    from utils import Document, cosine_similarity
+    from utils import Document, cosine_similarity, MySentenceTransformer, SimilarLengthsBatchifyer
 
 
 class FaissRetriever:
 
-    def __init__(self, embedding_model: SentenceTransformer, num_results: int = 5, similarity_threshold: float = 0.5):
+    def __init__(self, embedding_model: MySentenceTransformer, num_results: int = 5, similarity_threshold: float = 0.5):
         self.embedding_model = embedding_model
         self.num_results = num_results
         self.similarity_threshold = similarity_threshold
@@ -24,7 +23,7 @@ def add_documents(self, documents: List[Document]):
         if not documents:
            return
         self.documents = documents
-        self.document_embeddings = self.embedding_model.encode([doc.page_content for doc in documents])
+        self.document_embeddings = self.embedding_model.batch_encode([doc.page_content for doc in documents])
         self.index.add(self.document_embeddings)
 
     def get_relevant_documents(self, query: str) -> List[Document]:
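Putting the pieces together, a hypothetical caller would look roughly like this (Document is assumed to accept page_content as a keyword argument; the texts are placeholders):

import torch

from retrievers.faiss_retriever import FaissRetriever
from utils import Document, MySentenceTransformer

chunks = ["first chunk of page text", "a second, somewhat longer chunk of page text"]

embedder = MySentenceTransformer("all-MiniLM-L6-v2", device="cpu",
                                 model_kwargs={"torch_dtype": torch.float32})
retriever = FaissRetriever(embedder, num_results=5, similarity_threshold=0.5)

# add_documents() now embeds via batch_encode(), i.e. with length-aware batching.
retriever.add_documents([Document(page_content=chunk) for chunk in chunks])
docs = retriever.get_relevant_documents("example query")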
retrievers/splade_retriever.py (79 changes: 5 additions & 74 deletions)
@@ -11,80 +11,9 @@
 from scipy.sparse import csr_array
 
 try:
-    from ..utils import Document
+    from ..utils import Document, SimilarLengthsBatchifyer
 except:
-    from utils import Document
-
-
-class SimilarLengthsBatchifyer:
-    """
-    Generator class to split samples into batches. Groups sample sequences
-    of equal/similar length together to minimize the need for padding within a batch.
-    """
-    def __init__(self, batch_size, inputs, max_padding_len=10):
-        # Remember number of samples
-        self.num_samples = len(inputs)
-
-        self.unique_lengths = set()
-        self.length_to_sample_indices = {}
-
-        for i in range(0, len(inputs)):
-            len_input = len(inputs[i])
-
-            self.unique_lengths.add(len_input)
-
-            # For each length, keep track of the indices of the samples that have this length
-            # E.g.: self.length_to_sample_indices = { 3: [3,5,11], 4: [1,2], ...}
-            if len_input in self.length_to_sample_indices:
-                self.length_to_sample_indices[len_input].append(i)
-            else:
-                self.length_to_sample_indices[len_input] = [i]
-
-        # Use a dynamic batch size to speed up inference at a constant VRAM usage
-        self.unique_lengths = sorted(list(self.unique_lengths))
-        max_chars_per_batch = self.unique_lengths[-1] * batch_size
-        self.length_to_batch_size = {length: int(max_chars_per_batch / (length * batch_size)) * batch_size for length in self.unique_lengths}
-
-        # Merge samples of similar lengths in those cases where the amount of samples
-        # of a particular length is < dynamic batch size
-        accum_len_diff = 0
-        for i in range(1, len(self.unique_lengths)):
-            if accum_len_diff >= max_padding_len:
-                accum_len_diff = 0
-                continue
-            curr_len = self.unique_lengths[i]
-            prev_len = self.unique_lengths[i-1]
-            len_diff = curr_len - prev_len
-            if (len_diff <= max_padding_len and
-                    (len(self.length_to_sample_indices[curr_len]) < self.length_to_batch_size[curr_len]
-                     or len(self.length_to_sample_indices[prev_len]) < self.length_to_batch_size[prev_len])):
-                self.length_to_sample_indices[curr_len].extend(self.length_to_sample_indices[prev_len])
-                self.length_to_sample_indices[prev_len] = []
-                accum_len_diff += len_diff
-            else:
-                accum_len_diff = 0
-
-    def __len__(self):
-        return self.num_samples
-
-    def __iter__(self):
-        # Iterate over all possible sentence lengths
-        for length in self.unique_lengths:
-
-            # Get indices of all samples for the current length
-            # for example, all indices of samples with a length of 7
-            sequence_indices = self.length_to_sample_indices[length]
-            if len(sequence_indices) == 0:
-                continue
-
-            dyn_batch_size = self.length_to_batch_size[length]
-
-            # Compute the number of batches
-            num_batches = np.ceil(len(sequence_indices) / dyn_batch_size)
-
-            # Loop over all possible batches
-            for batch_indices in np.array_split(sequence_indices, num_batches):
-                yield batch_indices
+    from utils import Document, SimilarLengthsBatchifyer
 
 
 def neg_dot_dist(x, y):
@@ -112,7 +41,9 @@ def __init__(self, splade_doc_tokenizer, splade_doc_model, splade_query_tokenize
     def compute_document_vectors(self, texts: List[str], batch_size: int) -> Tuple[List[List[int]], List[List[float]]]:
         indices = []
         values = []
-        batchifyer = SimilarLengthsBatchifyer(batch_size, texts)
+        tokenized_texts = self.splade_doc_tokenizer(texts, truncation=False, padding=False,
+                                                    return_tensors="np")["input_ids"]
+        batchifyer = SimilarLengthsBatchifyer(batch_size, tokenized_texts)
         texts = np.array(texts)
         batch_indices = []
         for index_batch in batchifyer:
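The removed class above groups samples purely by len(inputs[i]), so feeding it pre-tokenized inputs, as compute_document_vectors now does, turns the grouping into a token-count grouping. A small demo of that behaviour with synthetic token-id lists, using the class as shown above (the version that moved to utils.py may differ in detail):

from utils import SimilarLengthsBatchifyer  # the class now lives in utils per this commit

# Synthetic "tokenized" inputs: lists of token ids of varying length.
tokenized = [list(range(n)) for n in (5, 5, 6, 6, 6, 12, 40)]

batchifyer = SimilarLengthsBatchifyer(batch_size=4, inputs=tokenized)
for index_batch in batchifyer:
    lengths = sorted({len(tokenized[i]) for i in index_batch})
    print(sorted(index_batch), "token lengths in this batch:", lengths)

# Under-filled length buckets whose lengths differ by at most max_padding_len get merged,
# so padding within a batch stays small; the lone length-40 sample is batched by itself.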
