Skip to content

Commit 0466eb8

Browse files
authored
cuda backend for eigh (#69)
1 parent b6b974d commit 0466eb8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2434
-19
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
name: Publish pyscfad-cuda-plugin
2+
3+
on:
4+
release:
5+
types:
6+
- released
7+
workflow_dispatch:
8+
9+
jobs:
10+
build_wheel_linux_x86:
11+
name: Build wheels on ${{ matrix.os }}
12+
runs-on: ${{ matrix.os }}
13+
strategy:
14+
matrix:
15+
os: [ubuntu-22.04]
16+
python-version: ["3.10", "3.11", "3.12", "3.13"]
17+
18+
steps:
19+
- uses: actions/checkout@v4
20+
21+
- name: Set up Python ${{ matrix.python-version }}
22+
uses: actions/setup-python@v5
23+
with:
24+
python-version: ${{ matrix.python-version }}
25+
26+
- name: Install Bazel
27+
run: |
28+
sudo apt install apt-transport-https curl gnupg -y
29+
curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor >bazel-archive-keyring.gpg
30+
sudo mv bazel-archive-keyring.gpg /usr/share/keyrings
31+
echo "deb [arch=amd64 signed-by=/usr/share/keyrings/bazel-archive-keyring.gpg] https://storage.googleapis.com/bazel-apt stable jdk1.8" | sudo tee /etc/apt/sources.list.d/bazel.list
32+
sudo apt update && sudo apt install bazel-7.4.1
33+
34+
- name: Build wheels
35+
run: |
36+
cd pyscfadlib
37+
python build/build.py build
38+
39+
- name: Upload wheels
40+
uses: actions/upload-artifact@v4
41+
with:
42+
name: cuda_plugin_wheels
43+
path: pyscfadlib/dist/*.whl
44+
overwrite: true

.github/workflows/publish_pyscfadlib.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88

99
jobs:
1010
publish_pypi_linux_x86_aarch64:
11-
name: publish linux_x86 wheels to pypi
11+
name: publish linux_x86 linux_aarch64 wheels to pypi
1212
runs-on: ubuntu-latest
1313

1414
environment: release
@@ -124,4 +124,3 @@ jobs:
124124

125125
- name: Publish to PyPI
126126
uses: pypa/gh-action-pypi-publish@release/v1
127-

examples/fci/10-nac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
mol = gto.Mole()
66
mol.atom = 'H 0 0 0; H 0 0 1.1'
77
mol.basis = 'ccpvdz'
8-
mol.build()
8+
mol.build(trace_exp=False, trace_ctr_coeff=False)
99

1010
# HF and FCI calculation
1111
nroots = 8

pyscfad/backend/_jax/lax/linalg.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,14 @@ def _eigh_gen_cpu_gpu_lowering(
135135
"jobz": np.uint8(ord("V")),
136136
"uplo": np.uint8(ord("L" if lower else "U")),
137137
}
138+
elif target_name_prefix == "cuda":
139+
target_name = "cusolver_sygvd_ffi"
140+
kwargs = {
141+
"itype": np.int32(itype),
142+
"lower": lower,
143+
}
138144
else:
139-
raise NotImplementedError
145+
raise NotImplementedError(f"Platform {target_name_prefix} is not supported.")
140146

141147
info_aval = ShapedArray(batch_dims, np.int32)
142148

@@ -200,7 +206,7 @@ def _eigh_gen_batching_rule(batched_args, batch_dims, *,
200206
)
201207
mlir.register_lowering(
202208
eigh_gen_p,
203-
partial(_eigh_gen_cpu_gpu_lowering, target_name_prefix="cu"),
209+
partial(_eigh_gen_cpu_gpu_lowering, target_name_prefix="cuda"),
204210
platform="cuda"
205211
)
206212

pyscfadlib/.bazelrc

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copied from JAX
2+
3+
# #############################################################################
4+
# All default build options below. These apply to all build commands.
5+
# #############################################################################
6+
# TODO: Enable Bzlmod
7+
common --noenable_bzlmod
8+
9+
# TODO: Migrate for https://github.com/bazelbuild/bazel/issues/7260
10+
#common --noincompatible_enable_cc_toolchain_resolution
11+
12+
# Make Bazel print out all options from rc files.
13+
common --announce_rc
14+
15+
# By default, execute all actions locally.
16+
build --spawn_strategy=local
17+
18+
# Enable host OS specific configs. For instance, "build:linux" will be used
19+
# automatically when building on Linux.
20+
build --enable_platform_specific_config
21+
22+
common --experimental_cc_shared_library
23+
24+
# Do not use C-Ares when building gRPC.
25+
build --define=grpc_no_ares=true
26+
27+
build --define=tsl_link_protobuf=true
28+
29+
# Enable optimization.
30+
build -c opt
31+
32+
# Suppress all warning messages.
33+
build --output_filter=DONT_MATCH_ANYTHING
34+
35+
# #############################################################################
36+
# Platform Specific configs below. These are automatically picked up by Bazel
37+
# depending on the platform that is running the build.
38+
# #############################################################################
39+
build:linux --config=posix
40+
build:linux --copt=-Wno-unknown-warning-option
41+
42+
# Workaround for gcc 10+ warnings related to upb.
43+
# See https://github.com/tensorflow/tensorflow/issues/39467
44+
build:linux --copt=-Wno-stringop-truncation
45+
build:linux --copt=-Wno-array-parameter
46+
47+
# #############################################################################
48+
# Feature-specific configurations. These are used by the CI configs below
49+
# depending on the type of build. E.g. `ci_linux_x86_64` inherits the Linux x86
50+
# configs such as `avx_linux` and `mkl_open_source_only`, `ci_linux_x86_64_cuda`
51+
# inherits `cuda` and `build_cuda_with_nvcc`, etc.
52+
# #############################################################################
53+
build:nonccl --define=no_nccl_support=true
54+
55+
build:posix --copt=-fvisibility=hidden
56+
#build:posix --copt=-Wno-sign-compare
57+
build:posix --cxxopt=-std=c++17
58+
build:posix --host_cxxopt=-std=c++17
59+
60+
build:avx_posix --copt=-mavx
61+
build:avx_posix --host_copt=-mavx
62+
63+
build:native_arch_posix --copt=-march=native
64+
build:native_arch_posix --host_copt=-march=native
65+
66+
build:avx_linux --copt=-mavx
67+
build:avx_linux --host_copt=-mavx
68+
69+
# Configs for CUDA
70+
build:cuda --repo_env TF_NEED_CUDA=1
71+
build:cuda --repo_env TF_NCCL_USE_STUB=1
72+
# "sm" means we emit only cubin, which is forward compatible within a GPU generation.
73+
# "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations.
74+
build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"
75+
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
76+
build:cuda --@local_config_cuda//:enable_cuda
77+
78+
# Default hermetic CUDA and CUDNN versions.
79+
build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2"
80+
build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1"
81+
build:cuda --@local_config_cuda//cuda:include_cuda_libs=true
82+
83+
# This config is used for building targets with CUDA libraries from stubs.
84+
build:cuda_libraries_from_stubs --@local_config_cuda//cuda:include_cuda_libs=false
85+
86+
# Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries,
87+
# ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to
88+
# point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA
89+
# packages.
90+
# This has pros and cons:
91+
# * pro: we'll ignore other CUDA installations, which has frequently confused
92+
# users in the past. By setting RPATH, we'll always use the NVIDIA pip
93+
# packages if they are installed.
94+
# * con: the user cannot override the CUDA installation location
95+
# via LD_LIBRARY_PATH, if the nvidia-... pip packages are installed. This is
96+
# acceptable, because the workaround is "remove the nvidia-..." pip packages.
97+
# The list of CUDA pip packages that JAX depends on are present in setup.py.
98+
build:cuda --linkopt=-Wl,--disable-new-dtags
99+
100+
# Build CUDA and other C++ targets with Clang
101+
build:build_cuda_with_clang --@local_config_cuda//:cuda_compiler=clang
102+
103+
# Build CUDA with NVCC and other C++ targets with Clang
104+
build:build_cuda_with_nvcc --action_env=TF_NVCC_CLANG="1"
105+
build:build_cuda_with_nvcc --@local_config_cuda//:cuda_compiler=nvcc
106+
107+
# Flag to enable remote config
108+
common --experimental_repo_remote_exec
109+
110+
# Load `.jax_configure.bazelrc` file written by build.py
111+
try-import %workspace%/.pyscfad_configure.bazelrc

pyscfadlib/.bazelversion

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
7.4.1

pyscfadlib/BUILD

Whitespace-only changes.

pyscfadlib/LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2021-2024 Xing Zhang
3+
Copyright (c) 2021-2025 Xing Zhang
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

pyscfadlib/MANIFEST.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
include MANIFEST.in
2-
include README.md setup.py LICENSE
2+
include README.md LICENSE pyproject.toml
33

44
recursive-include pyscfadlib/thirdparty *.so
55
include pyscfadlib/*.so pyscfadlib/config.h.in

pyscfadlib/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1+
# pyscfadlib: support library for pyscfad
2+
13
pyscfadlib is the support library for PySCFAD.
24
It contains most of the C optimized code for derivative calculations.

0 commit comments

Comments
 (0)