This repository was archived by the owner on Mar 2, 2025. It is now read-only.

Index #80
Draft · wants to merge 6 commits into main

Conversation

@StijnWoestenborghs (Collaborator) commented May 14, 2024

INDEX: an extension of SLICE that adds the ability to select values from a tensor based on specified indices (or slices). Note that it also adds the ability to duplicate values along a dimension (this feature is used in upsampling).

Short demo of the expected behaviour:

import torch

t = torch.arange(2*2).reshape(2, 2).float().requires_grad_(True)

r = t[:, [0, 0, 0]]
print(r) # tensor([[0., 0., 0.], [2., 2., 2.]], grad_fn=<IndexBackward0>)

ug = torch.ones_like(r)
r.backward(ug)
print(t.grad) # tensor([[3., 0.], [3., 0.]])
  • product (itertools-style, also required in upsampling; see the sketch after this list)
  • forward (done but not optimized)
  • backward (done but not optimized)
  • unit tests
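
If it helps to picture the first checklist item: a plain-Python sketch of what an itertools-style product computes when enumerating output coordinates (the duplicated indices mirror the demo above; the PR's actual helper is in Mojo and may differ):

from itertools import product

# Every output coordinate of a selection over a 2x2 tensor where
# dimension 1 uses the duplicated indices [0, 0, 0]:
for coord in product(range(2), [0, 0, 0]):
    print(coord)  # (0, 0) three times, then (1, 0) three times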

Notes: Using the attributes the way it works right now feels hacky and not the right way to do it. One option would be to convert all slices to indices in the API, but that would make the required size of the attribute a lot bigger than it is now (illustrated below). Even now you are basically restricted to specifying up to MAX_RANK indices in the list. I feel like this just needs an attribute of type List[Int], which can potentially be very big, but that has implications on the size of the other attributes as well.
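
To make the size concern concrete, a small hypothetical illustration (the helper name is made up for this sketch, not from the PR): a slice costs three ints regardless of extent, while the equivalent explicit index list grows with the dimension size.

# Hypothetical helper: expand a slice into explicit indices.
def slice_to_indices(start, end, step=1):
    return list(range(start, end, step))

# (start, end, step) = (0, 1000, 1) costs 3 ints as a slice attribute,
# but 1000 ints once converted to an explicit index list:
print(len(slice_to_indices(0, 1000)))  # 1000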

@StijnWoestenborghs marked this pull request as draft May 14, 2024 13:04
@StijnWoestenborghs mentioned this pull request May 15, 2024
@StijnWoestenborghs (Collaborator, Author) commented May 31, 2024

I was looking at this again:

  1. Comparing it with MLX: https://github.com/ml-explore/mlx/blob/main/mlx/backend/common/indexing.cpp#L256-L258
    They have a scatter that accepts a ReduceType Sum, which allows this backward pass to be implemented correctly.
    As you said, the stdlib scatter overwrites (https://github.com/modularml/mojo/blob/main/stdlib/src/sys/intrinsics.mojo#L903), and it also doesn't look like there is an LLVM intrinsic that supports this: https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics
    I think the correct way to go is to either leave it as it is now, or implement the scatter::sum ourselves (a sketch of the accumulate semantics we'd need follows after this comment).

  2. About the way attributes are being saved: I'm not convinced either. Could you explain how this would work?

# indices = list(0, 0, 1, 0)
# dimensions_indices = list(0, 0, 0, 1)  # the dimension each index corresponds to

The only alternative I can really think of is something like the following, which to me looks like it is becoming too much:

- starts, ends, steps, slice_dimensions
- indices, indices_dimensions

Or create the whole mask/indices to be selected beforehand and pass that as an attribute. But my worry is that the current attributes are not big enough to support this.
Any thoughts on these 2 points?
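
On point 1, a minimal NumPy sketch of the accumulate semantics the backward pass needs (np.add.at stands in for a scatter with sum reduction; this is illustrative, not the PR's kernel):

import numpy as np

# Gradient w.r.t. t for r = t[:, [0, 0, 0]] with an all-ones upstream gradient.
# A plain scatter would overwrite on the duplicated index 0; scatter-with-sum
# accumulates, matching the t.grad of [[3., 0.], [3., 0.]] in the demo above.
grad_t = np.zeros((2, 2))
upstream = np.ones((2, 3))
np.add.at(grad_t, (slice(None), [0, 0, 0]), upstream)
print(grad_t)  # [[3. 0.] [3. 0.]]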

@andresnowak (Collaborator) commented May 31, 2024

Maybe the way I showed the attributes is too much, but once we are able to have a frontend the user wouldn't really call the code this way anyway; they would do it like in PyTorch, MLX or others (not like ONNX, 99% of the time). And I feel it is better to have the attributes declared this way:

starts: list
ends: list
steps: list
slice_dimensions: list
indices: list             # e.g. (0, 0, 1, 2): per indices_dimensions below, value 0 is repeated on dimension 0, and 1 and 2 work on dimension 1
indices_dimensions: list  # e.g. (0, 0, 1, 1)

It lets us declare more info easily, I would say (extra: a slice and indices can't work on the same dimension, but I think you already know). But that's just how I see it, I don't know. (A worked example of this layout follows below.)
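
For concreteness, here is how the proposed attributes could encode a selection like t[0:2, [0, 0, 1]] on a rank-2 tensor (purely illustrative values, not a layout the PR has settled on):

# Hypothetical encoding of t[0:2, [0, 0, 1]] with the attributes above:
starts             = [0]         # slice start, for dimension 0
ends               = [2]         # slice end, for dimension 0
steps              = [1]         # slice step, for dimension 0
slice_dimensions   = [0]         # the slice applies to dimension 0
indices            = [0, 0, 1]   # explicit (possibly duplicated) indices
indices_dimensions = [1, 1, 1]   # each index applies to dimension 1
# Per the note above, a slice and indices never target the same dimension.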

And for the scatter part: if MLX does the slice and index combined, instead of how PyTorch calls the two functions separately (slice and then index), then yeah, maybe for now we can leave the code as it is and implement the scatter function later. (I also don't know how to do it, because not all processors have scatter SIMD functionality baked in from what I understand, so I think that's why Mojo uses the scatter LLVM intrinsic.)
