
Can I launch cupy kernels in C++? #8232

Open
chaoming0625 opened this issue Mar 10, 2024 · 15 comments

Comments

@chaoming0625

Description

I define a kernel using cupy's cupy.RawKernel (https://docs.cupy.dev/en/stable/reference/generated/cupy.RawKernel.html). Can I then launch that kernel from my own C++ code, and pass a CUDA stream into the launch? Thanks!

Additional Information

No response

@chaoming0625 chaoming0625 added the cat:feature New features/APIs label Mar 10, 2024
@takagi
Member

takagi commented Mar 11, 2024

You cannot launch a kernel defined by cupy.RawKernel; however, one option may be cupy.RawModule, which can load a .cubin or .ptx file. Does that fit your use case?
https://docs.cupy.dev/en/stable/reference/generated/cupy.RawModule.html

@chaoming0625
Author

What a great answer! So the key is to use RawModule to generate a .cubin or .ptx file, and then I load the generated file in my C++ backend to run it. Am I right?

@chaoming0625
Author

Moreover, can kernels compiled by cupyx.jit.rawkernel be saved into a .ptx file?

@takagi
Member

takagi commented Mar 11, 2024

What I meant was the opposite: you write the kernel in C++ (a .cu file) and compile it into a .cubin or .ptx file, and then you can use it from RawModule. cupyx.jit.rawkernel doesn't have a feature to save a .ptx file in a way that is easily usable from external programs.

@chaoming0625
Author

Thanks for the explanation. I am wondering how to get the compiled binary when using cupy.RawKernel.

@takagi
Member

takagi commented Mar 11, 2024

Please supply the path to the compiled binary via the path argument of RawModule (not RawKernel). https://docs.cupy.dev/en/stable/reference/generated/cupy.RawModule.html

@chaoming0625
Author

I want to use cupy to compile the CUDA source code and get the compiled kernel, rather than providing the path of an already-compiled CUDA binary (*.cubin) or PTX file. So, how can I provide the CUDA source code and then obtain the path of the binary that cupy compiled?

@chaoming0625
Author

Or, how can I find the kernel under the $HOME/.cupy/kernel_cache/ directory? The names of the .cubin files there seem to follow no obvious pattern.

@kmaehashi
Member

Are there any specific reasons to use CuPy for that purpose? If your C++ application needs to compile CUDA code on the fly, you can just call NVRTC to get cubin/ptx.

@kmaehashi kmaehashi added issue-checked and removed cat:feature New features/APIs labels Mar 11, 2024
@leofang
Member

leofang commented Mar 13, 2024

In theory, for a given RawKernel (whether you use it directly or get it via RawModule.get_function) you can retrieve the CUfunction pointer via RawKernel.kernel.ptr, but

  1. This is not public API
  2. This is untested

Either way, it's unclear to me why you'd need this. @chaoming0625 could you elaborate?

@chaoming0625
Author

chaoming0625 commented Mar 13, 2024

Thank you @kmaehashi @leofang. Currently, I am using the pointer in RawKernel.kernel.ptr, just as @leofang pointed out. However, I also agree with @kmaehashi's suggestion.

The motivation for my question is to use cupy as a compiler for custom CUDA extensions in JAX. JAX's jit system needs an XLA custom call to be registered when using custom CUDA kernels. Usually, we need to write the CUDA code, pre-compile it, bind it to Python, and register the kernels in XLA. To remove this complex process, we can instead compile the source code (kept as a Python string) directly at the Python level, get the compiled kernel, pass it into the custom call, and everything stays compatible with JAX's jit system with minimal effort (only writing the Python string).

Currently, we are working on this functionality.

@chaoming0625
Author

Moreover, can I get the pointer of the function after compiling through cupyx.jit.rawkernel?

@leofang
Member

leofang commented Mar 13, 2024

Thanks for sharing your use case @chaoming0625, this is very interesting!

Would you be able to show us how you use this capability to make CuPy and JAX interoperable at the kernel level? I would love to see how it allows you to avoid writing complex boilerplate code. Eventually, I would like to craft a small interop demo like the one we showed for PyTorch-CuPy:
https://docs.cupy.dev/en/stable/user_guide/interoperability.html#using-custom-kernels-in-pytorch
If you already have a small demo that we can copy/paste to the document that's even better! 😄

Moreover, can I get the pointer of the function after compiling through cupyx.jit.rawkernel?

Right now it's not public API either, but according to the internal implementation (subject to change)

kern, enable_cg = self._cache.get((in_types, device_id), (None, None))

it is possible to get the Function object from jit.rawkernel._cache once the kernel is instantiated (it's part of the cached value). Then, you can get the CUfunction pointer via Function.ptr as before.

If you show us your workflow as I asked above, it'll help us stabilize the interface and expose these features properly. Thanks!

@leofang
Member

leofang commented Mar 27, 2024

Friendly nudge @chaoming0625 :)

@chaoming0625
Author

@leofang Many thanks for the reminder.

Currently, we provide an interface for wrapping a cupy RawKernel or cupyx.jit.rawkernel as a jittable operator in JAX. Examples can be found in our documentation. Please feel free to try the API we have provided there.

The key to such integration is getting the pointer of the cupy-compiled kernels. For RawKernel, it is convenient to get its pointer via kernel.kernel.ptr.

https://github.com/brainpy/BrainPy/blob/87858c54b4f4e45192ad4d9c6ff359f4c1e7ecf8/brainpy/_src/math/op_register/cupy_based.py#L116

However, for cupyx.jit.rawkernel, we found that it does not expose an accessible pointer. So, we hacked into cupy and wrote a compilation step for cupyx.jit.rawkernel to get its pointer.

https://github.com/brainpy/BrainPy/blob/87858c54b4f4e45192ad4d9c6ff359f4c1e7ecf8/brainpy/_src/math/op_register/cupy_based.py#L154-L189

In the future, we hope for a publicly available interface to get the compiled kernel of a cupyx.jit.rawkernel.
