Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

working on perturbation detectors #51

Merged
merged 22 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
05a4cc2
working on perturbation detectors
rabah-khalek Aug 8, 2024
cf2a8c4
Refactor HF ppl model to convert numpy array to PIL image
Inokinoki Aug 5, 2024
7ba35b6
Allow to set global mode for an HF ppl model for PIL conversion
Inokinoki Aug 5, 2024
6c3ffc9
mode switch in hf models
rabah-khalek Aug 8, 2024
7e14795
supporting gray scale
rabah-khalek Aug 8, 2024
f3d8e32
Merge branch 'main' into perturbation-detectors
rabah-khalek Aug 8, 2024
c30dade
Merge branch 'main' into perturbation-detectors
rabah-khalek Aug 10, 2024
98baa6d
added missing predict_rgb_image
rabah-khalek Aug 12, 2024
a9fa22f
ensuring backward compatibility with predict_image
rabah-khalek Aug 12, 2024
e9198ce
updating detectors
rabah-khalek Aug 12, 2024
4dd46b4
Adding noise perturbation detector with Gaussian noise (#52)
bmalezieux Aug 12, 2024
e547d4d
updating detectors
rabah-khalek Aug 12, 2024
a44399d
refactoring detectors
rabah-khalek Aug 13, 2024
fe26272
small updates
rabah-khalek Aug 13, 2024
c359c9c
refactored spec setting
rabah-khalek Aug 13, 2024
6dca401
fixed import in object_detection dataloader
rabah-khalek Aug 13, 2024
99d98dd
renaming pert detectors
rabah-khalek Aug 13, 2024
6601ecb
Merge branch 'main' into perturbation-detectors
rabah-khalek Aug 13, 2024
14de1fa
Merge branch 'perturbation-detectors' into refactoring-detectors
rabah-khalek Aug 13, 2024
182731c
fixed import
rabah-khalek Aug 13, 2024
6ba1994
fixing get_scan_results args
rabah-khalek Aug 13, 2024
ffbb425
Merge pull request #53 from Giskard-AI/refactoring-detectors
rabah-khalek Aug 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions giskard_vision/core/dataloaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,10 @@
get_image_channel_number,
get_image_size,
)
from giskard_vision.core.detectors.base import IssueGroup
from giskard_vision.core.issues import AttributesIssueMeta

from ..types import TypesBase

EthicalIssueMeta = IssueGroup(
"Ethical",
description="The data are filtered by metadata like age, facial hair, or gender to detect ethical biases.",
)
PerformanceIssueMeta = IssueGroup(
"Performance",
description="The data are filtered by metadata like emotion, head pose, or exposure value to detect performance issues.",
)
AttributesIssueMeta = IssueGroup(
"Attributes",
description="The data are filtered by the image attributes like width, height, or brightness value to detect issues.",
)


