diff --git a/utils/utils.py b/utils/utils.py index d2ee0f8..b17e9aa 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -14,7 +14,9 @@ from utils.cam import grad_cam, grad_cam_plusplus -def load_network(network_name: str) -> Tuple[ +def load_network( + network_name: str, +) -> Tuple[ tf.keras.Model, Tuple[int, int], tf.keras.layers.Layer, @@ -190,6 +192,14 @@ def draw_heatmap( def guided_bp_map_postprocessing(guided_bp_map: np.array) -> np.array: + """Function that applies postprocessing to a given guided backpropagation map + + Arguments: + guided_bp_map (np.array): The guided backpropagation map + + Returns: + guided_bp_map (np.array): The postprocessed guided backpropagation map + """ # Center on 0 with std 0.25 guided_bp_map -= guided_bp_map.mean()