Skip to content

Commit 760a255

Browse files
author
Gopalakrishnan Nallasamy
committed
WebGPU Tile: clamp output element budget to uint32 max
1 parent fe7cdf3 commit 760a255

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

  • onnxruntime/core/providers/webgpu/tensor

onnxruntime/core/providers/webgpu/tensor/tile.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "core/providers/webgpu/webgpu_supported_types.h"
99

1010
#include <limits>
11+
#include <algorithm>
1112

1213
namespace onnxruntime {
1314
namespace webgpu {
@@ -76,7 +77,14 @@ Status Tile::ComputeInternal(ComputeContext& context) const {
7677
? static_cast<int64_t>(std::numeric_limits<size_t>::max())
7778
: kMaxTileOutputBytes;
7879
const int64_t element_size = narrow<int64_t>(input_tensor->DataType()->Size());
79-
const int64_t max_elements = kMaxSupportedTileOutputBytes / element_size;
80+
// The WebGPU shader uses a uint32_t uniform for the total output element
81+
// count and dispatches based on it. Clamp the per-element budget to
82+
// uint32_t::max() so that any combination of repeats producing more than
83+
// 2^32 - 1 elements is rejected by the byte-cap check below instead of
84+
// silently truncating to a smaller dispatch / OOB-guard value.
85+
const int64_t max_elements =
86+
std::min<int64_t>(kMaxSupportedTileOutputBytes / element_size,
87+
static_cast<int64_t>(std::numeric_limits<uint32_t>::max()));
8088
int64_t total_elements = 1;
8189
for (size_t axis = 0; axis < input_rank; axis++) {
8290
if (repeats_data[axis] < 0) {

0 commit comments

Comments
 (0)