From 8a57bfb4a94cdfca99fb6070a703b3c78d7b129b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Oskar=20B=C3=B6rjesson?= Date: Mon, 30 Sep 2024 09:04:52 +0200 Subject: [PATCH] Use target cell_key when mapping (#122) * Use target cell_key when mapping * Fix test --- scarf/datastore/mapping_datastore.py | 18 +++++++++++++++--- scarf/mapping_utils.py | 11 ++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/scarf/datastore/mapping_datastore.py b/scarf/datastore/mapping_datastore.py index 2a072e3..5aef6d4 100644 --- a/scarf/datastore/mapping_datastore.py +++ b/scarf/datastore/mapping_datastore.py @@ -33,6 +33,7 @@ def run_mapping( target_assay: Assay, target_name: str, target_feat_key: str, + target_cell_key: str = "I", from_assay: Optional[str] = None, cell_key: str = "I", feat_key: Optional[str] = None, @@ -57,6 +58,7 @@ def run_mapping( target_name: Name of target data. This used to keep track of projections in the Zarr hierarchy target_feat_key: This will be used to name wherein the normalized target data will be saved in its own zarr hierarchy. + target_cell_key: Cell key for the target data. (Default value: 'I') from_assay: Name of assay to be used. If no value is provided then the default assay will be used. cell_key: Cell key. Should be same as the one that was used in the desired graph. (Default value: 'I') feat_key: Feature key. Should be same as the one that was used in the desired graph. By default, the latest @@ -119,6 +121,7 @@ def run_mapping( cell_key, feat_key, target_feat_key, + target_cell_key, filter_null, exclude_missing, self.nthreads, @@ -151,15 +154,23 @@ def run_mapping( logger.warning(f"`save_k` was decreased to {ann_obj.k}") save_k = ann_obj.k target_data = daskarr.from_zarr( - target_assay.z[f"normed__I__{target_feat_key}/data"], inline_array=True + target_assay.z[f"normed__{target_cell_key}__{target_feat_key}/data"], + inline_array=True, ) if run_coral is True: # Reversing coral here to correct target data coral( - target_data, ann_obj.data, target_assay, target_feat_key, self.nthreads + target_data, + ann_obj.data, + target_assay, + target_feat_key, + target_cell_key, + self.nthreads, ) target_data = daskarr.from_zarr( - target_assay.z[f"normed__I__{target_feat_key}/data_coral"], + target_assay.z[ + f"normed__{target_cell_key}__{target_feat_key}/data_coral" + ], inline_array=True, ) if ann_obj.method == "pca" and run_coral is False: @@ -328,6 +339,7 @@ def get_target_classes( store = self.zw[store_loc] indices = store["indices"][:] dists = store["distances"][:] + preds = [] weights = 1 - (dists / dists.max(axis=1).reshape(-1, 1)) for n in range(indices.shape[0]): diff --git a/scarf/mapping_utils.py b/scarf/mapping_utils.py index 76ab5aa..7bfd81e 100644 --- a/scarf/mapping_utils.py +++ b/scarf/mapping_utils.py @@ -38,7 +38,7 @@ def _correlation_alignment(s: daskarr, t: daskarr, nthreads: int) -> daskarr: return daskarr.dot(s, a_coral) -def coral(source_data, target_data, assay, feat_key: str, nthreads: int): +def coral(source_data, target_data, assay, feat_key: str, cell_key: str, nthreads: int): """Applies CORAL error correction to the input data. Args: @@ -46,6 +46,7 @@ def coral(source_data, target_data, assay, feat_key: str, nthreads: int): target_data (): assay (): feat_key (): + cell_key (): nthreads (): """ from .writers import dask_to_zarr @@ -87,7 +88,7 @@ def coral(source_data, target_data, assay, feat_key: str, nthreads: int): dask_to_zarr( data, assay.z["/"], - f"{assay.z.name}/normed__I__{feat_key}/data_coral", + f"{assay.z.name}/normed__{cell_key}__{feat_key}/data_coral", 1000, nthreads, msg="Writing out coral corrected data", @@ -149,6 +150,7 @@ def align_features( source_cell_key: str, source_feat_key: str, target_feat_key: str, + target_cell_key: str, filter_null: bool, exclude_missing: bool, nthreads: int, @@ -185,11 +187,10 @@ def align_features( norm_params = source_assay.z[normed_loc].attrs["subset_params"] sorted_t_idx = np.array(sorted(t_idx[t_idx != -1])) - # TODO: add target cell key normed_data = target_assay.normed( - target_assay.cells.active_index("I"), sorted_t_idx, **norm_params + target_assay.cells.active_index(target_cell_key), sorted_t_idx, **norm_params ) - loc = f"{target_assay.z.name}/normed__I__{target_feat_key}/data" + loc = f"{target_assay.z.name}/normed__{target_cell_key}__{target_feat_key}/data" og = create_zarr_dataset( target_assay.z["/"], loc, (1000,), "float64", (normed_data.shape[0], len(t_idx))