Skip to content

Commit 135631b

Browse files
committed
zluda enable triton experimentally
1 parent 878fb9f commit 135631b

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

modules/zluda_hijacks.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,36 @@ def topk(input: torch.Tensor, *args, **kwargs): # pylint: disable=redefined-buil
99
return torch.return_types.topk((values.to(device), indices.to(device),))
1010

1111

12+
class DeviceProperties:
13+
PROPERTIES_OVERRIDE = {"regs_per_multiprocessor": 65535}
14+
internal: torch._C._CudaDeviceProperties
15+
16+
def __init__(self, props: torch._C._CudaDeviceProperties):
17+
self.internal = props
18+
19+
def __getattr__(self, name):
20+
if name in DeviceProperties.PROPERTIES_OVERRIDE:
21+
return DeviceProperties.PROPERTIES_OVERRIDE[name]
22+
return getattr(self.internal, name)
23+
24+
25+
__get_device_properties = torch.cuda._get_device_properties # pylint: disable=protected-access
26+
def torch_cuda__get_device_properties(device):
27+
return DeviceProperties(__get_device_properties(device))
28+
29+
1230
def do_hijack():
1331
torch.version.hip = rocm.version
1432
torch.topk = topk
33+
34+
torch.cuda._get_device_properties = torch_cuda__get_device_properties # pylint: disable=protected-access
35+
try:
36+
import triton
37+
_get_device_properties = triton.runtime.driver.active.utils.get_device_properties
38+
def triton_runtime_driver_active_utils_get_device_properties(device):
39+
props = _get_device_properties(device)
40+
props["mem_bus_width"] = 384
41+
return props
42+
triton.runtime.driver.active.utils.get_device_properties = triton_runtime_driver_active_utils_get_device_properties
43+
except Exception:
44+
pass

0 commit comments

Comments
 (0)