
philipperemy/keract


Keract: Keras Activations + Gradients


Tested with TensorFlow 2.9, 2.10, 2.11, 2.12, 2.13, 2.14 and 2.15 (Nov 17, 2023).

pip install keract

You have just found a way to get the activations (outputs) and gradients for each layer of your TensorFlow/Keras model (LSTMs, conv nets...).

Important note: nested models are not well supported. Recent versions of TensorFlow make it extremely tricky to extract layer outputs reliably. Please refer to the Examples section to see what is possible.

API

Get activations (nodes/layers outputs as Numpy arrays)

keract.get_activations(model, x, layer_names=None, nodes_to_evaluate=None, output_format='simple', nested=False, auto_compile=True)

Fetch activations (nodes/layers outputs as Numpy arrays) for a Keras model and an input x. By default, the activations of all the layers are returned.

  • model: compiled Keras model or one of ['vgg16', 'vgg19', 'inception_v3', 'inception_resnet_v2', 'mobilenet_v2', 'mobilenetv2', ...].
  • x: Numpy array to feed the model as input. For multi-input models, x should be a list of arrays.
  • layer_names: (optional) Name of a single layer, or list of layer names, for which activations should be returned. This is useful in very large networks where evaluating all the layers/nodes is computationally expensive.
  • nodes_to_evaluate: (optional) List of Keras nodes to be evaluated.
  • output_format: Controls the keys of the returned dictionary.
    • simple: keys match the names of the Keras layers. For example, Dense(1, name='d1') will return {'d1': ...}.
    • full: keys are the full names of the layers' output tensors. In the example above, it will return {'d1/BiasAdd:0': ...}.
    • numbered: keys are indices, based on the order in which the layers are defined within the model.
  • nested: If True, recursively traverses the model definition to retrieve nested layers. Recursion stops at the leaf layers of the model tree or at layers whose names are given in layer_names. For example, a Sequential model inside another Sequential model is considered nested.
  • auto_compile: If True, the model is compiled automatically if needed.

Returns: Dict {layer_name (specified by output_format) -> activation of the layer output/node (Numpy array)}.
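The stopping rule described for nested can be illustrated with a toy model tree. This is only a sketch, not Keract's implementation: ToyLayer and collect_leaves are hypothetical names, and a "model" here is simply a node that has children.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyLayer:
    """Hypothetical stand-in for a Keras layer; a non-empty `children` list marks a nested model."""
    name: str
    children: List["ToyLayer"] = field(default_factory=list)


def collect_leaves(node: ToyLayer, layer_names: Optional[List[str]] = None) -> List[str]:
    """Return the names where recursion stops: leaf layers, or layers listed in layer_names.

    This toy only shows where the recursion stops; it does not model which
    activations Keract eventually returns.
    """
    if not node.children or (layer_names is not None and node.name in layer_names):
        return [node.name]
    names: List[str] = []
    for child in node.children:
        names.extend(collect_leaves(child, layer_names))
    return names


# A Sequential-like model that contains another Sequential-like model.
inner = ToyLayer('inner', [ToyLayer('dense_a'), ToyLayer('dense_b')])
outer = ToyLayer('outer', [ToyLayer('input'), inner, ToyLayer('out')])

print(collect_leaves(outer))                         # ['input', 'dense_a', 'dense_b', 'out']
print(collect_leaves(outer, layer_names=['inner']))  # ['input', 'inner', 'out']
```

Naming the inner model in layer_names stops the descent there, so its own layers are not visited.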

Example

import numpy as np
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, concatenate
from keract import get_activations

# model definition
i1 = Input(shape=(10,), name='i1')
i2 = Input(shape=(10,), name='i2')

a = Dense(1, name='fc1')(i1)
b = Dense(1, name='fc2')(i2)

c = concatenate([a, b], name='concat')
d = Dense(1, name='out')(c)
model = Model(inputs=[i1, i2], outputs=[d])

# inputs to the model
x = [np.random.uniform(size=(32, 10)), np.random.uniform(size=(32, 10))]

# call to fetch the activations of the model.
activations = get_activations(model, x, auto_compile=True)

# print the activations shapes.
for k, v in activations.items():
    print(k, '->', v.shape, '- Numpy array')

# Print output:
# i1 -> (32, 10) - Numpy array
# i2 -> (32, 10) - Numpy array
# fc1 -> (32, 1) - Numpy array
# fc2 -> (32, 1) - Numpy array
# concat -> (32, 2) - Numpy array
# out -> (32, 1) - Numpy array
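Because the return value is a plain dictionary of Numpy arrays, it can be post-processed with ordinary Numpy code. The sketch below does not call Keract: it uses random stand-in arrays with the shapes printed above, and summarize_activations is a hypothetical helper that reports per-layer statistics.

```python
import numpy as np


def summarize_activations(activations):
    """Return {layer_name: (shape, mean, std)} for a dict of Numpy arrays."""
    return {
        name: (value.shape, float(value.mean()), float(value.std()))
        for name, value in activations.items()
    }


# Random stand-ins with the same shapes as the example output above.
rng = np.random.default_rng(0)
fake_activations = {
    'i1': rng.uniform(size=(32, 10)),
    'i2': rng.uniform(size=(32, 10)),
    'fc1': rng.normal(size=(32, 1)),
    'fc2': rng.normal(size=(32, 1)),
    'concat': rng.normal(size=(32, 2)),
    'out': rng.normal(size=(32, 1)),
}

for name, (shape, mean, std) in summarize_activations(fake_activations).items():
    print(f'{name:>6} {shape} mean={mean:+.3f} std={std:.3f}')
```

With real activations, the same loop works unchanged on the dictionary returned by get_activations.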

Display the activations you've obtained

keract.display_activations(activations, cmap=None, save=False, directory='.', data_format='channels_last', fig_size=(24, 24), reshape_1d_layers=False)

Plot the activations of each layer using matplotlib.

Inputs are:

  • activations: dict - a dictionary mapping layers to their activations (the output of get_activations)
  • cmap: (optional) string - a valid matplotlib colormap to be used
  • save: (optional) bool - if True the images of the activations are saved rather than being shown
  • directory: (optional) string - where to store the activations (if save is True)
  • data_format: (optional) string - one of "channels_last" (default) or "channels_first".
  • reshape_1d_layers: (optional) bool - tries to reshape large 1d layers to a square/rectangle.
  • fig_size: (optional) (float, float) - width, height in inches.

data_format sets the ordering of the dimensions in the inputs: "channels_last" corresponds to inputs with shape (batch, steps, channels) (the default format for temporal data in Keras), while "channels_first" corresponds to inputs with shape (batch, channels, steps).
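To make the two layouts concrete, a channels_first array can be converted to channels_last by moving the channel axis to the end. This sketch uses only Numpy (np.moveaxis) on random data; the shapes are illustrative.

```python
import numpy as np

# Illustrative temporal activation in "channels_first" layout: (batch, channels, steps).
act_channels_first = np.random.rand(4, 8, 50)

# Move the channel axis (1) to the end to obtain "channels_last": (batch, steps, channels).
act_channels_last = np.moveaxis(act_channels_first, 1, -1)

print(act_channels_first.shape)  # (4, 8, 50)
print(act_channels_last.shape)   # (4, 50, 8)

# The values are unchanged; only the axis order differs.
assert np.array_equal(act_channels_last[:, :, 3], act_channels_first[:, 3, :])
```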

Display the activations as a heatmap overlaid on an image

keract.display_heatmaps(activations, input_image, directory='.', save=False, fix=True, merge_filters=False)

Plot heatmaps of the activations of all filters, overlaid on the input image, for each layer.

Inputs are:

  • activations: a dictionary mapping layers to their activations (the output of get_activations).
  • input_image: numpy array - the image that was passed as x to get_activations.
  • directory: (optional) string - where to store the heatmaps (if save is True).
  • save: (optional) bool - if True the heatmaps are saved rather than being shown.
  • fix: (optional) bool - if True, automated checks and fixes for incorrect images are run.
  • merge_filters: (optional) bool - if True, one heatmap (with all the filters averaged together) is produced for each layer; if False, a heatmap is produced for each filter in each layer.
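As a rough illustration of merge_filters, averaging the filters of a layer collapses a (1, height, width, filters) activation into a single map. The Numpy sketch below shows that averaging plus a min-max rescaling to [0, 1] for display. It reflects the general idea only, not Keract's exact code, and merge_filter_maps is a hypothetical helper.

```python
import numpy as np


def merge_filter_maps(activation):
    """Average a (1, height, width, filters) activation over its filters -> (height, width),
    then rescale to [0, 1] so it can be drawn as a heatmap."""
    merged = activation[0].mean(axis=-1)
    span = merged.max() - merged.min()
    return (merged - merged.min()) / span if span > 0 else np.zeros_like(merged)


# Illustrative activation for one image: 224x224 spatial map with 64 filters.
activation = np.random.rand(1, 224, 224, 64)

merged = merge_filter_maps(activation)          # merge_filters=True: one heatmap for the layer
per_filter = np.moveaxis(activation[0], -1, 0)  # merge_filters=False: one map per filter

print(merged.shape)      # (224, 224)
print(per_filter.shape)  # (64, 224, 224)
```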

Get gradients of weights

keract.get_gradients_of_trainable_weights(model, x, y)
  • model: a keras.models.Model object.
  • x: Numpy array to feed the model as input. In the case of multi-inputs, x should be of type List.
  • y: labels (Numpy array), following the Keras convention.

The output is a dictionary mapping each trainable weight to the values of its gradients (computed for inputs x and labels y).
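To make "gradients of the trainable weights" concrete, here is a hand-worked case in plain Numpy, assuming the gradients are those of a training loss: a single Dense(1) layer (kernel W, bias b) with a mean squared error loss chosen for illustration. The analytic gradients are checked against central finite differences. The dictionary keys 'dense/kernel' and 'dense/bias' are hypothetical; Keract computes its gradients through TensorFlow, not with this code.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(32, 10))  # inputs: batch of 32, 10 features
y = rng.normal(size=(32, 1))   # labels
W = rng.normal(size=(10, 1))   # Dense(1) kernel
b = np.zeros(1)                # Dense(1) bias


def mse_loss():
    """Mean squared error of the Dense(1) layer, using the current W and b."""
    return np.mean((x @ W + b - y) ** 2)


# Analytic gradients of the MSE loss with respect to each trainable weight.
residual = x @ W + b - y  # shape (32, 1)
n = residual.size
gradients = {
    'dense/kernel': 2.0 / n * x.T @ residual,      # shape (10, 1), same as W
    'dense/bias': 2.0 / n * residual.sum(axis=0),  # shape (1,), same as b
}


def numerical_grad(f, param, eps=1e-6):
    """Central finite differences, perturbing one entry of `param` in place at a time."""
    grad = np.zeros_like(param)
    for idx in np.ndindex(param.shape):
        old = param[idx]
        param[idx] = old + eps
        plus = f()
        param[idx] = old - eps
        minus = f()
        param[idx] = old
        grad[idx] = (plus - minus) / (2 * eps)
    return grad


numerical_kernel = numerical_grad(mse_loss, W)
numerical_bias = numerical_grad(mse_loss, b)
print(np.allclose(gradients['dense/kernel'], numerical_kernel, atol=1e-5))  # True
print(np.allclose(gradients['dense/bias'], numerical_bias, atol=1e-5))      # True
```

Each gradient has the same shape as the weight it belongs to, which is what lets such a result be keyed weight by weight.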

Get gradients of activations

keract.get_gradients_of_activations(model, x, y, layer_name=None, output_format='simple')
  • model: a keras.models.Model object.
  • x: Numpy array to feed the model as input. In the case of multi-inputs, x should be of type List.
  • y: labels (Numpy array), following the Keras convention.
  • layer_name: (optional) Name of the layer for which gradients of activations should be returned.
  • output_format: Controls the keys of the returned dictionary.
    • simple: keys match the names of the Keras layers. For example, Dense(1, name='d1') will return {'d1': ...}.
    • full: keys are the full names of the layers' output tensors. In the example above, it will return {'d1/BiasAdd:0': ...}.
    • numbered: keys are indices, based on the order in which the layers are defined within the model.

Returns: Dict {layer_name (specified by output_format) -> gradient of the layer output/node (Numpy array)}.

That is, each layer is mapped to the gradients of its activations (computed for inputs x and labels y).
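The gradient of an activation is taken with respect to a layer's output rather than a weight. Assuming again a training loss (mean squared error, chosen for illustration), the plain Numpy sketch below uses a two-layer network (a bias-free hidden Dense layer, then Dense(1)) and the chain rule: the gradient at the hidden output equals the output gradient multiplied by the transpose of the second kernel. Shapes and names are illustrative, not Keract internals.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(8, 4))
y = rng.normal(size=(8, 1))
W1 = rng.normal(size=(4, 3))  # hidden Dense layer: 3 units, no bias or activation for brevity
W2 = rng.normal(size=(3, 1))  # output Dense(1) layer

h = x @ W1    # hidden activation, shape (8, 3)
out = h @ W2  # output activation, shape (8, 1)

# MSE loss L = mean((out - y)^2): gradient with respect to the output activation.
grad_out = 2.0 / out.size * (out - y)  # shape (8, 1)
# Chain rule: dL/dh = dL/dout @ W2^T.
grad_h = grad_out @ W2.T               # shape (8, 3), same as h


def loss_from_hidden(h_val):
    """Loss recomputed from a given hidden activation."""
    return np.mean((h_val @ W2 - y) ** 2)


# Central finite-difference check on one entry of h.
eps = 1e-6
h_plus, h_minus = h.copy(), h.copy()
h_plus[2, 1] += eps
h_minus[2, 1] -= eps
numeric = (loss_from_hidden(h_plus) - loss_from_hidden(h_minus)) / (2 * eps)
print(np.isclose(grad_h[2, 1], numeric, atol=1e-6))  # True
```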

Persist activations to JSON

keract.persist_to_json_file(activations, filename)
  • activations: dict mapping layer names to their activations (e.g. the output of get_activations)
  • filename: output filename (JSON format)

Load activations from JSON

keract.load_activations_from_json_file(filename)
  • filename: filename to read the activations from (JSON format)

It returns the activations.
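JSON has no array type, so persisting activations requires converting each Numpy array to nested lists and back. The round trip below sketches that idea with the standard json module and Numpy; save_activations and load_activations are hypothetical helpers, not Keract's code, and keract.persist_to_json_file / keract.load_activations_from_json_file remain the documented entry points.

```python
import json
import os
import tempfile

import numpy as np


def save_activations(activations, filename):
    """Write {layer_name: ndarray} to JSON by converting every array to nested lists."""
    with open(filename, 'w') as f:
        json.dump({name: value.tolist() for name, value in activations.items()}, f)


def load_activations(filename):
    """Read the JSON back and rebuild a Numpy array for every layer."""
    with open(filename) as f:
        return {name: np.array(value) for name, value in json.load(f).items()}


activations = {'fc1': np.random.rand(32, 1), 'concat': np.random.rand(32, 2)}
path = os.path.join(tempfile.mkdtemp(), 'activations.json')
save_activations(activations, path)
restored = load_activations(path)

print(restored['concat'].shape)  # (32, 2)
```

Python's json module writes floats with their shortest round-tripping representation, so float64 values come back exactly; the trade-off is that text files are much larger than a binary format.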

Examples

Examples are provided for:

  • keras.models.Sequential - mnist.py
  • keras.models.Model - multi_inputs.py
  • Recurrent networks - recurrent.py

In the case of MNIST with LeNet, we are able to fetch the activations for a batch of size 128:

conv2d_1/Relu:0
(128, 26, 26, 32)

conv2d_2/Relu:0
(128, 24, 24, 64)

max_pooling2d_1/MaxPool:0
(128, 12, 12, 64)

dropout_1/cond/Merge:0
(128, 12, 12, 64)

flatten_1/Reshape:0
(128, 9216)

dense_1/Relu:0
(128, 128)

dropout_2/cond/Merge:0
(128, 128)

dense_2/Softmax:0
(128, 10)
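The spatial shapes above follow from standard convolution arithmetic. The check below assumes 28x28 MNIST inputs, 3x3 "valid" convolutions with stride 1 and 2x2 max pooling; these hyperparameters are assumptions consistent with the printed shapes rather than something stated here. Dropout layers leave the shape unchanged.

```python
def conv_valid(size, kernel):
    """Output size of a stride-1 convolution without padding ('valid')."""
    return size - kernel + 1


side = 28                   # assumed MNIST input: 28x28x1
side = conv_valid(side, 3)  # conv2d_1 -> 26 (32 filters)
conv1_side = side
side = conv_valid(side, 3)  # conv2d_2 -> 24 (64 filters)
conv2_side = side
side = side // 2            # max_pooling2d_1 (2x2) -> 12
flat = side * side * 64     # flatten_1: 12 * 12 * 64
print(side, flat)           # 12 9216
```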

We can visualise the activations. Here's another example using VGG16:

cd examples
pip install -r examples-requirements.txt
python vgg16.py


A cat.


Outputs of the first convolutional layer of VGG16.

Also, we can visualise the heatmaps of the activations:

cd examples
pip install -r examples-requirements.txt
python heat_map.py

Limitations / Ways of improvement

In some specific cases, Keract does not handle models that contain submodels well. Feel free to fork this repo and propose a PR to fix it!

Citation

@misc{Keract,
  author = {Philippe Remy},
  title = {Keract: A library for visualizing activations and gradients},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/philipperemy/keract}},
}
