[CUDA][64-bit indexing] Support 64-bit indexing in `distribution_elementwise_grid_stride_kernel` (pytorch#141613)

For pytorch#141544
Overhead doesn't seem to be noticeable even at small sizes (e.g., 2**10 elements).
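As a rough illustration of the failure mode being addressed (a hypothetical kernel sketch, not code from this change): with a 32-bit linear index, flat offsets beyond 2^31 - 1 (INT32_MAX) cannot be reached, so the tail of a tensor with more elements than that is never written correctly.

// Hypothetical sketch of the pre-fix pitfall: a 32-bit grid-stride index.
// Once the running index should pass 2^31 - 1, a plain `int` wraps/overflows,
// so elements at larger flat offsets are never written.
__global__ void fill_with_int_index(int64_t numel, int8_t* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;   // 32-bit index: the pitfall
  for (; i < numel; i += blockDim.x * gridDim.x) {
    out[i] = 1;                                    // offsets > INT32_MAX stay unreachable
  }
}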

Pull Request resolved: pytorch#141613
Approved by: https://github.com/Skylion007, https://github.com/ngimel
eqy authored and pytorchmergebot committed Nov 30, 2024
1 parent 7fafaa9 commit 9532589
Showing 2 changed files with 12 additions and 5 deletions.
10 changes: 5 additions & 5 deletions aten/src/ATen/native/cuda/DistributionTemplates.h
@@ -63,25 +63,25 @@ std::tuple<uint64_t, dim3, dim3> calc_execution_policy(const int64_t total_eleme
 // grid stride loop kernel for distributions
 template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
 C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
-__global__ void distribution_elementwise_grid_stride_kernel(int numel,
+__global__ void distribution_elementwise_grid_stride_kernel(int64_t numel,
                                                              PhiloxCudaState philox_args,
                                                              const dist_t dist_func,
                                                              const transform_t transform_func) {
   auto seeds = at::cuda::philox::unpack(philox_args);
-  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
   curandStatePhilox4_32_10_t state;
   curand_init(std::get<0>(seeds),
               idx,
               std::get<1>(seeds),
               &state);

-  int rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
+  int64_t rounded_size = ((numel - 1)/(blockDim.x * gridDim.x * unroll_factor)+1) *
         blockDim.x * gridDim.x * unroll_factor;
-  for(int linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
+  for(int64_t linear_index = idx; linear_index < rounded_size; linear_index += blockDim.x * gridDim.x * unroll_factor) {
     auto rand = dist_func(&state);
     #pragma unroll
     for (int ii = 0; ii < unroll_factor; ii++) {
-      int li = linear_index + blockDim.x * gridDim.x * ii;
+      int64_t li = linear_index + blockDim.x * gridDim.x * ii;
       if (li < numel) {
         transform_func(li, static_cast<accscalar_t>((&rand.x)[ii]));
       }
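For readers skimming the hunk above, here is a standalone sketch of the same grid-stride pattern with 64-bit indexing (hypothetical kernel `fill_ones`, not code from this commit). The explicit widening casts are a precaution for very large grids; the essential point mirrored from the change is simply that the running index is an int64_t.

// Standalone sketch (hypothetical `fill_ones`, not from this commit) of a
// grid-stride loop that can address more than 2^31 - 1 elements.
#include <cstdint>

__global__ void fill_ones(int64_t numel, float* out) {
  // Widen before multiplying so the index arithmetic itself stays in 64 bits.
  const int64_t start  = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (int64_t i = start; i < numel; i += stride) {
    out[i] = 1.0f;  // every flat index up to numel - 1 is representable in int64_t
  }
}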
7 changes: 7 additions & 0 deletions test/test_cuda.py
@@ -36,6 +36,7 @@
 )
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
+    largeTensorTest,
     onlyCUDA,
     onlyNativeDeviceTypes,
 )
@@ -1051,6 +1052,12 @@ def run(dev: torch.device) -> int:
             abs(run(torch.device("cuda")) - run(torch.device("cpu"))) < 10_000
         )

+    @largeTensorTest("20GB", "cuda")
+    def test_randint_generation_for_large_numel(self) -> None:
+        numel = 2**31 + 1
+        s = torch.randint(2, (numel,), device="cuda", dtype=torch.int8).sum()
+        self.assertTrue(s > 0, "expected randint in [0, 1] to generate nonzero values")
+
     @parametrize("dtype", [torch.float32, torch.double])
     def test_random_no_reused_random_states(self, dtype: torch.dtype) -> None:
         # Test if random states do not overlap between consecutive rand/randn calls.
