From 9532589b53b5f40e3e0222b12a65366584c4271c Mon Sep 17 00:00:00 2001
From: eqy
Date: Sat, 30 Nov 2024 06:55:02 +0000
Subject: [PATCH] [CUDA][64-bit indexing] Support 64-bit indexing in
 `distribution_elementwise_grid_stride_kernel` (#141613)

For #141544.

Overhead doesn't seem to be noticeable even on small sizes (e.g., 2**10 elements).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141613
Approved by: https://github.com/Skylion007, https://github.com/ngimel
---
 aten/src/ATen/native/cuda/DistributionTemplates.h | 10 +++++-----
 test/test_cuda.py                                 |  7 +++++++
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/aten/src/ATen/native/cuda/DistributionTemplates.h b/aten/src/ATen/native/cuda/DistributionTemplates.h
index 49b05fc3c50d8..f3807f2b7e0e8 100644
--- a/aten/src/ATen/native/cuda/DistributionTemplates.h
+++ b/aten/src/ATen/native/cuda/DistributionTemplates.h
@@ -63,25 +63,25 @@ std::tuple<uint64_t, dim3, dim3> calc_execution_policy(const int64_t total_eleme
 // grid stride loop kernel for distributions
 template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
 C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
-__global__ void distribution_elementwise_grid_stride_kernel(int numel,
+__global__ void distribution_elementwise_grid_stride_kernel(int64_t numel,
                                                             PhiloxCudaState philox_args,
                                                             const dist_t dist_func,
                                                             const transform_t transform_func) {
   auto seeds = at::cuda::philox::unpack(philox_args);
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   curandStatePhilox4_32_10_t state;
   curand_init(std::get<0>(seeds),
               idx,
               std::get<1>(seeds),
               &state);
 
-  int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
+  int64_t rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
       blockDim.x * gridDim.x * unroll_factor;
-  for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
+  for(int64_t linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
     auto rand = dist_func(&state);
     #pragma unroll
     for (int ii = 0; ii < unroll_factor; ii++) {
-      int li = linear_index + blockDim.x * gridDim.x * ii;
+      int64_t li = linear_index + blockDim.x * gridDim.x * ii;
       if (li < numel) {
         transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
       }
diff --git a/test/test_cuda.py b/test/test_cuda.py
index bcb7e4dfdeeef..54383da438d21 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -36,6 +36,7 @@
 )
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
+    largeTensorTest,
     onlyCUDA,
     onlyNativeDeviceTypes,
 )
@@ -1051,6 +1052,12 @@ def run(dev: torch.device) -> int:
             abs(run(torch.device("cuda")) - run(torch.device("cpu"))) < 10_000
         )
 
+    @largeTensorTest("20GB", "cuda")
+    def test_randint_generation_for_large_numel(self) -> None:
+        numel = 2**31 + 1
+        s = torch.randint(2, (numel,), device="cuda", dtype=torch.int8).sum()
+        self.assertTrue(s > 0, "expected randint in [0, 1] to generate nonzero values")
+
     @parametrize("dtype", [torch.float32, torch.double])
     def test_random_no_reused_random_states(self, dtype: torch.dtype) -> None:
         # Test if random states do not overlap between consecutive rand/randn calls.
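
The commit message's overhead claim can be probed with a minimal, hypothetical
micro-benchmark sketch (not part of the patch): it times torch.randint, which
dispatches to the distribution grid-stride kernel on CUDA, at a few sizes
around and above the old int32 index range. It assumes a CUDA build of PyTorch
and a GPU with a few GiB of free memory; the helper name time_randint is
invented for illustration.

    import torch
    from torch.utils import benchmark

    def time_randint(numel: int) -> float:
        # torch.utils.benchmark.Timer handles CUDA synchronization and picks
        # the number of iterations automatically via blocked_autorange().
        t = benchmark.Timer(
            stmt="torch.randint(2, (numel,), device='cuda', dtype=torch.int8)",
            globals={"numel": numel, "torch": torch},
        )
        return t.blocked_autorange().median  # median per-call time in seconds

    if __name__ == "__main__" and torch.cuda.is_available():
        # The last size exceeds 2**31 elements (~2 GiB of int8), which is the
        # case that overflows 32-bit indexing without this patch.
        for numel in (2**10, 2**20, 2**31 + 1):
            print(f"numel={numel}: {time_randint(numel) * 1e6:.1f} us")

Comparing the small-size timings before and after the patch is one way to
check that widening the index type to int64_t does not add measurable
per-launch cost.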