Skip to content

[Question] Why I am getting this internal error #1714

@Da1sypetals

Description

@Da1sypetals

Required prerequisites

Questions

Reproduce

import math
from pathlib import Path

import einops as ein
import numpy as np
import tilelang
import tilelang.language as T
import torch
from icecream import ic
import itertools


@tilelang.jit
def make_kernel(
    out,
    dout,
    res,
    nstr: int,
    tilesize: int,
    threads: int = 128,
):
    sql = T.dynamic("sql")
    dtype = T.float32

    assert nstr == 4
    assert tilesize == 8

    @T.prim_func
    def main(
        out: T.Tensor([sql, nstr, nstr], dtype),
        dout: T.Tensor([sql, nstr, nstr], dtype),
        res: T.Tensor([sql, nstr, nstr], dtype),
    ):
        with T.Kernel(T.ceildiv(sql, tilesize), threads=threads) as i_seq:
            R = T.alloc_shared([tilesize, nstr, nstr], dtype=dtype)
            b = T.alloc_fragment([tilesize, nstr], dtype=dtype)
            T.reduce_sum(R, b, dim=-1)
            T.reduce_sum(R, b, dim=-2)

    return main


def main():
    sql = 2048
    nstr = 4
    tilesize = 8

    dist = torch.distributions.uniform.Uniform(0.0, 4.0)
    x = dist.sample((sql, nstr, nstr))
    x = x.to("cuda")
    x.requires_grad_()

    stx = torch.randn_like(x)
    dstx = torch.randn_like(stx)

    # Method 2: NumPy Implicit differentiation
    stx = stx.detach()
    gmi = torch.zeros_like(stx)
    make_kernel(
        stx,
        dstx,
        gmi,
        nstr,
        tilesize,
    )


if __name__ == "__main__":
    main()

error

  File "/opt/conda/lib/python3.11/site-packages/tilelang/jit/__init__.py", line 435, in __call__
    kernel = self.compile(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tilelang/jit/__init__.py", line 374, in compile
    kernel_result = compile(
                    ^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tilelang/jit/__init__.py", line 98, in compile
    return cached(
           ^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tilelang/cache/__init__.py", line 74, in cached
    return _dispatch_map[execution_backend].cached(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tilelang/cache/kernel_cache.py", line 204, in cached
    kernel = JITKernel(
             ^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tilelang/jit/kernel.py", line 137, in __init__
    adapter = self._compile_and_create_adapter(func, out_idx)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tilelang/jit/kernel.py", line 242, in _compile_and_create_adapter
    artifact = tilelang.lower(
               ^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tilelang/engine/lower.py", line 248, in lower
    mod = LowerAndLegalize(mod, target)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tilelang/engine/phase.py", line 178, in LowerAndLegalize
    mod = tilelang.transform.LayoutInference()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/tilelang/3rdparty/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 923, in tvm_ffi.core.Function.__call__
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule) const
  File "<unknown>", line 0, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  File "<unknown>", line 0, in tvm::tl::LayoutInferencer::Substitute(tvm::tir::PrimFunc, bool)
  File "<unknown>", line 0, in tvm::tl::BufferUseDefCollector::Run()
  File "<unknown>", line 0, in tvm::tl::BufferUseDefCollector::InferInFreeMode(tvm::ffi::Map<tvm::tir::Buffer, tvm::tl::Layout, void>&, tvm::ffi::Map<tvm::tir::Buffer, tvm::tl::Layout, void> const&)
  File "<unknown>", line 0, in tvm::runtime::detail::LogFatal::Entry::Finalize()
tvm.error.InternalError: Check failed: (min_reg_num < INT64_MAX) is false: no available layout found

versions

Name: tilelang
Version: 0.1.7.post3
Summary: A tile level programming language to generate high performance code.
Home-page: 
Author: TileLang Contributors, Tile-AI
Author-email: 
License-Expression: MIT
Location: /opt/conda/lib/python3.11/site-packages
Requires: apache-tvm-ffi, cloudpickle, ml-dtypes, numpy, psutil, torch, torch-c-dlpack-ext, tqdm, typing-extensions, z3-solver
Required-by: 

Metadata

Metadata

Assignees

Labels

questionFurther information is requested

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions