Fix serialization tensor v2 tests #1519

Merged
merged 4 commits on Feb 16, 2024
4 changes: 3 additions & 1 deletion .github/workflows/ci.yml
@@ -35,6 +35,8 @@ jobs:
- name: Install PyTorch on MacOS
if: matrix.operating-system == 'macos-latest'
run: pip install torch==2.2.0 torchvision==0.17.0
- name: Install dev requirements
run: pip install -r requirements-dev.txt
- name: Install dependencies
run: |
pip install .[tests]
@@ -61,7 +63,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Update pip
run: python -m pip install --upgrade pip
- name: Install linters
- name: Install dev requirements
run: pip install -r requirements-dev.txt
- name: Run checks
run: pre-commit run --files $(find albumentations -type f)
4 changes: 4 additions & 0 deletions .gitignore
@@ -161,3 +161,7 @@ conda_build/

.vscode/
conda.recipe/

.gitingore

*.ipynb
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -30,7 +30,6 @@ repos:
- id: mixed-line-ending
- id: destroyed-symlinks
- id: fix-byte-order-marker
- id: pretty-format-json
- id: check-json
- id: check-yaml
args: [ --unsafe ]
4 changes: 1 addition & 3 deletions albumentations/augmentations/geometric/functional.py
@@ -968,9 +968,7 @@ def bbox_flip(bbox: BoxInternalType, d: int, rows: int, cols: int) -> BoxInternalType:
return bbox


def bbox_transpose(
bbox: KeypointInternalType, axis: int, rows: int, cols: int
) -> KeypointInternalType: # skipcq: PYL-W0613
def bbox_transpose(bbox: KeypointInternalType, axis: int, rows: int, cols: int) -> KeypointInternalType:
"""Transposes a bounding box along given axis.

Args:
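As a quick illustration of the function whose signature was compacted above: `bbox_transpose` works on normalized `(x_min, y_min, x_max, y_max)` coordinates, and the removed `skipcq: PYL-W0613` marker reflects that `rows` and `cols` are unused by this operation. A minimal sketch (the expected output assumes the usual main-diagonal semantics for `axis=0`):

```python
from albumentations.augmentations.geometric.functional import bbox_transpose

# A normalized bounding box: (x_min, y_min, x_max, y_max).
bbox = (0.1, 0.2, 0.4, 0.5)

# axis=0 transposes over the main diagonal, swapping x and y;
# rows and cols are accepted but unused by this particular helper.
print(bbox_transpose(bbox, 0, 100, 100))  # expected: (0.2, 0.1, 0.5, 0.4)
```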
105 changes: 71 additions & 34 deletions albumentations/core/serialization.py
@@ -1,8 +1,8 @@
import json
import typing
import warnings
from abc import ABC, ABCMeta, abstractmethod
from typing import Any, Dict, Optional, Tuple, Type, Union
from pathlib import Path
from typing import Any, Dict, Optional, TextIO, Tuple, Type, Union, cast

try:
import yaml
@@ -146,7 +146,7 @@ def from_dict(
) -> Optional[Serializable]:
"""
Args:
transform_dict (dict): A dictionary with serialized transform pipeline.
transform_dict: A dictionary with serialized transform pipeline.
nonserializable (dict): A dictionary that contains non-serializable transforms.
This dictionary is required when you are restoring a pipeline that contains non-serializable transforms.
Keys in that dictionary should be named same as `name` arguments in respective transforms from
@@ -155,7 +155,7 @@
"""
if lambda_transforms != "deprecated":
warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning)
nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms)
nonserializable = cast(Optional[Dict[str, Any]], lambda_transforms)

register_additional_transforms()
transform = transform_dict["transform"]
@@ -176,57 +176,94 @@ def check_data_format(data_format: str) -> None:


def save(
transform: Serializable, filepath: str, data_format: str = "json", on_not_implemented_error: str = "raise"
transform: "Serializable",
filepath_or_buffer: Union[str, Path, TextIO],
data_format: str = "json",
on_not_implemented_error: str = "raise",
) -> None:
"""
Take a transform pipeline, serialize it and save a serialized version to a file
using either json or yaml format.
Serialize a transform pipeline and save it to either a file specified by a path or a file-like object
in either JSON or YAML format.

Args:
transform (obj): Transform to serialize.
filepath (str): Filepath to write to.
data_format (str): Serialization format. Should be either `json` or 'yaml'.
on_not_implemented_error (str): Parameter that describes what to do if a transform doesn't implement
the `to_dict` method. If 'raise' then `NotImplementedError` is raised, if `warn` then the exception will be
ignored and no transform arguments will be saved.
transform (Serializable): The transform pipeline to serialize.
filepath_or_buffer (Union[str, Path, TextIO]): The file path or file-like object to write the serialized
data to.
If a string is provided, it is interpreted as a path to a file. If a file-like object is provided,
the serialized data will be written to it directly.
data_format (str): The format to serialize the data in. Valid options are 'json' and 'yaml'.
Defaults to 'json'.
on_not_implemented_error (str): Determines the behavior if a transform does not implement the `to_dict` method.
If set to 'raise', a `NotImplementedError` is raised. If set to 'warn', the exception is ignored, and
no transform arguments are saved. Defaults to 'raise'.

Raises:
ValueError: If `data_format` is 'yaml' but PyYAML is not installed.
"""
check_data_format(data_format)
transform_dict = transform.to_dict(on_not_implemented_error=on_not_implemented_error)
with open(filepath, "w") as f:

# Determine whether to write to a file or a file-like object
if isinstance(filepath_or_buffer, (str, Path)): # It's a filepath
with open(filepath_or_buffer, "w") as f:
if data_format == "yaml":
if not yaml_available:
raise ValueError("You need to install PyYAML to save a pipeline in YAML format")
yaml.safe_dump(transform_dict, f, default_flow_style=False)
elif data_format == "json":
json.dump(transform_dict, f)
else: # Assume it's a file-like object
if data_format == "yaml":
if not yaml_available:
raise ValueError("You need to install PyYAML to save a pipeline in yaml format")
yaml.safe_dump(transform_dict, f, default_flow_style=False)
raise ValueError("You need to install PyYAML to save a pipeline in YAML format")
yaml.safe_dump(transform_dict, filepath_or_buffer, default_flow_style=False)
elif data_format == "json":
json.dump(transform_dict, f)
json.dump(transform_dict, filepath_or_buffer)


def load(
filepath: str,
filepath_or_buffer: Union[str, Path, TextIO],
data_format: str = "json",
nonserializable: Optional[Dict[str, Any]] = None,
lambda_transforms: Union[Optional[Dict[str, Any]], str] = "deprecated",
) -> object:
"""
Load a serialized pipeline from a json or yaml file and construct a transform pipeline.
Load a serialized pipeline from a file or file-like object and construct a transform pipeline.

Args:
filepath (str): Filepath to read from.
data_format (str): Serialization format. Should be either `json` or 'yaml'.
nonserializable (dict): A dictionary that contains non-serializable transforms.
This dictionary is required when you are restoring a pipeline that contains non-serializable transforms.
Keys in that dictionary should be named same as `name` arguments in respective transforms from
a serialized pipeline.
lambda_transforms (dict): Deprecated. Use 'nonserizalizable' instead.
filepath_or_buffer (Union[str, Path, TextIO]): The file path or file-like object to read the serialized
data from.
If a string is provided, it is interpreted as a path to a file. If a file-like object is provided,
the serialized data will be read from it directly.
data_format (str): The format of the serialized data. Valid options are 'json' and 'yaml'.
Defaults to 'json'.
nonserializable (Optional[Dict[str, Any]]): A dictionary that contains non-serializable transforms.
This dictionary is required when restoring a pipeline that contains non-serializable transforms.
Keys in the dictionary should be named the same as the `name` arguments in respective transforms
from the serialized pipeline. Defaults to None.

Returns:
object: The deserialized transform pipeline.

Raises:
ValueError: If `data_format` is 'yaml' but PyYAML is not installed.
"""
if lambda_transforms != "deprecated":
warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning)
nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms)

check_data_format(data_format)
load_fn = json.load if data_format == "json" else yaml.safe_load
with open(filepath) as f:
transform_dict = load_fn(f) # type: ignore

if isinstance(filepath_or_buffer, (str, Path)): # Assume it's a filepath
with open(filepath_or_buffer) as f:
if data_format == "json":
transform_dict = json.load(f)
else:
if not yaml_available:
raise ValueError("You need to install PyYAML to load a pipeline in yaml format")
transform_dict = yaml.safe_load(f)
else: # Assume it's a file-like object
if data_format == "json":
transform_dict = json.load(filepath_or_buffer)
else:
if not yaml_available:
raise ValueError("You need to install PyYAML to load a pipeline in yaml format")
transform_dict = yaml.safe_load(filepath_or_buffer)

return from_dict(transform_dict, nonserializable=nonserializable)

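For context, a minimal sketch of the extended `save`/`load` API introduced in this hunk, exercising both a plain path and an in-memory buffer (the `HorizontalFlip` pipeline is only an illustrative placeholder):

```python
import io

import albumentations as A

# Any serializable pipeline will do; HorizontalFlip is just an example.
transform = A.Compose([A.HorizontalFlip(p=0.5)])

# New in this change: save to / load from a file-like object.
buffer = io.StringIO()
A.save(transform, buffer, data_format="json")
buffer.seek(0)  # rewind before reading the serialized pipeline back
restored = A.load(buffer, data_format="json")

# Saving to a path keeps working, and pathlib.Path is now accepted too.
A.save(transform, "pipeline.json", data_format="json")
restored_from_path = A.load("pipeline.json", data_format="json")
```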
25 changes: 0 additions & 25 deletions black.toml

This file was deleted.

1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -1,4 +1,5 @@
black==24.2.0
deepdiff==6.7.1
flake8==7.0.0
isort==5.13.2
mypy==1.8.0
4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -6,8 +6,8 @@
import pytest

try:
import torch # skipcq: PYL-W0611
import torchvision # skipcq: PYL-W0611
import torch
import torchvision

torch_available = True
except ImportError:
1 change: 0 additions & 1 deletion tests/files/output_v1.1.0_with_totensor.json

This file was deleted.

1 change: 0 additions & 1 deletion tests/files/output_v1.1.0_without_totensor.json

This file was deleted.

113 changes: 55 additions & 58 deletions tests/test_serialization.py
@@ -1,14 +1,14 @@
import json
import os
import io
import random
from pathlib import Path
from unittest.mock import patch

import cv2
import numpy as np
import pytest
from deepdiff import DeepDiff

import albumentations as A
import albumentations.augmentations.functional as F
import albumentations.augmentations.geometric.functional as FGeometric
from albumentations.core.serialization import SERIALIZABLE_REGISTRY, shorten_class_name
from albumentations.core.transforms_interface import ImageOnlyTransform
@@ -808,61 +808,58 @@ def vflip_keypoint(keypoint, **kwargs):
assert np.array_equal(aug_data["keypoints"], deserialized_aug_data["keypoints"])


# def test_serialization_v2_conversion_without_totensor():
# current_directory = os.path.dirname(os.path.abspath(__file__))
# files_directory = os.path.join(current_directory, "files")
# transform_1_1_0 = A.load(os.path.join(files_directory, "transform_v1.1.0_without_totensor.json"))
# with open(os.path.join(files_directory, "output_v1.1.0_without_totensor.json")) as f:
# output_1_1_0 = json.load(f)
# np.random.seed(42)
# image = np.random.randint(low=0, high=255, size=(256, 256, 3), dtype=np.uint8)
# random.seed(42)
# transformed_image = transform_1_1_0(image=image)["image"]
# assert transformed_image.tolist() == output_1_1_0


# @skipif_no_torch
# def test_serialization_v2_conversion_with_totensor():
# current_directory = os.path.dirname(os.path.abspath(__file__))
# files_directory = os.path.join(current_directory, "files")
# transform_1_1_0 = A.load(os.path.join(files_directory, "transform_v1.1.0_with_totensor.json"))
# with open(os.path.join(files_directory, "output_v1.1.0_with_totensor.json")) as f:
# output_1_1_0 = json.load(f)
# np.random.seed(42)
# random.seed(42)
# image = np.random.randint(low=0, high=255, size=(256, 256, 3), dtype=np.uint8)
# transformed_image = transform_1_1_0(image=image)["image"]
# assert transformed_image.numpy().tolist() == output_1_1_0


# def test_serialization_v2_without_totensor():
# current_directory = os.path.dirname(os.path.abspath(__file__))
# files_directory = os.path.join(current_directory, "files")
# transform = A.load(os.path.join(files_directory, "transform_serialization_v2_without_totensor.json"))
# with open(os.path.join(files_directory, "output_v1.1.0_without_totensor.json")) as f:
# output_1_1_0 = json.load(f)
# np.random.seed(42)
# random.seed(42)
# image = np.random.randint(low=0, high=255, size=(256, 256, 3), dtype=np.uint8)
# transformed_image = transform(image=image)["image"]

# with open("1.json", "w") as f:
# json.dump(transformed_image.tolist(), f)
# assert transformed_image.tolist() == output_1_1_0


# @skipif_no_torch
# def test_serialization_v2_with_totensor():
# current_directory = os.path.dirname(os.path.abspath(__file__))
# files_directory = os.path.join(current_directory, "files")
# transform = A.load(os.path.join(files_directory, "transform_serialization_v2_with_totensor.json"))
# with open(os.path.join(files_directory, "output_v1.1.0_with_totensor.json")) as f:
# output_1_1_0 = json.load(f)
# np.random.seed(42)
# image = np.random.randint(low=0, high=255, size=(256, 256, 3), dtype=np.uint8)
# random.seed(42)
# transformed_image = transform(image=image)["image"]
# assert transformed_image.numpy().tolist() == output_1_1_0
@pytest.mark.parametrize(
"transform_file_name",
["transform_v1.1.0_without_totensor.json", "transform_serialization_v2_without_totensor.json"],
)
def test_serialization_conversion_without_totensor(transform_file_name):
# Step 1: Load transform from file
current_directory = Path(__file__).resolve().parent
files_directory = current_directory / "files"
transform_file_path = files_directory / transform_file_name
transform = A.load(transform_file_path, data_format="json")

# Step 2: Serialize it to buffer in memory
buffer = io.StringIO()
A.save(transform, buffer, data_format="json")
buffer.seek(0) # Reset buffer position to the beginning

# Step 3: Load transform from this memory buffer
transform_from_buffer = A.load(buffer, data_format="json")

# Ensure the buffer is closed after use
buffer.close()

assert (
DeepDiff(transform.to_dict(), transform_from_buffer.to_dict()) == {}
), "The loaded transform is not equal to the original one"


@pytest.mark.parametrize(
"transform_file_name",
["transform_v1.1.0_with_totensor.json", "transform_serialization_v2_with_totensor.json"],
)
@skipif_no_torch
def test_serialization_conversion_with_totensor(transform_file_name):
# Load transform from file
current_directory = Path(__file__).resolve().parent
files_directory = current_directory / "files"
transform_file_path = files_directory / transform_file_name

transform = A.load(transform_file_path, data_format="json")

# Serialize it to buffer in memory
buffer = io.StringIO()
A.save(transform, buffer, data_format="json")
buffer.seek(0) # Reset buffer position to the beginning

# Load transform from this memory buffer
transform_from_buffer = A.load(buffer, data_format="json")
buffer.close() # Ensure the buffer is closed after use

assert (
DeepDiff(transform.to_dict(), transform_from_buffer.to_dict()) == {}
), "The loaded transform is not equal to the original one"


def test_custom_transform_with_overlapping_name():
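Relatedly, the `nonserializable` argument documented in the `from_dict` hunk above is the mechanism the Lambda-based tests rely on. A hedged sketch of that pattern (names here are illustrative; per the docstring, the dictionary key must match the transform's `name` argument):

```python
import albumentations as A

def vflip_image(image, **kwargs):
    # Flip the image vertically; **kwargs absorbs extra call-time parameters.
    return image[::-1]

# Lambda transforms are matched by `name` when a pipeline is restored.
lambda_transform = A.Lambda(name="vflip_image", image=vflip_image, p=1.0)
transform = A.Compose([lambda_transform])

transform_dict = A.to_dict(transform)
restored = A.from_dict(transform_dict, nonserializable={"vflip_image": lambda_transform})
```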