[BUG] Incorrect gradient shapes with conv_general and flipped kernels #1826

Closed
acsweet opened this issue Feb 3, 2025 · 5 comments · Fixed by #1827

acsweet commented Feb 3, 2025

Describe the bug
I'm using conv_general to implement a transposed convolution layer and am getting incorrect gradient shapes.
Below are three ways of flipping the spatial dimensions of the kernel.

To Reproduce

import mlx.core as mx
import mlx.nn as nn

inputs = mx.random.normal((1, 14, 14, 2))
kernel = mx.random.normal((2, 7, 7, 2))
bias = mx.random.normal((2,))
target = mx.random.normal((1, 224, 224, 2))

def forward_fn_with_flip(params, inputs):
    kernel, bias = params
    result = mx.conv_general(
        inputs,
        kernel,
        stride=1,
        padding=([6, 6], [15, 15]),
        kernel_dilation=(1, 1),
        input_dilation=(16, 16),
        groups=1,
        flip=True,
    )
    return result + bias.reshape(1, 1, 1, -1)

def reverse_sequence(xs, axis=0):
    indices = mx.arange(xs.shape[axis] - 1, -1, -1)
    return mx.take(xs, indices, axis=axis)

def forward_fn_manual_flip(params, inputs):
    kernel, bias = params

    for ax in range(1, kernel.ndim - 1):
        kernel = reverse_sequence(kernel, axis=ax)
    
    result = mx.conv_general(
        inputs,
        kernel,
        stride=1,
        padding=([6, 6], [15, 15]),
        kernel_dilation=(1, 1),
        input_dilation=(16, 16),
        groups=1,
        flip=False,
    )
    return result + bias.reshape(1, 1, 1, -1)

def forward_fn_slice_flip(params, inputs):
    kernel, bias = params
    kernel = kernel[:, ::-1, ::-1, :]

    result = mx.conv_general(
        inputs,
        kernel,
        stride=1,
        padding=([6, 6], [15, 15]),
        kernel_dilation=(1, 1),
        input_dilation=(16, 16),
        groups=1,
        flip=False,
    )
    return result + bias.reshape(1, 1, 1, -1)


params = (kernel, bias)
a = forward_fn_with_flip(params, inputs)
b = forward_fn_manual_flip(params, inputs)
c = forward_fn_slice_flip(params, inputs)
# assert forward passes give same output
assert mx.all(mx.isclose(a, b))
assert mx.all(mx.isclose(b, c))

print(f'kernel shape: {kernel.shape}, bias shape: {bias.shape}') # (2, 7, 7, 2) and (2,)


##########################################
# conv_general(flip=True)
##########################################
def loss_fn(params, inputs, target):
    pred = forward_fn_with_flip(params, inputs)
    return mx.mean((pred - target) ** 2)

loss_and_grad = mx.value_and_grad(loss_fn)
params = (kernel, bias)
loss, grads = loss_and_grad(params, inputs, target)

print(f'(a) kernel grads shape {grads[0].shape}') # (2, 16, 16, 2)
print(f'(a) bias grads shape {grads[1].shape}') # (2,)


##########################################
# manually flip each kernel spatial axes
##########################################
def loss_fn(params, inputs, target):
    pred = forward_fn_manual_flip(params, inputs)
    return mx.mean((pred - target) ** 2)

loss_and_grad = mx.value_and_grad(loss_fn)
params = (kernel, bias)
loss, grads = loss_and_grad(params, inputs, target)

print(f'(b) kernel grads shape {grads[0].shape}') # (2, 7, 7, 2) expected shape
print(f'(b) bias grads shape {grads[1].shape}') # (2,)


##########################################
# flip kernel with stride and neg slice
##########################################
def loss_fn(params, inputs, target):
    pred = forward_fn_slice_flip(params, inputs)
    return mx.mean((pred - target) ** 2)

loss_and_grad = mx.value_and_grad(loss_fn)
params = (kernel, bias)
 # ValueError: [broadcast_shapes] Shapes (2,7,7,2) and (2,0,0,2) cannot be broadcast.
loss, grads = loss_and_grad(params, inputs, target)

print(f'(c) kernel grads shape {grads[0].shape}')
print(f'(c) bias grads shape {grads[1].shape}')

Expected behavior
I expect the gradients with respect to the kernel to be the same shape as the kernel, regardless of how the spatial axes are flipped.
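
For concreteness, the expected invariant can be written as a small check over the three forward functions in the repro above (a sketch: with the current versions, the flip=True case fails the shape assertion and the ::-1 case raises the ValueError noted below):

# Sketch: kernel/bias gradients should match the parameter shapes,
# no matter how the spatial flip is expressed.
for fn in (forward_fn_with_flip, forward_fn_manual_flip, forward_fn_slice_flip):
    def loss_fn(params, inputs, target):
        return mx.mean((fn(params, inputs) - target) ** 2)

    _, grads = mx.value_and_grad(loss_fn)((kernel, bias), inputs, target)
    assert grads[0].shape == kernel.shape
    assert grads[1].shape == bias.shape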

Desktop (please complete the following information):

  • OS Version: MacOS 15.1
  • Version: 0.22.0
  • Chip: M4

Additional context
The forward pass through each layer produces the same output. Using conv_general with flip=True produces kernel gradients whose shape differs from the kernel's. Manually flipping each spatial axis (with reverse_sequence(), which uses mx.take in the example) produces kernel gradients with the same shape as the kernel. And flipping with a negative-stride slice, ::-1, raises a broadcasting error when calling the loss_and_grad function (is it doing a copy somewhere?).

And thank you!

awni commented Feb 3, 2025

The third example (using ::-1) should be fixed in #1827.

I will look into the first one, which uses flip=True.

awni added the bug label on Feb 3, 2025
awni self-assigned this on Feb 3, 2025

acsweet commented Feb 3, 2025

Thank you!

This is great. I noticed the issue with ::-1 when I was working on the bidirectional RNN piece of the Keras backend, but had trouble reproducing it until now. (I've been using mx.take with reversed indices as a workaround.)

Is there a scheduled date for the next tagged release?

awni commented Feb 4, 2025

We should get out at least a patch release by the end of the week.

acsweet commented Feb 4, 2025

Thank you!

This might be separate, but I was wondering whether it's addressed in the same PR.
It seems arrays can't be sliced with mlx.array scalar values; is that expected?

example:

import mlx.core as mx

a = mx.random.normal((2, 5, 5, 3))
print(a[:, 0:2, 0:2, :].shape)

i = mx.array(0, dtype=mx.int32)
j = mx.array(2, dtype=mx.int32)
print(a[:, i, j, :].shape)
print(a[:, i:j, i:j, :].shape) # ValueError: Slice indices must be integers or None.

awni commented Feb 4, 2025

That is expected and it isn't addressed in the PR I sent.

In general, operations whose output shapes depend on input data are not supported. You could open an issue for this one specifically, but it's not a simple fix in this case, so I can't say whether or when we will be able to support it.
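
One possible workaround for the slice case (a sketch, not part of #1827): convert the array scalars to Python ints with item() before slicing, so the slice bounds, and therefore the output shape, are static:

import mlx.core as mx

a = mx.random.normal((2, 5, 5, 3))
i = mx.array(0, dtype=mx.int32)
j = mx.array(2, dtype=mx.int32)

# item() returns the 0-d array as a plain Python int, which the slice syntax accepts.
print(a[:, i.item():j.item(), i.item():j.item(), :].shape)  # (2, 2, 2, 3)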

awni closed this as completed in #1827 on Feb 6, 2025