Skip to content

pyproject.toml #1852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Jun 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ jobs:
- name: 'Dependencies'
run: |
apt-get update
apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
pip install cmake==3.21.0
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake==3.21.0 pybind11[global] ninja
- name: 'Checkout'
uses: actions/checkout@v3
with:
Expand All @@ -42,8 +42,8 @@ jobs:
- name: 'Dependencies'
run: |
apt-get update
apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
apt-get install -y git python3.9 pip cudnn9-cuda-12
pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops
- name: 'Checkout'
uses: actions/checkout@v3
with:
Expand All @@ -62,6 +62,8 @@ jobs:
image: ghcr.io/nvidia/jax:jax
options: --user root
steps:
- name: 'Dependencies'
run: pip install pybind11[global]
- name: 'Checkout'
uses: actions/checkout@v3
with:
Expand Down
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,13 @@ Alternatively, install directly from the GitHub repository:

.. code-block:: bash

pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable

When installing from GitHub, you can explicitly specify frameworks using the environment variable:

.. code-block:: bash

NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
NVTE_FRAMEWORK=pytorch,jax pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable

conda Installation
^^^^^^^^^^^^^^^^^^
Expand Down
17 changes: 3 additions & 14 deletions build_tools/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

def install_requirements() -> List[str]:
"""Install dependencies for TE/JAX extensions."""
return ["jax[cuda12]", "flax>=0.7.1"]
return ["jax", "flax>=0.7.1"]


