-
Notifications
You must be signed in to change notification settings - Fork 363
Open
Description
问题
当模型部署在GPU时,使用HWAB模块会抛出设备不匹配错误。尽管显式调用了.cuda(),IWT模块中的中间张量h默认在CPU上创建,导致RuntimeError: Expected all tensors to be on the same device。
错误位置:
iwt_init函数中创建的张量h未指定设备,导致其停留在CPU。
h = torch.zeros([out_batch, out_channel, out_height, out_width])建议改成
def iwt_init(x):
...
h = torch.zeros(
(out_batch, out_channel, out_height, out_width), device=x.device
)
...同时修改main函数
block = HWAB(n_feat=512, o_feat=64).cuda()
input = torch.randn(1, 512, 64, 64).cuda()
output = block(input)Metadata
Metadata
Assignees
Labels
No labels