|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from collections import defaultdict |
| 4 | +from itertools import product |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import pandas as pd |
| 8 | +from anndata import AnnData |
| 9 | +from scanpy import logging as logg |
| 10 | +from spatialdata import SpatialData |
| 11 | + |
| 12 | +from squidpy._docs import d |
| 13 | +from squidpy.gr._utils import _save_data |
| 14 | + |
| 15 | +__all__ = ["sliding_window"] |
| 16 | + |
| 17 | + |
| 18 | +@d.dedent |
| 19 | +def sliding_window( |
| 20 | + adata: AnnData | SpatialData, |
| 21 | + library_key: str | None = None, |
| 22 | + window_size: int | None = None, |
| 23 | + overlap: int = 0, |
| 24 | + coord_columns: tuple[str, str] = ("globalX", "globalY"), |
| 25 | + sliding_window_key: str = "sliding_window_assignment", |
| 26 | + spatial_key: str = "spatial", |
| 27 | + drop_partial_windows: bool = False, |
| 28 | + copy: bool = False, |
| 29 | +) -> pd.DataFrame | None: |
| 30 | + """ |
| 31 | + Divide a tissue slice into regulary shaped spatially contiguous regions (windows). |
| 32 | +
|
| 33 | + Parameters |
| 34 | + ---------- |
| 35 | + %(adata)s |
| 36 | + window_size: int |
| 37 | + Size of the sliding window. |
| 38 | + %(library_key)s |
| 39 | + coord_columns: Tuple[str, str] |
| 40 | + Tuple of column names in `adata.obs` that specify the coordinates (x, y), e.i. ('globalX', 'globalY') |
| 41 | + sliding_window_key: str |
| 42 | + Base name for sliding window columns. |
| 43 | + overlap: int |
| 44 | + Overlap size between consecutive windows. (0 = no overlap) |
| 45 | + %(spatial_key)s |
| 46 | + drop_partial_windows: bool |
| 47 | + If True, drop windows that are smaller than the window size at the borders. |
| 48 | + copy: bool |
| 49 | + If True, return the result, otherwise save it to the adata object. |
| 50 | +
|
| 51 | + Returns |
| 52 | + ------- |
| 53 | + If ``copy = True``, returns the sliding window annotation(s) as pandas dataframe |
| 54 | + Otherwise, stores the sliding window annotation(s) in .obs. |
| 55 | + """ |
| 56 | + if overlap < 0: |
| 57 | + raise ValueError("Overlap must be non-negative.") |
| 58 | + |
| 59 | + if isinstance(adata, SpatialData): |
| 60 | + adata = adata.table |
| 61 | + |
| 62 | + # we don't want to modify the original adata in case of copy=True |
| 63 | + if copy: |
| 64 | + adata = adata.copy() |
| 65 | + |
| 66 | + # extract coordinates of observations |
| 67 | + x_col, y_col = coord_columns |
| 68 | + if x_col in adata.obs and y_col in adata.obs: |
| 69 | + coords = adata.obs[[x_col, y_col]].copy() |
| 70 | + elif spatial_key in adata.obsm: |
| 71 | + coords = pd.DataFrame( |
| 72 | + adata.obsm[spatial_key][:, :2], |
| 73 | + index=adata.obs.index, |
| 74 | + columns=[x_col, y_col], |
| 75 | + ) |
| 76 | + else: |
| 77 | + raise ValueError( |
| 78 | + f"Coordinates not found. Provide `{coord_columns}` in `adata.obs` or specify a suitable `spatial_key` in `adata.obsm`." |
| 79 | + ) |
| 80 | + |
| 81 | + # infer window size if not provided |
| 82 | + if window_size is None: |
| 83 | + coord_range = max( |
| 84 | + coords[x_col].max() - coords[x_col].min(), |
| 85 | + coords[y_col].max() - coords[y_col].min(), |
| 86 | + ) |
| 87 | + # mostly arbitrary choice, except that full integers usually generate windows with 1-2 cells at the borders |
| 88 | + window_size = max(int(np.floor(coord_range // 3.95)), 1) |
| 89 | + |
| 90 | + if window_size <= 0: |
| 91 | + raise ValueError("Window size must be larger than 0.") |
| 92 | + |
| 93 | + if library_key is not None and library_key not in adata.obs: |
| 94 | + raise ValueError(f"Library key '{library_key}' not found in adata.obs") |
| 95 | + |
| 96 | + libraries = [None] if library_key is None else adata.obs[library_key].unique() |
| 97 | + |
| 98 | + # Create a DataFrame to store the sliding window assignments |
| 99 | + sliding_window_df = pd.DataFrame(index=adata.obs.index) |
| 100 | + |
| 101 | + if sliding_window_key in adata.obs: |
| 102 | + logg.warning(f"Overwriting existing column '{sliding_window_key}' in adata.obs.") |
| 103 | + |
| 104 | + for lib in libraries: |
| 105 | + if lib is not None: |
| 106 | + lib_mask = adata.obs[library_key] == lib |
| 107 | + lib_coords = coords.loc[lib_mask] |
| 108 | + else: |
| 109 | + lib_mask = np.ones(len(adata), dtype=bool) |
| 110 | + lib_coords = coords |
| 111 | + |
| 112 | + min_x, max_x = lib_coords[x_col].min(), lib_coords[x_col].max() |
| 113 | + min_y, max_y = lib_coords[y_col].min(), lib_coords[y_col].max() |
| 114 | + |
| 115 | + # precalculate windows |
| 116 | + windows = _calculate_window_corners( |
| 117 | + min_x=min_x, |
| 118 | + max_x=max_x, |
| 119 | + min_y=min_y, |
| 120 | + max_y=max_y, |
| 121 | + window_size=window_size, |
| 122 | + overlap=overlap, |
| 123 | + drop_partial_windows=drop_partial_windows, |
| 124 | + ) |
| 125 | + |
| 126 | + lib_key = f"{lib}_" if lib is not None else "" |
| 127 | + |
| 128 | + # assign observations to windows |
| 129 | + for idx, window in windows.iterrows(): |
| 130 | + x_start = window["x_start"] |
| 131 | + x_end = window["x_end"] |
| 132 | + y_start = window["y_start"] |
| 133 | + y_end = window["y_end"] |
| 134 | + |
| 135 | + mask = ( |
| 136 | + (lib_coords[x_col] >= x_start) |
| 137 | + & (lib_coords[x_col] <= x_end) |
| 138 | + & (lib_coords[y_col] >= y_start) |
| 139 | + & (lib_coords[y_col] <= y_end) |
| 140 | + ) |
| 141 | + obs_indices = lib_coords.index[mask] |
| 142 | + |
| 143 | + if overlap == 0: |
| 144 | + mask = ( |
| 145 | + (lib_coords[x_col] >= x_start) |
| 146 | + & (lib_coords[x_col] <= x_end) |
| 147 | + & (lib_coords[y_col] >= y_start) |
| 148 | + & (lib_coords[y_col] <= y_end) |
| 149 | + ) |
| 150 | + obs_indices = lib_coords.index[mask] |
| 151 | + sliding_window_df.loc[obs_indices, sliding_window_key] = f"{lib_key}window_{idx}" |
| 152 | + |
| 153 | + else: |
| 154 | + col_name = f"{sliding_window_key}_{lib_key}window_{idx}" |
| 155 | + sliding_window_df.loc[obs_indices, col_name] = True |
| 156 | + sliding_window_df.loc[:, col_name].fillna(False, inplace=True) |
| 157 | + |
| 158 | + if overlap == 0: |
| 159 | + # create categorical variable for ordered windows |
| 160 | + sliding_window_df[sliding_window_key] = pd.Categorical( |
| 161 | + sliding_window_df[sliding_window_key], |
| 162 | + ordered=True, |
| 163 | + categories=sorted( |
| 164 | + sliding_window_df[sliding_window_key].unique(), |
| 165 | + key=lambda x: int(x.split("_")[-1]), |
| 166 | + ), |
| 167 | + ) |
| 168 | + |
| 169 | + sliding_window_df[x_col] = coords[x_col] |
| 170 | + sliding_window_df[y_col] = coords[y_col] |
| 171 | + |
| 172 | + if copy: |
| 173 | + return sliding_window_df |
| 174 | + for col_name, col_data in sliding_window_df.items(): |
| 175 | + _save_data(adata, attr="obs", key=col_name, data=col_data) |
| 176 | + |
| 177 | + |
| 178 | +def _calculate_window_corners( |
| 179 | + min_x: int, |
| 180 | + max_x: int, |
| 181 | + min_y: int, |
| 182 | + max_y: int, |
| 183 | + window_size: int, |
| 184 | + overlap: int = 0, |
| 185 | + drop_partial_windows: bool = False, |
| 186 | +) -> pd.DataFrame: |
| 187 | + """ |
| 188 | + Calculate the corner points of all windows covering the area from min_x to max_x and min_y to max_y, |
| 189 | + with specified window_size and overlap. |
| 190 | +
|
| 191 | + Parameters |
| 192 | + ---------- |
| 193 | + min_x: float |
| 194 | + minimum X coordinate |
| 195 | + max_x: float |
| 196 | + maximum X coordinate |
| 197 | + min_y: float |
| 198 | + minimum Y coordinate |
| 199 | + max_y: float |
| 200 | + maximum Y coordinate |
| 201 | + window_size: float |
| 202 | + size of each window |
| 203 | + overlap: float |
| 204 | + overlap between consecutive windows (must be less than window_size) |
| 205 | + drop_partial_windows: bool |
| 206 | + if True, drop border windows that are smaller than window_size; |
| 207 | + if False, create smaller windows at the borders to cover the remaining space. |
| 208 | +
|
| 209 | + Returns |
| 210 | + ------- |
| 211 | + windows: pandas DataFrame with columns ['x_start', 'x_end', 'y_start', 'y_end'] |
| 212 | + """ |
| 213 | + if overlap < 0: |
| 214 | + raise ValueError("Overlap must be non-negative.") |
| 215 | + if overlap >= window_size: |
| 216 | + raise ValueError("Overlap must be less than the window size.") |
| 217 | + |
| 218 | + x_step = window_size - overlap |
| 219 | + y_step = window_size - overlap |
| 220 | + |
| 221 | + # Generate starting points |
| 222 | + x_starts = np.arange(min_x, max_x, x_step) |
| 223 | + y_starts = np.arange(min_y, max_y, y_step) |
| 224 | + |
| 225 | + # Create all combinations of x and y starting points |
| 226 | + starts = list(product(x_starts, y_starts)) |
| 227 | + windows = pd.DataFrame(starts, columns=["x_start", "y_start"]) |
| 228 | + windows["x_end"] = windows["x_start"] + window_size |
| 229 | + windows["y_end"] = windows["y_start"] + window_size |
| 230 | + |
| 231 | + # Adjust windows that extend beyond the bounds |
| 232 | + if not drop_partial_windows: |
| 233 | + windows["x_end"] = windows["x_end"].clip(upper=max_x) |
| 234 | + windows["y_end"] = windows["y_end"].clip(upper=max_y) |
| 235 | + else: |
| 236 | + valid_windows = (windows["x_end"] <= max_x) & (windows["y_end"] <= max_y) |
| 237 | + windows = windows[valid_windows] |
| 238 | + |
| 239 | + windows = windows.reset_index(drop=True) |
| 240 | + return windows[["x_start", "x_end", "y_start", "y_end"]] |
0 commit comments