class DataIteratorBase(ABC):
"""Abstract class serving as a base template for DataLoaderBase and DataLoaderWrapper classes.
Expand Down
3 changes: 2 additions & 1 deletion giskard_vision/core/dataloaders/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from PIL.Image import Image as PILImage

from giskard_vision.core.dataloaders.base import AttributesIssueMeta, DataIteratorBase
from giskard_vision.core.dataloaders.base import DataIteratorBase
from giskard_vision.core.dataloaders.meta import MetaData, get_pil_image_depth
from giskard_vision.core.issues import AttributesIssueMeta
from giskard_vision.utils.errors import GiskardError, GiskardImportError


Expand Down
2 changes: 1 addition & 1 deletion giskard_vision/core/dataloaders/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from PIL.Image import Image as PILImage

from giskard_vision.core.detectors.base import IssueGroup
from giskard_vision.core.issues import IssueGroup


class MetaData:
Expand Down
72 changes: 72 additions & 0 deletions giskard_vision/core/dataloaders/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,78 @@ def get_image(self, idx: int) -> np.ndarray:
return cv2.GaussianBlur(image, self._kernel_size, *self._sigma)


class NoisyDataLoader(DataLoaderWrapper):
    """Wrapper class for a DataIteratorBase, providing images with additive Gaussian noise.

    Args:
        dataloader (DataIteratorBase): The data loader to be wrapped.
        sigma (float): Standard deviation of the Gaussian noise, expressed as a
            fraction of the 0-255 intensity range (it is multiplied by 255
            before the noise is drawn).

    Returns:
        NoisyDataLoader: Noisy data loader instance.
    """

    def __init__(
        self,
        dataloader: DataIteratorBase,
        sigma: float = 0.1,
    ) -> None:
        """
        Initializes the NoisyDataLoader.

        Args:
            dataloader (DataIteratorBase): The data loader to be wrapped.
            sigma (float): Standard deviation of the Gaussian noise, as a
                fraction of 255.
        """
        super().__init__(dataloader)
        self._sigma = sigma

    @property
    def name(self):
        """
        Gets the name of the noisy data loader.

        Returns:
            str: The name of the noisy data loader.
        """
        return "noisy"

    def get_image(self, idx: int) -> np.ndarray:
        """
        Gets an image with Gaussian noise added.

        Args:
            idx (int): Index of the data.

        Returns:
            np.ndarray: Noisy image data (uint8, clipped to 0-255).
        """
        image = super().get_image(idx)
        # sigma is stored as a fraction of the intensity range; scale it to pixel units.
        return self.add_gaussian_noise(image, self._sigma * 255)

    def add_gaussian_noise(self, image, std_dev):
        """
        Add zero-mean Gaussian noise to the image

        Args:
            image (np.ndarray): Image
            std_dev (float): Standard deviation of the Gaussian noise, in pixel
                intensity units.

        Returns:
            np.ndarray: Noisy image, clipped to [0, 255] and cast to uint8
        """
        # Generate Gaussian noise with the same shape as the image
        noise = np.random.normal(0, std_dev, image.shape).astype(np.float32)

        # Add the noise in float32 so that negative / >255 values survive until clipping.
        # Plain NumPy addition: both operands are float32 arrays of identical shape,
        # so this is an element-wise sum whatever the rank or channel count of the image.
        noisy_image = image.astype(np.float32) + noise

        # Clip the values to stay within valid range (0-255 for uint8)
        return np.clip(noisy_image, 0, 255).astype(np.uint8)


class ColoredDataLoader(DataLoaderWrapper):
"""Wrapper class for a DataIteratorBase, providing color-altered images using OpenCV color conversion.

Expand Down
51 changes: 39 additions & 12 deletions giskard_vision/core/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence, Tuple

from giskard_vision.core.issues import IssueGroup
from giskard_vision.utils.errors import GiskardImportError


@dataclass(frozen=True)
class IssueGroup:
name: str
description: str
from .specs import DetectorSpecsBase


@dataclass
Expand Down Expand Up @@ -51,7 +48,7 @@ def get_meta_required(self) -> dict:
}


class DetectorVisionBase:
class DetectorVisionBase(DetectorSpecsBase):
"""
Abstract class for Vision Detectors

Expand All @@ -67,12 +64,6 @@ class DetectorVisionBase:
evaluation results for the scan.
"""

issue_group: IssueGroup
warning_messages: dict
issue_level_threshold: float = 0.2
deviation_threshold: float = 0.05
num_images: int = 0

