
[Question] What does device / embedding_specs.compute_device parameter in ctor of TBE mean? #2395

Open
JacoCheung opened this issue Mar 6, 2024 · 4 comments


@JacoCheung

Hi team, I am confused by the following parameters related to the device context in the TBE constructor.

What combinations of those parameters are legal? Could anyone provide some hints?

Thanks!

@q10
Contributor

q10 commented Mar 6, 2024

Hi @JacoCheung

ComputeDevice specifies the TBE kernel variant (i.e. whether the kernel will execute on the CPU or on CUDA) that will be applied to each embedding table.

EmbeddingLocation specifies the target memory location of the embedding tables that are constructed by the operator (i.e. CUDA-device-only, managed (UVM), managed + caching, or host-only).

device specifies the target location of memory buffers used internally by the operator.

There is a list of constraints checked at runtime during TBE construction, including (a short construction sketch follows this list):

  1. ComputeDevice values should be the same across all embedding tables
  2. EmbeddingLocation values can be different
  3. But if use_cpu is set or optimizer is set to None, EmbeddingLocation values can only be set to HOST
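
For reference, here is a minimal construction sketch, assuming the import paths and the `(num_embeddings, embedding_dim, EmbeddingLocation, ComputeDevice)` spec layout from recent fbgemm_gpu versions; please double-check against the version you have installed:

```python
import torch
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    EmbeddingLocation,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Constraint 1: ComputeDevice is CUDA for every table.
# Constraint 2: EmbeddingLocation may still differ per table.
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        # (num_embeddings, embedding_dim, EmbeddingLocation, ComputeDevice)
        (1_000_000, 128, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
        (2_000_000, 64, EmbeddingLocation.MANAGED, ComputeDevice.CUDA),
    ],
    optimizer=OptimType.EXACT_ROWWISE_ADAGRAD,
    # `device` controls where the memory buffers used internally by the
    # operator are placed, independently of the per-table EmbeddingLocation.
    device=torch.device("cuda:0"),
)
```

A CPU-only variant would instead use ComputeDevice.CPU together with EmbeddingLocation.HOST for every table (constraint 3 above when use_cpu is set).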

If the constraints are not met, an error with a detailed message will be thrown, which will guide you toward constructing the TBE with a correct combination of parameters.

We will update our docs to explain this in more detail. Let us know if you have other questions. cc @sryap

@JacoCheung
Author

@q10 Thanks for your reply! What if ComputeDevice == CPU while EmbeddingLocation == DEVICE (GPU)? Will the kernel spawn some D2H memcpys?

@q10
Contributor

q10 commented Mar 7, 2024

@JacoCheung
Author

Thanks @q10. I have another question (though it may be beyond this issue's scope).

How are the physical tables allocated (assuming EmbeddingLocation is DEVICE for all tables)? Do multiple tables share a single memory chunk?

For example, suppose I have 2 embedding tables with different embedding dimensions. Will there be 2 separate memory buffers or a single one? And how many lookup kernels will be launched during the forward pass? (A sketch of this setup is below.)
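
For concreteness, here is a sketch of the setup I have in mind (table sizes and dimensions are made up; I am assuming the same spec layout as in the example above):

```python
import torch
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    EmbeddingLocation,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# Two tables, both placed in GPU memory, with different embedding dims.
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (100_000, 128, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
        (50_000, 32, EmbeddingLocation.DEVICE, ComputeDevice.CUDA),
    ],
    device=torch.device("cuda:0"),
)
# Do the two tables end up in one fused buffer or two separate ones,
# and how many lookup kernels does a single forward() launch?
```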
