Skip to content

Commit aedd7e1

Browse files
authored
pyproject.toml (#1852)
* Initial basic setup Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * rm setup reqs Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * fix Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * buil-isolation support Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * rm not needed funcs Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix workflows Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * fix wheel Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix invalid wheel Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix JAX build in baremetal env Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Update install inst in readme Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Update build.yml Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * docstring fix Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * fix Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent faee0e8 commit aedd7e1

File tree

11 files changed

+56
-101
lines changed

11 files changed

+56
-101
lines changed

.github/workflows/build.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ jobs:
1818
- name: 'Dependencies'
1919
run: |
2020
apt-get update
21-
apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
22-
pip install cmake==3.21.0
21+
apt-get install -y git python3.9 pip cudnn9-cuda-12
22+
pip install cmake==3.21.0 pybind11[global] ninja
2323
- name: 'Checkout'
2424
uses: actions/checkout@v3
2525
with:
@@ -42,8 +42,8 @@ jobs:
4242
- name: 'Dependencies'
4343
run: |
4444
apt-get update
45-
apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
46-
pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
45+
apt-get install -y git python3.9 pip cudnn9-cuda-12
46+
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
4747
- name: 'Checkout'
4848
uses: actions/checkout@v3
4949
with:
@@ -62,6 +62,8 @@ jobs:
6262
image: ghcr.io/nvidia/jax:jax
6363
options: --user root
6464
steps:
65+
- name: 'Dependencies'
66+
run: pip install pybind11[global]
6567
- name: 'Checkout'
6668
uses: actions/checkout@v3
6769
with:

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,13 @@ Alternatively, install directly from the GitHub repository:
216216

217217
.. code-block:: bash
218218
219-
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
219+
pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable
220220
221221
When installing from GitHub, you can explicitly specify frameworks using the environment variable:
222222

223223
.. code-block:: bash
224224
225-
NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
225+
NVTE_FRAMEWORK=pytorch,jax pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable
226226
227227
conda Installation
228228
^^^^^^^^^^^^^^^^^^

build_tools/jax.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def install_requirements() -> List[str]:
1616
"""Install dependencies for TE/JAX extensions."""
17-
return ["jax[cuda12]", "flax>=0.7.1"]
17+
return ["jax", "flax>=0.7.1"]
1818

1919

2020
def test_requirements() -> List[str]:
@@ -75,20 +75,9 @@ def setup_jax_extension(
7575
# Define TE/JAX as a Pybind11Extension
7676
from pybind11.setup_helpers import Pybind11Extension
7777

78-
class Pybind11CPPExtension(Pybind11Extension):
79-
"""Modified Pybind11Extension to allow custom CXX flags."""
80-
81-
def _add_cflags(self, flags: List[str]) -> None:
82-
if isinstance(self.extra_compile_args, dict):
83-
cxx_flags = self.extra_compile_args.pop("cxx", [])
84-
cxx_flags += flags
85-
self.extra_compile_args["cxx"] = cxx_flags
86-
else:
87-
self.extra_compile_args[:0] = flags
88-
89-
return Pybind11CPPExtension(
78+
return Pybind11Extension(
9079
"transformer_engine_jax",
9180
sources=[str(path) for path in sources],
9281
include_dirs=[str(path) for path in include_dirs],
93-
extra_compile_args={"cxx": cxx_flags},
82+
extra_compile_args=cxx_flags,
9483
)

build_tools/utils.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -354,10 +354,3 @@ def copy_common_headers(
354354
new_path = dst_dir / path.relative_to(src_dir)
355355
new_path.parent.mkdir(exist_ok=True, parents=True)
356356
shutil.copy(path, new_path)
357-
358-
359-
def install_and_import(package):
360-
"""Install a package via pip (if not already installed) and import into globals."""
361-
main_package = package.split("[")[0]
362-
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
363-
globals()[main_package] = importlib.import_module(main_package)

build_tools/wheel_utils/build_wheels.sh

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ cd /TransformerEngine
2020
git checkout $TARGET_BRANCH
2121
git submodule update --init --recursive
2222

23+
# Install deps
24+
/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja
25+
2326
if $BUILD_METAPACKAGE ; then
2427
cd /TransformerEngine
2528
NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt
@@ -31,15 +34,15 @@ if $BUILD_COMMON ; then
3134
WHL_BASE="transformer_engine-${VERSION}"
3235

3336
# Create the wheel.
34-
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
37+
/opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
3538

3639
# Repack the wheel for cuda specific package, i.e. cu12.
37-
/opt/python/cp38-cp38/bin/wheel unpack dist/*
40+
/opt/python/cp310-cp310/bin/wheel unpack dist/*
3841
# From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore).
3942
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
4043
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
4144
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
42-
/opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE}
45+
/opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE}
4346

4447
# Rename the wheel to make it python version agnostic.
4548
whl_name=$(basename dist/*)
@@ -51,14 +54,14 @@ fi
5154

5255
if $BUILD_PYTORCH ; then
5356
cd /TransformerEngine/transformer_engine/pytorch
54-
/opt/python/cp38-cp38/bin/pip install torch
55-
/opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt
57+
/opt/python/cp310-cp310/bin/pip install torch
58+
/opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt
5659
cp dist/* /wheelhouse/
5760
fi
5861

5962
if $BUILD_JAX ; then
6063
cd /TransformerEngine/transformer_engine/jax
61-
/opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib
64+
/opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib
6265
/opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
6366
cp dist/* /wheelhouse/
6467
fi

pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
[build-system]
6+
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax[cuda12]", "flax>=0.7.1"]
7+
8+
# Use legacy backend to import local packages in setup.py
9+
build-backend = "setuptools.build_meta:__legacy__"
10+

setup.py

Lines changed: 4 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,8 @@
1616
from build_tools.te_version import te_version
1717
from build_tools.utils import (
1818
cuda_archs,
19-
found_cmake,
20-
found_ninja,
21-
found_pybind11,
2219
get_frameworks,
23-
install_and_import,
2420
remove_dups,
25-
cuda_toolkit_include_path,
2621
)
2722

2823
frameworks = get_frameworks()
@@ -36,7 +31,6 @@
3631
if "pytorch" in frameworks:
3732
from torch.utils.cpp_extension import BuildExtension
3833
elif "jax" in frameworks:
39-
install_and_import("pybind11[global]")
4034
from pybind11.setup_helpers import build_ext as BuildExtension
4135

4236

@@ -82,57 +76,34 @@ def setup_common_extension() -> CMakeExtension:
8276
)
8377

8478

85-
def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
79+
def setup_requirements() -> Tuple[List[str], List[str]]:
8680
"""Setup Python dependencies
8781
88-
Returns dependencies for build, runtime, and testing.
82+
Returns dependencies for runtime and testing.
8983
"""
9084

9185
# Common requirements
92-
setup_reqs: List[str] = []
93-
if cuda_toolkit_include_path() is None:
94-
setup_reqs.extend(
95-
[
96-
"nvidia-cuda-runtime-cu12",
97-
"nvidia-cublas-cu12",
98-
"nvidia-cudnn-cu12",
99-
"nvidia-cuda-cccl-cu12",
100-
"nvidia-cuda-nvcc-cu12",
101-
"nvidia-nvtx-cu12",
102-
"nvidia-cuda-nvrtc-cu12",
103-
]
104-
)
10586
install_reqs: List[str] = [
10687
"pydantic",
10788
"importlib-metadata>=1.0",
10889
"packaging",
10990
]
11091
test_reqs: List[str] = ["pytest>=8.2.1"]
11192

112-
# Requirements that may be installed outside of Python
113-
if not found_cmake():
114-
setup_reqs.append("cmake>=3.21")
115-
if not found_ninja():
116-
setup_reqs.append("ninja")
117-
if not found_pybind11():
118-
setup_reqs.append("pybind11")
119-
12093
# Framework-specific requirements
12194
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
12295
if "pytorch" in frameworks:
12396
from build_tools.pytorch import install_requirements, test_requirements
12497

125-
setup_reqs.extend(["torch>=2.1"])
12698
install_reqs.extend(install_requirements())
12799
test_reqs.extend(test_requirements())
128100
if "jax" in frameworks:
129101
from build_tools.jax import install_requirements, test_requirements
130102

131-
setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
132103
install_reqs.extend(install_requirements())
133104
test_reqs.extend(test_requirements())
134105

135-
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
106+
return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]]
136107

137108

138109
if __name__ == "__main__":
@@ -149,14 +120,13 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
149120
ext_modules = []
150121
package_data = {}
151122
include_package_data = False
152-
setup_requires = []
153123
install_requires = ([f"transformer_engine_cu12=={__version__}"],)
154124
extras_require = {
155125
"pytorch": [f"transformer_engine_torch=={__version__}"],
156126
"jax": [f"transformer_engine_jax=={__version__}"],
157127
}
158128
else:
159-
setup_requires, install_requires, test_requires = setup_requirements()
129+
install_requires, test_requires = setup_requirements()
160130
ext_modules = [setup_common_extension()]
161131
package_data = {"": ["VERSION.txt"]}
162132
include_package_data = True
@@ -203,7 +173,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
203173
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
204174
python_requires=">=3.8",
205175
classifiers=["Programming Language :: Python :: 3"],
206-
setup_requires=setup_requires,
207176
install_requires=install_requires,
208177
license_files=("LICENSE",),
209178
include_package_data=include_package_data,

transformer_engine/jax/pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
[build-system]
6+
requires = ["setuptools>=61.0", "pybind11[global]", "pip", "jax[cuda12]", "flax>=0.7.1"]
7+
8+
# Use legacy backend to import local packages in setup.py
9+
build-backend = "setuptools.build_meta:__legacy__"
10+

transformer_engine/jax/setup.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,10 @@
4444

4545

4646
from build_tools.build_ext import get_build_ext
47-
from build_tools.utils import copy_common_headers, install_and_import, cuda_toolkit_include_path
47+
from build_tools.utils import copy_common_headers
4848
from build_tools.te_version import te_version
4949
from build_tools.jax import setup_jax_extension, install_requirements, test_requirements
5050

51-
install_and_import("pybind11")
5251
from pybind11.setup_helpers import build_ext as BuildExtension
5352

5453
os.environ["NVTE_PROJECT_BUILDING"] = "1"
@@ -94,28 +93,13 @@
9493
)
9594
]
9695

97-
setup_requires = ["jax[cuda12]", "flax>=0.7.1"]
98-
if cuda_toolkit_include_path() is None:
99-
setup_requires.extend(
100-
[
101-
"nvidia-cuda-runtime-cu12",
102-
"nvidia-cublas-cu12",
103-
"nvidia-cudnn-cu12",
104-
"nvidia-cuda-cccl-cu12",
105-
"nvidia-cuda-nvcc-cu12",
106-
"nvidia-nvtx-cu12",
107-
"nvidia-cuda-nvrtc-cu12",
108-
]
109-
)
110-
11196
# Configure package
11297
setuptools.setup(
11398
name="transformer_engine_jax",
11499
version=te_version(),
115100
description="Transformer acceleration library - Jax Lib",
116101
ext_modules=ext_modules,
117102
cmdclass={"build_ext": CMakeBuildExtension},
118-
setup_requires=setup_requires,
119103
install_requires=install_requirements(),
120104
tests_require=test_requirements(),
121105
)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# See LICENSE for license information.
4+
5+
[build-system]
6+
requires = ["setuptools>=61.0", "pip", "torch>=2.1"]
7+
8+
# Use legacy backend to import local packages in setup.py
9+
build-backend = "setuptools.build_meta:__legacy__"
10+

transformer_engine/pytorch/setup.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131
from build_tools.build_ext import get_build_ext
32-
from build_tools.utils import copy_common_headers, cuda_toolkit_include_path
32+
from build_tools.utils import copy_common_headers
3333
from build_tools.te_version import te_version
3434
from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements
3535

@@ -48,28 +48,13 @@
4848
)
4949
]
5050

51-
setup_requires = ["torch>=2.1"]
52-
if cuda_toolkit_include_path() is None:
53-
setup_requires.extend(
54-
[
55-
"nvidia-cuda-runtime-cu12",
56-
"nvidia-cublas-cu12",
57-
"nvidia-cudnn-cu12",
58-
"nvidia-cuda-cccl-cu12",
59-
"nvidia-cuda-nvcc-cu12",
60-
"nvidia-nvtx-cu12",
61-
"nvidia-cuda-nvrtc-cu12",
62-
]
63-
)
64-
6551
# Configure package
6652
setuptools.setup(
6753
name="transformer_engine_torch",
6854
version=te_version(),
6955
description="Transformer acceleration library - Torch Lib",
7056
ext_modules=ext_modules,
7157
cmdclass={"build_ext": CMakeBuildExtension},
72-
setup_requires=setup_requires,
7358
install_requires=install_requirements(),
7459
tests_require=test_requirements(),
7560
)

0 commit comments

Comments
 (0)