Skip to content
This repository was archived by the owner on May 11, 2025. It is now read-only.

Commit 6ca9ad7

Browse files
wasertechmikeshi80casper-hansen
authored
Update build process for flexibility and torch version compatibility (#29)
* add the support for Jetson Orin devices whose compute capability is 87 * don't force CC and CXX in env and be more flexible with torch version required * update torch version requirement to be even more flexible * update setup.py to support multiple compute capabilities * impove compute_capabilities set definition * Add support for Python 3.12 in classifiers tags * (first draft) update readme with new building procedure * Add install notes --------- Co-authored-by: Shi Hui <[email protected]> Co-authored-by: Casper <[email protected]>
1 parent 2136e91 commit 6ca9ad7

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

README.md

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,25 @@ pip install autoawq-kernels
2323
```
2424

2525
### Build from source
26-
You can also build from source:
26+
27+
To build the kernels from source, you first need to setup an environment containing the necessary dependencies.
28+
29+
#### Build Requirements
30+
31+
- Python>=3.8.0
32+
- Numpy
33+
- Wheel
34+
- PyTorch
35+
- ROCm: You need to install the following packages `rocsparse-dev hipsparse-dev rocthrust-dev rocblas-dev hipblas-dev`.
36+
37+
#### Building process
2738

2839
```
29-
git clone https://github.com/casper-hansen/AutoAWQ_kernels
30-
cd AutoAWQ_kernels
31-
pip install -e .
40+
pip install git+https://github.com/casper-hansen/AutoAWQ_kernels.git
3241
```
3342

34-
To build for ROCm, you need to first install the following packages `rocsparse-dev hipsparse-dev rocthrust-dev rocblas-dev hipblas-dev`.
43+
Notes on environment variables:
44+
- `TORCH_VERSION`: By default, we build using the current version of torch by `torch.__version__`. You can override it with `TORCH_VERSION`.
45+
- `CUDA_VERSION` or `ROCM_VERSION` can also be used to build for a specific version of CUDA or ROCm.
46+
- `CC` and `CXX`: You can specify which build system to use for the C code, e.g. `CC=g++-13 CXX=g++-13 pip install -e .`
47+
- `COMPUTE_CAPABILITIES`: You can specify specific compute capabilities to compile for: `COMPUTE_CAPABILITIES="75,80,86,87,89,90" pip install -e .`

setup.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
from distutils.sysconfig import get_python_lib
66
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
77

8-
os.environ["CC"] = "g++"
9-
os.environ["CXX"] = "g++"
8+
if "CC" not in os.environ:
9+
os.environ["CC"] = "g++"
10+
if "CXX" not in os.environ:
11+
os.environ["CXX"] = "g++"
1012
AUTOAWQ_KERNELS_VERSION = "0.0.8"
1113
PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
14+
COMPUTE_CAPABILITIES = os.getenv("COMPUTE_CAPABILITIES", "75,80,86,87,89,90")
15+
TORCH_VERSION = str(os.getenv("TORCH_VERSION", None) or torch.__version__).split('+', maxsplit=1)[0]
1216
CUDA_VERSION = os.getenv("CUDA_VERSION", None) or torch.version.cuda
1317
ROCM_VERSION = os.environ.get("ROCM_VERSION", None) or torch.version.hip
1418

@@ -57,7 +61,7 @@
5761
}
5862

5963
requirements = [
60-
"torch==2.4.1",
64+
f"torch>={TORCH_VERSION}",
6165
]
6266

6367

@@ -91,7 +95,7 @@ def get_generator_flag():
9195

9296

9397
def get_compute_capabilities(
94-
compute_capabilities={75, 80, 86, 89, 90}
98+
compute_capabilities=set(map(int, COMPUTE_CAPABILITIES.split(",")))
9599
):
96100
capability_flags = []
97101

0 commit comments

Comments
 (0)