diff --git a/scarf/readers.py b/scarf/readers.py index 1fc905c..69ef4f9 100644 --- a/scarf/readers.py +++ b/scarf/readers.py @@ -9,11 +9,11 @@ - LoomReader: A class to read in data in the form of a Loom file. """ +import math import os from abc import ABC, abstractmethod -from typing import Generator, Dict, List, Optional, Tuple -from typing import IO -import math +from typing import IO, Dict, Generator, List, Optional, Tuple + import h5py import numpy as np import pandas as pd @@ -392,38 +392,47 @@ def _read_dataset(self, key: Optional[str] = None): vals = None return vals - def read_header(self) -> pl.DataFrame: - header = pl.read_csv( + def read_header(self) -> pd.DataFrame: + header = pd.read_csv( self.matFn, - comment_prefix = '%', - separator=self.sep, - has_header=False, - n_rows=1, - new_columns=["nFeatures", "nCells", "nCounts"], + comment="%", + sep=self.sep, + header=None, + nrows=1, + names=["nFeatures", "nCells", "nCounts"], ) - if header['nCells'][0] == 0 and self.nCells > 0: - raise ValueError("ERROR: Barcode count in MTX header is 0 but barcodes are present in the barcodes file") - if header['nCells'][0] > 0 and self.nCells == 0: - raise ValueError("ERROR: Barcode count in MTX header is greater than 0 but no barcodes are present in the barcodes file") - if header['nCells'][0] == 0 and self.nCells == 0: - raise ValueError("ERROR: Barcode count in MTX header and barcodes file is 0. No data to read") + if header["nCells"][0] == 0 and self.nCells > 0: + raise ValueError( + "ERROR: Barcode count in MTX header is 0 but barcodes are present in the barcodes file" + ) + if header["nCells"][0] > 0 and self.nCells == 0: + raise ValueError( + "ERROR: Barcode count in MTX header is greater than 0 but no barcodes are present in the barcodes file" + ) + if header["nCells"][0] == 0 and self.nCells == 0: + raise ValueError( + "ERROR: Barcode count in MTX header and barcodes file is 0. No data to read" + ) return header - def process_batch(self, dfs: pl.DataFrame, filtering_cutoff: int) -> List: + def process_batch(self, dfs: List[pd.DataFrame], filtering_cutoff: int) -> np.array: """Returns a list of valid barcodes after filtering out background barcodes for a given batch. Args: dfs: A Polar DataFrame containing a chunk of data from the MTX file. filtering_cutoff: The cutoff value for filtering out background barcodes """ - dfs_ = dfs.group_by('barcode').agg(pl.sum('count')) + pl_dfs = [pl.DataFrame(df) for df in dfs] + pl_dfs = pl.concat(pl_dfs) + dfs_ = pl_dfs.group_by('barcode').agg(pl.sum('count')) dfs_ = dfs_.filter(pl.col('count') > filtering_cutoff) return np.sort(dfs_['barcode']) def _get_valid_barcodes( - self, filtering_cutoff: int, - batch_size: int = int(10e4), - lines_in_mem: int = int(10e6) + self, + filtering_cutoff: int, + batch_size: int = int(10e3), + lines_in_mem: int = int(10e6), ) -> np.ndarray: """Returns a list of valid barcodes after filtering out background barcodes. @@ -433,48 +442,53 @@ def _get_valid_barcodes( lines_in_mem: The number of lines to read into memory """ test_counter = 0 - matrixIO = pl.scan_csv( - self.matFn, - comment_prefix='%', - # skip_rows=3, - skip_rows_after_header=1, - separator=self.sep, - has_header=False, + matrixIO = pd.read_csv( + self.matFn, + comment="%", + sep=self.sep, + header=0, + chunksize=lines_in_mem, + names=["gene", "barcode", "count"], ) - assert len(matrixIO.collect_schema().names()) == 3 - matrixIO = matrixIO.rename({'column_1': 'gene', 'column_2': 'barcode', 'column_3': 'count'}) + header = self.read_header() nChunks = math.ceil(header["nCounts"][0] / lines_in_mem) test_counter = 0 valid_idx = [] start = 1 - dfs = pl.DataFrame() - for i in tqdmbar( - range(nChunks), desc="Filtering out background barcodes" + + dfs = [] + for chunk in tqdmbar( + # range(nChunks), + matrixIO, + total=nChunks, + desc="Filtering out background barcodes", ): - chunk = matrixIO.slice(i*lines_in_mem, lines_in_mem).collect() - # Check if we've reached or exceeded the current batch boundary - if (chunk[-1]['barcode'][0] - start) >= batch_size: # If the last "cell id" is greater than the start + batch size + if ( + (chunk.iloc[-1]["barcode"] - start) >= batch_size + ): # If the last "cell id" is greater than the start + batch size # Filter rows in the current chunk that belong to the current batch - idx = np.array(chunk['barcode'] < (batch_size + start)) # This is the crucial line. This makes sure that if any cell ID is spread over multiple chunks, it is not missed, as any cell ID that is less than the batch size + start is included. + idx = np.array( + chunk["barcode"].values < (batch_size + start) + ) # This is the crucial line. This makes sure that if any cell ID is spread over multiple chunks, it is not missed, as any cell ID that is less than the batch size + start is included. # If no rows belong to the current batch, move to the next batch. if idx.sum() == 0: - dfs = pl.concat([dfs, chunk]) + dfs.append(chunk) start += batch_size test_counter += len(chunk) continue # Process the rows belonging to the current batch mask_pos = np.where(idx)[0] mask_neg = np.where(~idx)[0] - dfs = pl.concat([dfs, chunk[mask_pos]]) + dfs.append(chunk.iloc[mask_pos]) valid_idx.append(self.process_batch(dfs, filtering_cutoff)) # Prepare for the next batch del dfs - dfs = chunk[mask_neg] + dfs = [chunk.iloc[mask_neg]] start += batch_size else: # If we haven't reached the batch boundary, accumulate the chunk - dfs = pl.concat([dfs, chunk]) + dfs.append(chunk) test_counter += len(chunk) # Process any remaining data after the main loop if len(dfs) > 0: @@ -512,7 +526,7 @@ def cell_names(self) -> List[str]: def rename_batches(self, collect: List[pl.DataFrame], batch_size: int) -> List: df = pl.concat(collect) - barcodes = np.array(df['barcode']) + barcodes = np.array(df["barcode"]) count_hash = {} for i, x in enumerate(np.unique(barcodes)): count_hash[x] = i @@ -535,14 +549,14 @@ def consume( dtype: The data type of the matrix. """ matrixIO = pl.read_csv_batched( - self.matFn, - has_header=False, + self.matFn, + has_header=False, separator=self.sep, comment_prefix="%", - skip_rows_after_header=1, - new_columns=['gene', 'barcode', 'count'], - schema_overrides={'gene': pl.Int64, 'barcode': pl.Int64, 'count': pl.Int64}, - batch_size=lines_in_mem + skip_rows_after_header=1, + new_columns=["gene", "barcode", "count"], + schema_overrides={"gene": pl.Int64, "barcode": pl.Int64, "count": pl.Int64}, + batch_size=lines_in_mem, ) unique_list = [] collect = [] @@ -551,20 +565,20 @@ def consume( if chunk is None: break chunk = chunk[0] - chunk = chunk.filter(pl.col('barcode').is_in(self.validBarcodeIdx)) - in_uniques = np.unique(chunk['barcode']) + chunk = chunk.filter(pl.col("barcode").is_in(self.validBarcodeIdx)) + in_uniques = np.unique(chunk["barcode"]) unique_list.extend(in_uniques) unique_list = list(set(unique_list)) if len(unique_list) > batch_size: diff = batch_size - (len(unique_list) - len(in_uniques)) mask_pos = in_uniques[:diff] mask_neg = in_uniques[diff:] - extra = chunk.filter(pl.col('barcode').is_in(mask_pos)) + extra = chunk.filter(pl.col("barcode").is_in(mask_pos)) collect.append(extra) collect = self.rename_batches(collect, batch_size) mtx = self.to_sparse(np.array(collect), dtype=dtype) yield mtx - left_out = chunk.filter(pl.col('barcode').is_in(mask_neg)) + left_out = chunk.filter(pl.col("barcode").is_in(mask_neg)) collect = [] unique_list = list(mask_neg) collect.append(left_out) @@ -635,8 +649,9 @@ def __init__( self.obsmAttrsKey: self._validate_group(self.obsmAttrsKey), self.matrixKey: self._validate_group(self.matrixKey), } - self.nCells, self.nFeatures = self._get_n(self.cellAttrsKey), self._get_n( - self.featureAttrsKey + self.nCells, self.nFeatures = ( + self._get_n(self.cellAttrsKey), + self._get_n(self.featureAttrsKey), ) self.cellIdsKey = self._fix_name_key(self.cellAttrsKey, cell_ids_key) self.featIdsKey = self._fix_name_key(self.featureAttrsKey, feature_ids_key) @@ -809,8 +824,9 @@ def _get_col_data( if i in ignore_keys: continue if isinstance(self.h5[group][i], h5py.Dataset): - yield i, self._replace_category_values( - self.h5[group][i][:], i, group + yield ( + i, + self._replace_category_values(self.h5[group][i][:], i, group), ) def _get_obsm_data( @@ -832,7 +848,7 @@ def _get_obsm_data( yield f"{i}{j+1}", g[:, j] else: logger.warning( - f"Reading of obsm failed because it either does not exist or is not in expected format" # noqa: F541 + f"Reading of obsm failed because it either does not exist or is not in expected format" # noqa: F541 ) def get_cell_columns(self) -> Generator[Tuple[str, np.ndarray], None, None]: