
[BUG] ValueError: [scatter] Cannot calculate VJP with respect to indices. #1439

sachinraja13 opened this issue Sep 26, 2024 · 6 comments

@sachinraja13

Describe the bug
I understand that the HungarianMatching algorithm requires linear_sum_assignment from SciPy, which needs the cost matrix to be evaluated. Hence, I cannot compile my train step function. However, if I use the SimpleMatching algorithm and then compile my train step, I get the following error:

    (loss_value, loss_dict), grads = train_step_fn(samples, targets, need_tgt_for_training, return_outputs=False)
  File "site-packages/mlx/nn/utils.py", line 35, in wrapped_value_grad_fn
    value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
ValueError: [scatter] Cannot calculate VJP with respect to indices.

The code for matching is as follows:

        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.reshape(bs, num_queries, -1)

        sizes = [v["num_objects"] for v in targets]
        indices = []
        
        start_index = 0
        for i, c in enumerate(C):
            # Columns for batch element i start where the previous
            # element's columns ended (running cumulative offset).
            end_index = start_index + sizes[i]
            cost_matrix = c[:, start_index:end_index]
            size_ = cost_matrix.shape[1]
            idx_i = cost_matrix.argmin(0)
            idx_j = mx.arange(size_)
            indices.append((idx_i, idx_j))
            start_index = end_index
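For reference, the loop above can be sketched as a self-contained NumPy function (hypothetical names; it assumes the column offsets accumulate across batch elements, and that MLX's argmin/arange behave like their NumPy counterparts here):

```python
import numpy as np

def greedy_match(C, sizes):
    """Greedy matching: for each target column, pick the query (row)
    with the minimum cost.

    C: (batch, num_queries, total_objects) cost array
    sizes: number of objects per batch element
    """
    indices = []
    offset = 0  # running start of this element's columns in C
    for i, c in enumerate(C):
        cost_matrix = c[:, offset:offset + sizes[i]]
        idx_i = cost_matrix.argmin(axis=0)       # best query per target
        idx_j = np.arange(cost_matrix.shape[1])  # target indices 0..size-1
        indices.append((idx_i, idx_j))
        offset += sizes[i]
    return indices

# 1 batch element, 2 queries, 2 objects
cost = np.array([[[3.0, 1.0], [0.5, 2.0]]])
matches = greedy_match(cost, [2])
```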

Also, is there a workaround to get mx.compile working for the Hungarian matching algorithm?

I would greatly appreciate your help in solving this.

Additional context
Using MLX 0.17.3

@barronalex
Collaborator

It's probably a good idea to add an all zeros vjp w.r.t. indices for scatter like we did for gather for consistency.

Is this for your MLX implementation of DINO? Looking at the PyTorch implementation it seems like they get around this by zeroing out the gradients for the Hungarian matcher. Maybe that will work for you in the meantime?

@sachinraja13
Author

> It's probably a good idea to add an all zeros vjp w.r.t. indices for scatter like we did for gather for consistency.
>
> Is this for your MLX implementation of DINO? Looking at the PyTorch implementation it seems like they get around this by zeroing out the gradients for the Hungarian matcher. Maybe that will work for you in the meantime?

Hi @barronalex: yes, this is for the implementation of DINO. Zeroing out the gradient should solve the problem for the simple greedy matcher. I'm not sure it would solve the problem for the Hungarian matcher, though, since that uses SciPy's linear_sum_assignment. This is a particular problem because it requires evaluating the cost matrix, which in turn prevents me from calling mx.compile on the value_and_grad function for training. Please correct me if I'm mistaken here.

@barronalex
Collaborator

Yes, you would need an MLX implementation of linear_sum_assignment to be able to compile the whole thing. That being said, you should be able to compile the rest of the model if you put @mx.compile on the model forward but not on the loss.

@sachinraja13
Author

sachinraja13 commented Sep 26, 2024

Thanks @barronalex, understood.

Referring to this JAX implementation, there is another challenge in writing a compilable MLX implementation in Python:

#1441

@barronalex
Collaborator

Is HungarianMatcher a big performance bottleneck when you're training? If not it might be worth leaving it uncompiled and using the scipy implementation the way the original PyTorch implementation does it.

We could definitely implement something like scipy.optimize.linear_sum_assignment in MLX, but my guess is it would require some custom C++/Metal to be competitive performance-wise.
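For context, a minimal sketch of the SciPy routine in question (the cost values are illustrative); an MLX-native equivalent would need to produce the same row/column index pair that minimizes the total assignment cost:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# 3x3 cost matrix: entry [i, j] is the cost of assigning query i to target j
cost = np.array([[4.0, 1.0, 3.0],
                 [2.0, 0.0, 5.0],
                 [3.0, 2.0, 2.0]])

# Solves the assignment problem exactly; returns plain integer index
# arrays, so no gradient can flow through this step.
row_ind, col_ind = linear_sum_assignment(cost)
total = cost[row_ind, col_ind].sum()  # minimal total cost
```

Because the result is host-side integer arrays, the cost matrix must be fully evaluated before the call, which is exactly what blocks mx.compile around this step.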

@sachinraja13
Author

Thanks for your response @barronalex !

I realise that HungarianMatcher is not a big performance bottleneck. However, I'm facing memory inflation in the MLX port of the prepare_for_cdn function:

#1432

I thought that maybe if I compiled the entire computation graph, that would solve the problem.
