Issue Reproducing Figure 4 from the Paper #9

Open
ChubbyPear opened this issue Dec 20, 2024 · 0 comments

@ChubbyPear

Dear Authors,

I hope this message finds you well. I am having a problem reproducing Figure 4 from your paper. Below is a snippet of the code from attnmap.py:

```python
for stage in stages:
    ws_a0 = torch.stack([ws_rec['stage{}'.format(stage)][i][0] for i in range(len(rand_examples))]).mean(0)
    ws_a1 = torch.stack([ws_rec['stage{}'.format(stage)][i][1] for i in range(len(rand_examples))]).mean(0)
    ws_a2 = torch.stack([ws_rec['stage{}'.format(stage)][i][2] for i in range(len(rand_examples))]).mean(0)
    ws_a3 = torch.stack([ws_rec['stage{}'.format(stage)][i][3] for i in range(len(rand_examples))]).mean(0)
    if half:
        # interpolate a1, a2, a3 to a0 size
        h, w = ws_a0.shape
        ws_a1 = F.interpolate(ws_a1[None, None], size=(h, w), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
        ws_a2 = F.interpolate(ws_a2[None, None], size=(h, w), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
        ws_a3 = F.interpolate(ws_a3[None, None], size=(h, w), mode='bilinear', align_corners=False).squeeze(0).squeeze(0)
    visualize_attnmaps([
        (ws_a0, "Scan1"),
        (ws_a1, "Scan2"),
        (ws_a2, "Scan3"),
        (ws_a3, "Scan4"),
    ], rows=2, savefig=f"{dst_path}/stage{stage}_ws_scan.jpg", fontsize=60, title=title)
```

However, ws_rec has the structure shown below. This means ws_rec['stage{}'.format(stage)] is itself a dictionary keyed by layer name ('layer0', 'layer1', ...), so indexing it with [i][0] raises a KeyError during execution.

```python
ws_rec = {
    'stage1': {
        'layer0': {
            'ave_ws': ...,
            'ratio': ...,
            'ws_ah': ...,
            'ws_aw': ...
        },
        'layer1': {
            'ave_ws': ...,
            'ratio': ...,
            'ws_ah': ...,
            'ws_aw': ...
        },
        ...
    },
    'stage2': {
        ...
    },
    ...
}
```
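To make the mismatch concrete, here is a minimal sketch with a mock ws_rec that mirrors the nesting above. The leaf values are placeholders (None), not real attention maps, and list_paths is a hypothetical helper, not part of the repository. Indexing a stage dictionary with an integer fails because its keys are layer names; walking the nested dictionaries and listing every key path is a quick way to see where the per-example tensors actually live.

```python
def list_paths(node, prefix=()):
    """Yield every key path through nested dicts down to a non-dict leaf."""
    if isinstance(node, dict):
        for key, child in node.items():
            yield from list_paths(child, prefix + (key,))
    else:
        yield prefix


# Mock structure mirroring ws_rec; None stands in for the real recorded values.
ws_rec = {
    "stage1": {
        "layer0": {"ave_ws": None, "ratio": None, "ws_ah": None, "ws_aw": None},
        "layer1": {"ave_ws": None, "ratio": None, "ws_ah": None, "ws_aw": None},
    },
}

try:
    # This is what ws_rec['stage{}'.format(stage)][i] does when i == 0.
    ws_rec["stage1"][0]
except KeyError as err:
    print("KeyError:", err)

for path in list_paths(ws_rec):
    print("/".join(path))
```

Running list_paths on the real ws_rec (with the real leaf values) would show which key, such as 'ave_ws', holds the per-example, per-scan maps.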

I tried modifying the code as follows:

```python
ws_a0 = torch.stack([ws_rec['stage{}'.format(stage)]['layer1']['ave_ws'][i][0] for i in range(len(rand_examples))]).mean(0)
ws_a1 = torch.stack([ws_rec['stage{}'.format(stage)]['layer1']['ave_ws'][i][1] for i in range(len(rand_examples))]).mean(0)
ws_a2 = torch.stack([ws_rec['stage{}'.format(stage)]['layer1']['ave_ws'][i][2] for i in range(len(rand_examples))]).mean(0)
ws_a3 = torch.stack([ws_rec['stage{}'.format(stage)]['layer1']['ave_ws'][i][3] for i in range(len(rand_examples))]).mean(0)
```
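The same modification can be folded into one parameterized function so that other layers and keys are easy to try. This is a sketch under assumptions: average_scan_maps is a hypothetical helper, and it assumes ws_rec[stage][layer][key] can be indexed as [example][scan] with each entry a 2D tensor, as the modified code above implies. The half option mirrors the resizing branch of the original snippet.

```python
import torch
import torch.nn.functional as F


def average_scan_maps(ws_rec, stage, layer, num_examples, key="ave_ws",
                      num_scans=4, half=False):
    """Average the per-example maps of each scan route for one stage/layer.

    Assumes (hypothetically) that ws_rec[f"stage{stage}"][f"layer{layer}"][key]
    can be indexed as [example][scan] and every entry is a 2D tensor.
    """
    entries = ws_rec[f"stage{stage}"][f"layer{layer}"][key]
    maps = []
    for scan in range(num_scans):
        # Stack this scan route's map from every example, then average them.
        maps.append(torch.stack([entries[i][scan] for i in range(num_examples)]).mean(0))
    if half:
        # Mirror the original `half` branch: resize scans 1..N-1 to scan 0's size.
        h, w = maps[0].shape
        maps = [maps[0]] + [
            F.interpolate(m[None, None], size=(h, w), mode="bilinear",
                          align_corners=False)[0, 0]
            for m in maps[1:]
        ]
    return maps
```

Inside the original loop, the four ws_aN lines would then become something like ws_a0, ws_a1, ws_a2, ws_a3 = average_scan_maps(ws_rec, stage, 1, len(rand_examples), half=half), and changing the layer argument makes it simple to compare layers.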

With these changes, I can obtain results similar to yours, but only when using ['stage3']['layer1']['ave_ws'] from the ms_vssm_tiny model; that combination yields results akin to Figure 4 (decay with different scanning routes).

Could you please confirm whether there is an error in the original code, or whether I need to make additional modifications?

Looking forward to your response.

Thank you for your assistance!
