From c47b1593d4ff630610d82e9ed05c3720619bb8ca Mon Sep 17 00:00:00 2001 From: JensRahnfeld Date: Thu, 18 Jul 2024 13:54:50 +0200 Subject: [PATCH 1/8] use fastai api to (down-)load pets dataset --- vit_shapley/datamodules/datasets/Pet_dataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vit_shapley/datamodules/datasets/Pet_dataset.py b/vit_shapley/datamodules/datasets/Pet_dataset.py index 2ba0bf7..3e53fb3 100644 --- a/vit_shapley/datamodules/datasets/Pet_dataset.py +++ b/vit_shapley/datamodules/datasets/Pet_dataset.py @@ -3,6 +3,7 @@ import random import sklearn import sklearn.model_selection +from fastai.vision.all import untar_data, URLs import os from torch.utils.data import DataLoader @@ -66,7 +67,8 @@ def __init__(self, dataset_location, transform_params, explanation_location, exp def get_data_list(self): # Load files containing labels, and perform train/valid split if necessary - data=pd.read_csv(os.path.join(self.dataset_location, "annotations/list.txt"), sep=' ', skiprows=[0,1,2,3,4,5], names=["classid","species","breed"]) + path = untar_data(URLs.PETS) + data=pd.read_csv(os.path.join(path, "annotations/list.txt"), sep=' ', skiprows=[0,1,2,3,4,5], names=["classid","species","breed"]) idx_train, idx_valtest = sklearn.model_selection.train_test_split(data.index, random_state=44, test_size=0.2) idx_val, idx_test = sklearn.model_selection.train_test_split(idx_valtest, random_state=44, test_size=0.5) @@ -83,7 +85,7 @@ def get_data_list(self): #labels = data["classid"].astype(int)-1 #data['noisy_labels_0'].map(lambda x: self.labels.index(x)) labels = data.index.map(lambda x: "_".join(x.split('_')[:-1])).map(lambda x: self.labels.index(x)) - img_paths = data.index.map(lambda x: str(os.path.join(os.path.join(self.dataset_location, "images/"),x))+".jpg").values.tolist() + img_paths = data.index.map(lambda x: str(os.path.join(os.path.join(path, "images/"),x))+".jpg").values.tolist() data_list = [{'img_path': img_path, 'label': label, 
'dataset': self.__class__.__name__} for img_path, label in zip(img_paths, labels)] random.Random(42).shuffle(data_list) From 2640a68b71b7fc3fadc3e45c761ee48458dde9af Mon Sep 17 00:00:00 2001 From: JensRahnfeld Date: Fri, 19 Jul 2024 13:44:40 +0200 Subject: [PATCH 2/8] merge code duplicates across classifier, surrogate & explainer into one single base model --- vit_shapley/modules/base_model.py | 135 +++++++++++++++++++++++ vit_shapley/modules/classifier.py | 81 +++----------- vit_shapley/modules/classifier_masked.py | 73 ++---------- vit_shapley/modules/classifier_utils.py | 34 ------ vit_shapley/modules/explainer.py | 97 +++++----------- vit_shapley/modules/explainer_utils.py | 33 ------ vit_shapley/modules/surrogate.py | 78 ++----------- vit_shapley/modules/surrogate_utils.py | 31 ------ 8 files changed, 191 insertions(+), 371 deletions(-) create mode 100644 vit_shapley/modules/base_model.py diff --git a/vit_shapley/modules/base_model.py b/vit_shapley/modules/base_model.py new file mode 100644 index 0000000..20b6b83 --- /dev/null +++ b/vit_shapley/modules/base_model.py @@ -0,0 +1,135 @@ +import logging +from typing import Union + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.optim as optim +from torchvision import models as cnn_models +from transformers import get_cosine_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup + + +class BaseModel(pl.LightningModule): + """ + Args: + backbone_type: should be the class name defined in `timm.models.vision_transformer` + download_weight: whether to initialize backbone with the pretrained weights + load_path: If not None. 
loads the weights saved in the checkpoint to the model + target_type: `binary` or `multi-class` or `multi-label` + output_dim: the dimension of output + checkpoint_metric: the metric used to determine whether to save the current status as checkpoints during the validation phase + optim_type: type of optimizer for optimizing parameters + learning_rate: learning rate of optimizer + weight_decay: weight decay of optimizer + decay_power: only `cosine` annealing scheduler is supported currently + warmup_steps: parameter for the `cosine` annealing scheduler + """ + + def __init__(self, + backbone_type: str = "vit_base_patch16_224", + download_weight: bool = False, + load_path: Union[str, None] = None, + target_type: str = "multiclass", + output_dim: int = 10, + checkpoint_metric: str = "accuracy", + optim_type: str = "Adamw", + learning_rate: float = 1e-5, + learning_rate_min: float = 0.0, + layer_decay_rate: float = 1.0, + weight_decay: float = 1e-5, + decay_power: str = "cosine", + warmup_steps: int = 500): + super().__init__() + + self.save_hyperparameters() + + self.logger_ = logging.getLogger(self.__class__.__name__) + + assert not (self.hparams.download_weight and self.hparams.load_path is not None), \ + "'download_weight' and 'load_path' cannot be activated at the same time as the downloaded weight will be overwritten by weights in 'load_path'." + + # Backbone initialization. 
(currently support only vit) + if self.__class__.__name__ == "Classifier": + import vit_shapley.modules.vision_transformer_verbose as vit + else: + import vit_shapley.modules.vision_transformer as vit + + if hasattr(vit, self.hparams.backbone_type): + self.backbone = getattr(vit, self.hparams.backbone_type)(pretrained=self.hparams.download_weight) + elif hasattr(cnn_models, self.hparams.backbone_type): + self.backbone = getattr(cnn_models, self.hparams.backbone_type)(pretrained=self.hparams.download_weight) + else: + raise NotImplementedError("Not supported backbone type") + if self.hparams.download_weight: + self.logger_.info("The backbone parameters were initialized with the downloaded pretrained weights.") + + # Nullify classification head built in the backbone module and rebuild. + if self.backbone.__class__.__name__ == 'VisionTransformer': + self.head_in_features = self.backbone.head.in_features + self.backbone.head = nn.Identity() + elif self.backbone.__class__.__name__ == 'ResNet': + self.head_in_features = self.backbone.fc.in_features + self.backbone.fc = nn.Identity() + elif self.backbone.__class__.__name__ == 'DenseNet': + self.head_in_features = self.backbone.classifier.in_features + self.backbone.classifier = nn.Identity() + else: + raise NotImplementedError("Not supported backbone type") + + def load_checkpoint(self): + # Load checkpoints + if self.hparams.load_path is not None: + checkpoint = torch.load(self.hparams.load_path, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + ret = self.load_state_dict(state_dict, strict=False) + self.logger_.info(f"Model parameters were updated from a checkpoint file {self.hparams.load_path}") + self.logger_.info(f"Unmatched parameters - missing_keys: {ret.missing_keys}") + self.logger_.info(f"Unmatched parameters - unexpected_keys: {ret.unexpected_keys}") + elif not self.hparams.download_weight: + self.logger_.info("The backbone 
parameters were randomly initialized.") + + def configure_optimizers(self): + if self.hparams.optim_type == "Adamw": + optimizer = optim.AdamW(self.parameters(), + lr=self.hparams.learning_rate, + weight_decay=self.hparams.weight_decay) + elif self.hparams.optim_type == "Adam": + optimizer = optim.Adam(self.parameters(), + lr=self.hparams.learning_rate, + weight_decay=self.hparams.weight_decay) + elif self.hparams.optim_type == "SGD": + optimizer = optim.SGD(self.parameters(), + lr=self.hparams.learning_rate, + momentum=0.9, + weight_decay=self.hparams.weight_decay) + else: + optimizer = None + + # setup scheduler + if self.trainer.max_steps is None or self.trainer.max_steps == -1: + max_steps = ( + len(self.trainer.datamodule.train_dataloader()) * self.trainer.max_epochs // self.trainer.accumulate_grad_batches) + else: + max_steps = self.trainer.max_steps + + if self.hparams.decay_power == "cosine": + scheduler = {"scheduler": get_cosine_schedule_with_warmup(optimizer, + num_warmup_steps=self.hparams.warmup_steps, + num_training_steps=max_steps), + "interval": "step"} + return ([optimizer], [scheduler]) + elif self.hparams.decay_power == "polynomial": + scheduler = {"scheduler": get_polynomial_decay_schedule_with_warmup(optimizer, + num_warmup_steps=self.hparams.warmup_steps, + num_training_steps=max_steps, + lr_end=self.hparams.learning_rate_min, + power=0.9), + "interval": "step"} + return ([optimizer], [scheduler]) + elif self.hparams.decay_power is None: + return optimizer + else: + NotImplementedError("Unsupported scheduler!") diff --git a/vit_shapley/modules/classifier.py b/vit_shapley/modules/classifier.py index 95b62eb..f9599ac 100644 --- a/vit_shapley/modules/classifier.py +++ b/vit_shapley/modules/classifier.py @@ -1,80 +1,28 @@ -import logging +from typing import Union -import pytorch_lightning as pl -import torch import torch.nn as nn -from torchvision import models as cnn_models -import vit_shapley.modules.vision_transformer_verbose as vit +from 
vit_shapley.modules.base_model import BaseModel from vit_shapley.modules import classifier_utils -class Classifier(pl.LightningModule): +class Classifier(BaseModel): """ `pytorch_lightning` module for image classifier Args: - backbone_type: should be the class name defined in `torchvision.models.cnn_models` or `timm.models.vision_transformer` - download_weight: whether to initialize backbone with the pretrained weights - load_path: If not None. loads the weights saved in the checkpoint to the model - target_type: `binary` or `multi-class` or `multi-label` - output_dim: the dimension of output - checkpoint_metric: the metric used to determine whether to save the current status as checkpoints during the validation phase - optim_type: type of optimizer for optimizing parameters - learning_rate: learning rate of optimizer - weight_decay: weight decay of optimizer - decay_power: only `cosine` annealing scheduler is supported currently - warmup_steps: parameter for the `cosine` annealing scheduler + enable_pos_embed: wether to add positional embeddings to patch embeddings + loss_weight: weighting of classes in cross-entropy loss """ - def __init__(self, backbone_type: str, download_weight: bool, load_path: str or None, - target_type: str, output_dim: int, enable_pos_embed: bool, - checkpoint_metric: str or None, optim_type: str or None, learning_rate: float or None, - loss_weight: None, - weight_decay: float or None, decay_power: str or None, warmup_steps: int or None): + def __init__(self, + enable_pos_embed: bool, + loss_weight: Union[float, None] = None, + *args, **kwargs): - super().__init__() - self.save_hyperparameters() - - self.logger_ = logging.getLogger(__name__) - - assert not (self.hparams.download_weight and self.hparams.load_path is not None), \ - "'download_weight' and 'load_path' cannot be activated at the same time as the downloaded weight will be overwritten by weights in 'load_path'." - - # Backbone initialization. 
(currently support only vit and cnn) - if hasattr(vit, self.hparams.backbone_type): - self.backbone = getattr(vit, self.hparams.backbone_type)(pretrained=self.hparams.download_weight) - elif hasattr(cnn_models, self.hparams.backbone_type): - self.backbone = getattr(cnn_models, self.hparams.backbone_type)(pretrained=self.hparams.download_weight) - else: - raise NotImplementedError("Not supported backbone type") - if self.hparams.download_weight: - self.logger_.info("The backbone parameters were initialized with the downloaded pretrained weights.") - else: - self.logger_.info("The backbone parameters were randomly initialized.") - - # Nullify classification head built in the backbone module and rebuild. - if self.backbone.__class__.__name__ == 'VisionTransformer': - head_in_features = self.backbone.head.in_features - self.backbone.head = nn.Identity() - elif self.backbone.__class__.__name__ == 'ResNet': - head_in_features = self.backbone.fc.in_features - self.backbone.fc = nn.Identity() - elif self.backbone.__class__.__name__ == 'DenseNet': - head_in_features = self.backbone.classifier.in_features - self.backbone.classifier = nn.Identity() - else: - raise NotImplementedError("Not supported backbone type") - self.head = nn.Linear(head_in_features, self.hparams.output_dim) - - # Load checkpoints - if self.hparams.load_path is not None: - checkpoint = torch.load(self.hparams.load_path, map_location="cpu") - state_dict = checkpoint["state_dict"] - ret = self.load_state_dict(state_dict, strict=False) - self.logger_.info(f"Model parameters were updated from a checkpoint file {self.hparams.load_path}") - self.logger_.info(f"Unmatched parameters - missing_keys: {ret.missing_keys}") - self.logger_.info(f"Unmatched parameters - unexpected_keys: {ret.unexpected_keys}") + super().__init__(*args, **kwargs) + self.head = nn.Linear(self.head_in_features, self.hparams.output_dim) + self.load_checkpoint() if not self.hparams.enable_pos_embed: self.backbone.pos_embed.requires_grad = 
False @@ -82,9 +30,6 @@ def __init__(self, backbone_type: str, download_weight: bool, load_path: str or # Set up modules for calculating metric classifier_utils.set_metrics(self) - def configure_optimizers(self): - return classifier_utils.set_schedule(self) - def forward(self, images, output_attentions=False, output_hidden_states=False): if self.backbone.__class__.__name__ == 'VisionTransformer': output = self.backbone(images, output_attentions=output_attentions, @@ -125,4 +70,4 @@ def test_step(self, batch, batch_idx): loss = classifier_utils.compute_metrics(self, logits=logits, labels=labels, phase='test') def test_epoch_end(self, outs): - classifier_utils.epoch_wrapup(self, phase='test') \ No newline at end of file + classifier_utils.epoch_wrapup(self, phase='test') diff --git a/vit_shapley/modules/classifier_masked.py b/vit_shapley/modules/classifier_masked.py index 79c3ee1..3dbde48 100644 --- a/vit_shapley/modules/classifier_masked.py +++ b/vit_shapley/modules/classifier_masked.py @@ -1,81 +1,27 @@ -import logging -import pytorch_lightning as pl import torch import torch.nn as nn from torchvision import models as cnn_models +from vit_shapley.modules.base_model import BaseModel import vit_shapley.modules.vision_transformer as vit from vit_shapley.modules import classifier_utils -class ClassifierMasked(pl.LightningModule): +class ClassifierMasked(BaseModel): """ `pytorch_lightning` module for surrogate Args: mask_location: how the mask is applied to the input. ("pre-softmax" or "zero-input") - backbone_type: should be the class name defined in `torchvision.models.cnn_models` or `timm.models.vision_transformer` - download_weight: whether to initialize backbone with the pretrained weights - load_path: If not None. 
loads the weights saved in the checkpoint to the model - target_type: `binary` or `multi-class` or `multi-label` - output_dim: the dimension of output - target_model: This model will be trained to generate output similar to the output generated by 'target_model' for the same input. - checkpoint_metric: the metric used to determine whether to save the current status as checkpoints during the validation phase - optim_type: type of optimizer for optimizing parameters - learning_rate: learning rate of optimizer - weight_decay: weight decay of optimizer - decay_power: only `cosine` annealing scheduler is supported currently - warmup_steps: parameter for the `cosine` annealing scheduler """ - def __init__(self, mask_location: str, backbone_type: str, download_weight: bool, load_path: str or None, - target_type: str, output_dim: int, - checkpoint_metric: str or None, optim_type: str or None, learning_rate: float or None, - loss_weight: None, - weight_decay: float or None, decay_power: str or None, warmup_steps: int or None): + def __init__(self, + mask_location: str, + *args, **kwargs): - super().__init__() - self.save_hyperparameters() - - self.logger_ = logging.getLogger(__name__) - - assert not (self.hparams.download_weight and self.hparams.load_path is not None), \ - "'download_weight' and 'load_path' cannot be activated at the same time as the downloaded weight will be overwritten by weights in 'load_path'." - - # Backbone initialization. 
(currently support only vit and cnn) - if hasattr(vit, self.hparams.backbone_type): - self.backbone = getattr(vit, self.hparams.backbone_type)(pretrained=self.hparams.download_weight) - elif hasattr(cnn_models, self.hparams.backbone_type): - self.backbone = getattr(cnn_models, self.hparams.backbone_type)(pretrained=self.hparams.download_weight) - else: - raise NotImplementedError("Not supported backbone type") - if self.hparams.download_weight: - self.logger_.info("The backbone parameters were initialized with the downloaded pretrained weights.") - else: - self.logger_.info("The backbone parameters were randomly initialized.") - - # Nullify classification head built in the backbone module and rebuild. - if self.backbone.__class__.__name__ == 'VisionTransformer': - head_in_features = self.backbone.head.in_features - self.backbone.head = nn.Identity() - elif self.backbone.__class__.__name__ == 'ResNet': - head_in_features = self.backbone.fc.in_features - self.backbone.fc = nn.Identity() - elif self.backbone.__class__.__name__ == 'DenseNet': - head_in_features = self.backbone.classifier.in_features - self.backbone.classifier = nn.Identity() - else: - raise NotImplementedError("Not supported backbone type") - self.head = nn.Linear(head_in_features, self.hparams.output_dim) - - # Load checkpoints - if self.hparams.load_path is not None: - checkpoint = torch.load(self.hparams.load_path, map_location="cpu") - state_dict = checkpoint["state_dict"] - ret = self.load_state_dict(state_dict, strict=False) - self.logger_.info(f"Model parameters were updated from a checkpoint file {self.hparams.load_path}") - self.logger_.info(f"Unmatched parameters - missing_keys: {ret.missing_keys}") - self.logger_.info(f"Unmatched parameters - unexpected_keys: {ret.unexpected_keys}") + super().__init__(*args, **kwargs) + self.head = nn.Linear(self.head_in_features, self.hparams.output_dim) + self.load_checkpoint() # Check the validity of 'mask_location` parameter if hasattr(vit, 
self.hparams.backbone_type): @@ -95,9 +41,6 @@ def __init__(self, mask_location: str, backbone_type: str, download_weight: bool # Set up modules for calculating metric classifier_utils.set_metrics(self) - def configure_optimizers(self): - return classifier_utils.set_schedule(self) - def forward(self, images, masks, mask_location=None): assert masks.shape[-1] == self.num_players mask_location = self.hparams.mask_location if mask_location is None else mask_location diff --git a/vit_shapley/modules/classifier_utils.py b/vit_shapley/modules/classifier_utils.py index b0499f0..1c5e2b9 100644 --- a/vit_shapley/modules/classifier_utils.py +++ b/vit_shapley/modules/classifier_utils.py @@ -2,40 +2,6 @@ import torch from torch.nn import functional as F from torchmetrics import Accuracy, MeanMetric, AUROC, Precision, F1Score, Recall, CohenKappa -from transformers import get_cosine_schedule_with_warmup -from transformers.optimization import AdamW - - -def set_schedule(pl_module): - optimizer = None - if pl_module.hparams.optim_type is None: - return ([None], [None],) - else: - if pl_module.hparams.optim_type == "Adamw": - optimizer = AdamW(params=pl_module.parameters(), lr=pl_module.hparams.learning_rate, - weight_decay=pl_module.hparams.weight_decay) - elif pl_module.hparams.optim_type == "Adam": - optimizer = torch.optim.Adam(pl_module.parameters(), lr=pl_module.hparams.learning_rate, - weight_decay=pl_module.hparams.weight_decay) - elif pl_module.hparams.optim_type == "SGD": - optimizer = torch.optim.SGD(pl_module.parameters(), lr=pl_module.hparams.learning_rate, momentum=0.9, - weight_decay=pl_module.hparams.weight_decay) - - if pl_module.trainer.max_steps is None or pl_module.trainer.max_steps == -1: - max_steps = ( - len(pl_module.trainer.datamodule.train_dataloader()) * pl_module.trainer.max_epochs // pl_module.trainer.accumulate_grad_batches) - else: - max_steps = pl_module.trainer.max_steps - - if pl_module.hparams.decay_power == "cosine": - scheduler = {"scheduler": 
get_cosine_schedule_with_warmup(optimizer, - num_warmup_steps=pl_module.hparams.warmup_steps, - num_training_steps=max_steps), - "interval": "step"} - else: - NotImplementedError("Only cosine scheduler is implemented for now") - - return ([optimizer], [scheduler],) def set_metrics(pl_module): diff --git a/vit_shapley/modules/explainer.py b/vit_shapley/modules/explainer.py index b2184aa..2f7eab6 100644 --- a/vit_shapley/modules/explainer.py +++ b/vit_shapley/modules/explainer.py @@ -1,17 +1,17 @@ -import copy +from typing import Union + import ipdb -import logging import pytorch_lightning as pl import torch import torch.nn as nn from torch.nn import functional as F -from torchvision import models as cnn_models +from vit_shapley.modules.base_model import BaseModel import vit_shapley.modules.vision_transformer as vit from vit_shapley.modules import explainer_utils -class Explainer(pl.LightningModule): +class Explainer(BaseModel): """ `pytorch_lightning` module for surrogate @@ -19,11 +19,6 @@ class Explainer(pl.LightningModule): normalization: 'additive' or 'multiplicative' normalization_class: 'additive', activation: - backbone_type: should be the class name defined in `torchvision.models.cnn_models` or `timm.models.vision_transformer` - download_weight: whether to initialize backbone with the pretrained weights - load_path: If not None. loads the weights saved in the checkpoint to the model - target_type: `binary` or `multi-class` or `multi-label` - output_dim: the dimension of output, explainer_head_num_attention_blocks: explainer_head_include_cls: @@ -36,60 +31,29 @@ class Explainer(pl.LightningModule): efficiency_lambda: lambda hyperparameter for efficiency penalty. efficiency_class_lambda: lambda hyperparameter for efficiency penalty. 
freeze_backbone: whether to freeze the backbone while training - checkpoint_metric: the metric used to determine whether to save the current status as checkpoints during the validation phase - optim_type: type of optimizer for optimizing parameters - learning_rate: learning rate of optimizer - weight_decay: weight decay of optimizer - decay_power: only `cosine` annealing scheduler is supported currently - warmup_steps: parameter for the `cosine` annealing scheduler """ - def __init__(self, normalization, normalization_class, activation, backbone_type: str, download_weight: bool, - load_path: str or None, - residual: list, - target_type: str, output_dim: int, - explainer_head_num_attention_blocks: int, explainer_head_include_cls: bool, - explainer_head_num_mlp_layers: int, explainer_head_mlp_layer_ratio: bool, explainer_norm: bool, - surrogate: pl.LightningModule, link: pl.LightningModule or nn.Module or None, efficiency_lambda, - efficiency_class_lambda, - freeze_backbone: bool, checkpoint_metric: str or None, - optim_type: str or None, learning_rate: float or None, weight_decay: float or None, - decay_power: str or None, warmup_steps: int or None, load_path_state_dict=False): - - super().__init__() + def __init__(self, + normalization = "additive", + normalization_class = None, + activation = "tanh", + residual: list = [], + explainer_head_num_attention_blocks: int = 1, + explainer_head_include_cls: bool = True, + explainer_head_num_mlp_layers: int = 3, + explainer_head_mlp_layer_ratio: bool = 4, + explainer_norm: bool = True, + surrogate: pl.LightningModule = None, + link: str = "softmax", + efficiency_lambda: float = 0.0, + efficiency_class_lambda: float = 0.0, + freeze_backbone: str = 'none', + *args, **kwargs): + + super().__init__(*args, **kwargs) self.save_hyperparameters() self.__null = None - self.logger_ = logging.getLogger(__name__) - - assert not (self.hparams.download_weight and self.hparams.load_path is not None), \ - "'download_weight' and 
'load_path' cannot be activated at the same time as the downloaded weight will be overwritten by weights in 'load_path'." - - # Backbone initialization. (currently support only vit and cnn) - if hasattr(vit, self.hparams.backbone_type): - self.backbone = getattr(vit, self.hparams.backbone_type)(pretrained=self.hparams.download_weight) - elif hasattr(cnn_models, self.hparams.backbone_type): - self.backbone = getattr(cnn_models, self.hparams.backbone_type)(pretrained=self.hparams.download_weight) - else: - raise NotImplementedError("Not supported backbone type") - if self.hparams.download_weight: - self.logger_.info("The backbone parameters were initialized with the downloaded pretrained weights.") - else: - self.logger_.info("The backbone parameters were randomly initialized.") - - # Nullify classification head built in the backbone module and rebuild. - if self.backbone.__class__.__name__ == 'VisionTransformer': - head_in_features = self.backbone.head.in_features - self.backbone.head = nn.Identity() - elif self.backbone.__class__.__name__ == 'ResNet': - head_in_features = self.backbone.fc.in_features - self.backbone.fc = nn.Identity() - elif self.backbone.__class__.__name__ == 'DenseNet': - head_in_features = self.backbone.classifier.in_features - self.backbone.classifier = nn.Identity() - else: - raise NotImplementedError("Not supported backbone type") - if self.backbone.__class__.__name__ == 'VisionTransformer': # attention_blocks if self.hparams.explainer_head_num_attention_blocks == 0: @@ -174,17 +138,9 @@ def __init__(self, normalization, normalization_class, activation, backbone_type else: raise NotImplementedError("'explainer_head' is only implemented for VisionTransformer.") + # Load checkpoints - if self.hparams.load_path is not None: - if load_path_state_dict: - state_dict = torch.load(self.hparams.load_path, map_location="cpu") - else: - checkpoint = torch.load(self.hparams.load_path, map_location="cpu") - state_dict = checkpoint["state_dict"] - ret = 
self.load_state_dict(state_dict, strict=False) - self.logger_.info(f"Model parameters were updated from a checkpoint file {self.hparams.load_path}") - self.logger_.info(f"Unmatched parameters - missing_keys: {ret.missing_keys}") - self.logger_.info(f"Unmatched parameters - unexpected_keys: {ret.unexpected_keys}") + self.load_checkpoint() # Set up link function if self.hparams.link is None: @@ -414,10 +370,7 @@ def __init__(self, normalization, normalization_class, activation, backbone_type # self.hparams.surrogate.backbone.norm = nn.Identity() - def configure_optimizers(self): - return explainer_utils.set_schedule(self) - - def null(self, images: torch.Tensor or None = None) -> torch.Tensor: + def null(self, images: Union[torch.Tensor, None] = None) -> torch.Tensor: """ calculate or load cached null diff --git a/vit_shapley/modules/explainer_utils.py b/vit_shapley/modules/explainer_utils.py index e3fc168..1ebc4e4 100644 --- a/vit_shapley/modules/explainer_utils.py +++ b/vit_shapley/modules/explainer_utils.py @@ -1,39 +1,6 @@ import torch from torch.nn import functional as F from torchmetrics import MeanMetric -from transformers import get_cosine_schedule_with_warmup -from transformers.optimization import AdamW - - -def set_schedule(pl_module): - optimizer = None - if pl_module.hparams.optim_type == "Adamw": - optimizer = AdamW(params=pl_module.parameters(), lr=pl_module.hparams.learning_rate, - weight_decay=pl_module.hparams.weight_decay) - elif pl_module.hparams.optim_type == "Adam": - optimizer = torch.optim.Adam(pl_module.parameters(), lr=pl_module.hparams.learning_rate, - weight_decay=pl_module.hparams.weight_decay) - elif pl_module.hparams.optim_type == "SGD": - optimizer = torch.optim.SGD(pl_module.parameters(), lr=pl_module.hparams.learning_rate, momentum=0.9, - weight_decay=pl_module.hparams.weight_decay) - - if pl_module.trainer.max_steps is None or pl_module.trainer.max_steps == -1: - max_steps = ( - len(pl_module.trainer.datamodule.train_dataloader()) * 
pl_module.trainer.max_epochs // pl_module.trainer.accumulate_grad_batches) - else: - max_steps = pl_module.trainer.max_steps - - if pl_module.hparams.decay_power == "cosine": - scheduler = {"scheduler": get_cosine_schedule_with_warmup(optimizer, - num_warmup_steps=pl_module.hparams.warmup_steps, - num_training_steps=max_steps), - "interval": "step"} - return ([optimizer], [scheduler],) - elif pl_module.hparams.decay_power is None: - return optimizer - else: - - NotImplementedError("Only cosine scheduler is implemented for now") def set_metrics(pl_module): diff --git a/vit_shapley/modules/surrogate.py b/vit_shapley/modules/surrogate.py index 2310941..af24691 100644 --- a/vit_shapley/modules/surrogate.py +++ b/vit_shapley/modules/surrogate.py @@ -1,87 +1,32 @@ -import logging +from typing import Union import pytorch_lightning as pl import torch import torch.nn as nn from torchvision import models as cnn_models +from vit_shapley.modules.base_model import BaseModel import vit_shapley.modules.vision_transformer as vit from vit_shapley.modules import surrogate_utils -class Surrogate(pl.LightningModule): +class Surrogate(BaseModel): """ `pytorch_lightning` module for surrogate Args: mask_location: how the mask is applied to the input. ("pre-softmax" or "zero-input") - backbone_type: should be the class name defined in `torchvision.models.cnn_models` or `timm.models.vision_transformer` - download_weight: whether to initialize backbone with the pretrained weights - load_path: If not None. loads the weights saved in the checkpoint to the model - target_type: `binary` or `multi-class` or `multi-label` - output_dim: the dimension of output target_model: This model will be trained to generate output similar to the output generated by 'target_model' for the same input. 
- checkpoint_metric: the metric used to determine whether to save the current status as checkpoints during the validation phase - optim_type: type of optimizer for optimizing parameters - learning_rate: learning rate of optimizer - weight_decay: weight decay of optimizer - decay_power: only `cosine` annealing scheduler is supported currently - warmup_steps: parameter for the `cosine` annealing scheduler """ - def __init__(self, mask_location: str, backbone_type: str, download_weight: bool, load_path: str or None, - target_type: str, output_dim: int, + def __init__(self, + mask_location: str, + target_model: Union[pl.LightningModule, nn.Module, None] = None, + *args, **kwargs): - target_model: pl.LightningModule or nn.Module or None, checkpoint_metric: str or None, - optim_type: str or None, - learning_rate: float or None, weight_decay: float or None, - decay_power: str or None, warmup_steps: int or None, load_path_state_dict=False): - - super().__init__() - self.save_hyperparameters() - - self.logger_ = logging.getLogger(__name__) - - assert not (self.hparams.download_weight and self.hparams.load_path is not None), \ - "'download_weight' and 'load_path' cannot be activated at the same time as the downloaded weight will be overwritten by weights in 'load_path'." - - # Backbone initialization. 
(currently support only vit and cnn) - if hasattr(vit, self.hparams.backbone_type): - self.backbone = getattr(vit, self.hparams.backbone_type)(pretrained=self.hparams.download_weight) - elif hasattr(cnn_models, self.hparams.backbone_type): - self.backbone = getattr(cnn_models, self.hparams.backbone_type)(pretrained=self.hparams.download_weight) - else: - raise NotImplementedError("Not supported backbone type") - if self.hparams.download_weight: - self.logger_.info("The backbone parameters were initialized with the downloaded pretrained weights.") - else: - self.logger_.info("The backbone parameters were randomly initialized.") - - # Nullify classification head built in the backbone module and rebuild. - if self.backbone.__class__.__name__ == 'VisionTransformer': - head_in_features = self.backbone.head.in_features - self.backbone.head = nn.Identity() - elif self.backbone.__class__.__name__ == 'ResNet': - head_in_features = self.backbone.fc.in_features - self.backbone.fc = nn.Identity() - elif self.backbone.__class__.__name__ == 'DenseNet': - head_in_features = self.backbone.classifier.in_features - self.backbone.classifier = nn.Identity() - else: - raise NotImplementedError("Not supported backbone type") - self.head = nn.Linear(head_in_features, self.hparams.output_dim) - - # Load checkpoints - if self.hparams.load_path is not None: - if load_path_state_dict: - state_dict = torch.load(self.hparams.load_path, map_location="cpu") - else: - checkpoint = torch.load(self.hparams.load_path, map_location="cpu") - state_dict = checkpoint["state_dict"] - ret = self.load_state_dict(state_dict, strict=False) - self.logger_.info(f"Model parameters were updated from a checkpoint file {self.hparams.load_path}") - self.logger_.info(f"Unmatched parameters - missing_keys: {ret.missing_keys}") - self.logger_.info(f"Unmatched parameters - unexpected_keys: {ret.unexpected_keys}") + super().__init__(*args, **kwargs) + self.head = nn.Linear(self.head_in_features, self.hparams.output_dim) 
+ self.load_checkpoint() # Check the validity of 'mask_location` parameter if hasattr(vit, self.hparams.backbone_type): @@ -101,9 +46,6 @@ def __init__(self, mask_location: str, backbone_type: str, download_weight: bool # Set up modules for calculating metric surrogate_utils.set_metrics(self) - def configure_optimizers(self): - return surrogate_utils.set_schedule(self) - def forward(self, images, masks, mask_location=None): assert masks.shape[-1] == self.num_players mask_location = self.hparams.mask_location if mask_location is None else mask_location diff --git a/vit_shapley/modules/surrogate_utils.py b/vit_shapley/modules/surrogate_utils.py index 6385dbd..9bed918 100644 --- a/vit_shapley/modules/surrogate_utils.py +++ b/vit_shapley/modules/surrogate_utils.py @@ -2,37 +2,6 @@ import torch from torch.nn import functional as F from torchmetrics import MeanMetric -from transformers import get_cosine_schedule_with_warmup -from transformers.optimization import AdamW - - -def set_schedule(pl_module): - optimizer = None - if pl_module.hparams.optim_type == "Adamw": - optimizer = AdamW(params=pl_module.parameters(), lr=pl_module.hparams.learning_rate, - weight_decay=pl_module.hparams.weight_decay) - elif pl_module.hparams.optim_type == "Adam": - optimizer = torch.optim.Adam(pl_module.parameters(), lr=pl_module.hparams.learning_rate, - weight_decay=pl_module.hparams.weight_decay) - elif pl_module.hparams.optim_type == "SGD": - optimizer = torch.optim.SGD(pl_module.parameters(), lr=pl_module.hparams.learning_rate, momentum=0.9, - weight_decay=pl_module.hparams.weight_decay) - - if pl_module.trainer.max_steps is None or pl_module.trainer.max_steps == -1: - max_steps = ( - len(pl_module.trainer.datamodule.train_dataloader()) * pl_module.trainer.max_epochs // pl_module.trainer.accumulate_grad_batches) - else: - max_steps = pl_module.trainer.max_steps - - if pl_module.hparams.decay_power == "cosine": - scheduler = {"scheduler": get_cosine_schedule_with_warmup(optimizer, - 
num_warmup_steps=pl_module.hparams.warmup_steps, - num_training_steps=max_steps), - "interval": "step"} - else: - NotImplementedError("Only cosine scheduler is implemented for now") - - return ([optimizer], [scheduler],) def set_metrics(pl_module): From 94276de2db487fb787bbbfd0b919573e99b1587f Mon Sep 17 00:00:00 2001 From: JensRahnfeld Date: Fri, 19 Jul 2024 14:24:52 +0200 Subject: [PATCH 3/8] - keep working directory consistent upon restart - create experiment's result directories automatically - import explainer from existing module - dynamic state dict loading when initializing LRP - add option to specify number of masks per rise batch --- notebooks/2_1_benchmarking.ipynb | 126 +++++++++++++++---------------- 1 file changed, 62 insertions(+), 64 deletions(-) diff --git a/notebooks/2_1_benchmarking.ipynb b/notebooks/2_1_benchmarking.ipynb index b5dc655..f7c8208 100644 --- a/notebooks/2_1_benchmarking.ipynb +++ b/notebooks/2_1_benchmarking.ipynb @@ -2,16 +2,26 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "7a263fe4", "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/rahnfelj/GitRepositories/vit-shapley\n" + ] + } + ], "source": [ "import os\n", - "print(os.getcwd())\n", - "os.chdir('../')\n", + "\n", + "if \"cwd\" not in locals():\n", + " cwd = os.path.dirname(os.getcwd())\n", + "os.chdir(cwd)\n", "print(os.getcwd())" ] }, @@ -25,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "72c897bc", "metadata": {}, "outputs": [], @@ -64,19 +74,21 @@ "parallel_mode = (0, 1)\n", "backbone_to_use=[\"vit_base_patch16_224\"]\n", "_config.update(dataset_Pet())\n", - "evaluation_stage=[\"1_classifier_evaluate\",\n", - " \"2_surrogate_evaluate\",\n", - " \"3_explanation_generate\",\n", - " \"4_insert_delete\",\n", - " \"5_sensitivity\",\n", - " \"6_noretraining\",\n", - " \"7_classifiermasked\",\n", -
" \"8_elapsedtime\",\n", - " \"9_estimationerror\"][1]\n", - "\n", - "_config.update(env_chanwkim()); _config.update({'gpus_classifier':[4,],\n", - " 'gpus_surrogate':[4,],\n", - " 'gpus_explainer':[4,]})\n", + "\n", + "evaluation_stages=[\"1_classifier_evaluate\",\n", + " \"2_surrogate_evaluate\",\n", + " \"3_explanation_generate\",\n", + " \"4_insert_delete\",\n", + " \"5_sensitivity\",\n", + " \"6_noretraining\",\n", + " \"7_classifiermasked\",\n", + " \"8_elapsedtime\",\n", + " \"9_estimationerror\"]\n", + "evaluation_stage = evaluation_stages[0]\n", + "\n", + "_config.update(env_chanwkim()); _config.update({'gpus_classifier':[0,],\n", + " 'gpus_surrogate':[0,],\n", + " 'gpus_explainer':[0,]})\n", "\n", "_config.update({'classifier_backbone_type': None,\n", " 'classifier_download_weight': False,\n", @@ -89,7 +101,10 @@ " 'surrogate_download_weight': False,\n", " 'surrogate_load_path': None})\n", "_config.update({'explainer_num_mask_samples': 2,\n", - " 'explainer_paired_mask_samples': True})" + " 'explainer_paired_mask_samples': True})\n", + "\n", + "for stage in evaluation_stages:\n", + " os.makedirs(f'results/{stage}/{_config[\"datasets\"]}/', exist_ok=True)" ] }, { @@ -184,14 +199,14 @@ " \"deit_small_patch16_224\":{\n", " },\n", " \"vit_base_patch16_224\":{\n", - " \"classifier_path\":\"results/wandb_transformer_interpretability_project/3g01rci7/checkpoints/epoch=9-step=909.ckpt\",\n", + " \"classifier_path\":\"results/wandb_transformer_interpretability_project/pets/classifier/checkpoints/epoch=18-step=1728.ckpt\",\n", " \"surrogate_path\": {\n", - " \"original\": \"results/wandb_transformer_interpretability_project/3g01rci7/checkpoints/epoch=9-step=909.ckpt\",\n", - " \"pre-softmax\": \"results/wandb_transformer_interpretability_project/146vf465/checkpoints/epoch=40-step=3730.ckpt\",\n", + " \"original\": \"results/wandb_transformer_interpretability_project/pets/classifier/checkpoints/epoch=18-step=1728.ckpt\",\n", + " \"pre-softmax\": 
\"results/wandb_transformer_interpretability_project/pets/surrogate/checkpoints/epoch=44-step=4094.ckpt\",\n", " #\"zero-input\": \"results/wandb_transformer_interpretability_project/2z2qs6t0/checkpoints/epoch=44-step=23219.ckpt\",\n", " #\"zero-embedding\": \"results/wandb_transformer_interpretability_project/1pbmwnvb/checkpoints/epoch=45-step=23735.ckpt\"\n", " },\n", - " \"explainer_path\":\"results/wandb_transformer_interpretability_project/2oq7lhr7/checkpoints/epoch=85-step=7911.ckpt\"\n", + " \"explainer_path\":\"results/wandb_transformer_interpretability_project/pets/explainer/checkpoints/epoch=85-step=7911.ckpt\"\n", " },\n", " \"deit_base_patch16_224\":{\n", "\n", @@ -449,14 +464,6 @@ "len(train_dataset), len(val_dataset), len(test_dataset)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "427faeaa", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "id": "7380660e", @@ -589,16 +596,6 @@ " surrogate_dict[backbone_type]=mask_method_dict" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "f8414a2d", - "metadata": {}, - "outputs": [], - "source": [ - "from vitmedical.modules.explainer import Explainer" - ] - }, { "cell_type": "code", "execution_count": null, @@ -758,9 +755,12 @@ "\n", "for idx, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n", " checkpoint = torch.load(backbone_type_config[\"classifier_path\"], map_location=\"cpu\")\n", - " checkpoint[\"state_dict\"]=OrderedDict([(k.replace('backbone.',''), v) for k, v in checkpoint[\"state_dict\"].items()])\n", - " state_dict = checkpoint[\"state_dict\"]\n", - " \n", + " if \"state_dict\" in checkpoint:\n", + " state_dict = checkpoint[\"state_dict\"]\n", + " else:\n", + " state_dict = checkpoint\n", + " state_dict = OrderedDict([(k.replace('backbone.',''), v) for k, v in state_dict.items()])\n", + "\n", " model = getattr(ViT_new, 
backbone_type)(num_classes=_config[\"output_dim\"]).to(_config[\"gpus_classifier\"][idx])\n", " ret = model.load_state_dict(state_dict, strict=False)\n", " print(f\"Model parameters were updated from a checkpoint file {backbone_type_config['classifier_path']}\")\n", @@ -1124,17 +1124,17 @@ "metadata": {}, "outputs": [], "source": [ - "def rise(image, surrogate=None, classifier=None, include_prob=0.5, N=2000):\n", + "def rise(image, surrogate=None, classifier=None, include_prob=0.5, N=2000, masks_per_gpu=100):\n", " assert (surrogate is None) != (classifier is None)\n", " \n", " prob_list=[]\n", " mask_list=[]\n", " \n", " with torch.no_grad():\n", - " for i in range(N//100):\n", - " mask=torch.rand(100, 196) Date: Fri, 19 Jul 2024 16:59:58 +0200 Subject: [PATCH 4/8] evaluate only surrogate checkpoints that have been loaded --- notebooks/2_1_benchmarking.ipynb | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/notebooks/2_1_benchmarking.ipynb b/notebooks/2_1_benchmarking.ipynb index f7c8208..769b36f 100644 --- a/notebooks/2_1_benchmarking.ipynb +++ b/notebooks/2_1_benchmarking.ipynb @@ -2,20 +2,12 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "7a263fe4", "metadata": { "scrolled": false }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/rahnfelj/GitRepositories/vit-shapley\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "\n", @@ -35,7 +27,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "72c897bc", "metadata": {}, "outputs": [], @@ -1932,7 +1924,7 @@ " out_original=surrogate_dict[backbone_type][\"original\"](batch[\"images\"].to(surrogate_dict[backbone_type][\"original\"].device),\n", " torch.ones((len(batch[\"images\"]), 196)).to(surrogate_dict[backbone_type][\"original\"].device))\n", "\n", - " for mask_location_model in [\"original\" , \"pre-softmax\", \"zero-input\", \"zero-embedding\"]:\n", + " 
for mask_location_model in surrogate_dict[backbone_type]:\n", " if mask_location_model==\"original\":\n", " kl_divergence=0\n", "\n", From fccea6db2d6324f64a585994e870727e15d25a9e Mon Sep 17 00:00:00 2001 From: JensRahnfeld Date: Fri, 19 Jul 2024 17:01:59 +0200 Subject: [PATCH 5/8] clean up unused code in load dataset cell --- notebooks/2_1_benchmarking.ipynb | 74 +------------------------------- 1 file changed, 1 insertion(+), 73 deletions(-) diff --git a/notebooks/2_1_benchmarking.ipynb b/notebooks/2_1_benchmarking.ipynb index 769b36f..37304c6 100644 --- a/notebooks/2_1_benchmarking.ipynb +++ b/notebooks/2_1_benchmarking.ipynb @@ -221,51 +221,6 @@ "metadata": {}, "outputs": [], "source": [ - "def generate_mask(num_players: int, num_mask_samples: int or None = None, paired_mask_samples: bool = True,\n", - " mode: str = 'uniform', random_state: np.random.RandomState or None = None) -> np.array:\n", - " \"\"\"\n", - " Args:\n", - " num_players: the number of players in the coalitional game\n", - " num_mask_samples: the number of masks to generate\n", - " paired_mask_samples: if True, the generated masks are pairs of x and 1-x.\n", - " mode: the distribution that the number of masked features follows. 
('uniform' or 'shapley')\n", - " random_state: random generator\n", - "\n", - " Returns:\n", - " torch.Tensor of shape\n", - " (num_masks, num_players) if num_masks is int\n", - " (num_players) if num_masks is None\n", - "\n", - " \"\"\"\n", - " random_state = random_state or np.random\n", - "\n", - " num_samples_ = num_mask_samples or 1\n", - "\n", - " if paired_mask_samples:\n", - " assert num_samples_ % 2 == 0, \"'num_samples' must be a multiple of 2 if 'paired' is True\"\n", - " num_samples_ = num_samples_ // 2\n", - " else:\n", - " num_samples_ = num_samples_\n", - "\n", - " if mode == 'uniform':\n", - " masks = (random_state.rand(num_samples_, num_players) > random_state.rand(num_samples_, 1)).astype('int')\n", - " elif mode == 'shapley':\n", - " probs = 1 / (np.arange(1, num_players) * (num_players - np.arange(1, num_players)))\n", - " probs = probs / probs.sum()\n", - " masks = (random_state.rand(num_samples_, num_players) > 1 / num_players * random_state.choice(\n", - " np.arange(num_players - 1), p=probs, size=[num_samples_, 1])).astype('int')\n", - " else:\n", - " raise ValueError(\"'mode' must be 'random' or 'shapley'\")\n", - "\n", - " if paired_mask_samples:\n", - " masks = np.stack([masks, 1 - masks], axis=1).reshape(num_samples_ * 2, num_players)\n", - "\n", - " if num_mask_samples is None:\n", - " masks = masks.squeeze(0)\n", - " return masks # (num_masks)\n", - " else:\n", - " return masks # (num_samples, num_masks)\n", - "\n", "def set_datamodule(datasets,\n", " dataset_location,\n", " explanation_location_train,\n", @@ -340,33 +295,6 @@ " per_gpu_batch_size=_config[\"per_gpu_batch_size\"],\n", " test_data_split=_config[\"test_data_split\"])\n", "\n", - "# The batch for training classifier consists of images and labels, but the batch for training explainer consists of images and masks.\n", - "# The masks are generated to follow the Shapley distribution.\n", - "\"\"\"\n", - "original_getitem = copy.deepcopy(datamodule.dataset_cls.__getitem__)\n", 
- "def __getitem__(self, idx):\n", - " if self.split == 'train':\n", - " masks = generate_mask(num_players=surrogate.num_players,\n", - " num_mask_samples=_config[\"explainer_num_mask_samples\"],\n", - " paired_mask_samples=_config[\"explainer_paired_mask_samples\"], mode='shapley')\n", - " elif self.split == 'val' or self.split == 'test':\n", - " # get cached if available\n", - " if not hasattr(self, \"masks_cached\"):\n", - " self.masks_cached = {}\n", - " masks = self.masks_cached.setdefault(idx, generate_mask(num_players=surrogate.num_players,\n", - " num_mask_samples=_config[\n", - " \"explainer_num_mask_samples\"],\n", - " paired_mask_samples=_config[\n", - " \"explainer_paired_mask_samples\"],\n", - " mode='shapley'))\n", - " else:\n", - " raise ValueError(\"'split' variable must be train, val or test.\")\n", - " return {\"images\": original_getitem(self, idx)[\"images\"],\n", - " \"labels\": original_getitem(self, idx)[\"labels\"],\n", - " \"masks\": masks}\n", - "datamodule.dataset_cls.__getitem__ = __getitem__\n", - "\"\"\"\n", - "\n", "datamodule.set_train_dataset()\n", "datamodule.set_val_dataset()\n", "datamodule.set_test_dataset()\n", @@ -431,7 +359,7 @@ " xy=[dset[idx] for idx in images_idx]\n", " x, y = zip(*[(i['images'], i['labels']) for i in xy])\n", " x = torch.stack(x)\n", - " y_labels=[dset.labels[i] for i in y] " + " y_labels=[dset.labels[i] for i in y]" ] }, { From a99977fe2f9cfe256bbc08a51559be9bcda1b975 Mon Sep 17 00:00:00 2001 From: JensRahnfeld Date: Fri, 19 Jul 2024 17:10:20 +0200 Subject: [PATCH 6/8] fix bug: kernelshap method is evaluated on a subset of ImageNette. 
The corresponding cell was executed regardless of dataset leading to an exception when running it with pets --- notebooks/2_1_benchmarking.ipynb | 33 +++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/notebooks/2_1_benchmarking.ipynb b/notebooks/2_1_benchmarking.ipynb index 37304c6..74af09c 100644 --- a/notebooks/2_1_benchmarking.ipynb +++ b/notebooks/2_1_benchmarking.ipynb @@ -76,7 +76,7 @@ " \"7_classifiermasked\",\n", " \"8_elapsedtime\",\n", " \"9_estimationerror\"]\n", - "evaluation_stage = evaluation_stages[0]\n", + "evaluation_stage = evaluation_stages[3]\n", "\n", "_config.update(env_chanwkim()); _config.update({'gpus_classifier':[0,],\n", " 'gpus_surrogate':[0,],\n", @@ -1750,7 +1750,7 @@ " \"riseclassifier\", \n", " \"ours\"]\n", "explanation_method_to_run=[]\n", - "explanation_method_to_run+=explanation_method_to_run_[-1:]\n", + "explanation_method_to_run+=explanation_method_to_run_[1:2]\n", "\n", "\n", "print(explanation_method_to_run)" @@ -1959,19 +1959,22 @@ "metadata": {}, "outputs": [], "source": [ - "label_to_use=['Garbage truck', \n", - " 'Tench', \n", - " 'English springer', \n", - " 'Parachute', \n", - " 'Golf ball', \n", - " 'Gas pump']\n", - "kernelshap_sample_idx_list_all=[]\n", - "for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n", - " for random_seed in [2, 3, 4, 5]:\n", - " label_data_list=np.array([i['label'] for i in dset.data])\n", - " kernelshap_sample_idx_list=[np.random.RandomState(random_seed).choice(np.arange(len(label_data_list))[(label_data_list==label_idx)]) for label_idx in [label_name_list.index(label) for label in label_to_use]]\n", - " kernelshap_sample_idx_list_all+=kernelshap_sample_idx_list\n", - "kernelshap_sample_path_list_all=[dset[i]['path'] for i in kernelshap_sample_idx_list_all] " + "if _config[\"datasets\"] == \"ImageNette\":\n", + " label_to_use=['Garbage truck', \n", + " 'Tench', \n", + " 'English springer', \n", + " 
'Parachute', \n", + " 'Golf ball', \n", + " 'Gas pump']\n", + " kernelshap_sample_idx_list_all=[]\n", + " for _, (backbone_type, backbone_type_config) in enumerate(backbone_type_config_dict.items()):\n", + " for random_seed in [2, 3, 4, 5]:\n", + " label_data_list=np.array([i['label'] for i in dset.data])\n", + " kernelshap_sample_idx_list=[np.random.RandomState(random_seed).choice(np.arange(len(label_data_list))[(label_data_list==label_idx)]) for label_idx in [label_name_list.index(label) for label in label_to_use]]\n", + " kernelshap_sample_idx_list_all+=kernelshap_sample_idx_list\n", + " kernelshap_sample_path_list_all=[dset[i]['path'] for i in kernelshap_sample_idx_list_all]\n", + "else:\n", + " kernelshap_sample_path_list_all = []" ] }, { From 284a1e5b9a21ed3fcb10c7059c7c939380cb1551 Mon Sep 17 00:00:00 2001 From: JensRahnfeld Date: Fri, 19 Jul 2024 17:53:47 +0200 Subject: [PATCH 7/8] remove unused cells --- notebooks/2_1_benchmarking.ipynb | 30 +----------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/notebooks/2_1_benchmarking.ipynb b/notebooks/2_1_benchmarking.ipynb index 74af09c..3331197 100644 --- a/notebooks/2_1_benchmarking.ipynb +++ b/notebooks/2_1_benchmarking.ipynb @@ -76,7 +76,7 @@ " \"7_classifiermasked\",\n", " \"8_elapsedtime\",\n", " \"9_estimationerror\"]\n", - "evaluation_stage = evaluation_stages[3]\n", + "evaluation_stage = evaluation_stages[0]\n", "\n", "_config.update(env_chanwkim()); _config.update({'gpus_classifier':[0,],\n", " 'gpus_surrogate':[0,],\n", @@ -2328,26 +2328,6 @@ " return explanation_expaned_bool" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "835e5d77", - "metadata": {}, - "outputs": [], - "source": [ - "evaluation_stage=\"4_insert_delete\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "161dc89e", - "metadata": {}, - "outputs": [], - "source": [ - "explanation_method_to_run=[\"kernelshap\"]" - ] - }, { "cell_type": "code", "execution_count": 
null, @@ -2358,14 +2338,6 @@ "estimationerror_sample_path_list=pd.DataFrame(data_loader.dataset.data).groupby(\"label\").apply(lambda x: x.sample(n=10, random_state=42))[\"img_path\"].tolist()" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "f37cb8db", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, From e470f9734c9a1fad3a7fa952a53e4f592be80d4a Mon Sep 17 00:00:00 2001 From: JensRahnfeld Date: Fri, 19 Jul 2024 20:48:38 +0200 Subject: [PATCH 8/8] add conda env file --- environment.yaml | 273 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 environment.yaml diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000..76f8897 --- /dev/null +++ b/environment.yaml @@ -0,0 +1,273 @@ +name: vit-shapley +channels: + - pytorch + - fastai + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1 + - _openmp_mutex=4.5 + - abseil-cpp=20230802.0 + - absl-py=2.1.0 + - annotated-types=0.6.0 + - asttokens=2.4.1 + - blas=1.0 + - brotli=1.1.0 + - brotli-bin=1.1.0 + - brotli-python=1.1.0 + - bzip2=1.0.8 + - c-ares=1.32.2 + - ca-certificates=2024.7.4 + - catalogue=2.0.10 + - certifi=2024.7.4 + - cffi=1.16.0 + - charset-normalizer=3.3.2 + - click=7.1.2 + - cloudpathlib=0.16.0 + - colorama=0.4.6 + - comm=0.2.2 + - confection=0.1.4 + - configparser=7.0.0 + - contourpy=1.2.1 + - cudatoolkit=11.1.1 + - cycler=0.11.0 + - cymem=2.0.6 + - cython-blis=0.7.9 + - dataclasses=0.8 + - dbus=1.13.18 + - debugpy=1.8.2 + - decorator=5.1.1 + - docker-pycreds=0.4.0 + - docopt=0.6.2 + - exceptiongroup=1.2.2 + - executing=2.0.1 + - expat=2.6.2 + - fastai=2.5.3 + - fastcore=1.3.29 + - fastdownload=0.0.7 + - fastprogress=1.0.3 + - ffmpeg=4.3 + - filelock=3.15.4 + - fontconfig=2.14.2 + - fonttools=4.53.1 + - freetype=2.12.1 + - fsspec=2024.6.1 + - future=1.0.0 + - gitdb=4.0.11 + - gitpython=3.1.43 + - glib=2.80.2 + - glib-tools=2.80.2 + - gmp=6.3.0 + - 
gnutls=3.6.13 + - grpc-cpp=1.48.2 + - grpcio=1.48.2 + - gst-plugins-base=1.14.1 + - gstreamer=1.14.1 + - gtest=1.14.0 + - h2=4.1.0 + - hpack=4.0.0 + - huggingface_hub=0.23.4 + - hyperframe=6.0.1 + - icu=73.2 + - idna=3.7 + - importlib-metadata=8.0.0 + - importlib-resources=6.4.0 + - importlib_metadata=8.0.0 + - importlib_resources=6.4.0 + - intel-openmp=2023.1.0 + - ipdb=0.13.9 + - ipykernel=6.29.5 + - ipython=8.18.1 + - jbig=2.1 + - jedi=0.19.1 + - jinja2=3.1.4 + - joblib=1.4.2 + - jpeg=9e + - jsonpickle=1.5.1 + - jupyter_client=8.6.2 + - jupyter_core=5.7.2 + - keyutils=1.6.1 + - kiwisolver=1.4.4 + - krb5=1.20.1 + - lame=3.100 + - langcodes=3.3.0 + - lcms2=2.12 + - ld_impl_linux-64=2.38 + - lerc=2.2.1 + - libblas=3.9.0 + - libbrotlicommon=1.1.0 + - libbrotlidec=1.1.0 + - libbrotlienc=1.1.0 + - libcblas=3.9.0 + - libclang=14.0.6 + - libclang13=14.0.6 + - libcups=2.3.3 + - libdeflate=1.7 + - libedit=3.1.20191231 + - libexpat=2.6.2 + - libffi=3.4.4 + - libgcc-ng=14.1.0 + - libgfortran-ng=14.1.0 + - libgfortran5=14.1.0 + - libglib=2.80.2 + - libgomp=14.1.0 + - libhwloc=2.11.1 + - libiconv=1.17 + - liblapack=3.9.0 + - libllvm14=14.0.6 + - libnsl=2.0.1 + - libpng=1.6.43 + - libpq=12.17 + - libprotobuf=3.20.3 + - libsodium=1.0.18 + - libsqlite=3.45.3 + - libstdcxx-ng=14.1.0 + - libtiff=4.3.0 + - libuuid=2.38.1 + - libuv=1.48.0 + - libwebp-base=1.4.0 + - libxcb=1.16 + - libxcrypt=4.4.36 + - libxkbcommon=1.7.0 + - libxml2=2.13.1 + - libzlib=1.2.13 + - lz4-c=1.9.3 + - markdown=3.6 + - markdown-it-py=2.2.0 + - markupsafe=2.1.5 + - matplotlib=3.9.1 + - matplotlib-base=3.9.1 + - matplotlib-inline=0.1.7 + - mdurl=0.1.0 + - mkl=2023.1.0 + - mkl-service=2.4.0 + - munch=2.5.0 + - munkres=1.1.4 + - murmurhash=1.0.7 + - mysql=5.7.20 + - ncurses=6.4.20240210 + - nest-asyncio=1.6.0 + - nettle=3.6 + - numpy=1.26.4 + - olefile=0.47 + - openh264=2.1.1 + - openjpeg=2.4.0 + - openssl=3.3.1 + - packaging=24.1 + - pandas=1.4.0 + - parso=0.8.4 + - pathtools=0.1.2 + - pathy=0.10.2 + - 
patsy=0.5.6 + - pcre2=10.43 + - pexpect=4.9.0 + - pickleshare=0.7.5 + - pillow=8.4.0 + - pip=24.0 + - platformdirs=4.2.2 + - preshed=3.0.6 + - promise=2.3 + - prompt-toolkit=3.0.47 + - protobuf=3.20.3 + - psutil=6.0.0 + - pthread-stubs=0.4 + - ptyprocess=0.7.0 + - pure_eval=0.2.2 + - py-cpuinfo=9.0.0 + - pybind11-abi=4 + - pycparser=2.22 + - pydantic=2.5.3 + - pydantic-core=2.14.6 + - pydeprecate=0.3.2 + - pygments=2.15.1 + - pyparsing=3.0.9 + - pyqt=5.15.4 + - pyqt5-sip=12.9.0 + - pysocks=1.7.1 + - python=3.9.19 + - python-dateutil=2.9.0 + - python_abi=3.9 + - pytorch=1.10.2 + - pytorch-lightning=1.5.9 + - pytorch-mutex=1.0 + - pytz=2024.1 + - pyyaml=6.0.1 + - pyzmq=26.0.3 + - qhull=2020.2 + - qt-main=5.15.2 + - re2=2022.04.01 + - readline=8.2 + - regex=2024.5.15 + - requests=2.32.3 + - rich=13.7.1 + - sacred=0.8.2 + - sacremoses=0.0.53 + - scikit-learn=1.4.2 + - scipy=1.13.1 + - seaborn=0.13.2 + - seaborn-base=0.13.2 + - sentry-sdk=2.10.0 + - setuptools=59.5.0 + - shellingham=1.5.0.post1 + - shortuuid=1.0.13 + - sip=6.5.1 + - six=1.16.0 + - smart_open=5.2.1 + - smmap=5.0.0 + - spacy=3.7.2 + - spacy-legacy=3.0.12 + - spacy-loggers=1.0.4 + - sqlite=3.45.3 + - srsly=2.4.8 + - stack_data=0.6.2 + - statsmodels=0.14.2 + - subprocess32=3.5.4 + - tbb=2021.12.0 + - tensorboard=2.17.0 + - tensorboard-data-server=0.7.0 + - termcolor=2.4.0 + - thinc=8.2.2 + - threadpoolctl=3.5.0 + - timm=0.5.4 + - tk=8.6.14 + - tokenizers=0.10.3 + - toml=0.10.2 + - torchmetrics=0.7.0 + - torchvision=0.11.3 + - tornado=6.4.1 + - tqdm=4.66.4 + - traitlets=5.14.3 + - transformers=4.16.2 + - ttach=0.0.3 + - typer=0.9.0 + - typing-extensions=4.12.2 + - typing_extensions=4.12.2 + - tzdata=2024a + - unicodedata2=15.1.0 + - urllib3=2.2.2 + - wandb=0.12.10 + - wasabi=0.9.1 + - wcwidth=0.2.13 + - weasel=0.3.4 + - werkzeug=3.0.3 + - wheel=0.43.0 + - wrapt=1.16.0 + - xkeyboard-config=2.42 + - xorg-kbproto=1.0.7 + - xorg-libx11=1.8.9 + - xorg-libxau=1.0.11 + - xorg-libxdmcp=1.1.3 + - xorg-xextproto=7.3.0 
+ - xorg-xproto=7.0.31 + - xz=5.4.6 + - yaml=0.2.5 + - yaspin=2.2.0 + - zeromq=4.3.5 + - zipp=3.19.2 + - zlib=1.2.13 + - zstandard=0.19.0 + - zstd=1.5.6 + - pip: + - captum==0.5.0 + - einops==0.8.0 + - opencv-python==4.6.0.66