def run(
self,
model: Any,
Expand Down Expand Up @@ -139,6 +130,42 @@ def get_issues(

return issues

def get_scan_result(
    self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
    """Build a ScanResult graded by how far a metric moved from its reference.

    The delta is ``metric_value - metric_reference_value``; when
    ``self.metric_type`` is ``"relative"`` it is further divided by
    ``metric_reference_value`` (no guard against a zero reference).

    The issue level starts at MINOR and is raised only when the delta goes in
    the "worse" direction given by ``self.metric_direction``
    (``"better_lower"`` or ``"better_higher"``):
    beyond ``issue_level_threshold`` it is MEDIUM, beyond
    ``issue_level_threshold + deviation_threshold`` it is MAJOR. Any other
    ``metric_direction`` value leaves the level at MINOR.

    Args:
        metric_value: Metric measured on the evaluated data.
        metric_reference_value: Metric measured on the reference data.
        metric_name: Name of the metric, stored in the result.
        filename_examples: Example file names, stored in the result.
        name: Name stored in the result.
        size_data: Stored in the result as ``slice_size``.
        issue_group: Issue group stored in the result.

    Returns:
        ScanResult: The populated scan result.

    Raises:
        GiskardImportError: If ``giskard`` cannot be imported.
    """
    try:
        from giskard.scanner.issues import IssueLevel
    except (ImportError, ModuleNotFoundError) as e:
        raise GiskardImportError(["giskard"]) from e

    delta = metric_value - metric_reference_value
    if self.metric_type == "relative":
        delta /= metric_reference_value

    # Normalise the sign so that a larger "degradation" always means "worse";
    # an unrecognised direction yields no degradation measure at all.
    if self.metric_direction == "better_lower":
        degradation = delta
    elif self.metric_direction == "better_higher":
        degradation = -delta
    else:
        degradation = None

    issue_level = IssueLevel.MINOR
    if degradation is not None:
        medium_bound = self.issue_level_threshold
        major_bound = medium_bound + self.deviation_threshold
        if degradation > major_bound:
            issue_level = IssueLevel.MAJOR
        elif degradation > medium_bound:
            issue_level = IssueLevel.MEDIUM

    return ScanResult(
        name=name,
        metric_name=metric_name,
        metric_value=metric_value,
        metric_reference_value=metric_reference_value,
        issue_level=issue_level,
        slice_size=size_data,
        filename_examples=filename_examples,
        relative_delta=delta,
        issue_group=issue_group,
    )

@abstractmethod
def get_results(self, model: Any, dataset: Any) -> List[ScanResult]:
"""Returns a list of ScanResult
Expand Down
38 changes: 1 addition & 37 deletions giskard_vision/core/detectors/metadata_scan_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import numpy as np
import pandas as pd

from giskard_vision.core.dataloaders.base import PerformanceIssueMeta
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
from giskard_vision.core.issues import PerformanceIssueMeta
from giskard_vision.core.tests.base import MetricBase
from giskard_vision.utils.errors import GiskardImportError

Expand Down Expand Up @@ -258,39 +258,3 @@ def get_df_for_scan(self, model: Any, dataset: Any, list_metadata: Sequence[str]
pass

return pd.DataFrame(df)

def get_scan_result(
    self, metric_value, metric_reference_value, metric_name, filename_examples, name, size_data, issue_group
) -> ScanResult:
    """Build a ScanResult and grade it from the metric's gap to its reference.

    Args:
        metric_value: Metric measured on the evaluated data.
        metric_reference_value: Metric measured on the reference data.
        metric_name: Name of the metric, stored in the result.
        filename_examples: Example file names, stored in the result.
        name: Name stored in the result.
        size_data: Stored in the result as ``slice_size``.
        issue_group: Issue group stored in the result.

    Returns:
        ScanResult: Result whose ``issue_level`` is MINOR, MEDIUM or MAJOR.

    Raises:
        GiskardImportError: If ``giskard`` cannot be imported.
    """
    try:
        from giskard.scanner.issues import IssueLevel
    except (ImportError, ModuleNotFoundError) as e:
        raise GiskardImportError(["giskard"]) from e

    # Absolute difference by default; divided by the reference for "relative" metrics
    # (no guard against a zero reference value).
    relative_delta = metric_value - metric_reference_value
    if self.metric_type == "relative":
        relative_delta /= metric_reference_value

    # MINOR unless the delta moves in the "worse" direction past the thresholds;
    # any other metric_direction value keeps the level at MINOR.
    issue_level = IssueLevel.MINOR
    if self.metric_direction == "better_lower":
        if relative_delta > self.issue_level_threshold + self.deviation_threshold:
            issue_level = IssueLevel.MAJOR
        elif relative_delta > self.issue_level_threshold:
            issue_level = IssueLevel.MEDIUM
    elif self.metric_direction == "better_higher":
        if relative_delta < -(self.issue_level_threshold + self.deviation_threshold):
            issue_level = IssueLevel.MAJOR
        elif relative_delta < -self.issue_level_threshold:
            issue_level = IssueLevel.MEDIUM

    return ScanResult(
        name=name,
        metric_name=metric_name,
        metric_value=metric_value,
        metric_reference_value=metric_reference_value,
        issue_level=issue_level,
        slice_size=size_data,
        filename_examples=filename_examples,
        relative_delta=relative_delta,
        issue_group=issue_group,
    )
87 changes: 87 additions & 0 deletions giskard_vision/core/detectors/perturbation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import os
from abc import abstractmethod
from importlib import import_module
from pathlib import Path
from typing import Any, Sequence

import cv2

from giskard_vision.core.dataloaders.wrappers import FilteredDataLoader
from giskard_vision.core.detectors.base import DetectorVisionBase, ScanResult
from giskard_vision.core.issues import Robustness
from giskard_vision.core.tests.base import TestDiffBase


class PerturbationBaseDetector(DetectorVisionBase):
    """
    Abstract class for perturbation (robustness) detectors

    Each subclass provides perturbed views of a dataset through
    ``get_dataloaders``; the model's metric on every perturbed dataloader is
    compared with its metric on the original dataset.

    Methods:
        set_specs_from_model_type(model_type) -> None:
            Copies the detector settings declared for the given model type
            onto this detector

        get_dataloaders(dataset: Any) -> Sequence[Any]:
            Abstract method that returns a list of dataloaders corresponding to
            slices or transformations

        get_results(model: Any, dataset: Any) -> Sequence[ScanResult]:
            Returns a list of ScanResult containing the evaluation results
    """

    issue_group = Robustness

    def set_specs_from_model_type(self, model_type):
        """
        Override this detector's settings with those declared for ``model_type``.

        Imports ``giskard_vision.<model_type>.detectors.specs`` and copies every
        non-dunder attribute of its ``DetectorSpecs`` class onto ``self``, but
        only for attribute names that ``self`` already has.

        Args:
            model_type (str): Model type used to locate the specs module.

        Raises:
            ValueError: If the specs module does not define ``DetectorSpecs``.
        """
        module = import_module(f"giskard_vision.{model_type}.detectors.specs")
        # The default is required: without it getattr raises AttributeError and
        # the ValueError below can never be reached.
        DetectorSpecs = getattr(module, "DetectorSpecs", None)

        if DetectorSpecs is None:
            raise ValueError(f"No detector specifications found for model type: {model_type}")

        # Only set attributes that are not part of Python's special attributes (those starting with __)
        for attr_name, attr_value in vars(DetectorSpecs).items():
            if not attr_name.startswith("__") and hasattr(self, attr_name):
                setattr(self, attr_name, attr_value)

    @abstractmethod
    def get_dataloaders(self, dataset: Any) -> Sequence[Any]: ...

    def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
        """
        Evaluate the model on every perturbed dataloader against the original dataset.

        For each dataloader an example image is written under
        ``examples_images/`` in the current working directory (plus the
        matching original image when the dataloader is a FilteredDataLoader).

        Args:
            model (Any): Model to evaluate; its ``model_type`` selects the specs.
            dataset (Any): Reference dataloader.

        Returns:
            Sequence[ScanResult]: One result per dataloader.
        """
        self.set_specs_from_model_type(model.model_type)
        dataloaders = self.get_dataloaders(dataset)

        examples_dir = Path() / "examples_images"

        results = []
        for dl in dataloaders:
            test_result = TestDiffBase(metric=self.metric, threshold=1).run(
                model=model,
                dataloader=dl,
                dataloader_ref=dataset,
            )

            # Save example images from dataloader and dataset
            os.makedirs(examples_dir, exist_ok=True)
            filename_examples = []

            # Fall back to the first sample when no example index is reported
            # (None or an empty sequence) instead of failing on [0].
            indexes_examples = test_result.indexes_examples
            if indexes_examples is None or len(indexes_examples) == 0:
                index_worst = 0
            else:
                index_worst = indexes_examples[0]

            if isinstance(dl, FilteredDataLoader):
                filename_example_dataloader_ref = str(examples_dir / f"{dataset.name}_{index_worst}.png")
                cv2.imwrite(filename_example_dataloader_ref, dataset[index_worst][0][0])
                filename_examples.append(filename_example_dataloader_ref)

            filename_example_dataloader = str(examples_dir / f"{dl.name}_{index_worst}.png")
            cv2.imwrite(filename_example_dataloader, dl[index_worst][0][0])
            filename_examples.append(filename_example_dataloader)

            results.append(
                self.get_scan_result(
                    test_result.metric_value_test,
                    test_result.metric_value_ref,
                    test_result.metric_name,
                    filename_examples,
                    dl.name,
                    len(dl),
                    issue_group=self.issue_group,
                )
            )

        return results
13 changes: 13 additions & 0 deletions giskard_vision/core/detectors/specs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from giskard_vision.core.issues import IssueGroup
from giskard_vision.image_classification.tests.performance import MetricBase


class DetectorSpecsBase:
    """Class-level default settings for vision detectors.

    DetectorVisionBase inherits from this class, so these attributes are the
    defaults every detector starts with. Perturbation detectors overwrite them
    at run time from the ``DetectorSpecs`` class of the model type's
    ``detectors.specs`` module.
    """

    # Issue group reported with the detector's results (annotation only, no default).
    issue_group: IssueGroup
    # Annotation only, no default; NOTE(review): presumably per-issue-level
    # warning texts -- confirm against the detectors that define it.
    warning_messages: dict
    # Metric handed to the test run by the perturbation detectors; None until a spec sets it.
    metric: MetricBase = None
    # "relative" makes get_scan_result divide the metric delta by the reference value.
    metric_type: str = None
    # "better_lower" or "better_higher": which sign of the delta counts as a degradation.
    metric_direction: str = None
    # Extra margin added on top of issue_level_threshold before an issue is graded MAJOR.
    deviation_threshold: float = 0.10
    # Degradation above which an issue is graded MEDIUM instead of MINOR.
    issue_level_threshold: float = 0.05
    # NOTE(review): usage not visible here -- presumably a cap on the number of
    # images processed; confirm.
    num_images: int = 0
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from giskard_vision.core.dataloaders.wrappers import BlurredDataLoader

from ...core.detectors.decorator import maybe_detector
from .base import LandmarkDetectionBaseDetector, Robustness
from .perturbation import PerturbationBaseDetector


@maybe_detector("blurring_landmark", tags=["vision", "face", "landmark", "transformed", "blurred"])
class TransformationBlurringDetectorLandmark(LandmarkDetectionBaseDetector):
@maybe_detector("blurring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
class TransformationBlurringDetector(PerturbationBaseDetector):
"""
Detector that evaluates models performance on blurred images
"""

issue_group = Robustness

def __init__(self, kernel_size=(11, 11), sigma=(3, 3)):
self.kernel_size = kernel_size
self.sigma = sigma
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from giskard_vision.core.dataloaders.wrappers import ColoredDataLoader

from ...core.detectors.decorator import maybe_detector
from .base import LandmarkDetectionBaseDetector, Robustness
from .perturbation import PerturbationBaseDetector


@maybe_detector("color_landmark", tags=["vision", "face", "landmark", "filtered", "colored"])
class TransformationColorDetectorLandmark(LandmarkDetectionBaseDetector):
@maybe_detector("coloring", tags=["vision", "robustness", "image_classification", "landmark", "object_detection"])
class TransformationColorDetector(PerturbationBaseDetector):
"""
Detector that evaluates models performance depending on images in grayscale
"""

issue_group = Robustness

def get_dataloaders(self, dataset):
dl = ColoredDataLoader(dataset)

Expand Down
Loading
Loading