Skip to content

Commit

Permalink
feat: Enable padding by default in CellMapDataset and update raw valu…
Browse files Browse the repository at this point in the history
…e transforms in CellMapDataSplit
  • Loading branch information
rhoadesScholar committed Jan 30, 2025
1 parent 3f7ef68 commit d780256
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
rng: Optional[torch.Generator] = None,
force_has_data: bool = False,
empty_value: float | int = torch.nan,
pad: bool = False,
pad: bool = True,
device: Optional[str | torch.device] = None,
) -> None:
"""Initializes the CellMapDataset class.
Expand Down
30 changes: 22 additions & 8 deletions src/cellmap_data/datasplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from typing import Any, Callable, Mapping, Optional, Sequence
import tensorstore
import torch
import torchvision.transforms.v2 as T
from tqdm import tqdm
from .transforms.augment import NaNtoNum, Normalize, Binarize
from .dataset import CellMapDataset
from .multidataset import CellMapMultiDataset
from .subdataset import CellMapSubset
Expand Down Expand Up @@ -51,9 +53,9 @@ class CellMapDataSplit:
{transform_name: {transform_args}}
train_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in training datasets. Defaults to None. Example is to add gaussian noise to the raw data.
val_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in validation datasets. Defaults to None. Example is to normalize the raw data.
target_value_transforms (Optional[Callable | Sequence[Callable] | Mapping[str, Callable]]): A function to convert the ground truth data to target arrays. Defaults to None. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order.
train_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in training datasets. Example is to add gaussian noise to the raw data. Defaults to Normalize, ToDtype, and NaNtoNum.
val_raw_value_transforms (Optional[Callable]): A function to apply to the raw data in validation datasets. Example is to normalize the raw data. Defaults to Normalize, ToDtype, and NaNtoNum.
target_value_transforms (Optional[Callable | Sequence[Callable] | Mapping[str, Callable]]): A function to convert the ground truth data to target arrays. Example is to convert the ground truth data to a signed distance transform. May be a single function, a list of functions, or a dictionary of functions for each class. In the case of a list of functions, it is assumed that the functions correspond to each class in the classes list in order. Defaults to ToDtype and Binarize.
class_relation_dict (Optional[Mapping[str, Sequence[str]]]): A dictionary containing the class relations. The dictionary should have the following structure::
{
Expand Down Expand Up @@ -97,11 +99,23 @@ def __init__(
dataset_dict: Optional[Mapping[str, Sequence[Mapping[str, str]]]] = None,
csv_path: Optional[str] = None,
spatial_transforms: Optional[Mapping[str, Any]] = None,
train_raw_value_transforms: Optional[Callable] = None,
val_raw_value_transforms: Optional[Callable] = None,
target_value_transforms: Optional[
Callable | Sequence[Callable] | Mapping[str, Callable]
] = None,
train_raw_value_transforms: Optional[T.Transform] = T.Compose(
[
Normalize(),
T.ToDtype(torch.float, scale=True),
NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
],
),
val_raw_value_transforms: Optional[T.Transform] = T.Compose(
[
Normalize(),
T.ToDtype(torch.float, scale=True),
NaNtoNum({"nan": 0, "posinf": None, "neginf": None}),
],
),
target_value_transforms: Optional[T.Transform] = T.Compose(
[T.ToDtype(torch.float), Binarize()]
),
class_relation_dict: Optional[Mapping[str, Sequence[str]]] = None,
force_has_data: bool = False,
context: Optional[tensorstore.Context] = None, # type: ignore
Expand Down

0 comments on commit d780256

Please sign in to comment.