-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_01_preprocess.py
More file actions
345 lines (307 loc) · 13.2 KB
/
_01_preprocess.py
File metadata and controls
345 lines (307 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
from typing import Dict, Optional, Union
import numpy as np
from scipy.sparse import issparse
import scanpy as sc
from scanpy.get import _get_obs_rep, _set_obs_rep
from anndata import AnnData
import logging
import torch
# Module-level logger (PEP 282 style); handler/level configuration is left
# to the importing application.
logger = logging.getLogger(__name__)
class Preprocessor:
    """
    Prepare an :class:`AnnData` object for model input.

    The pipeline (each step optional, controlled by the constructor args) is:
    gene filtering, cell filtering, highly-variable-gene subsetting,
    total-count normalization, log1p transform, and quantile value binning.
    All steps modify ``adata`` in place; derived representations are stored
    in the layers named by the ``result_*_key`` arguments.
    """

    def __init__(
        self,
        use_key: Optional[str] = None,
        filter_gene_by_counts: Union[int, bool] = False,
        filter_cell_by_counts: Union[int, bool] = False,
        normalize_total: Union[float, bool] = 1e4,
        result_normed_key: Optional[str] = "X_normed",
        log1p: bool = False,
        result_log1p_key: str = "X_log1p",
        subset_hvg: Union[int, bool] = False,
        hvg_use_key: Optional[str] = None,
        hvg_flavor: str = "seurat_v3",
        binning: Optional[int] = None,
        result_binned_key: str = "X_binned",
    ):
        """
        Args:
            use_key: key of the representation to process ("X" or a layer
                name). ``None`` means ``adata.X``.
            filter_gene_by_counts: minimum counts to keep a gene;
                ``False`` disables the step.
            filter_cell_by_counts: minimum counts to keep a cell;
                ``False`` or a non-positive value disables the step.
            normalize_total: target sum for total-count normalization;
                ``False`` disables. NOTE: only a ``float`` value is forwarded
                as ``target_sum``; any other truthy value falls back to
                scanpy's default (median of counts).
            result_normed_key: layer to store normalized values in
                (``None`` keeps writing to the current representation).
            log1p: whether to log1p-transform; skipped when the data already
                look log-transformed (see :meth:`check_logged`).
            result_log1p_key: layer to store log1p values in.
            subset_hvg: number of highly variable genes to subset to;
                ``False`` disables the step.
            hvg_use_key: layer used for HVG selection (``None`` -> ``adata.X``).
            hvg_flavor: flavor passed to ``sc.pp.highly_variable_genes``.
            binning: number of bins for value binning; ``None`` disables.
            result_binned_key: layer to store binned values in.
        """
        self.use_key = use_key
        self.filter_gene_by_counts = filter_gene_by_counts
        self.filter_cell_by_counts = filter_cell_by_counts
        self.normalize_total = normalize_total
        self.result_normed_key = result_normed_key
        self.log1p = log1p
        self.result_log1p_key = result_log1p_key
        self.subset_hvg = subset_hvg
        self.hvg_use_key = hvg_use_key
        self.hvg_flavor = hvg_flavor
        self.binning = binning
        self.result_binned_key = result_binned_key

    def __call__(self, adata: AnnData, batch_key: Optional[str] = None) -> None:
        """
        Run the configured preprocessing steps on ``adata`` in place.

        Args:
            adata: the :class:`AnnData` object to preprocess.
            batch_key: key of :class:`AnnData.obs` holding batch labels.
                Used only by the highly-variable-gene selection step.
        """
        key_to_process = self.use_key
        # scanpy APIs use layer=None to mean adata.X
        if key_to_process == "X":
            key_to_process = None

        # step 1: filter genes
        if self.filter_gene_by_counts:
            logger.info("Filtering genes by counts ...")
            sc.pp.filter_genes(
                adata,
                min_counts=self.filter_gene_by_counts
                if isinstance(self.filter_gene_by_counts, int)
                else None,
            )

        # step 2: filter cells
        if (
            isinstance(self.filter_cell_by_counts, int)
            and self.filter_cell_by_counts > 0
        ):
            logger.info("Filtering cells by counts ...")
            # the guard above already ensures an int > 0
            sc.pp.filter_cells(adata, min_counts=self.filter_cell_by_counts)

        # step 3: subset highly variable genes (before normalization, so the
        # seurat_v3 flavor sees count-like data)
        if self.subset_hvg:
            logger.info("Subsetting highly variable genes ...")
            if batch_key is None:
                logger.warning(
                    "No batch_key is provided, will use all cells for HVG selection."
                )
            sc.pp.highly_variable_genes(
                adata,
                layer=self.hvg_use_key,
                n_top_genes=self.subset_hvg
                if isinstance(self.subset_hvg, int)
                else None,
                batch_key=batch_key,
                flavor=self.hvg_flavor,
                subset=True,
                # NOTE(review): non-default loess span (default is 0.3 for
                # seurat_v3) — confirm this is intentional.
                span=1,
            )

        # step 4: normalize total counts
        if self.normalize_total:
            logger.info("Normalizing total counts ...")
            normed_ = sc.pp.normalize_total(
                adata,
                target_sum=self.normalize_total
                if isinstance(self.normalize_total, float)
                else None,
                layer=key_to_process,
                inplace=False,
            )["X"]
            # write the result into its own layer and track it for later steps
            key_to_process = self.result_normed_key or key_to_process
            _set_obs_rep(adata, normed_, layer=key_to_process)

        # step 5: log1p (skipped if the data already look log1p transformed)
        if self.log1p:
            logger.info("Checking for prior log1p transformation...")
            if self.check_logged(adata, obs_key=key_to_process):
                logger.warning(
                    "The input data seems to be already log1p transformed. "
                    "Skipping log1p to avoid double transformation."
                )
            else:
                logger.info("Applying log1p transformation...")
                if self.result_log1p_key:
                    # copy current values into the log1p layer, then
                    # transform that layer in place
                    _set_obs_rep(
                        adata,
                        _get_obs_rep(adata, layer=key_to_process),
                        layer=self.result_log1p_key,
                    )
                    key_to_process = self.result_log1p_key
                sc.pp.log1p(adata, layer=key_to_process)

        # step 6: quantile binning
        if self.binning:
            logger.info("Binning data ...")
            if not isinstance(self.binning, int):
                raise ValueError(
                    f"Binning arg must be an integer, but got {self.binning}"
                )
            # BUGFIX: bin the most-processed representation tracked in
            # key_to_process, not the raw input layer (self.use_key); the
            # latter ignored normalization/log1p and broke for use_key="X".
            layer_data = _get_obs_rep(adata, layer=key_to_process)
            layer_data = layer_data.A if issparse(layer_data) else layer_data
            binned_rows = []
            for row in layer_data:
                binned_row = binning(row, self.binning)
                # the binning helper may hand back a tensor; store numpy
                if isinstance(binned_row, torch.Tensor):
                    binned_row = binned_row.numpy()
                binned_rows.append(binned_row)
            adata.layers[self.result_binned_key] = np.stack(binned_rows)

    def check_logged(self, adata: AnnData, obs_key: Optional[str] = None) -> bool:
        """
        Heuristically check whether the data is already log1p transformed.

        The data are considered logged when the maximum value is <= 30,
        no value is negative, and the smallest non-zero value is < 1.

        Args:
            adata: the :class:`AnnData` object to inspect.
            obs_key: the layer to inspect (``None`` -> ``adata.X``).

        Returns:
            ``True`` if the data look log1p transformed, ``False`` otherwise.
        """
        data = _get_obs_rep(adata, layer=obs_key)
        max_, min_ = data.max(), data.min()
        if max_ > 30:
            return False
        if min_ < 0:
            return False
        non_zero_min = data[data > 0].min()
        if non_zero_min >= 1:
            return False
        return True
def _digitize(x: np.ndarray, bins: np.ndarray, side="both") -> np.ndarray:
# Method implementation remains the same
"""
Digitize the data into bins. This method spreads data uniformly when bins
have same values.
Args:
x (:class:`np.ndarray`):
The data to digitize.
bins (:class:`np.ndarray`):
The bins to use for digitization, in increasing order.
side (:class:`str`, optional):
The side to use for digitization. If "one", the left side is used. If
"both", the left and right side are used. Default to "one".
Returns:
:class:`np.ndarray`:
The digitized data.
"""
assert x.ndim == 1 and bins.ndim == 1
left_digits = np.digitize(x, bins)
if side == "one":
return left_digits
right_difits = np.digitize(x, bins, right=True)
rands = np.random.rand(len(x)) # uniform random numbers
digits = rands * (right_difits - left_digits) + left_digits
digits = np.ceil(digits).astype(np.int64)
return digits
# You can add additional helper functions as required, ensuring they do not depend on scgpt
def binning(
    row: Union[np.ndarray, torch.Tensor], n_bins: int
) -> Union[np.ndarray, torch.Tensor]:
    """
    Bin the values of ``row`` into ``n_bins`` quantile-based categories.

    Zero is always its own category (bin 0). When the row contains zeros or
    negatives, quantiles are computed over the non-zero entries only and the
    zero positions keep bin 0; otherwise the whole row is digitized.

    Args:
        row: 1-D values to bin (numpy array or torch tensor).
        n_bins: total number of bins, including the reserved zero bin.

    Returns:
        Binned values, as the same container type as the input: a tensor of
        int64 bin indices for tensor input, or a numpy array cast back to
        the input dtype for array input.
    """
    dtype = row.dtype
    return_np = not isinstance(row, torch.Tensor)
    row = row.cpu().numpy() if isinstance(row, torch.Tensor) else row
    # TODO: use torch.quantile and torch.bucketize
    if row.max() == 0:
        logger.warning(
            "The input data contains row of zeros. Please make sure this is expected."
        )
        if return_np:
            return np.zeros_like(row, dtype=dtype)
        # BUGFIX: `row` is a numpy array at this point, so the original
        # torch.zeros_like(row, dtype=dtype) raised TypeError; build the
        # zero tensor from the shape instead.
        return torch.zeros(row.shape, dtype=dtype)
    if row.min() <= 0:
        # keep zeros apart: quantiles are taken over non-zero entries only
        non_zero_ids = row.nonzero()
        non_zero_row = row[non_zero_ids]
        bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
        non_zero_digits = _digitize(non_zero_row, bins)
        binned_row = np.zeros_like(row, dtype=np.int64)
        binned_row[non_zero_ids] = non_zero_digits
    else:
        bins = np.quantile(row, np.linspace(0, 1, n_bins - 1))
        binned_row = _digitize(row, bins)
    return torch.from_numpy(binned_row) if not return_np else binned_row.astype(dtype)