From dee82a2b33fd7cfb894625cecd9c492adcac5fe0 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 28 Nov 2024 16:11:35 +0100 Subject: [PATCH] (chore): good chunk size --- src/anndata/_core/merge.py | 17 ++++++++++++----- src/anndata/_io/specs/lazy_methods.py | 9 ++++++++- src/anndata/experimental/backed/_lazy_arrays.py | 12 ++++++++++++ 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index 086880bed..44a5357f6 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -1096,9 +1096,10 @@ def make_dask_col_from_extension_dtype( from anndata._io.specs.lazy_methods import ( compute_chunk_layout_for_axis_size, + get_chunksize, maybe_open_h5, ) - from anndata.experimental import read_lazy + from anndata.experimental import read_elem_lazy from anndata.experimental.backed._compat import DataArray from anndata.experimental.backed._compat import xarray as xr @@ -1106,10 +1107,17 @@ def make_dask_col_from_extension_dtype( elem_name = col.attrs.get("elem_name") dims = col.dims coords = col.coords.copy() + with maybe_open_h5(base_path_or_zarr_group, elem_name) as f: + maybe_chunk_size = get_chunksize(read_elem_lazy(f)) + chunk_size = ( + compute_chunk_layout_for_axis_size( + 1000 if maybe_chunk_size is None else maybe_chunk_size[0], col.shape[0] + ), + ) def get_chunk(block_info=None): with maybe_open_h5(base_path_or_zarr_group, elem_name) as f: - v = read_lazy(f) + v = read_elem_lazy(f) variable = xr.Variable( data=xr.core.indexing.LazilyIndexedArray(v), dims=dims ) @@ -1128,10 +1136,9 @@ def get_chunk(block_info=None): dtype = "object" else: dtype = col.dtype.numpy_dtype - # TODO: get good chunk size? return da.map_blocks( get_chunk, - chunks=(compute_chunk_layout_for_axis_size(1000, col.shape[0]),), + chunks=chunk_size, meta=np.array([], dtype=dtype), dtype=dtype, ) @@ -1185,7 +1192,7 @@ def get_attrs(annotations: Iterable[Dataset2D]) -> dict: """ index_names = np.unique([a.index.name for a in annotations]) assert len(index_names) == 1, "All annotations must have the same index name." - if any(a.index.dtype == "int64" for a in annotations): + if any(np.issubdtype(a.index.dtype, np.integer) for a in annotations): msg = "Concatenating with a pandas numeric index among the indices. Index may likely not be unique." warn(msg, UserWarning) index_keys = [ diff --git a/src/anndata/_io/specs/lazy_methods.py b/src/anndata/_io/specs/lazy_methods.py index 27dc4c992..c247d304d 100644 --- a/src/anndata/_io/specs/lazy_methods.py +++ b/src/anndata/_io/specs/lazy_methods.py @@ -2,7 +2,7 @@ import re from contextlib import contextmanager -from functools import partial +from functools import partial, singledispatch from pathlib import Path from typing import TYPE_CHECKING, overload @@ -92,6 +92,13 @@ def make_dask_chunk( return chunk +@singledispatch +def get_chunksize(obj) -> tuple[int, ...]: + if hasattr(obj, "chunks"): + return obj.chunks + raise ValueError("object of type {type(obj)} has no recognized chunks") + + @_LAZY_REGISTRY.register_read(H5Group, IOSpec("csc_matrix", "0.1.0")) @_LAZY_REGISTRY.register_read(H5Group, IOSpec("csr_matrix", "0.1.0")) @_LAZY_REGISTRY.register_read(ZarrGroup, IOSpec("csc_matrix", "0.1.0")) diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index 6ee4eb404..e928847c1 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -7,6 +7,7 @@ from anndata._core.index import _subset from anndata._core.views import as_view +from anndata._io.specs.lazy_methods import get_chunksize from anndata.compat import H5Array, ZarrArray from ..._settings import settings @@ -28,6 +29,7 @@ class ZarrOrHDF5Wrapper(ZarrArrayWrapper, Generic[K]): def __init__(self, array: K): + self.chunks = array.chunks if isinstance(array, ZarrArray): return super().__init__(array) self._array = array @@ -152,3 +154,13 @@ def _subset_masked(a: DataArray, subset_idx: Index): @as_view.register(DataArray) def _view_pd_boolean_array(a: DataArray, view_args): return a + + +@get_chunksize.register(MaskedArray) +def _(a: MaskedArray): + return get_chunksize(a._values) + + +@get_chunksize.register(CategoricalArray) +def _(a: CategoricalArray): + return get_chunksize(a._codes)