Skip to content

Commit

Permalink
feat: Add device parameter to CellMapDataset, CellMapDataSplit, and C…
Browse files Browse the repository at this point in the history
…ellMapDataLoader for flexible device management
  • Loading branch information
rhoadesScholar committed Dec 13, 2024
1 parent d8d3ce1 commit eda089e
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
pip install --upgrade ".[test]"
- name: Test with pytest
run: |
pytest --color=yes --cov=cellmap_models --cov-report=xml --cov-report=term-missing tests
pytest --color=yes --cov=cellmap_data --cov-report=xml --cov-report=term-missing tests
# Coverage should work out of the box for public repos. For private repos, more setup is likely required.
- name: Coverage
uses: codecov/codecov-action@v5
12 changes: 10 additions & 2 deletions src/cellmap_data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
sampler: Sampler | Callable | None = None,
is_train: bool = True,
rng: Optional[torch.Generator] = None,
device: Optional[str | torch.device] = None,
**kwargs,
):
"""
Expand All @@ -52,6 +53,7 @@ def __init__(
sampler (Sampler | Callable | None): The sampler to use.
is_train (bool): Whether the data is for training and thus should be shuffled.
rng (Optional[torch.Generator]): The random number generator to use.
device (Optional[str | torch.device]): The device to use. Defaults to "cuda" or "mps" if available, else "cpu".
`**kwargs`: Additional arguments to pass to the DataLoader.
"""
Expand All @@ -63,13 +65,19 @@ def __init__(
self.sampler = sampler
self.is_train = is_train
self.rng = rng
if device is None:
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
else:
device = "cpu"
self.dataset.to(device)
if self.sampler is None and self.weighted_sampler:
assert isinstance(
self.dataset, CellMapMultiDataset
), "Weighted sampler only relevant for CellMapMultiDataset"
self.sampler = self.dataset.get_weighted_sampler(self.batch_size, self.rng)
if torch.cuda.is_available():
self.dataset.to("cuda")
self.default_kwargs = kwargs.copy()
kwargs.update(
{
Expand Down
5 changes: 5 additions & 0 deletions src/cellmap_data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(
force_has_data: bool = False,
empty_value: float | int = torch.nan,
pad: bool = False,
device: Optional[str | torch.device] = None,
) -> None:
"""Initializes the CellMapDataset class.
Expand Down Expand Up @@ -144,6 +145,7 @@ def __init__(
force_has_data (bool, optional): Whether to force the dataset to report that it has data. Defaults to False.
empty_value (float | int, optional): The value to fill in for empty data. Defaults to torch.nan.
pad (bool, optional): Whether to pad the image data to match requested arrays. Defaults to False.
device (Optional[str | torch.device], optional): The device for the dataset. Defaults to None. If None, the device will be set to "cuda" if available, "mps" if available, or "cpu" if neither are available.
"""
self.raw_path = raw_path
Expand All @@ -166,6 +168,8 @@ def __init__(
self._current_center = None
self._current_spatial_transforms = None
self.input_sources: dict[str, CellMapImage] = {}
if device is not None:
self._device = torch.device(device)
for array_name, array_info in self.input_arrays.items():
self.input_sources[array_name] = CellMapImage(
self.raw_path,
Expand All @@ -184,6 +188,7 @@ def __init__(
self.has_data = False
for array_name, array_info in self.target_arrays.items():
self.target_sources[array_name] = self.get_target_array(array_info)
self.to(self.device)

@property
def center(self) -> Mapping[str, float] | None:
Expand Down
20 changes: 20 additions & 0 deletions src/cellmap_data/datasplit.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class CellMapDataSplit:
force_has_data (bool): Whether to force the datasets to have data even if no ground truth data is found. Defaults to False. Useful for training with only raw data.
context (Optional[tensorstore.Context]): The TensorStore context for the image data. Defaults to None.
device (Optional[str | torch.device]): Device to use for the dataloaders. Defaults to None.
Note:
The csv_path, dataset_dict, and datasets arguments are mutually exclusive, but one must be supplied.
Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
class_relation_dict: Optional[Mapping[str, Sequence[str]]] = None,
force_has_data: bool = False,
context: Optional[tensorstore.Context] = None, # type: ignore
device: Optional[str | torch.device] = None,
) -> None:
"""Initializes the CellMapDatasets class.
Expand Down Expand Up @@ -158,6 +160,7 @@ def __init__(
force_has_data (bool): Whether to force the datasets to have data even if no ground truth data is found. Defaults to False. Useful for training with only raw data.
context (Optional[tensorstore.Context]): The TensorStore context for the image data. Defaults to None.
device (Optional[str | torch.device]): Device to use for the dataloaders. Defaults to None.
Note:
The csv_path, dataset_dict, and datasets arguments are mutually exclusive, but one must be supplied.
Expand All @@ -170,6 +173,7 @@ def __init__(
self.classes = classes
self.empty_value = empty_value
self.pad = pad
self.device = device
if isinstance(pad, str):
self.pad_training = pad.lower() == "train"
self.pad_validation = pad.lower() == "validate"
Expand Down Expand Up @@ -309,6 +313,7 @@ def construct(self, dataset_dict) -> None:
empty_value=self.empty_value,
class_relation_dict=self.class_relation_dict,
pad=self.pad_training,
device=self.device,
)
)
except ValueError as e:
Expand Down Expand Up @@ -336,6 +341,7 @@ def construct(self, dataset_dict) -> None:
empty_value=self.empty_value,
class_relation_dict=self.class_relation_dict,
pad=self.pad_validation,
device=self.device,
)
)
except ValueError as e:
Expand Down Expand Up @@ -437,3 +443,17 @@ def set_arrays(
for attr in reset_attrs:
if hasattr(self, attr):
delattr(self, attr)

def to(self, device: str | torch.device) -> None:
"""Sets the device for the dataloaders."""
self.device = device
for dataset in self.train_datasets:
dataset.to(device)
for dataset in self.validation_datasets:
dataset.to(device)
if hasattr(self, "_train_datasets_combined"):
self._train_datasets_combined.to(device)
if hasattr(self, "_validation_datasets_combined"):
self._validation_datasets_combined.to(device)
if hasattr(self, "_validation_blocks"):
self._validation_blocks.to(device)
6 changes: 5 additions & 1 deletion src/cellmap_data/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
axis_order: str | Sequence[str] = "zyx",
value_transform: Optional[Callable] = None,
context: Optional[tensorstore.Context] = None, # type: ignore
device: Optional[str | torch.device] = None,
) -> None:
"""Initializes a CellMapImage object.
Expand All @@ -87,6 +88,7 @@ def __init__(
axis_order (str, optional): The order of the axes in the image. Defaults to "zyx".
value_transform (Optional[callable], optional): A function to transform the image pixel data. Defaults to None.
context (Optional[tensorstore.Context], optional): The context for the image data. Defaults to None.
device (Optional[str | torch.device], optional): The device to load the image data onto. Defaults to "cuda" if available, then "mps", then "cpu".
"""

self.path = path
Expand Down Expand Up @@ -114,7 +116,9 @@ def __init__(
self._current_spatial_transforms = None
self._current_coords = None
self._current_center = None
if torch.cuda.is_available():
if device is not None:
self.device = device
elif torch.cuda.is_available():
self.device = "cuda"
elif torch.backends.mps.is_available():
self.device = "mps"
Expand Down

0 comments on commit eda089e

Please sign in to comment.