argmax_internal::PairInputIterator and argmax_internal::PairOutputIterator don't implement the random access iterator concept #1277
Thank you for reporting. I am trying to reproduce it locally and will try to fix it if I can reproduce it.
Sorry, I could not reproduce it with the latest commit of cub. Here is my change to `cmake/cub.cmake`:

```diff
diff --git a/cmake/cub.cmake b/cmake/cub.cmake
index dd66606a..51e85b67 100644
--- a/cmake/cub.cmake
+++ b/cmake/cub.cmake
@@ -20,18 +20,22 @@ function(download_cub)
  include(FetchContent)
- set(cub_URL "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
- set(cub_URL2 "https://hub.nuaa.cf/NVlabs/cub/archive/1.15.0.tar.gz")
- set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")
+ # set(cub_URL "https://github.com/NVlabs/cub/archive/1.15.0.tar.gz")
+ # set(cub_URL2 "https://hub.nuaa.cf/NVlabs/cub/archive/1.15.0.tar.gz")
+ # set(cub_HASH "SHA256=1781ee5eb7f00acfee5bff88e3acfc67378f6b3c24281335e18ae19e1f2ff685")
+
+ set(cub_URL "https://github.com/NVlabs/cub/archive/0fc3c3701632a4be906765b73be20a9ad0da603d.zip")
+ set(cub_URL2 "https://hub.nuaa.cf/NVlabs/cub/archive/0fc3c3701632a4be906765b73be20a9ad0da603d.zip")
+ set(cub_HASH "SHA256=88dc9f86564f4a76f4407cdc98eec2dd1cfdca9d92fcf6e1d2a51f6456e118b5")
  # If you don't have access to the Internet,
  # please pre-download cub
  set(possible_file_locations
- $ENV{HOME}/Downloads/cub-1.15.0.tar.gz
- ${CMAKE_SOURCE_DIR}/cub-1.15.0.tar.gz
- ${CMAKE_BINARY_DIR}/cub-1.15.0.tar.gz
- /tmp/cub-1.15.0.tar.gz
- /star-fj/fangjun/download/github/cub-1.15.0.tar.gz
+ $ENV{HOME}/Downloads/cub-0fc3c3701632a4be906765b73be20a9ad0da603d.zip
+ ${CMAKE_SOURCE_DIR}/cub-0fc3c3701632a4be906765b73be20a9ad0da603d.zip
+ ${CMAKE_BINARY_DIR}/cub-0fc3c3701632a4be906765b73be20a9ad0da603d.zip
+ /tmp/cub-0fc3c3701632a4be906765b73be20a9ad0da603d.zip
+ /star-fj/fangjun/download/github/cub-0fc3c3701632a4be906765b73be20a9ad0da603d.zip
  )
  foreach(f IN LISTS possible_file_locations)
```

I am using torch 1.13.0 and CUDA 11.6.

There are no PyTorch versions for CUDA 12.3, and we don't have such an environment locally. k2 can be built for all currently available versions of PyTorch.
Just to note, there is no problem with mixing CUDA toolkit versions within a single application. The main requirement is that the CUDA driver (aka libcuda.so and the associated .ko kernel modules) is new enough for the newest toolkit version in use. There are a few instances in NeMo where we depend upon, e.g., CUDA 12.3 or CUDA 12.4 features, but use PyTorch versions that link to older versions of the CUDA toolkit.
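A minimal sketch of that rule, assuming CUDA's convention of encoding versions as `1000 * major + 10 * minor` (the format returned by `cudaDriverGetVersion` and `cudaRuntimeGetVersion`). `CudaVersion`, `DecodeCudaVersion`, and `DriverCoversAll` are hypothetical names, and the check is deliberately conservative: it ignores CUDA's minor-version compatibility, so it may reject combinations that actually work.

```cpp
// CUDA encodes versions as 1000 * major + 10 * minor (e.g. 12030 is 12.3).
struct CudaVersion {
  int cuda_major;
  int cuda_minor;
};

// Hypothetical helper: split an encoded version into major and minor parts.
constexpr CudaVersion DecodeCudaVersion(int encoded) {
  return CudaVersion{encoded / 1000, (encoded % 1000) / 10};
}

// Hypothetical, conservative check: true if the driver's CUDA version is at
// least as new as every runtime version in use. It ignores CUDA's
// minor-version compatibility, so it can reject setups that would work.
constexpr bool DriverCoversAll(int driver_version, const int* runtime_versions,
                               int count) {
  return count == 0 ||
         (runtime_versions[0] <= driver_version &&
          DriverCoversAll(driver_version, runtime_versions + 1, count - 1));
}

// Example values only: components built against CUDA 11.6 and 12.3.
constexpr int kRuntimeVersions[] = {11060, 12030};

// A driver reporting 12.4 is new enough for both of them.
static_assert(DriverCoversAll(12040, kRuntimeVersions, 2), "12.4 covers both");
```

In a running process, the driver value would come from `cudaDriverGetVersion`, and each component's runtime value from `cudaRuntimeGetVersion` as seen by that component.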
I can take a look at some point, but it's not high priority for me right now. If someone else runs into this issue, they should comment here.
Hi all, a colleague recently reached out to me asking for help because he couldn't build k2 from source against a newer version of cub than the one you build with. He hit the following cub error:
Unfortunately, someone recently changed cub so that this segmented reduce implementation uses `operator+` and `operator*` on its iterators instead of `operator[]`.
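As an illustration of what this change means for callers, here is a hypothetical host-side sketch (not cub's actual code) of a per-segment argmax that reads elements as `*(begin + i)` rather than `begin[i]`. The name `SegmentArgMax` and the `float` element type are assumptions; the point is only that any iterator passed in must now provide `operator+` and `operator*`, so a type offering just `operator[]` fails to compile.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Hypothetical sketch: find (index, value) of the maximum in one segment.
// Elements are read as *(begin + i), so InputIt must support operator+
// and operator*; an iterator that only defines operator[] will not compile.
template <typename InputIt>
std::pair<std::ptrdiff_t, float> SegmentArgMax(InputIt begin, std::ptrdiff_t n) {
  assert(n >= 1);  // an empty segment has no maximum in this sketch
  std::ptrdiff_t best_index = 0;
  float best_value = *begin;
  for (std::ptrdiff_t i = 1; i < n; ++i) {
    const float v = *(begin + i);  // the old style would have been begin[i]
    if (v > best_value) {
      best_value = v;
      best_index = i;
    }
  }
  return {best_index, best_value};
}
```

With a plain `const float*` both spellings are equivalent, so pointer-based callers are unaffected; custom iterator types are where the difference shows up.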
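The root cause is that these two types, `argmax_internal::PairInputIterator` and `argmax_internal::PairOutputIterator` (defined in `k2/k2/csrc/ragged_ops_inl.h`, lines 550 and 580 at commit e8158de), don't implement the full random access iterator concept.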
This page lists the operations a type must provide to model a random access iterator: https://en.cppreference.com/w/cpp/iterator/random_access_iterator
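A minimal sketch of what the missing pieces could look like, assuming a read-only iterator that yields `(index, value)` pairs over a `float` array. `PairInputIteratorSketch` is hypothetical and is not k2's `argmax_internal::PairInputIterator`; a real version used inside cub kernels would also need `__host__ __device__` qualifiers, which are omitted here so it compiles as plain C++. The key invariant is that `it[n]` and `*(it + n)` yield the same element.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <utility>

// Hypothetical read-only iterator: position i yields the pair (i, data[i]).
class PairInputIteratorSketch {
 public:
  using Self = PairInputIteratorSketch;
  using iterator_category = std::random_access_iterator_tag;
  using value_type = std::pair<std::ptrdiff_t, float>;
  using difference_type = std::ptrdiff_t;
  using pointer = void;          // no operator->: elements are computed
  using reference = value_type;  // returned by value, proxy-style

  PairInputIteratorSketch() = default;
  PairInputIteratorSketch(const float* data, difference_type pos)
      : data_(data), pos_(pos) {}

  // Element access; operator[] is defined via operator+ and operator*,
  // so the two spellings cannot disagree.
  reference operator*() const { return value_type(pos_, data_[pos_]); }
  reference operator[](difference_type n) const { return *(*this + n); }

  // Increment and decrement (prefix and postfix).
  Self& operator++() { ++pos_; return *this; }
  Self operator++(int) { Self old = *this; ++pos_; return old; }
  Self& operator--() { --pos_; return *this; }
  Self operator--(int) { Self old = *this; --pos_; return old; }

  // Compound assignment and arithmetic with a difference.
  Self& operator+=(difference_type n) { pos_ += n; return *this; }
  Self& operator-=(difference_type n) { pos_ -= n; return *this; }
  friend Self operator+(Self it, difference_type n) { return it += n; }
  friend Self operator+(difference_type n, Self it) { return it += n; }
  friend Self operator-(Self it, difference_type n) { return it -= n; }
  friend difference_type operator-(const Self& a, const Self& b) {
    return a.pos_ - b.pos_;
  }

  // Comparisons; only meaningful for iterators over the same array.
  friend bool operator==(const Self& a, const Self& b) { return a.pos_ == b.pos_; }
  friend bool operator!=(const Self& a, const Self& b) { return !(a == b); }
  friend bool operator<(const Self& a, const Self& b) { return a.pos_ < b.pos_; }
  friend bool operator>(const Self& a, const Self& b) { return b < a; }
  friend bool operator<=(const Self& a, const Self& b) { return !(b < a); }
  friend bool operator>=(const Self& a, const Self& b) { return !(a < b); }

 private:
  const float* data_ = nullptr;
  difference_type pos_ = 0;
};

// Usage: the properties that iterator-arithmetic-based code relies on.
inline void CheckPairIteratorSketch() {
  const float data[] = {0.5f, 2.0f, 1.0f};
  const PairInputIteratorSketch it(data, 0);
  assert(it[2] == *(it + 2));          // subscript agrees with + and *
  assert((it + 3) - it == 3);          // distance between iterators
  assert((*(it + 1)).second == 2.0f);  // element 1 is (1, 2.0f)
}
```

Defining `operator[]` in terms of `operator+` and `operator*` is a small design choice that keeps the two access paths from drifting apart when the type is edited later.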
Anyway, the point is that a few more methods need to be implemented on those types (and maybe on others) to fix this error. I could probably do this myself, but I haven't kept up with k2 work recently, so if someone else would be willing to do it, that would be much appreciated. Unfortunately, I'm not sure of a way to "statically assert" that a specific type implements the random access iterator concept right now.
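One possible approach, offered as a hedged sketch: a pre-C++20 trait that checks at compile time whether the expressions a random access iterator needs are well-formed, and whether `std::iterator_traits` reports a category derived from `std::random_access_iterator_tag`. `IsRandomAccessIteratorLike` and its helpers are hypothetical names, and the trait only checks that the expressions compile, not that they behave correctly. With a C++20 compiler, `static_assert(std::random_access_iterator<It>);` performs the standard concept check directly.

```cpp
#include <cstddef>
#include <iterator>
#include <type_traits>
#include <utility>

// void_t-style helper that also works before C++17 (a struct is used
// instead of a plain alias to avoid CWG 1558 issues on older compilers).
template <typename...>
struct MakeVoid { using type = void; };
template <typename... Ts>
using VoidT = typename MakeVoid<Ts...>::type;

// True if iterator_traits<It> names a category derived from random access.
template <typename It, typename = void>
struct HasRandomAccessCategory : std::false_type {};
template <typename It>
struct HasRandomAccessCategory<
    It, VoidT<typename std::iterator_traits<It>::iterator_category>>
    : std::is_base_of<std::random_access_iterator_tag,
                      typename std::iterator_traits<It>::iterator_category> {};

// True if every operation listed below compiles for It.
template <typename It, typename = void>
struct HasRandomAccessOps : std::false_type {};
template <typename It>
struct HasRandomAccessOps<
    It,
    VoidT<decltype(*std::declval<const It&>()),
          decltype(std::declval<const It&>()[std::declval<std::ptrdiff_t>()]),
          decltype(std::declval<const It&>() + std::declval<std::ptrdiff_t>()),
          decltype(std::declval<std::ptrdiff_t>() + std::declval<const It&>()),
          decltype(std::declval<const It&>() - std::declval<std::ptrdiff_t>()),
          decltype(std::declval<const It&>() - std::declval<const It&>()),
          decltype(std::declval<It&>() += std::declval<std::ptrdiff_t>()),
          decltype(std::declval<It&>() -= std::declval<std::ptrdiff_t>()),
          decltype(++std::declval<It&>()), decltype(--std::declval<It&>()),
          decltype(std::declval<It&>()++), decltype(std::declval<It&>()--),
          decltype(std::declval<const It&>() == std::declval<const It&>()),
          decltype(std::declval<const It&>() < std::declval<const It&>())>>
    : std::true_type {};

// Combined check: right category and all required expressions compile.
template <typename It>
struct IsRandomAccessIteratorLike
    : std::integral_constant<bool, HasRandomAccessCategory<It>::value &&
                                       HasRandomAccessOps<It>::value> {};

// Raw pointers are random access iterators, so this must hold.
static_assert(IsRandomAccessIteratorLike<const float*>::value,
              "const float* should qualify");

#if __cplusplus >= 202002L
// C++20 offers the standard concept directly.
static_assert(std::random_access_iterator<const float*>);
#endif
```

Placing a `static_assert(IsRandomAccessIteratorLike<YourIteratorType>::value, "...")` next to each iterator definition would report a missing operator at the type itself rather than deep inside cub's templates.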
You can reproduce the error by building with the HEAD version of cub, or with the cub version included in the latest CUDA 12.3 toolkit.