Skip to content

修复HWAB模块中的设备不一致问题(张量未正确分配到GPU) #12

@Fan-hr

Description

@Fan-hr

问题

当模型部署在GPU时,使用HWAB模块会抛出设备不匹配错误。尽管显式调用了.cuda()IWT模块中的中间张量h默认在CPU上创建,导致RuntimeError: Expected all tensors to be on the same device

错误位置:
iwt_init函数中创建的张量h未指定设备,导致其停留在CPU。

h = torch.zeros([out_batch, out_channel, out_height, out_width])

建议改成

def iwt_init(x):
    ...
    h = torch.zeros(
        (out_batch, out_channel, out_height, out_width), device=x.device
    )
    ...

同时修改main函数

block = HWAB(n_feat=512, o_feat=64).cuda()
input = torch.randn(1, 512, 64, 64).cuda()
output = block(input)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions