Panic w/ backwards pass when combining gather and max_dim #1687
Labels: bug

Comments
I was only able to reproduce the bug on the ndarray backend; it seems to work on the tch backend. You can see the test on the branch:
Hmm, odd that it didn't reproduce with tch! I had whittled down the example to be minimal, and indeed, that one doesn't cause a panic on tch for me either. Perhaps this is two bugs in a trenchcoat pretending to be one! Here's a specific snippet which does crash for me:

```rust
let a: Vec<f32> = vec![-0.35060948, -0.6759874, -1.2398422, -0.55234957];
let b = [2, 2, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 3, 2, 2];
let b: Tensor<Autodiff<LibTorch>, 2, Int> =
    Tensor::from_data(Data::from(b.as_slice()), &LibTorchDevice::default()).reshape([5, 4]);
let a = Tensor::from_data(Data::from(a.as_slice()), &LibTorchDevice::default())
    .reshape([1, 4])
    .require_grad();
let grammar: Tensor<_, 2> = a.clone().repeat(0, 5);
let loss = grammar.gather(1, b);
let loss = loss.clone().max_dim(0) + loss;
let loss = loss.sum();
let g = loss.backward();
```
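For reference, here is a self-contained version of that snippet. The import paths, the `main` wrapper, and the shape comments are my additions, assuming a burn release of that era (~0.13) where `Tensor::from_data` takes a device argument and `repeat` takes `(dim, times)`:

```rust
use burn::backend::libtorch::LibTorchDevice;
use burn::backend::{Autodiff, LibTorch};
use burn::tensor::{Data, Int, Tensor};

fn main() {
    let device = LibTorchDevice::default();

    let a: Vec<f32> = vec![-0.35060948, -0.6759874, -1.2398422, -0.55234957];
    let b = [2, 2, 2, 3, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 3, 2, 2];

    // Integer index tensor of shape [5, 4] on the autodiff-wrapped tch backend.
    let b: Tensor<Autodiff<LibTorch>, 2, Int> =
        Tensor::from_data(Data::from(b.as_slice()), &device).reshape([5, 4]);

    // Leaf float tensor of shape [1, 4] that tracks gradients.
    let a = Tensor::from_data(Data::from(a.as_slice()), &device)
        .reshape([1, 4])
        .require_grad();

    // Repeat [1, 4] -> [5, 4] along dim 0, then gather along dim 1.
    let grammar: Tensor<_, 2> = a.clone().repeat(0, 5);
    let gathered = grammar.gather(1, b);

    // max_dim(0) reduces [5, 4] -> [1, 4] (the reduced dim is kept with size 1);
    // the addition then broadcasts back to [5, 4].
    let loss = gathered.clone().max_dim(0) + gathered;

    // The reported panic occurs during this backward pass.
    let _gradients = loss.sum().backward();
}
```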
Hi, there seems to be a problem with keeping track of the number of dimensions when doing some combination of max_dim and gather. The following code leads to a panic complaining about the number of dimensions, while it runs without issue if the max_dim line is removed. This also doesn't seem related to any specific backend: I found it initially when using the tch backend.

To Reproduce
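The reporter's original snippet is not shown here. Purely as an illustration of the combination being described, here is a hypothetical minimal sketch on the NdArray backend (where the bug was also reproduced, per the first comment); every name, value, and import path below is my assumption for a burn release of that era (~0.13), not the reporter's code:

```rust
use burn::backend::ndarray::NdArrayDevice;
use burn::backend::{Autodiff, NdArray};
use burn::tensor::{Data, Int, Tensor};

fn main() {
    let device = NdArrayDevice::default();

    // Hypothetical float tensor of shape [2, 3] with gradients enabled.
    let values: Vec<f32> = vec![0.5, -1.0, 2.0, 3.0, -0.5, 1.5];
    let x: Tensor<Autodiff<NdArray>, 2> =
        Tensor::from_data(Data::from(values.as_slice()), &device)
            .reshape([2, 3])
            .require_grad();

    // Hypothetical index tensor of the same shape, for gather along dim 1.
    let indices = [0, 2, 1, 1, 0, 2];
    let indices: Tensor<Autodiff<NdArray>, 2, Int> =
        Tensor::from_data(Data::from(indices.as_slice()), &device).reshape([2, 3]);

    // gather keeps the shape [2, 3]; max_dim(0) reduces it to [1, 3].
    let gathered = x.gather(1, indices);
    let reduced = gathered.clone().max_dim(0);

    // The addition broadcasts [1, 3] + [2, 3] -> [2, 3].
    let loss = (reduced + gathered).sum();

    // The panic described above happens in this backward pass.
    let _gradients = loss.backward();
}
```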
When run, produces the following output: