diff --git a/README.md b/README.md index 914559c..5900e43 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,9 @@ Tutorials demonstrating how to use ONNX in practice for varied scenarios across * [Float16 <-> Float32 converter](https://github.com/onnx/onnx-docker/blob/master/onnx-ecosystem/converter_scripts/float32_float16_onnx.ipynb) * [Version conversion](tutorials/VersionConversion.md) - +## Application of ONNX +* [Explainable AI for ONNX models](tutorials/XAI4ONNX_dianna_overview.ipynb) + ## Contributing We welcome improvements to the convertor tools and contributions of new ONNX bindings. Check out [contributor guide](https://github.com/onnx/onnx/blob/main/CONTRIBUTING.md) to get started. diff --git a/tutorials/XAI4ONNX_dianna_overview.ipynb b/tutorials/XAI4ONNX_dianna_overview.ipynb new file mode 100644 index 0000000..ca79e7f --- /dev/null +++ b/tutorials/XAI4ONNX_dianna_overview.ipynb @@ -0,0 +1,2000 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "\"Logo_ER10\"\n", + "\n", + "# General overview of using dianna\n", + "\n", + "**[DIANNA](https://dianna-ai.github.io/dianna/) is a Python package that brings explainable AI (XAI) to your research project.**\n", + "\n", + "It wraps carefully selected XAI methods (explainers) in a simple, uniform interface. It's built by, with and for (academic) researchers and research software engineers working on machine learning projects.\n", + "\n", + "This overview illustrates the main strengths of DIANNA, namely supporting *many data modalities* and *several explainers*. DIANNA is *future-proof* by supporting and advocating the *[ONNX](https://onnx.ai/) de-facto standard* for Neural Network models. Many modern frameworks already support native export to ONNX. For tutorials on conversion from PyTorch, Keras, Scikit-learn and TensorFlow, see the [conversion_onnx](https://github.com/dianna-ai/dianna/tree/main/tutorials/conversion_onnx) folder.
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## **General workflow**\n", + "\n", + "1. Provide your *trained model* and *data item* ( *text, image, time series or tabular* )\n", + "\n", + "```python\n", + "model_path = 'your_model.onnx' # model trained on your data modality\n", + "data_item = <data_item> # data item for which the model's prediction needs to be explained\n", + "```\n", + "\n", + "2. If the task is *classification*: which are the *classes* your model has been trained for?\n", + "\n", + "```python\n", + "labels = [class_a, class_b] # example of binary classification labels\n", + "```\n", + "*Which* of these classes do you want an explanation for?\n", + "```python\n", + "explained_class_index = labels.index(<explained_class>) # explained_class can be any of the labels\n", + "```\n", + "\n", + "3. Run dianna with the *explainer* of your choice ( *'LIME', 'RISE' or 'KernelSHAP'*) and visualize the output:\n", + "\n", + "```python\n", + "explanation = dianna.<explanation_function>(model_path, data_item, explainer)\n", + "dianna.visualization.<visualization_function>(explanation[explained_class_index], data_item)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setting up" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Colab Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "running_in_colab = 'google.colab' in str(get_ipython())\n", + "if running_in_colab:\n", + "    # install dianna\n", + "    !python3 -m pip install dianna[notebooks]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "#### Libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-06-19 11:56:05.340972: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU 
instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + ] + } + ], + "source": [ + "import os\n", + "import warnings\n", + "warnings.filterwarnings('ignore') # disable warnings related to versions of tf\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "# for explanations and visualization\n", + "import dianna\n", + "from dianna import visualization\n", + "from dianna import utils as dianna_utils\n", + "from dianna.utils.tokenizers import SpacyTokenizer\n", + "from dianna.utils.onnx_runner import SimpleModelRunner\n", + "from dianna.utils.downloader import download\n", + "\n", + "# ONNX\n", + "import onnx\n", + "import onnxruntime\n", + "from onnx_tf.backend import prepare\n", + "from dianna.utils.onnx_runner import SimpleModelRunner\n", + "\n", + "# text-related\n", + "import spacy\n", + "from torchtext.vocab import Vectors\n", + "from scipy.special import softmax\n", + "from scipy.special import expit as sigmoid\n", + "\n", + "# keras model and preprocessing tools for Image\n", + "from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions\n", + "from keras import backend as K\n", + "from keras import utils as keras_utils\n", + "\n", + "# for tabular data \n", + "from sklearn.model_selection import train_test_split\n", + "from numba.core.errors import NumbaDeprecationWarning\n", + "import warnings\n", + "# silence the Numba deprecation warnings in shap\n", + "warnings.simplefilter('ignore', category=NumbaDeprecationWarning)\n", + "\n", + "# visualizations\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "import random\n", + "random.seed(42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## **Data modalities**" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "DIANNA 
supports *text, images, time-series* and *tabular data.*" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Text example*" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's illustrate the general workflow above with textual data. The data item of interest is a sentence from a movie review, and the model has been trained to classify movie reviews from the [Stanford sentiment treebank](https://nlp.stanford.edu/sentiment/index.html) into 'positive' and 'negative' sentiment classes. We are interested in which words contribute positively (red) and which negatively (blue) towards the model's decision to classify the review as positive, and we would like to use the LIME explainer:\n", + "\n", + "*For a full example see the [lime_text](https://github.com/dianna-ai/dianna/blob/main/tutorials/explainers/LIME/lime_text.ipynb) tutorial" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. Provide your *trained model* and *text* of interest." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Download the pre-trained model from Zenodo:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "model_path = download('movie_review_model.onnx', 'model')" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# word vectors and labels\n", + "word_vector_path = download('movie_reviews_word_vectors.txt', 'data')\n", + "labels = (\"negative\", \"positive\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The classifier accepts numerical tokens as input and outputs a score between 0 (the review is negative) and 1 (the review is positive).\n", + "Therefore, we define a model runner class, which accepts a sentence as input instead and returns a score for each of the two classes: negative and positive."
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenizer downloaded.\n" + ] + } + ], + "source": [ + "# ensure the tokenizer for english is available\n", + "from IPython.display import clear_output\n", + "\n", + "spacy.cli.download('en_core_web_sm')\n", + "clear_output()\n", + "print(\"Tokenizer downloaded.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class MovieReviewsModelRunner:\n", + "    def __init__(self, model, word_vectors, max_filter_size):\n", + "        self.run_model = dianna_utils.get_function(str(model))\n", + "        self.vocab = Vectors(word_vectors, cache=os.path.dirname(word_vectors))\n", + "        self.max_filter_size = max_filter_size\n", + "\n", + "        self.tokenizer = SpacyTokenizer(name='en_core_web_sm')\n", + "\n", + "    def __call__(self, sentences):\n", + "        # ensure the input has a batch axis\n", + "        if isinstance(sentences, str):\n", + "            sentences = [sentences]\n", + "\n", + "        tokenized_sentences = []\n", + "        for sentence in sentences:\n", + "            # tokenize and pad to minimum length\n", + "            tokens = self.tokenizer.tokenize(sentence.lower())\n", + "            if len(tokens) < self.max_filter_size:\n", + "                tokens += ['<pad>'] * (self.max_filter_size - len(tokens))\n", + "\n", + "            # numericalize the tokens\n", + "            tokens_numerical = [self.vocab.stoi[token] if token in self.vocab.stoi else self.vocab.stoi['<pad>']\n", + "                                for token in tokens]\n", + "            tokenized_sentences.append(tokens_numerical)\n", + "\n", + "        # run the model, applying a sigmoid because the model outputs logits\n", + "        logits = self.run_model(tokenized_sentences)\n", + "        pred = np.apply_along_axis(sigmoid, 1, logits)\n", + "\n", + "        # output two classes\n", + "        positivity = pred[:, 0]\n", + "        negativity = 1 - positivity\n", + "        return np.transpose([negativity, positivity])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + 
"metadata": {}, + "outputs": [], + "source": [ + "# define model runner. max_filter_size is a property of the model\n", + "model_runner = MovieReviewsModelRunner(model_path, word_vector_path, max_filter_size=5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define a sentence of interest:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "review = \"A delectable and intriguing thriller filled with surprises\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " 2. Which are the classes your model has been trained for? Which of these classes do you want an explanation for?" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "labels = (\"negative\", \"positive\") # sentiments of the movie reviews\n", + "explained_class_index = labels.index(\"positive\") # we are interested why our sentence is classified as having a positive sentiment\n", + "explained_class_index\n", + "labels.index('positive')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " 3. Run dianna with the explainer of your choice, 'LIME', and visualize the output. For textual data use the ```explain_text``` function.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "explanation = dianna.explain_text(model_runner, review, model_runner.tokenizer,'LIME', labels=[explained_class_index])[0]\n", + "explanation\n", + "fig, _ = visualization.highlight_text(explanation, model_runner.tokenizer.tokenize(review))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*The positive words (in red) carry the 'positive' sentiment classification.*" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Image example*\n", + "\n", + "Here we apply the general workflow to image data from ImageNet. The data item of interest is an image of a bee, and we use the ResNet50 model trained on ImageNet to classify images into 1000 object classes. We are interested in which pixels contribute positively (red) and which negatively (blue) towards the model's decision to classify the image as a 'bee', and we would like to use the RISE explainer [Petsiuk et al., \"RISE: Randomized Input Sampling for Explanation of Black-box Models\", BMVC 2018].\n", + " \n", + "\n", + "*For a full example see the [rise_imagenet](https://github.com/dianna-ai/dianna/blob/main/tutorials/explainers/RISE/rise_imagenet.ipynb) tutorial" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. Provide your *trained model* and *image* of interest." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Initialize the pretrained model."
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "class Model():\n", + " def __init__(self):\n", + " K.set_learning_phase(0)\n", + " self.model = ResNet50()\n", + " self.input_size = (224, 224)\n", + " \n", + " def run_on_batch(self, x):\n", + " return self.model.predict(x)\n", + "\n", + "model = Model()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load and preprocess the 'bee' image." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def load_img(path):\n", + " img = keras_utils.load_img(path, target_size=model.input_size)\n", + " x = keras_utils.img_to_array(img)\n", + " preproc_img = preprocess_input(x)\n", + " return img, preproc_img " + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "img, preproc_img = load_img(download('bee.jpg', 'data'))\n", + "fig, ax = plt.subplots() \n", + "ax.axis('off')\n", + "plt.imshow(img)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " 2. Which are the classes your model has been trained for? Which of these classes do you want an explanation for?" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "309\n", + "bee\n" + ] + } + ], + "source": [ + "labels = [range(1000)] # 1000 classes of objects\n", + "# we are interested why our image is classified as a 'bee'\n", + "def class_name(idx):\n", + "    return decode_predictions(np.eye(1, 1000, idx))[0][0][1]\n", + "for i in range(1000):\n", + "    if class_name(i) == 'bee':\n", + "        explained_class_index = i\n", + "print(explained_class_index)\n", + "print(class_name(explained_class_index))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "3. Run dianna with the explainer of your choice, 'RISE', and visualize the output. For image data use the ```explain_image``` function." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "RISE masks random portions of the input image and passes each masked image through the model; the portions whose masking decreases the model's score for a class the most are the most “important” for that class. To call the explainer and generate a relevance score map, the user needs to specify the number of randomly generated masks (`n_masks`), the resolution of features in the masks (`feature_res`) and, for each mask and each feature in the image, the probability of being kept unmasked (`p_keep`)."
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Explaining: 0%| | 0/10 [00:00" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "print(f'Explanation for `{class_name(explained_class_index)}` ({predictions[0][explained_class_index]})')\n", + "visualization.plot_image(explanation[explained_class_index], keras_utils.img_to_array(img)/255., heatmap_cmap='bwr')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "What would make our model think that the image is one of a 'garden_spider'?" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Explanation for `garden_spider` (0.005400071851909161)\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "another_class_index = 74 # the fifth prediction was 'garden_spider'\n", + "print(f'Explanation for `{class_name(another_class_index)}` ({predictions[0][another_class_index]})')\n", + "visualization.plot_image(explanation[another_class_index], keras_utils.img_to_array(img)/255., heatmap_cmap='bwr')\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*It is interesting to observe that the wings of the insect support the model's classification of the image as 'bee', while the body would be strong evidence for 'spider'.*" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Time series example*\n", + "\n", + "Here we apply the general workflow to very simple time series (TS) data representing daily temperatures with hot and cold days. \n", + "We define a simple expert model that classifies a time series as 'summer' or 'winter' by thresholding its mean temperature.\n", + "We use RISE to explain the individual days' contributions to the model's decision. This illustrates that the explainers can work on any ML model and are not limited to neural networks.\n", + "\n", + "*For a full example containing more complicated real temperature data from locations in Europe see the [rise_timeseries_weather](https://github.com/dianna-ai/dianna/blob/main/tutorials/explainers/RISE/rise_timeseries_weather.ipynb) tutorial" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. Define your *model* and *time-series* of interest." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# make up a weather dataset with extremes\n", + "cold_with_2_hot_days = np.expand_dims(np.array([30, 29] + list(np.zeros(26))), axis=1)\n", + "data_extreme = cold_with_2_hot_days\n", + "fig = plt.figure()\n", + "plt.plot(data_extreme)\n", + "plt.xlabel(\"Time index\")\n", + "plt.ylabel(\"Celsius\")\n", + "plt.title(\"Temperature\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can define an 'expert' model which decides it's summer if the mean temperature is above a threshold, and winter otherwise." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "# We define a threshold for the model to make decisions\n", + "# The labels are [\"summer\", \"winter\"]\n", + "threshold = 14\n", + "\n", + "def run_expert_model(data):\n", + "    is_summer = np.mean(np.mean(data, axis=1), axis=1) > threshold\n", + "    number_of_classes = 2\n", + "    number_of_instances = data.shape[0]\n", + "    result = np.zeros((number_of_instances, number_of_classes))\n", + "    result[is_summer] = [1.0, 0.0]\n", + "    result[~is_summer] = [0.0, 1.0]\n", + "    return result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " 2. Which are the classes your model has been trained for? Which of these classes do you want an explanation for?"
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "labels = ('summer', 'winter') # two seasons\n", + "explained_class_index = labels.index('summer') # we are interested why our time-series is classified as 'summer'\n", + "explained_class_index\n", + "labels.index('summer')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " 3. Run dianna with the explainer of your choice, 'RISE', and visualize the output. For time-series data use the ```explain_timeseries``` function.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "RISE masks random portions of the input time-series based on given segmentations and passes each masked time-series through the model; the portions whose masking decreases the model's score the most are the most “important”. We need to define the approach for masking (`mask_type`). Since our data is highly skewed, here we replace the masked values with the \"threshold\" value instead of the mean."
+ ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "# we use the threshold to mask the data\n", + "def input_train_mean(_data):\n", + " return threshold" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Explaining: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 15275.90it/s]\n" + ] + } + ], + "source": [ + "# call the explainer\n", + "explanation = dianna.explain_timeseries(run_expert_model, input_timeseries=data_extreme,\n", + " method='rise', labels=[0,1], p_keep=0.1,\n", + " n_masks=10000, mask_type=input_train_mean)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we can visualize the relevance scores overlaid on time-series using the visualization functionality in dianna." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Normalize the explanation scores for the purpose of visualization\n", + "def normalize(data):\n", + "    \"\"\"Squash all values into [-1,1] range.\"\"\"\n", + "    zero_to_one = (data - np.min(data)) / (np.max(data) - np.min(data))\n", + "    return 2 * zero_to_one - 1\n", + "\n", + "heatmap_channel = normalize(explanation[0])\n", + "segments = []\n", + "for i in range(len(heatmap_channel) - 1):\n", + "    segments.append({\n", + "        'index': i,\n", + "        'start': i - 0.5,\n", + "        'stop': i + 0.5,\n", + "        'weight': heatmap_channel[i]})\n", + "fig, _ = visualization.plot_timeseries(range(len(heatmap_channel)), data_extreme,\n", + "                                       segments, x_label=\"Time index\", y_label=\"Temperature\", cmap='bwr')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*The explanation for the classification of 'summer' given by the RISE explainer is consistent with our expectation as it marks all hot days in the timeseries.*" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tabular data example*\n", + "\n", + "In the examples so far, we have shown how dianna works on **classification** problems. Here we demonstrate the KernelSHAP explainer on a **regression** problem: predicting the next-day sunshine hours from tabular data. The model is an MLP regressor [trained](https://github.com/dianna-ai/dianna-exploration/blob/main/example_data/model_generation/sunshine_prediction/generate_model.ipynb) on a [weather dataset](https://zenodo.org/records/10580833) of weather observations for several locations in Europe.\n", + "\n", + "*The full example is given in the [kernelshap_tabular_weather](https://github.com/dianna-ai/dianna/blob/main/tutorials/explainers/KernelSHAP/kernelshap_tabular_weather.ipynb) tutorial." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. 
Get the *data* and the *model* to explain" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Load and prepare the data. As the target, the sunshine hours for the next day in the data-set will be used. Therefore, we will remove the last data point as this has no target. A tabular regression model will be trained which does not require time-based data, therefore DATE and MONTH can be removed.\n", + "\n", + "Select an instance to explain. DIANNA requires input in numpy format, so the input data is converted into a numpy array." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_csv(download('weather_prediction_dataset_light.csv', 'data'))\n", + "\n", + "X_data = data.drop(columns=['DATE', 'MONTH'])[:-1]\n", + "y_data = data.loc[1:][\"BASEL_sunshine\"]\n", + "\n", + "# training, validation and test split\n", + "X_train, X_holdout, y_train, y_holdout = train_test_split(X_data, y_data, test_size=0.3, random_state=0)\n", + "X_val, X_test, y_val, y_test = train_test_split(X_holdout, y_holdout, test_size=0.5, random_state=0)\n", + "\n", + "# select an instance to explain\n", + "data_instance = X_test.iloc[10].to_numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
8 rows × 89 columns
" + ], + "text/plain": [ + " BASEL_cloud_cover BASEL_humidity BASEL_pressure \\\n", + "count 548.000000 548.000000 548.000000 \n", + "mean 5.633212 0.745055 1.018165 \n", + "std 2.246783 0.104896 0.007659 \n", + "min 0.000000 0.420000 0.989300 \n", + "25% 4.000000 0.677500 1.013700 \n", + "50% 6.000000 0.760000 1.017900 \n", + "75% 7.000000 0.820000 1.022950 \n", + "max 8.000000 0.970000 1.040300 \n", + "\n", + " BASEL_global_radiation BASEL_precipitation BASEL_sunshine \\\n", + "count 548.000000 548.000000 548.000000 \n", + "mean 1.263376 0.267792 4.264781 \n", + "std 0.914286 0.580999 4.294393 \n", + "min 0.060000 0.000000 0.000000 \n", + "25% 0.470000 0.000000 0.300000 \n", + "50% 1.000000 0.010000 2.900000 \n", + "75% 1.922500 0.252500 7.400000 \n", + "max 3.470000 5.360000 15.000000 \n", + "\n", + " BASEL_temp_mean BASEL_temp_min BASEL_temp_max DE_BILT_cloud_cover \\\n", + "count 548.000000 548.000000 548.000000 548.000000 \n", + "mean 10.867883 6.972810 15.215693 5.385036 \n", + "std 6.992134 6.362011 8.281353 2.230648 \n", + "min -6.100000 -11.200000 -3.700000 0.000000 \n", + "25% 5.575000 2.075000 9.200000 4.000000 \n", + "50% 10.950000 7.100000 15.050000 6.000000 \n", + "75% 16.325000 11.825000 21.700000 7.000000 \n", + "max 27.700000 19.600000 35.900000 8.000000 \n", + "\n", + " ... SONNBLICK_temp_mean SONNBLICK_temp_min SONNBLICK_temp_max \\\n", + "count ... 548.000000 548.000000 548.000000 \n", + "mean ... -4.708759 -7.042883 -2.348723 \n", + "std ... 6.874519 7.120717 6.751890 \n", + "min ... -25.600000 -28.600000 -24.700000 \n", + "25% ... -9.500000 -12.200000 -6.900000 \n", + "50% ... -4.300000 -6.350000 -2.250000 \n", + "75% ... 0.300000 -1.500000 2.025000 \n", + "max ... 
10.400000 6.900000 14.100000 \n", + "\n", + " TOURS_humidity TOURS_pressure TOURS_global_radiation \\\n", + "count 548.000000 548.000000 548.000000 \n", + "mean 0.784799 1.017591 1.352026 \n", + "std 0.116005 0.008289 0.929626 \n", + "min 0.370000 0.985200 0.050000 \n", + "25% 0.710000 1.013000 0.517500 \n", + "50% 0.800000 1.017650 1.220000 \n", + "75% 0.872500 1.023300 2.050000 \n", + "max 0.990000 1.038800 3.450000 \n", + "\n", + " TOURS_precipitation TOURS_temp_mean TOURS_temp_min TOURS_temp_max \n", + "count 548.000000 548.000000 548.000000 548.000000 \n", + "mean 0.165146 12.007117 7.735219 16.278832 \n", + "std 0.372046 6.246073 5.570369 7.420308 \n", + "min 0.000000 -5.000000 -9.000000 -1.600000 \n", + "25% 0.000000 7.375000 3.675000 10.800000 \n", + "50% 0.000000 11.600000 8.200000 16.050000 \n", + "75% 0.160000 17.000000 12.025000 22.300000 \n", + "max 3.240000 27.700000 20.600000 36.200000 \n", + "\n", + "[8 rows x 89 columns]" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_test.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 8. 0.76 1.0003 0.92 0.39 0.2 2. -0.2\n", + " 4.7 4. 0.85 1.0038 1.28 0.2 5.5 3.8\n", + " -1. 8.5 8. 0.87 0.73 0.14 0. 1.5\n", + " -1.1 3.8 4. 0.83 1.0024 0.98 0.12 2.9\n", + " 3.3 -2.7 8.8 5. 0.68 1.0124 0.96 0.04\n", + " 2.5 4.4 1.9 6.9 0.83 0.9996 1.14 0.21\n", + " 3.9 2. -2. 6.8 6. 0.84 1.0034 1.21\n", + " 0.02 4.7 3.3 -1.5 8.3 0.16 3.2 -0.4\n", + " 7.9 8. 0.86 0.997 0.54 0.58 0. 1.2\n", + " -0.2 2.8 8. 0.98 1.45 0.9 0. -16.8\n", + " -17.6 -15.9 0.87 1.0079 0.81 0.14 4. 0.2\n", + " 7.8 ]\n" + ] + } + ], + "source": [ + "print(data_instance)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "2. 
Download the pretrained ONNX model" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[3.0719438]], dtype=float32)" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# download onnx model and check the prediction with it\n", + "model_path = download('sunshine_hours_regression_model.onnx', 'model')\n", + " \n", + "loaded_model = SimpleModelRunner(model_path)\n", + "predictions = loaded_model(data_instance.reshape(1,-1).astype(np.float32))\n", + "predictions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A runner function is created to prepare data for the ONNX inference session.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "import onnxruntime as ort\n", + "\n", + "def run_model(data):\n", + " # get ONNX predictions\n", + " sess = ort.InferenceSession(model_path)\n", + " input_name = sess.get_inputs()[0].name\n", + " output_name = sess.get_outputs()[0].name\n", + "\n", + " onnx_input = {input_name: data.astype(np.float32)}\n", + " pred_onnx = sess.run([output_name], onnx_input)[0]\n", + " \n", + " return pred_onnx" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "3. Run dianna with the KernelSHAP explainer and visualize the output:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The simplest way to run DIANNA on tabular data is with `dianna.explain_tabular`. Note that the training data is also required, since KernelSHAP needs it to generate proper perturbations. The method's mode needs to be specified as 'regression'."
+ ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "explanation = dianna.explain_tabular(run_model, input_tabular=data_instance, method='kernelshap',\n", + " mode ='regression', training_data = X_train, \n", + " training_data_kmeans = 5, feature_names=X_test.columns)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The output can be visualized with the DIANNA built-in visualization function. It shows the 10 features that contribute most to the prediction, ranked by importance." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "from dianna.visualization import plot_tabular\n", + "\n", + "fig, _ = plot_tabular(explanation, X_test.columns, num_features=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*We can see which min or max temperatures of which locations mostly influence (positively or negatively) the predicted by the trained modelnext-day temperature in Basel.*" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## **Explainers**\n", + "DIANNA supports **LIME**, [Ribeiro et al., \"Why Should I Trust You?: Explaining the Predictions of Any Classifier\", CoRR, 2016], **RISE**, [Petsiuk et al., \"RISE: Randomized Input Sampling for Explanation of Black-box Models\", BMVC 2018] and **KernalSHAP**, [Lundberg and Lee. ,\"A unified approach to interpreting model predictions.\", NIPS 2017] XAI methods. It allows users to compare the outputs of three different explainers on the same model and data, illustrated best by [dianna's dashboard](https://github.com/dianna-ai/dianna#dashboard). This section briefly demonstrates how to run on the command line the supported explainers for the simple binary classification task of distinguishing the hand-written digits \"0\" and \"1\" on a test example from the Binary MNIST dataset, a subset of the famous MNIST benchmark. It also gives the basics for each of the explainers." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Explaining a Pretrained Binary MNIST Classification Model *" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "Load the Binary MNIST data, the pretrained [binary MNIST model](https://zenodo.org/record/5907177) and chose image to be explained." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# load dataset\n", + "data_path = download('binary-mnist.npz', 'data')\n", + "data = np.load(data_path)\n", + "# load testing data and the related labels\n", + "X_test = data['X_test'].astype(np.float32).reshape([-1, 28, 28, 1]) / 256\n", + "y_test = data['y_test']" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# Download the onnx model\n", + "\n", + "# load the onnx model and check the prediction with it\n", + "model_path = download('mnist_model_tf.onnx', 'model')\n", + "onnx_model = onnx.load(model_path)\n", + "# get the output node\n", + "output_node = prepare(onnx_model, gen_tensor_dict=True).outputs[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "Print class and image of a single instance in the test data for preview." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The predicted class for this test image is: digit 0\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAANpklEQVR4nO3df+hVdZ7H8dcrV/+xojJWtImdioimaPshIayt1TBDW1L5jyk0tWTYjwlmaIUNVxohBmzZaemvQslyF7dhSIdkWnJa+zVmhPZj1bSZLIxRvmVipVIwa773j+9x+I597+d+vffce26+nw/4cu8973vueXPp1Tn3fM7x44gQgBPfSU03AKA/CDuQBGEHkiDsQBKEHUjir/q5Mduc+gd6LCI82vKu9uy2r7P9e9s7bT/QzWcB6C13Os5ue5ykP0j6gaTdkjZJmhcR2wvrsGcHeqwXe/YrJe2MiA8j4k+Sfinppi4+D0APdRP2syT9ccTr3dWyv2B7ge3Ntjd3sS0AXer5CbqIWCZpmcRhPNCkbvbseySdPeL1d6plAAZQN2HfJOl82+fYniBprqS19bQFoG4dH8ZHxGHb90laJ2mcpBUR8W5tnQGoVcdDbx1tjN/sQM/15KIaAN8ehB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4k0dcpm9EbM2bMaFl7/fXXi+tecMEFxfqsWbOK9RtuuKFYf+6554r1ko0bNxbrGzZs6PizM2LPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMIvrADj11FOL9VWrVhXr1157bcvaV199VVx3woQJxfrJJ59crPdSu96//PLLYv2ee+5pWXvmmWc66unboNUsrl1dVGN7l6SDkr6WdDgipnXzeQB6p44r6K6JiH01fA6AHuI3O5BEt2EPSb+1/abtBaO9wfYC25ttb+5yWwC60O1h/IyI2GP7ryW9YPu9iHh15BsiYpmkZRIn6IAmdbVnj4g91eNeSb+WdGUdTQGoX8dhtz3R9ilHn0v6oaRtdTUGoF4dj7PbPlfDe3Np+OfAf0XEz9usw2H8KB577LFi/a677urZtnfs2FGsf/rpp8X6gQMHOt62Pepw8J+1u1e+nYMHD7asXXXVVcV1t2zZ0tW2m1T7OHtEfCjpbzvuCEBfMfQGJEHYgSQIO5AEYQeSIOxAEtzi2gcXXXRRsf7yyy8X65MmTSrWd+/e3bJ22223FdfduXNnsf75558X64cOHSrWS046qbyvefDBB4v1xYsXF+vjxo1rWVuzZk1x3TvvvLNY/+yzz4r1JrUaemPPDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJMGVzH5xyyinFertx9HbXQjz88MMta+3G8Jt05MiRYn3JkiXFert/BnvhwoUta7Nnzy6uu2LFimK9m6mom8KeHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeS4H72Ppg5c2ax/tJLLxXrTz31VLF+xx13HG9LKXzwwQcta+ecc05x3SeffL
JYnz9/fkc99QP3swPJEXYgCcIOJEHYgSQIO5AEYQeSIOxAEtzP3gcPPfRQV+u/8cYbNXWSy7p161rW7r777uK606dPr7udxrXds9teYXuv7W0jlp1h+wXb71ePp/e2TQDdGsth/FOSrjtm2QOS1kfE+ZLWV68BDLC2YY+IVyXtP2bxTZJWVs9XSrq53rYA1K3T3+yTI2Koev6xpMmt3mh7gaQFHW4HQE26PkEXEVG6wSUilklaJuW9EQYYBJ0OvX1ie4okVY9762sJQC90Gva1km6vnt8u6dl62gHQK20P420/LelqSWfa3i3pZ5KWSvqV7fmSPpI0p5dNDrpzzz23WJ86dWqx/sUXXxTrW7duPe6eIL344osta+3G2U9EbcMeEfNalL5fcy8AeojLZYEkCDuQBGEHkiDsQBKEHUiCW1xrcOuttxbr7YbmVq9eXaxv3LjxuHsCjsWeHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSYJy9BnPnzi3W293C+uijj9bZDjAq9uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7H3w3nvvFesbNmzoUyfIjD07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOPsYTZw4sWVt/PjxfewE6EzbPbvtFbb32t42YtkS23tsv1P9Xd/bNgF0ayyH8U9Jum6U5f8eEZdWf/9db1sA6tY27BHxqqT9fegFQA91c4LuPttbqsP801u9yfYC25ttb+5iWwC61GnYH5N0nqRLJQ1J+kWrN0bEsoiYFhHTOtwWgBp0FPaI+CQivo6II5KWS7qy3rYA1K2jsNueMuLlbEnbWr0XwGBoO85u+2lJV0s60/ZuST+TdLXtSyWFpF2S7updi4Nhzpw5LWvnnXdecd19+/bV3Q7G4MYbb+x43cOHD9fYyWBoG/aImDfK4id60AuAHuJyWSAJwg4kQdiBJAg7kARhB5LgFld8a11xxRXF+qxZszr+7EWLFnW87qBizw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSTDOjoHVbhz9/vvvL9ZPO+20lrXXXnutuO66deuK9W8j9uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATj7GO0a9eulrWDBw/2r5ETyLhx44r1hQsXFuu33HJLsb5nz56OP/tE/Kek2bMDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKOiP5tzO7fxvpo+/btxXq773jmzJnF+iBP+XzJJZcU6/fee2/L2uWXX15cd9q0aR31dNQ111zTsvbKK6909dmDLCI82vK2e3bbZ9t+yfZ22+/a/km1/AzbL9h+v3o8ve6mAdRnLIfxhyX9U0R8T9J0ST+2/T1JD0haHxHnS1pfvQYwoNqGPSKGIuKt6vlBSTsknSXpJkkrq7etlHRzj3oEUIPjujbe9nclXSbpDUmTI2KoKn0saXKLdRZIWtBFjwBqMOaz8bZPlrRa0k8j4sDIWgyfgRr1LFRELIuIaRHR3dkWAF0ZU9htj9dw0FdFxJpq8Se2p1T1KZL29qZFAHVoexhv25KekLQjIh4ZUVor6XZJS6vHZ3vS4QngwgsvLNaff/75Yn1oaKhYb9L06dOL9UmTJnX82e2GHNeuXVusb9q0qeNtn4jG8pv97yT9SNJW2+9UyxZpOOS/sj1f0keS5vSkQwC1aBv2iNggadRBeknfr7cdAL3C5bJAEoQdSIKwA0kQdiAJwg4kwS2uNZg9e3axvnjx4mL9sssuq7OdgXLkyJGWtf379xfXfeSRR4r1pUuXdtTTia7jW1wBnBgIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtn7YOrUqcV6u/vZL7744jrbqdXy5cuL9bfffrtl7fHHH6+7HYhxdiA9wg4kQdiBJAg7kARhB5Ig7EAShB1IgnF24ATDODuQHGEHkiDsQBKEHUiCsANJEHYgCcIOJNE27LbPtv2S7e2237X9k2r5Ett7bL9T/V3f+3YBdKrtRTW2p0
iaEhFv2T5F0puSbtbwfOyHIuLfxrwxLqoBeq7VRTVjmZ99SNJQ9fyg7R2Szqq3PQC9dly/2W1/V9Jlkt6oFt1ne4vtFbZPb7HOAtubbW/urlUA3RjztfG2T5b0iqSfR8Qa25Ml7ZMUkh7S8KH+HW0+g8N4oMdaHcaPKey2x0v6jaR1EfGN2faqPf5vIqL4LyMSdqD3Or4RxrYlPSFpx8igVyfujpotaVu3TQLonbGcjZ8h6XeStko6Ov/uIknzJF2q4cP4XZLuqk7mlT6LPTvQY10dxteFsAO9x/3sQHKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJNr+g5M12yfpoxGvz6yWDaJB7W1Q+5LorVN19vY3rQp9vZ/9Gxu3N0fEtMYaKBjU3ga1L4neOtWv3jiMB5Ig7EASTYd9WcPbLxnU3ga1L4neOtWX3hr9zQ6gf5reswPoE8IOJNFI2G1fZ/v3tnfafqCJHlqxvcv21moa6kbnp6vm0Ntre9uIZWfYfsH2+9XjqHPsNdTbQEzjXZhmvNHvrunpz/v+m932OEl/kPQDSbslbZI0LyK297WRFmzvkjQtIhq/AMP230s6JOk/jk6tZftfJe2PiKXV/yhPj4h/HpDelug4p/HuUW+tphn/RzX43dU5/XknmtizXylpZ0R8GBF/kvRLSTc10MfAi4hXJe0/ZvFNklZWz1dq+D+WvmvR20CIiKGIeKt6flDS0WnGG/3uCn31RRNhP0vSH0e83q3Bmu89JP3W9pu2FzTdzCgmj5hm62NJk5tsZhRtp/Hup2OmGR+Y766T6c+7xQm6b5oREZdL+gdJP64OVwdSDP8GG6Sx08cknafhOQCHJP2iyWaqacZXS/ppRBwYWWvyuxulr758b02EfY+ks0e8/k61bCBExJ7qca+kX2v4Z8cg+eToDLrV496G+/mziPgkIr6OiCOSlqvB766aZny1pFURsaZa3Ph3N1pf/fremgj7Jknn2z7H9gRJcyWtbaCPb7A9sTpxItsTJf1QgzcV9VpJt1fPb5f0bIO9/IVBmca71TTjavi7a3z684jo+5+k6zV8Rv4DSf/SRA8t+jpX0v9Wf+823ZukpzV8WPd/Gj63MV/SJEnrJb0v6X8knTFAvf2nhqf23qLhYE1pqLcZGj5E3yLpnerv+qa/u0JfffneuFwWSIITdEAShB1IgrADSRB2IAnCDiRB2IEkCDuQxP8D0wdNenALPw0AAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# class name\n", + "class_name = ['digit 0', 'digit 1']\n", + "# instance index\n", + "i_instance = 3\n", + "# select instance for testing\n", + "test_sample = X_test[i_instance]\n", + "# model predictions with added batch axis to test sample\n", + "predictions = prepare(onnx_model).run(test_sample[None, ...])[f'{output_node}']\n", + "pred_class = class_name[np.argmax(predictions)]\n", + "other_class = class_name[np.argmin(predictions)]\n", + "# get the index of predictions\n", + "top_preds = np.argsort(-predictions)\n", + "inds = top_preds[0]\n", + "print(\"The predicted class for this test image is:\", pred_class)\n", + "plt.imshow(X_test[i_instance][:,:,0], cmap='gray') # 0 for channel" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "#### 1. LIME\n", + "\n", + "LIME (Local Interpretable Model-agnostic Explanations) is an explainable-AI method that aims to create an interpretable model that locally represents the classifier.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To use dianna with LIME, in the explanation function (for images `explain_image`) we simply specify `method=\"LIME\"` and optionally specify the LIME hyperparameters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0be6a2afa0e948e7a34ed348c1e27583", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/5000 [00:00" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "print(f'Explaination for `{pred_class}` with LIME')\n", + "fig, _ = visualization.plot_image(relevances[0], X_test[i_instance][:,:,0], data_cmap='gray')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Explaination for `digit 1` with LIME\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAASkAAADxCAYAAACXg0F0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAASuElEQVR4nO3de7BdZX3G8echIQlMbXPgoEYuMQ5BLkGDxijQAYqA0T/AVtQwowYHFC/Yaa0WGGa8xDIT2z/SsaVK1CheyqU4alpxUgjgpRDNSRsuSQw5hJiLURrCtScXAr/+sdeJy+PZ717nnOXh3Svfz8ye7L1+611rnRl5fNdl758jQgCQq0Ne7AMAgBRCCkDWCCkAWSOkAGSNkAKQNUIKQNYIKQAH2F5q+zHbD7Wp2/YXbPfbfsD260q1BbY3Fq8FdR0TIQWg7OuS5iXqb5U0s3h9UNIXJcn2EZI+LemNkuZK+rTtnjoOiJACcEBE/FjSrsQqF0n6RrSslDTV9jRJb5F0R0TsiognJN2hdNhVNjFVPNyOqXXsBUBbO6SdEXHUaMfPO/742DkwUGnd1Tt2rJW0p7RoSUQsGcHujpa0tfR5W7Gs3fIxS4bUVEkfqGMvANpaKP1yLON3Dgyo74orKq3rz3xmT0TMGcv+xhune0C3s6u/xm67pGNLn48plrVbPmaEFNAE4xdSyyS9r7jL9yZJT0XEDknLJV1gu6e4YH5BsWzMkqd7ALpEPQEk2zdJOkdSr+1tat2xO1SSIuJLkm6X9DZJ/ZIGJL2/qO2y/TlJq4pNLYyI1AX4yggpoAkOqeekKCIu6VAPSR9tU1sqaWktB1JCSAHdrr5TuSwRUkATEFIAskZIAcgaIQUga4QUgGzZtd3dyxEhBTQBMykAWSOkAGSNkAKQLR7mBJA9LpwDyBozKQBZI6QAZItrUgCyR0gByBohBSBr3N0DkK2GX5NqbvwCB5MaGzHYnmd7Q9FK/eph6ottryleD9t+slR7vlRbVsefxkwKaIL6GjFMkHS9pPPVavC5yvayiFg3uE5E/HVp/Y9JOq20id0RMbuWgykwkwKaoL6Z1FxJ/RGxKSL2SbpZrdbq7Vwi6aYa/oK2CCmgCeoLqcrt0m1PlzRD0l2lxVNs99leafvto/xrfgene0C3G9mP3vXa7it9XhIRS0a55/mSbo
uI50vLpkfEdtuvknSX7Qcj4pFRbl8SIQU0Q/VrUjsjYk6iPpJ26fM1pAdfRGwv/t1k+x61rleNKaQ43QOaoL7TvVWSZtqeYXuSWkH0e3fpbJ8oqUfSfaVlPbYnF+97JZ0pad3QsSPFTOpFcNxxx7Wtbdu6tW1Nkq7v/UKyfsIJr+5Qn5msP/zwxmQ95f1bL0/Wt2zZMuptI6HG56QiYr/tKyUtlzRB0tKIWGt7oaS+iBgMrPmSbi46Gg86SdINtl9QawK0qHxXcLQIKaAJanyYMyJul3T7kGWfGvL5M8OMu1fSqbUdSIGQApqAr8UAyFqDvxZDSAHdruHf3SOkgCYgpABkjZACkDVCCmWTJ09O1te+4/5k/fkZM9rWnntuf3Ls5RMmJOuTJk1K1js57rjpox776P70M1bPPfdcsv66/5jbtrZu3Zgft2mukX0tpusQUkATMJMCkDVCCkDWCCkAWSOkAGSLhzkBZI+7ewCyxkwKZfee95Nkfc7ME0a97YkTD03WP7LzY8n6Vf/3N8n63r37RnxMgzr9dzCzw9/d6W+7/6LVbWtLH1+aHHvFbz6arDceIQUgW1yTApA9QgpA1rhwDiBrDZ5JNTd+gYNF1U4xFYPM9jzbG2z32756mPqltv/X9pridXmptsD2xuK1oI4/j5kU0AQ1zaRsT5B0vaTz1epevMr2smG6vtwSEVcOGXuEpE9LmiMpJK0uxj4xlmMipIbxxaP+OVl//yknj2n7Tz/9dNvaqd+dnRw7c9eVyfrf7jkxWd+3bwyPIHSo/+icu5P1s846K1mfPHlK29o555yTHHvY99uPlaTde/Yk612vvtO9uZL6I2JTa7O+WdJFqtY/7y2S7oiIXcXYOyTNk3TTWA6I0z2gCaqf7vXa7iu9PjhkS0dLKjd/3FYsG+odth+wfZvtwY7HVceOCDMpoNuN7EfvOrVZr+LfJd0UEXttXyHpRknnjnGbbTGTApqgvgvn2yUdW/p8TLHsgIh4PCL2Fh+/Iun1VceOBiEFNEF9IbVK0kzbM2xPUqud+rLyCranlT5eKGl98X65pAts99jukXRBsWxMON0DmqCmC+cRsd/2lWqFywRJSyNire2FkvoiYpmkv7R9oaT9knZJurQYu8v259QKOklaOHgRfSwIKaAJanyYMyJul3T7kGWfKr2/RtI1bcYulZT+NvgIEVJAt+MLxgef6yYvSta3H/aRDluIZPXNPz2/be3nmzcnx/6VPpGsj/4pqM7Sf5V01j1/lqz/14QfJ+tnnHFG29qJJ56UHHvsfx+brD+8Md1uq+vx3T0AWWMmBSBrhBSAbHFNCkD2CCkAWePCOYCsMZMCkC2uSR18Jk6YMKbxa9bcn6zHqlXJelOduSL9e1JPzXq8bW3q1J7k2FtPujlZn73x9cl61yOkAGSNkAKQNUIKQLZG9qN3XYeQApqAmRSArBFSALJGSAHIGiF1cPnJuT9K1o/rMP6ybR9I1vtGeDwHi/7+R9rW5sxJNzj58rHpH4O8flRH1CUa/jBnc28JAAeTQw6p9qqgQpv1j9teV/TdW2F7eqn2fKn9+rKhY0eDmRTQBOPbZv1/JM2JiAHbH5b095LeXdR2R8TsWg6mwEwKaIL6WlodaLMeEfskDbZZPyAi7o6IgeLjSrX66/3BEFJAt6saUPW2WR90maQflj5PKba70vbb6/jzON0DmqD66V4dbdaLXfo9kuZIOru0eHpEbLf9Kkl32X4wItrfEamAkAKaoL6vxVRqlW77PEnXSjq71HJdEbG9+HeT7XsknSZpTCHF6R7QBOPbZv00STdIujAiHist77E9uXjfK+lMSeUL7qNyUM6k/qlncbL+3pe8JFnfu3dPsv6bxx5L1jG89z56adva+jkPjd+BdJsan5Oq2Gb9HyT9kaR/c2u/WyLiQkknSbrB9gtqTYAWDbkrOCoHZUgBjTO+bdbPazPuXkmn1nYgBUIKaIIGP3FOSAFNQEgByBY/eg
cge8ykAGSNkGqW17zmtcl6T0+6fdL69em7qpO2bk3WgdoRUgCyRkgByBYXzgFkj5kUgKwRUgCy1fDfOCekgCYgpABkjZBqllmzZiXrnX6K5YKVb0vWt4z4iIAx4u4egGxxTQpA9ggpAFkjpABkrcEh1dyrbcDBpL5GDFXarE+2fUtR/5ntV5Zq1xTLN9h+Sx1/GiEFdLvB7+5VeXXc1IE262+VdLKkS2yfPGS1yyQ9ERHHS1os6fPF2JPV6i5ziqR5kv6l2N6YEFJAE4xjm/Xi843F+9skvdmttjEXSbo5IvZGxKOS+ovtjQnXpIaxc+fOZH3SFp6EQl5Cla9J9druK31eEhFLSp+Ha7P+xiHbOLBO0QLrKUlHFstXDhmbatFeCSEFNEBE5VVra7M+XjjdAxogotqrgipt1g+sY3uipD+R9HjFsSNGSAFdLkJ64YVqrwo6tlkvPi8o3l8s6a6IiGL5/OLu3wxJMyX9fKx/H6d7QAOM4HSvw3YqtVn/qqRv2u6XtEutIFOx3q2S1knaL+mjEfH8WI+JkAIaoK6Qam2rY5v1PZLe2WbsdZKuq+9oCCmgEeoMqdwQUkCXG8FF8a7U2JA69NBD29YmTOB+AZqFkAKQtYp37roSIQU0ADMpANnimhSA7BFSALJGSAHIGiEFIFuD391rqsaG1KxTTmlbW9NzRHLswMBA3YeDCm559b+2rb2mw9gXmvxfaQXMpABkjZACkDVCCkC2eE4KQPaafEmOkAIagJkUgKwRUkANvjztS8n6+044YdTbXnHnnaMe2+2afk2KH1YCGqDGbjFt2T7C9h22Nxb/9gyzzmzb99lea/sB2+8u1b5u+1Hba4rX7Cr7JaSABhiPkJJ0taQVETFT0ori81ADkt4XEYOt1v/R9tRS/ZMRMbt4ramyU0IKaIAaW1qllNur3yjp7UNXiIiHI2Jj8f5Xkh6TdNRYdkpIAV2u6iyqmEn12u4rvT44gl29LCJ2FO9/LellqZVtz5U0SdIjpcXXFaeBi21PrrJTLpwDDVBXm3Xbd0p6+TCla393fxG22+7V9jRJ35S0ICIG53DXqBVukyQtkXSVpIWdDpiQAhqgxuag57Wr2f6N7WkRsaMIocfarPfHkn4g6dqIWFna9uAsbK/tr0n6RJVj4nQPaIBxunBebq++QNL3h65QtGb/rqRvRMRtQ2rTin+t1vWsh6rslJkUatPpOajTTz89WZ8yZUrb2tatW5Jj/cgjyXrTjdNzUosk3Wr7Mkm/lPQuSbI9R9KHIuLyYtlZko60fWkx7tLiTt63bR8lyZLWSPpQlZ0SUkCXG68fvYuIxyW9eZjlfZIuL95/S9K32ow/dzT7JaSABmjyE+eEFNAAhBSArBFSALLV9C8YE1JAA/CjdwCyxkyqCz355JNta/v27R2/A2mQQ+xk/YwzzkjWZ82alaw/88zTbWtvWn5mcuy2Jk8lKiCkAGSLa1IAskdIAcgaIQUga02+JEdIAV2Oa1IAskdIAcgaIdWFHt28uW3t6aefSY6dPDn908uHH354sj4wMJCsv5hueNn1yfob3jC3bW3atOF+Vfa3Zr3i6FEd04Hx33lt29rm7dvHtO2mI6QAZGu8fk/qxUJIAQ3ATApA1pocUjRiABoglzbrxXrPl1qpLystn2H7Z7b7bd9SNG3oiJACGiCjNuuStLvUSv3C0vLPS1ocEcdLekLSZVV2SkgBXW6EHYzHomOb9XaKNlbnShpsc1V5PNekhtHbm25dv+k9G5P1Z59NP+LwYlpwzDHJ+mGHpR+vSOn06MWGDRuS9Ym/+tWo932wG8HdvV7bfaXPSyJiScWxVdusTyn2sV/Sooj4nqQjJT0ZEfuLdbZJqvTMCiEFNEBmbdanR8R226+SdJftByU9VfkIhyCkgAbIqc16RGwv/t1k+x5Jp0n6jqSpticWs6ljJFV6QpdrUkCXG8drUlXarPfYnly875V0pqR1ERGS7pZ0cWr8cAgpoA
HGKaQWSTrf9kZJ5xWfZXuO7a8U65wkqc/2/WqF0qKIWFfUrpL0cdv9al2j+mqVnXK6BzTAeDzMWbHN+r2STm0zfpOk9l8ObYOQAhqA7+4ByBY/etdAf7rirGT9nrNXJOvTpk3rsIdO9XxF4n/tu3fvTo697777kvVzf3p+sv5csooUQgpA1ggpAFkjpABkix+9A5A9ZlIAskZIAcgaIQUga4RUw6z/xS+S9eO3vTpZ3/Te/mT9pS996YiPabysXr06Wb/i1x9uW1vV19e2JrV+0Qzjj4c5AWSPu3sAssZMCkDWCCkA2eKaFIDsEVIAskZIAcgad/cOMs88+2yyftQXh+v481s5/5/a6zrUV43LUaBO43VNyvYRkm6R9EpJmyW9KyKeGLLOn0laXFp0oqT5EfE921+XdLZ+297q0ohY02m/NGIAGiCXNusRcfdgi3W1nu8dkPSfpVU+WWrBvqbKTgkpoAEybbN+saQfRkS6tXUHhBTQAOMUUlXbrA+aL+mmIcuus/2A7cWD/fk64ZoU0OVG+KN3vbbLX8JcEhFLBj/U1GZdRYfjUyUtLy2+Rq1wmyRpiVp9+BZ2OmBCCmiAEcySdkbEnPbbGXub9cK7JH03Ig701yjNwvba/pqkT1Q5YE73gAbIpc16ySUacqpXBJtsW63rWQ9V2SkhBTRARm3WZfuVko6V9KMh479t+0FJD0rqlfR3VXbK6R7Q5cbrOakqbdaLz5slHT3MeqP6yTFCCmgAvhYDIGt8LQZA1phJAcgWvycFIHuEFICsEVIAskZIAcjWCL+713UIKaABmEkByBohBSBrhBSArBFSALLFhXMA2WMmBSBrhBSArBFSALLFF4wBZI+QApC1Jt/doxED0ADj0YjB9jttr7X9gu22bbFsz7O9wXa/7atLy2fY/lmx/Bbbk6rsl5ACulzVgKrhlPAhSX8h6cftVrA9QdL1kt4q6WRJl9g+uSh/XtLiiDhe0hOSLquyU0IKaIDxCKmIWB8RGzqsNldSf0Rsioh9km6WdFHRa+9cSbcV692oVu+9jpLXpHZIOxdKv6yyIQCjNn0sg3fsWL38s591b8XVp6TarNfgaElbS5+3SXqjpCMlPRkR+0vLf6/t1XCSIRURR43iIAGMo4iYV9e2bN8p6eXDlK6NiFTH4j8Y7u4BOCAizhvjJrar1b140DHFssclTbU9sZhNDS7viGtSAOq0StLM4k7eJEnzJS2LiJB0t6SLi/UWSKo0MyOkAFRi+89tb5N0uqQf2F5eLH+F7dslqZglXSlpuaT1km6NiLXFJq6S9HHb/Wpdo/pqpf1Gkx9VBdD1mEkByBohBSBrhBSArBFSALJGSAHIGiEFIGuEFICs/T/Bia3rDPz/9wAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "print(f'Explaination for `{other_class}` with LIME')\n", + "fig, _ = visualization.plot_image(relevances[1], X_test[i_instance][:,:,0], data_cmap='gray')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*It is worth noting that the explanation maps for both binary classes are complementary.*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "#### 2. RISE\n", + "\n", + "RISE is short for Randomized Input Sampling for Explanation of Black-box Models. It estimates the relevance empirically by probing the model with randomly masked versions of the input image to obtain the corresponding outputs.
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "RISE masks random portions of the input image and passes the masked image through the model — the portion that decreases the accuracy the most is the most “important” portion.
\n", + "To call the explainer and generate the relevance scores, the user needs to specify the number of randomly generated masks (`n_masks`), the resolution of features in the masks (`feature_res`) and, for each mask and each feature in the image, the probability of being kept unmasked (`p_keep`)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To use dianna with RISE, in the explanation function (for images `explain_image`) we simply specify `method=\"RISE\"` and optionally specify the RISE hyperparameters." + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Explaining: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:01<00:00, 36.06it/s]\n" + ] + } + ], + "source": [ + "relevances = dianna.explain_image(model_path, test_sample, method=\"RISE\",\n", + " labels=[i for i in range(2)],\n", + " n_masks=5000, feature_res=8, p_keep=.1,\n", + " axis_labels=('height','width','channels'))[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "Visualize the relevance scores for the predicted class on top of the image." + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Explaination for `digit 0` with RISE\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "print(f'Explaination for `{pred_class}` with RISE')\n", + "fig, _ = visualization.plot_image(relevances, X_test[i_instance][:,:,0], data_cmap='gray')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*It is worth noting that the explanation map clearly shows the pixels which contribute positively (in red) to the \"0\" classification on the shape of the hand-written digit and the pixels whcih contributed negatively (in blue) for that decision resemble the complimentary class for \"1\" digits.*" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "#### 3. KernelSHAP\n", + "\n", + "SHapley Additive exPlanations, in short, SHAP, is a model-agnostic explainable AI approach which is used to decrypt the black-box models through estimating the [Shapley values](https://en.wikipedia.org/wiki/Shapley_value), which represent the relevancies of each data feature (image pixel, word in text, etc.). [KernelSHAP](https://proceedings.neurips.cc/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf) is a variant of SHAP. It is a method that uses the LIME framework to compute Shapley Values, and visualizes the relevance attributions for each pixel/super-pixel by displaying them on an image.
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "The user needs to specify the number model re-evaluations when explaining each prediction (`nsamples`). A binary mask need to be applied to the image indicating whihc image regiona are hidden. It requires the background color for the masked image, which can be specified by `background`.
\n", + "\n", + "Performing KernelSHAP for each pixel is inefficient. It is good practice to segment the input image into super-pixels and perform the computations on them. The user has to specify some keyword arguments related to the segmentation: the (approximate) number of labels in the segmented output image (`n_segments`), and the width of the Gaussian smoothing kernel used for pre-processing each image dimension (`sigma`)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To use dianna with KernelSHAP, in the explanation function (for images `explain_image`) we simply specify `method=\"KernelSHAP\"` and optionally specify the method's hyperparameters." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "pycharm": { + "name": "#%%\n" + }, + "scrolled": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ce5a0f663b9e4b9b97f0e96e2497f024", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1 [00:00