[BUG] Incorrect gradient shapes with conv_general and flipped kernels #1826
Comments
The third example (using ::-1) … Will look into the first one using the flip=True flag.
Thank you! This is great. I noticed the issue with … Is there a scheduled date for the next tagged release?
We should get out at least a patch release by the end of the week.
Thank you! This might be separate, but I was wondering if it's addressed in the same PR. Example:
import mlx.core as mx
a = mx.random.normal((2, 5, 5, 3))
print(a[:, 0:2, 0:2, :].shape)
i = mx.array(0, dtype=mx.int32)
j = mx.array(2, dtype=mx.int32)
print(a[:, i, j, :].shape)
print(a[:, i:j, i:j, :].shape)  # ValueError: Slice indices must be integers or None.
That is expected, and it isn't addressed in the PR I sent. In general, operations whose output shapes depend on input data are not supported. You could open an issue for this one specifically, but it's not a simple fix in this case, so I can't say whether or when we will be able to support it.
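A possible workaround for this particular snippet, sketched under the assumption that eagerly reading the scalar indices back to Python is acceptable: convert them with .item() before slicing, so the slice bounds (and therefore the output shape) are known when the operation is built.

import mlx.core as mx

a = mx.random.normal((2, 5, 5, 3))
i = mx.array(0, dtype=mx.int32)
j = mx.array(2, dtype=mx.int32)
# .item() evaluates the scalar arrays into plain Python ints,
# so the sliced result has a static shape
print(a[:, i.item():j.item(), i.item():j.item(), :].shape)  # (2, 2, 2, 3)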
Describe the bug
I'm using conv_general to implement a transposed convolution layer, and getting incorrect gradient shapes.
Below are three examples of flipping the spatial dimensions of the kernels.
To Reproduce
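A minimal sketch of the three kernel-flipping variants described in this issue, assuming an NHWC input, a (C_out, kH, kW, C_in) kernel, and a simple sum loss; the shapes are illustrative, not the original reproduction script.

import mlx.core as mx

x = mx.random.normal((1, 8, 8, 4))   # NHWC input (illustrative shape)
w = mx.random.normal((8, 3, 3, 4))   # (C_out, kH, kW, C_in) kernel (illustrative shape)

def loss_flip(w):
    # 1) let conv_general flip the kernel's spatial axes
    return mx.conv_general(x, w, flip=True).sum()

def loss_take(w):
    # 2) flip the spatial axes manually with mx.take
    idx_h = mx.array(list(reversed(range(w.shape[1]))))
    idx_w = mx.array(list(reversed(range(w.shape[2]))))
    w_flipped = mx.take(mx.take(w, idx_h, axis=1), idx_w, axis=2)
    return mx.conv_general(x, w_flipped).sum()

def loss_slice(w):
    # 3) flip the spatial axes with negative-stride slicing
    return mx.conv_general(x, w[:, ::-1, ::-1, :]).sum()

for name, fn in [("flip=True", loss_flip), ("mx.take", loss_take), ("::-1", loss_slice)]:
    try:
        print(name, "grad shape:", mx.grad(fn)(w).shape, "kernel shape:", w.shape)
    except Exception as e:  # the ::-1 variant reportedly raises a broadcasting error
        print(name, "raised:", e)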
Expected behavior
I expect the gradients with respect to the kernel to be the same shape as the kernel, regardless of how the spatial axes are flipped.
Desktop (please complete the following information):
Additional context
The forward pass through each layer produces the same output. Just using conv_general with flip=True produces gradient shapes different from the kernel. Manually flipping each spatial axis (with reverse_sequence() using mx.take in the example) produces gradients for the kernel with the same shape as the kernel. And flipping with stride and negative slicing, ::-1, gives a broadcasting error when calling the loss_and_grad function (is it doing a copy somewhere?). And thank you!