[inductor] Remove usage of device_interface from _inductor.runtime
Summary: Redo of #124592, but with necessary internal changes

Test Plan: CI

Reviewed By: jansel

Differential Revision: D56642231
masnesral authored and facebook-github-bot committed Apr 29, 2024
1 parent da44d2f commit f05f68d
Showing 8 changed files with 105 additions and 78 deletions.
test/inductor/test_cuda_repro.py (3 changes: 2 additions & 1 deletion)

@@ -14,6 +14,7 @@
 from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx_inner
+from torch._inductor.runtime.hints import DeviceProperties
 from torch._inductor.utils import run_and_get_code
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
@@ -405,7 +406,7 @@ def decorator(fn):
             ],
             meta={
                 "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
-                "device": 0,
+                "device": DeviceProperties.create(torch.device("cuda")),
                 "configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],
                 "constants": {},
             },
torch/_inductor/codecache.py (30 changes: 8 additions & 22 deletions)

@@ -45,16 +45,12 @@
     Optional,
     Set,
     Tuple,
-    Type,
     TYPE_CHECKING,
     Union,
 )
 
 import torch
-from torch._dynamo.device_interface import (
-    get_interface_for_device,
-    get_registered_device_interfaces,
-)
+from torch._dynamo.device_interface import get_registered_device_interfaces
 from torch._dynamo.utils import counters, dynamo_timed
 from torch._inductor import config, exc, metrics
 from torch._inductor.codegen.cuda import cuda_env
@@ -70,7 +66,6 @@
 from torch.fx.experimental.symbolic_shapes import has_hint, hint_int, ShapeEnv
 
 if TYPE_CHECKING:
-    from torch._dynamo.device_interface import DeviceInterface
     from torch._inductor.graph import GraphLowering
     from torch._inductor.ir import ChoiceCaller
 
@@ -2823,14 +2818,9 @@ def _set_triton_ptxas_path() -> None:
 
 def _worker_compile_triton(
     load_kernel: Callable[[], Any],
-    cc: int,
-    device: torch.device,
-    device_interface: Type[DeviceInterface],
 ):
     _set_triton_ptxas_path()
-    device_interface.Worker.set_device(device.index)
-    kernel = load_kernel()
-    kernel.precompile(warm_cache_only_with_cc=cc)
+    load_kernel().precompile(warm_cache_only=True)
 
 
 class CodeCacheFuture:
@@ -2993,17 +2983,13 @@ def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
 
         kernel = TritonCodeCache.load(kernel_name, source_code)
         if config.compile_threads > 1:
-            device_interface = get_interface_for_device(device_str)
-            device = torch.device(device_str, device_interface.current_device())
-            cc = device_interface.get_compute_capability(device)
-            future = self.process_pool().submit(
-                _worker_compile_triton,
-                kernel._reload_in_subproc,
-                cc,
-                device,
-                device_interface,
+            return TritonFuture(
+                kernel,
+                self.process_pool().submit(
+                    _worker_compile_triton,
+                    kernel._reload_in_subproc,
+                ),
             )
-            return TritonFuture(kernel, future)
         else:
            kernel.precompile()
            return kernel
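Note on the codecache change above: the worker no longer sets a device or receives a compute capability; it simply reloads the kernel and warms the cache via precompile(warm_cache_only=True), so the parent process stops querying the device interface and only ships a small "reload" callable to the subprocess. Below is a minimal, self-contained sketch of that pattern; FakeKernel, reload_kernel, and the path are hypothetical stand-ins for illustration, not Inductor's actual classes.

from concurrent.futures import ProcessPoolExecutor
from functools import partial


class FakeKernel:
    """Hypothetical stand-in for a generated Triton kernel; only precompile() matters here."""

    def __init__(self, source_path):
        self.source_path = source_path

    def precompile(self, warm_cache_only=False):
        # A real kernel would compile its Triton source here; we just log.
        print(f"precompiling {self.source_path}, warm_cache_only={warm_cache_only}")


def reload_kernel(source_path):
    # Plays the role of kernel._reload_in_subproc: rebuild the kernel inside the worker.
    return FakeKernel(source_path)


def worker_compile(load_kernel):
    # Mirrors the slimmed-down _worker_compile_triton: no device index,
    # compute capability, or device interface is passed any more.
    load_kernel().precompile(warm_cache_only=True)


if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as pool:
        pool.submit(worker_compile, partial(reload_kernel, "/tmp/kernel.py")).result()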
torch/_inductor/codegen/triton.py (7 changes: 3 additions & 4 deletions)

@@ -34,7 +34,7 @@
 from torch._dynamo.utils import preserve_rng_state
 
 from torch._inductor.metrics import is_metric_table_enabled, log_kernel_metadata
-from torch._inductor.runtime.hints import AutotuneHint
+from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties
 from torch._prims_common import is_integer_dtype
 from torch.utils._sympy.functions import FloorDiv, ModularIndexing
 from torch.utils._sympy.value_ranges import ValueRanges
@@ -125,7 +125,7 @@ def gen_common_triton_imports():
         """
         from torch._inductor.runtime import triton_helpers, triton_heuristics
         from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
-        from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor
+        from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties
         """
     )
     return imports.getvalue()
@@ -2833,8 +2833,7 @@ def codegen_kernel(self, name=None):
         )
         triton_meta = {
             "signature": triton_meta_signature,
-            "device": V.graph.scheduler.current_device.index,
-            "device_type": V.graph.scheduler.current_device.type,
+            "device": DeviceProperties.create(V.graph.scheduler.current_device),
             "constants": {},
         }
 
torch/_inductor/codegen/triton_foreach.py (4 changes: 2 additions & 2 deletions)

@@ -6,6 +6,7 @@
 from sympy import Integer
 
 from .. import metrics
+from ..runtime.hints import DeviceProperties
 from ..scheduler import SchedulerNode
 from ..utils import ceildiv, Placeholder
 from ..virtualized import V
@@ -157,8 +158,7 @@ def jit_lines(self):
         _, _, signature = self.args.python_argdefs()
         triton_meta = {
             "signature": signature_to_meta(signature, size_dtype=size_dtype),
-            "device": V.graph.scheduler.current_device.index,
-            "device_type": V.graph.scheduler.current_device.type,
+            "device": DeviceProperties.create(V.graph.scheduler.current_device),
             "constants": {},
         }
         triton_meta["configs"] = [config_of(signature)]
torch/_inductor/codegen/wrapper.py (4 changes: 2 additions & 2 deletions)

@@ -40,6 +40,7 @@
 from .. import codecache, config, ir
 from ..ir import ReinterpretView
 from ..runtime import triton_heuristics
+from ..runtime.hints import DeviceProperties
 from ..utils import (
     cache_on_self,
     get_benchmark_name,
@@ -1130,8 +1131,7 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs):
                 size_dtype=index_dtype,
                 indices=non_constant_indices,
             ),
-            "device": V.graph.scheduler.current_device.index,
-            "device_type": V.graph.scheduler.current_device.type,
+            "device": DeviceProperties.create(V.graph.scheduler.current_device),
             # Triton compiler includes equal_to_1 args into constants even
             # when they are not constexpr. otherwise there may be a segfault
             # during launching the Inductor-compiled Triton kernel.
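The three codegen sites above (triton.py, triton_foreach.py, and wrapper.py) all make the same substitution: a single "device" entry holding a DeviceProperties snapshot replaces the separate "device" index and "device_type" string. As a rough illustration (the signature and constants below are placeholders, and a visible CUDA device is assumed), the emitted metadata takes this shape:

import torch
from torch._inductor.runtime.hints import DeviceProperties

# Illustrative only: in generated code the device comes from
# V.graph.scheduler.current_device rather than a hard-coded torch.device.
triton_meta = {
    "signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
    "device": DeviceProperties.create(torch.device("cuda")),
    "constants": {},
}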
torch/_inductor/runtime/hints.py (44 changes: 44 additions & 0 deletions)

@@ -1,6 +1,8 @@
 import collections
+import typing
 from dataclasses import fields
 from enum import auto, Enum
+from typing import Optional
 
 
 # NOTE: if these fail asserts submit a PR to increase them
@@ -89,3 +91,45 @@ class AutotuneHint(Enum):
     # which isn't valid python.
     # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
     __repr__ = Enum.__str__
+
+
+class DeviceProperties(typing.NamedTuple):
+    """Copy device properties into a data structure not requiring torch to be imported"""
+
+    type: str  # type: ignore[assignment]
+    index: int  # type: ignore[assignment]
+    cc: int
+    major: Optional[int] = None
+    regs_per_multiprocessor: Optional[int] = None
+    max_threads_per_multi_processor: Optional[int] = None
+    multi_processor_count: Optional[int] = None
+
+    @classmethod
+    def create(cls, device):
+        import torch
+        from torch._dynamo.device_interface import get_interface_for_device
+
+        device_type = device.type if torch.version.hip is None else "hip"
+        device_interface = get_interface_for_device(device)
+        if device_type == "cuda":
+            props = device_interface.get_device_properties(device)
+            return cls(
+                type=device_type,
+                index=device.index,
+                cc=device_interface.get_compute_capability(device),
+                major=props.major,
+                regs_per_multiprocessor=props.regs_per_multiprocessor,
+                max_threads_per_multi_processor=props.max_threads_per_multi_processor,
+                multi_processor_count=props.multi_processor_count,
+            )
+        return cls(
+            type=device_type,
+            index=device.index,
+            cc=device_interface.get_compute_capability(device),
+        )
+
+    @classmethod
+    def create_from_args(cls, *args, **kwargs):
+        import torch
+
+        return cls.create(torch.device(*args, **kwargs))

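For reference, a short usage sketch of the new DeviceProperties helper (an illustration, not part of the commit; it assumes a CUDA build with at least one visible device):

import torch
from torch._inductor.runtime.hints import DeviceProperties

# Snapshot the device's properties into a plain NamedTuple; per the class
# docstring, the runtime side can read these fields without importing torch.
props = DeviceProperties.create(torch.device("cuda", 0))
print(props.type, props.index, props.cc, props.multi_processor_count)

# create_from_args builds the same snapshot from torch.device(...) arguments.
assert DeviceProperties.create_from_args("cuda", 0) == props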