new explain_image notebook #16

Status: Open · wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
+**/__pycache__/**/*
+outputs/**/*
3 changes: 3 additions & 0 deletions README.md
@@ -72,5 +72,8 @@ Example cross-attention visualizations.
as seen in the illustration above.
This notebook can be used to provide an explanation for the generations produced by Attend-and-Excite.

+### Explainability on arbitrary images
+`notebooks/explain_image.ipynb` shows cross-attention maps on existing images. It illustrates how some concepts (e.g., subject tokens and easily identifiable body parts) can be located reliably in the image from the attention maps, while other concepts cannot. This notebook was contributed by Christian Laforte from Stability AI.

## Acknowledgements
This code builds on the code from the [diffusers](https://github.com/huggingface/diffusers) library as well as the [Prompt-to-Prompt](https://github.com/google/prompt-to-prompt/) codebase.
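
The notebook itself is not rendered in this diff (see below), so as context for the README entry above, here is a minimal, hedged sketch of the flow it likely follows, using only the utilities visible in `run.py` and `utils/vis_utils.py`. Every import path, name, and argument below is an assumption rather than the notebook's actual code.

```python
# Hedged sketch of the setup explain_image.ipynb likely uses; all names,
# import paths, and argument values are assumptions inferred from this diff.
import torch

from config import RunConfig                   # assumed: RunConfig lives in config.py
from run import load_model, run_on_prompt
from utils import ptp_utils
from utils.ptp_utils import AttentionStore     # assumed: AttentionStore lives in utils/ptp_utils.py

config = RunConfig(prompt="a cat wearing a red hat")   # hypothetical prompt/config values
stable = load_model(config)

controller = AttentionStore()
ptp_utils.register_attention_control(stable, controller)  # hook the UNet's cross-attention layers

# One pass through the pipeline fills `controller` with per-token attention maps.
# To explain an *existing* image rather than a generated one, the notebook can
# presumably pass the image's encoded latent via the `latents` argument that
# this PR adds to run_on_prompt (see the run.py diff below).
image = run_on_prompt(prompt=[config.prompt],
                      model=stable,            # parameter name inferred from the function body
                      controller=controller,
                      token_indices=[2],       # hypothetical index of "cat" in the tokenized prompt
                      seed=torch.Generator(device="cuda").manual_seed(0),
                      config=config)

# The recorded maps are then overlaid on the image with
# utils.vis_utils.show_cross_attention (see the vis_utils.py diff below).
```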
2 changes: 1 addition & 1 deletion environment/requirements.txt
@@ -3,7 +3,7 @@ opencv-python
ipywidgets
matplotlib
pyrallis
-torch==1.12.0
+torch==1.13.1
diffusers==0.3.0
transformers==4.23.1
jupyter
774 changes: 774 additions & 0 deletions notebooks/explain_image.ipynb

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions run.py
@@ -1,5 +1,5 @@
import pprint
-from typing import List
+from typing import List, Optional

import pyrallis
import torch
@@ -19,7 +19,6 @@ def load_model(config: RunConfig):
stable = AttendAndExcitePipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
return stable


def get_indices_to_alter(stable, prompt: str) -> List[int]:
token_idx_to_word = {idx: stable.tokenizer.decode(t)
for idx, t in enumerate(stable.tokenizer(prompt)['input_ids'])
@@ -37,7 +36,8 @@ def run_on_prompt(prompt: List[str],
controller: AttentionStore,
token_indices: List[int],
seed: torch.Generator,
-config: RunConfig) -> Image.Image:
+config: RunConfig,
+latents: Optional[torch.FloatTensor] = None) -> Image.Image:
if controller is not None:
ptp_utils.register_attention_control(model, controller)
outputs = model(prompt=prompt,
@@ -46,6 +46,7 @@ def run_on_prompt(prompt: List[str],
attention_res=config.attention_res,
guidance_scale=config.guidance_scale,
generator=seed,
+latents=latents,
num_inference_steps=config.n_inference_steps,
max_iter_to_alter=config.max_iter_to_alter,
run_standard_sd=config.run_standard_sd,
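The new optional `latents` argument lets a caller bypass the pipeline's internal noise sampling, e.g. to reproduce a run exactly or to feed in latents derived from an existing image. A hedged usage sketch follows; the `model` parameter name is inferred from the function body, and the 1×4×64×64 latent shape is Stable Diffusion v1.4's latent resolution for 512×512 outputs.

```python
# Hedged sketch: fixing the starting latent so repeated runs are directly
# comparable. `stable`, `controller`, and `config` are assumed to be set up
# as elsewhere in run.py.
import torch

g = torch.Generator(device="cuda").manual_seed(42)
latents = torch.randn((1, 4, 64, 64), generator=g, device="cuda")

image = run_on_prompt(prompt=["a cat and a frog"],
                      model=stable,
                      controller=controller,
                      token_indices=[2, 5],   # hypothetical indices of "cat" and "frog"
                      seed=g,
                      config=config,
                      latents=latents)        # same latents => same starting noise every call
```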
36 changes: 27 additions & 9 deletions utils/vis_utils.py
@@ -16,7 +16,18 @@ def show_cross_attention(prompt: str,
res: int,
from_where: List[str],
select: int = 0,
-orig_image=None):
+orig_image=None,
+vis_image_size=512,
+auto_scale_heatmap=True,
+heatmap_scale=100.0):
"""
Args:
vis_image_size: width/height of each image, in the final image grid
auto_scale_heatmap: if True, each image's attention heatmap gets normalized independently,
to make it easier to see the relative attention within that image.
If False, all images' heatmaps get scaled by `heatmap_scale`
heatmap_scale: amount by which to scale the heatmaps.
"""
tokens = tokenizer.encode(prompt)
decoder = tokenizer.decode
attention_maps = aggregate_attention(attention_store, res, from_where, True, select).detach().cpu()
@@ -26,31 +37,38 @@ def show_cross_attention(prompt: str,
for i in range(len(tokens)):
image = attention_maps[:, :, i]
if i in indices_to_alter:
-image = show_image_relevance(image, orig_image)
+image = show_image_relevance(image, orig_image, vis_image_size, auto_scale_heatmap, heatmap_scale)
image = image.astype(np.uint8)
-image = np.array(Image.fromarray(image).resize((res ** 2, res ** 2)))
+image = np.array(Image.fromarray(image).resize((vis_image_size, vis_image_size)))
image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
images.append(image)

ptp_utils.view_images(np.stack(images, axis=0))


-def show_image_relevance(image_relevance, image: Image.Image, relevnace_res=16):
+def show_image_relevance(image_relevance, image: Image.Image, vis_image_size=512, auto_scale_heatmap=True,
+                         heatmap_scale=100.0):
# create heatmap from mask on image
def show_cam_on_image(img, mask):
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
heatmap = np.float32(heatmap) / 255
cam = heatmap + np.float32(img)
-cam = cam / np.max(cam)
+if auto_scale_heatmap:
+    cam = cam / np.max(cam)
+else:
+    cam = cam / 2.0  # normalize by a constant value
return cam

-image = image.resize((relevnace_res ** 2, relevnace_res ** 2))
+image = image.resize((vis_image_size, vis_image_size))
image = np.array(image)

image_relevance = image_relevance.reshape(1, 1, image_relevance.shape[-1], image_relevance.shape[-1])
-image_relevance = torch.nn.functional.interpolate(image_relevance, size=relevnace_res ** 2, mode='bilinear')
-image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
-image_relevance = image_relevance.reshape(relevnace_res ** 2, relevnace_res ** 2)
+image_relevance = torch.nn.functional.interpolate(image_relevance, size=vis_image_size, mode='bilinear')  # nearest would be better
+if auto_scale_heatmap:
+    image_relevance = (image_relevance - image_relevance.min()) / (image_relevance.max() - image_relevance.min())
+else:
+    image_relevance *= heatmap_scale
+image_relevance = image_relevance.reshape(vis_image_size, vis_image_size)
image = (image - image.min()) / (image.max() - image.min())
vis = show_cam_on_image(image, image_relevance)
vis = np.uint8(255 * vis)
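To summarize the two scaling modes added above: with `auto_scale_heatmap=True`, each token's heatmap is min-max normalized on its own, which makes every map easy to read but not comparable across tokens; with `auto_scale_heatmap=False`, every map is multiplied by the same `heatmap_scale`, preserving relative attention strength between tokens. A hedged example call follows; the parameter names before `res` are inferred from the function body, and all values are illustrative.

```python
# Comparing attention strength across tokens: disable per-image normalization
# so a weakly-attended token stays visibly dim instead of being normalized up
# to full brightness. `controller` is an AttentionStore filled by a prior
# pipeline run and `orig_image` is the PIL image being explained (both assumed
# from context).
from utils.vis_utils import show_cross_attention

show_cross_attention(prompt="a cat wearing a red hat",
                     attention_store=controller,
                     tokenizer=stable.tokenizer,
                     indices_to_alter=[2, 6],       # hypothetical token positions to overlay
                     res=16,
                     from_where=("up", "down", "mid"),
                     orig_image=orig_image,
                     vis_image_size=512,
                     auto_scale_heatmap=False,      # use one shared scale across all tokens...
                     heatmap_scale=100.0)           # ...multiplying each map by this constant
```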