-
Notifications
You must be signed in to change notification settings - Fork 418
Open
Labels
questionFurther information is requestedFurther information is requested
Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker that this hasn't already been reported. (comment there if it has.)
Questions
Reproduce
import math
from pathlib import Path
import einops as ein
import numpy as np
import tilelang
import tilelang.language as T
import torch
from icecream import ic
import itertools
@tilelang.jit
def make_kernel(
out,
dout,
res,
nstr: int,
tilesize: int,
threads: int = 128,
):
sql = T.dynamic("sql")
dtype = T.float32
assert nstr == 4
assert tilesize == 8
@T.prim_func
def main(
out: T.Tensor([sql, nstr, nstr], dtype),
dout: T.Tensor([sql, nstr, nstr], dtype),
res: T.Tensor([sql, nstr, nstr], dtype),
):
with T.Kernel(T.ceildiv(sql, tilesize), threads=threads) as i_seq:
R = T.alloc_shared([tilesize, nstr, nstr], dtype=dtype)
b = T.alloc_fragment([tilesize, nstr], dtype=dtype)
T.reduce_sum(R, b, dim=-1)
T.reduce_sum(R, b, dim=-2)
return main
def main():
sql = 2048
nstr = 4
tilesize = 8
dist = torch.distributions.uniform.Uniform(0.0, 4.0)
x = dist.sample((sql, nstr, nstr))
x = x.to("cuda")
x.requires_grad_()
stx = torch.randn_like(x)
dstx = torch.randn_like(stx)
# Method 2: NumPy Implicit differentiation
stx = stx.detach()
gmi = torch.zeros_like(stx)
make_kernel(
stx,
dstx,
gmi,
nstr,
tilesize,
)
if __name__ == "__main__":
main()error
File "/opt/conda/lib/python3.11/site-packages/tilelang/jit/__init__.py", line 435, in __call__
kernel = self.compile(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/tilelang/jit/__init__.py", line 374, in compile
kernel_result = compile(
^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/tilelang/jit/__init__.py", line 98, in compile
return cached(
^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/tilelang/cache/__init__.py", line 74, in cached
return _dispatch_map[execution_backend].cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/tilelang/cache/kernel_cache.py", line 204, in cached
kernel = JITKernel(
^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/tilelang/jit/kernel.py", line 137, in __init__
adapter = self._compile_and_create_adapter(func, out_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/tilelang/jit/kernel.py", line 242, in _compile_and_create_adapter
artifact = tilelang.lower(
^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/tilelang/engine/lower.py", line 248, in lower
mod = LowerAndLegalize(mod, target)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/tilelang/engine/phase.py", line 178, in LowerAndLegalize
mod = tilelang.transform.LayoutInference()(mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/lib/python3.11/site-packages/tilelang/3rdparty/tvm/python/tvm/ir/transform.py", line 167, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule) const
File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
File "<unknown>", line 0, in tvm::tl::LayoutInferencer::Substitute(tvm::tir::PrimFunc, bool)
File "<unknown>", line 0, in tvm::tl::BufferUseDefCollector::Run()
File "<unknown>", line 0, in tvm::tl::BufferUseDefCollector::InferInFreeMode(tvm::ffi::Map<tvm::tir::Buffer, tvm::tl::Layout, void>&, tvm::ffi::Map<tvm::tir::Buffer, tvm::tl::Layout, void> const&)
File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
tvm.error.InternalError: Check failed: (min_reg_num < INT64_MAX) is false: no available layout found
versions
Name: tilelang
Version: 0.1.7.post3
Summary: A tile level programming language to generate high performance code.
Home-page:
Author: TileLang Contributors, Tile-AI
Author-email:
License-Expression: MIT
Location: /opt/conda/lib/python3.11/site-packages
Requires: apache-tvm-ffi, cloudpickle, ml-dtypes, numpy, psutil, torch, torch-c-dlpack-ext, tqdm, typing-extensions, z3-solver
Required-by:
Metadata
Metadata
Assignees
Labels
questionFurther information is requestedFurther information is requested