def test_requirements() -> List[str]:
Expand Down Expand Up @@ -75,20 +75,9 @@ def setup_jax_extension(
# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension

class Pybind11CPPExtension(Pybind11Extension):
"""Modified Pybind11Extension to allow custom CXX flags."""

def _add_cflags(self, flags: List[str]) -> None:
if isinstance(self.extra_compile_args, dict):
cxx_flags = self.extra_compile_args.pop("cxx", [])
cxx_flags += flags
self.extra_compile_args["cxx"] = cxx_flags
else:
self.extra_compile_args[:0] = flags

return Pybind11CPPExtension(
return Pybind11Extension(
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args={"cxx": cxx_flags},
extra_compile_args=cxx_flags,
)
7 changes: 0 additions & 7 deletions build_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,3 @@ def copy_common_headers(
new_path = dst_dir / path.relative_to(src_dir)
new_path.parent.mkdir(exist_ok=True, parents=True)
shutil.copy(path, new_path)


def install_and_import(package):
"""Install a package via pip (if not already installed) and import into globals."""
main_package = package.split("[")[0]
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
globals()[main_package] = importlib.import_module(main_package)
15 changes: 9 additions & 6 deletions build_tools/wheel_utils/build_wheels.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ cd /TransformerEngine
git checkout $TARGET_BRANCH
git submodule update --init --recursive

# Install deps
/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja

if $BUILD_METAPACKAGE ; then
cd /TransformerEngine
NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt
Expand All @@ -31,15 +34,15 @@ if $BUILD_COMMON ; then
WHL_BASE="transformer_engine-${VERSION}"

# Create the wheel.
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
/opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt

# Repack the wheel for cuda specific package, i.e. cu12.
/opt/python/cp38-cp38/bin/wheel unpack dist/*
/opt/python/cp310-cp310/bin/wheel unpack dist/*
# From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore).
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
/opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE}
/opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE}

# Rename the wheel to make it python version agnostic.
whl_name=$(basename dist/*)
Expand All @@ -51,14 +54,14 @@ fi

if $BUILD_PYTORCH ; then
cd /TransformerEngine/transformer_engine/pytorch
/opt/python/cp38-cp38/bin/pip install torch
/opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt
/opt/python/cp310-cp310/bin/pip install torch
/opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt
cp dist/* /wheelhouse/
fi

if $BUILD_JAX ; then
cd /TransformerEngine/transformer_engine/jax
/opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib
/opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib
/opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
cp dist/* /wheelhouse/
fi
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

[build-system]
requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax[cuda12]", "flax>=0.7.1"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean PyTorch is a dependency for JAX-only builds, and vice versa? That doesn't seem right.

One problem with the pyproject.toml approach is that build-time dependencies are static, but we would really like to have dynamic dependencies. Should we migrate away from supporting a monolithic TE package and commit to using framework-specific subpackages? The TE core package only has runtime dependencies on the framework subpackages, and runtime dependencies can be dynamic.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple of important points to note:

  1. These static build dependencies are only relevant when using build isolation, which is to be viewed purely as an extra feature on top of the current recommendation of using --no-build-isolation.
  2. This is also only the case for source installations; installations from PyPI circumvent this by having a wheel for the core build and specific dependencies from the framework packages.

Feature 1 will help users who want to be able to do a simple build without installing any dependencies beforehand and in that case they will get both frameworks. The main issue that I see with having the framework packages completely separate even for a source install is that we won't be able to support the workflow of doing a simple pip install from the top level directory.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think requiring --no-build-isolation is fine for now (I wonder if there's a way to error out if it's not set), but doesn't this mean you also need --no-deps for the source build? Or else you might install PyTorch/JAX unnecessarily. Maybe better to assume that the user has PyTorch/JAX in their environment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> doesn't this mean you also need --no-deps for the source build?

No, because of the distinction between build and runtime dependencies. --no-build-isolation is strictly for build-time dependencies, whereas --no-deps is for runtime dependencies. With --no-build-isolation (the current recommended install method), the user must preinstall all needed build deps for the framework extension they want to compile (including the framework itself). A few cases:

  1. If they only have torch installed then the Jax extensions won't be built and the install deps for TE/Jax will not be installed, and vice-versa.
  2. If the user has both frameworks but doesn't set NVTE_FRAMEWORK then everything will be built and all runtime deps will be installed.
  3. If the user has both frameworks and sets NVTE_FRAMEWORK then only that framework's extension will be built and thus only its runtime deps will be installed.

> I wonder if there's a way to error out if it's not set

I think this is not needed because:

  1. We updated all installation instructions and docs in v2.3 to ensure that --no-build-isolation is the default everywhere.
  2. This PR does add support for build isolation, purely as a new feature or fallback. Worst case the user env has a few extra libs/dependencies installed that they might not need, but this is a good trade-off for the basic user that doesn't want to manage deps and worry about the env. This will be powerful especially once we remove the toolkit restriction, because then we will no longer have any system package dependency and users can simply run `pip install .`.


# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"

39 changes: 4 additions & 35 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,8 @@
from build_tools.te_version import te_version
from build_tools.utils import (
cuda_archs,
found_cmake,
found_ninja,
found_pybind11,
get_frameworks,
install_and_import,
remove_dups,
cuda_toolkit_include_path,
)

frameworks = get_frameworks()
Expand All @@ -36,7 +31,6 @@
if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import("pybind11[global]")
from pybind11.setup_helpers import build_ext as BuildExtension


Expand Down Expand Up @@ -82,57 +76,34 @@ def setup_common_extension() -> CMakeExtension:
)


def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
def setup_requirements() -> Tuple[List[str], List[str]]:
"""Setup Python dependencies

Returns dependencies for build, runtime, and testing.
Returns dependencies for runtime and testing.
"""

# Common requirements
setup_reqs: List[str] = []
if cuda_toolkit_include_path() is None:
setup_reqs.extend(
[
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
]
)
install_reqs: List[str] = [
"pydantic",
"importlib-metadata>=1.0",
"packaging",
]
test_reqs: List[str] = ["pytest>=8.2.1"]

# Requirements that may be installed outside of Python
if not found_cmake():
setup_reqs.append("cmake>=3.21")
if not found_ninja():
setup_reqs.append("ninja")
if not found_pybind11():
setup_reqs.append("pybind11")

# Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
from build_tools.pytorch import install_requirements, test_requirements

setup_reqs.extend(["torch>=2.1"])
install_reqs.extend(install_requirements())
test_reqs.extend(test_requirements())
if "jax" in frameworks:
from build_tools.jax import install_requirements, test_requirements

setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
install_reqs.extend(install_requirements())
test_reqs.extend(test_requirements())

return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]]


if __name__ == "__main__":
Expand All @@ -149,14 +120,13 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
ext_modules = []
package_data = {}
include_package_data = False
setup_requires = []
install_requires = ([f"transformer_engine_cu12=={__version__}"],)
extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
}
else:
setup_requires, install_requires, test_requires = setup_requirements()
install_requires, test_requires = setup_requirements()
ext_modules = [setup_common_extension()]
package_data = {"": ["VERSION.txt"]}
include_package_data = True
Expand Down Expand Up @@ -203,7 +173,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8",
classifiers=["Programming Language :: Python :: 3"],
setup_requires=setup_requires,
install_requires=install_requires,
license_files=("LICENSE",),
include_package_data=include_package_data,
Expand Down
10 changes: 10 additions & 0 deletions transformer_engine/jax/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

[build-system]
requires = ["setuptools>=61.0", "pybind11[global]", "pip", "jax[cuda12]", "flax>=0.7.1"]

# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"

18 changes: 1 addition & 17 deletions transformer_engine/jax/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,10 @@


from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers, install_and_import, cuda_toolkit_include_path
from build_tools.utils import copy_common_headers
from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension, install_requirements, test_requirements

install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension

os.environ["NVTE_PROJECT_BUILDING"] = "1"
Expand Down Expand Up @@ -94,28 +93,13 @@
)
]

setup_requires = ["jax[cuda12]", "flax>=0.7.1"]
if cuda_toolkit_include_path() is None:
setup_requires.extend(
[
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
]
)

# Configure package
setuptools.setup(
name="transformer_engine_jax",
version=te_version(),
description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=setup_requires,
install_requires=install_requirements(),
tests_require=test_requirements(),
)
Expand Down
10 changes: 10 additions & 0 deletions transformer_engine/pytorch/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

[build-system]
requires = ["setuptools>=61.0", "pip", "torch>=2.1"]

# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"

17 changes: 1 addition & 16 deletions transformer_engine/pytorch/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@


from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers, cuda_toolkit_include_path
from build_tools.utils import copy_common_headers
from build_tools.te_version import te_version
from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements

Expand All @@ -48,28 +48,13 @@
)
]

setup_requires = ["torch>=2.1"]
if cuda_toolkit_include_path() is None:
setup_requires.extend(
[
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
]
)

# Configure package
setuptools.setup(
name="transformer_engine_torch",
version=te_version(),
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=setup_requires,
install_requires=install_requirements(),
tests_require=test_requirements(),
)
Expand Down