Releases: NVIDIA/cuEquivariance
v0.3.0
0.3.0 (2025-03-05)
The main changes are:
- [JAX] New JIT Uniform 1d kernel with JAX bindings
  - Computes any polynomial based on 1d uniform segmented tensor products (STPs)
  - Supports arbitrary derivatives
  - Provides optional fused scatter/gather for the inputs and outputs
  - 🎉 We observed a ~3x speedup for MACE with cuEquivariance-JAX v0.3.0 compared to cuEquivariance-Torch v0.2.0 🎉
- [Torch] Adds `torch.compile` support
- [Torch] Limited beta Torch bindings to the new JIT Uniform 1d kernel
  - enable the new kernel by setting the environment variable `CUEQUIVARIANCE_OPS_USE_JIT=1`
- [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d
  - this is a temporary API that will change: `cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed`
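To make "polynomial based on 1d uniform STPs" concrete, here is a minimal NumPy sketch of a single segmented tensor product under a simplified model assumed for illustration: every segment of every operand is a vector of the same length `u` (one mode, subscripts `u,u,u`), and each path `(i, j, k, c)` adds `c * x[i] * y[j]` into output segment `k`. The function name and path format are hypothetical and are not the cuEquivariance API; the actual kernel additionally handles batching, derivatives, and indexed operands.

```python
import numpy as np


def uniform_1d_stp(paths, x, y, num_out_segments):
    """Evaluate a 3-operand STP with a single mode "u" (simplified, assumed model).

    x, y: arrays of shape (num_segments, u); all segments share the same size u.
    paths: iterable of (i, j, k, c) = (x segment, y segment, output segment, coefficient).
    """
    if x.shape[1] != y.shape[1]:
        raise ValueError("uniform 1d: all segments must have the same size u")
    out = np.zeros((num_out_segments, x.shape[1]), dtype=np.result_type(x, y))
    for i, j, k, c in paths:
        # Each path multiplies two segments elementwise along "u",
        # scales the product by its coefficient and accumulates it into segment k.
        out[k] += c * x[i] * y[j]
    return out


x = np.array([[1.0, 2.0], [3.0, 4.0]])
paths = [(0, 1, 0, 2.0), (1, 1, 1, 1.0)]
print(uniform_1d_stp(paths, x, x, num_out_segments=2))
```

Passing the same array as `x` and `y` evaluates a degree-2 term; in this model, summing the outputs of STPs of different degrees gives a non-homogeneous polynomial.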
Breaking Changes
- [Torch/JAX] Removed `cue.TensorProductExecution` and added `cue.Operation`, which is more lightweight and better aligned with the backend.
- [JAX] In `cuex.equivariant_tensor_product`, the arguments `dtype_math` and `dtype_output` are renamed to `math_dtype` and `output_dtype` respectively. This change adds consistency with the rest of the library.
- [JAX] In `cuex.equivariant_tensor_product`, the arguments `algorithm`, `precision`, `use_custom_primitive` and `use_custom_kernels` have been removed. This change avoids a proliferation of arguments that are not used in all implementations. An argument `impl: str` has been added instead to select the implementation.
- [JAX] Removed `cuex.symmetric_tensor_product`. The `cuex.tensor_product` function now handles any non-homogeneous polynomial.
- [JAX] The batching support (`jax.vmap`) of `cuex.equivariant_tensor_product` is now limited to specific use cases.
- [JAX] The interface of `cuex.tensor_product` has changed. It now takes a list of `tuple[cue.Operation, cue.SegmentedTensorProduct]` instead of a single `cue.SegmentedTensorProduct`. This change allows `cuex.tensor_product` to execute any type of non-homogeneous polynomial.
- [JAX] Removed `cuex.flax_linen.Linear` to reduce maintenance burden. Use `cue.descriptors.linear` together with `cuex.equivariant_tensor_product` instead:
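```python
import jax
import cuequivariance as cue
import cuequivariance_jax as cuex

e = cue.descriptors.linear(input.irreps, output_irreps)  # inside a flax.linen Module
w = self.param(name, jax.random.normal, (e.inputs[0].dim,), input.dtype)
output = cuex.equivariant_tensor_product(e, w, input)
```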
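For code that still passes the pre-0.3.0 keyword arguments, a small shim can translate them before calling `cuex.equivariant_tensor_product`. The sketch below is hypothetical (the helper `migrate_etp_kwargs` is not part of the library): it renames `dtype_math` and `dtype_output`, and raises on the removed arguments instead of silently dropping them, since they have no one-to-one replacement; choosing a value for the new `impl` argument is left to the caller.

```python
_RENAMED = {"dtype_math": "math_dtype", "dtype_output": "output_dtype"}
_REMOVED = {"algorithm", "precision", "use_custom_primitive", "use_custom_kernels"}


def migrate_etp_kwargs(kwargs):
    """Return a copy of pre-0.3.0 keyword arguments using the 0.3.0 names (hypothetical helper)."""
    removed = _REMOVED.intersection(kwargs)
    if removed:
        # These options no longer exist; the implementation is now selected with `impl`.
        raise TypeError(f"removed in 0.3.0: {sorted(removed)}; use impl=... instead")
    # Rename old keys, pass every other keyword through unchanged.
    return {_RENAMED.get(name, name): value for name, value in kwargs.items()}


print(migrate_etp_kwargs({"dtype_math": "float32", "dtype_output": "float16"}))
```

The resulting dictionary can then be unpacked into the call, e.g. `cuex.equivariant_tensor_product(e, w, input, **new_kwargs)`.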
Fixed
- [Torch/JAX] Fixed `cue.descriptors.full_tensor_product`, which was ignoring the `irreps3_filter` argument.
- [Torch/JAX] Fixed a rare bug with `np.bincount` when using an old version of NumPy. The input is now flattened to make it work with all versions.
- [Torch] Identified a bug in the CUDA kernel and disabled the CUDA kernel for `cuet.TransposeSegments` and `cuet.TransposeIrrepsLayout`.
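The `np.bincount` fix comes down to always passing a one-dimensional array: `np.bincount` only accepts 1-D non-negative integer input, so a multi-dimensional index array has to be flattened first. A minimal sketch of that pattern (the helper `count_indices` is hypothetical, not a cuEquivariance function):

```python
import numpy as np


def count_indices(indices, size):
    # Flatten to 1-D first: np.bincount rejects multi-dimensional input.
    flat = np.asarray(indices, dtype=np.int64).ravel()
    # minlength pads the result so every index in [0, size) gets a count.
    return np.bincount(flat, minlength=size)


print(count_indices([[0, 2], [2, 3]], size=5))
```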
Added
- [Torch/JAX] Added `__mul__` to `cue.EquivariantTensorProduct` to allow rescaling the equivariant tensor product.
- [JAX] Added JAX bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomial as long as the contraction pattern (subscripts) has only one mode. It handles batched, shared, and indexed inputs/outputs. Indexed inputs/outputs are processed through atomic operations.
- [JAX] Added an `indices` argument to `cuex.equivariant_tensor_product` and `cuex.tensor_product` to handle the scatter/gather fusion.
- [Torch] Limited beta Torch bindings to the new JIT Uniform 1d kernel (enable the new kernel by setting the environment variable `CUEQUIVARIANCE_OPS_USE_JIT=1`).
- [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (this is a temporary API that will change: `cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed`).
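Conceptually, scatter/gather fusion replaces an explicit gather before the operation and a scatter-add after it with a single pass in which each result is accumulated directly into its output row, which is the role the atomic operations play in the kernel. The NumPy sketch below checks that equivalence on a toy graph. The elementwise product stands in for the tensor product, and the function and argument names are hypothetical; this is not the `indices` API of `cuex` or `TensorProductUniform4x1dIndexed`.

```python
import numpy as np


def gather_op_scatter(x, w, src, dst, num_nodes):
    """Unfused: materialize gathered inputs and per-edge outputs, then scatter."""
    gathered = x[src]                      # gather: one row of x per edge
    messages = gathered * w                # stand-in for the per-edge tensor product
    out = np.zeros((num_nodes, x.shape[1]), dtype=messages.dtype)
    np.add.at(out, dst, messages)          # scatter-add; repeated dst indices accumulate
    return out


def fused(x, w, src, dst, num_nodes):
    """Fused: read x[src[e]] and accumulate into out[dst[e]] edge by edge."""
    out = np.zeros((num_nodes, x.shape[1]), dtype=np.result_type(x, w))
    for e in range(len(src)):
        # In a GPU kernel this accumulation would be an atomic add.
        out[dst[e]] += x[src[e]] * w[e]
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))                # 4 nodes, 3 features
src = np.array([0, 1, 2, 3, 0])            # edge sources
dst = np.array([1, 1, 0, 2, 3])            # edge targets (node 1 receives twice)
w = rng.normal(size=(len(src), 3))         # one weight row per edge
print(np.allclose(gather_op_scatter(x, w, src, dst, 4), fused(x, w, src, dst, 4)))
```

In this model, the fused version never stores the `(num_edges, dim)` intermediates `gathered` and `messages`, which is where the memory saving of fusion comes from.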
Full Changelog: v0.2.0...v0.3.0
v0.2.0
Changelog
Breaking Changes
- Minimum Python version is now 3.10 in all packages.
- `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` now require inputs to be of shape `(batch_size, dim)` or `(1, dim)`. Inputs of shape `(dim,)` are no longer allowed.
- `cuex.IrrepsArray` is an alias for `cuex.RepArray`.
- `cuex.RepArray.irreps` and `cuex.RepArray.segments` are no longer functions. They are now properties.
- `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`.
- The function `cuet.spherical_harmonics` is replaced by the Torch module `cuet.SphericalHarmonics`. This was done to allow the use of `torch.jit.script` and `torch.compile`.
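Since `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` now expect a leading batch axis, callers that used to pass a single vector need to add one. A minimal sketch of that normalization, written with NumPy to stay self-contained (the helper `as_batched` is hypothetical; for a `torch.Tensor` the equivalent of `x[np.newaxis, :]` is `x.unsqueeze(0)`):

```python
import numpy as np


def as_batched(x):
    """Return x with shape (batch_size, dim), promoting a single (dim,) vector to (1, dim)."""
    x = np.asarray(x)
    if x.ndim == 1:
        return x[np.newaxis, :]
    if x.ndim != 2:
        raise ValueError(f"expected shape (dim,) or (batch_size, dim), got {x.shape}")
    return x


print(as_batched(np.ones(4)).shape)        # (1, 4)
print(as_batched(np.ones((8, 4))).shape)   # (8, 4)
print(as_batched(np.ones((0, 4))).shape)   # (0, 4): an empty batch passes through unchanged
```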
Added
- Add experimental support for `torch.compile`. Known issue: export to C++ is not working.
- Add `cue.IrrepsAndLayout`: a simple class that inherits from `cue.Rep` and contains a `cue.Irreps` and a `cue.IrrepsLayout`.
- Add `cuex.RepArray` for representing an array of any kind of representation (not only irreps, as before with `cuex.IrrepsArray`).
Fixed
- Add support for empty batch dimension in `cuet` (`cuequivariance_torch`).
- Move `README.md` and `LICENSE` into the source distribution.
- Fix `cue.SegmentedTensorProduct.flop_cost` for the special case of 1 operand.
Improved
- No more special case for degree 0 in `cuet.SymmetricTensorProduct`.
List of Pull Requests
- full_tp descriptor by @mitkotak in #26
- Add support for zero batch by @mariogeiger in #27
- Documentation: add pytorch version requirement by @mariogeiger in #36
- Avoid calling irreps.dim and logger in forward (fix for #32) by @mariogeiger in #35
- List as inputs by @mariogeiger in #28
- pre-commit by @mariogeiger in #42
- test on python==3.9 by @mariogeiger in #43
- Compatibility with jit.script and torch.compile: COMPLETE by @borisfom in #40
- Uniformize the python version requirements by @mariogeiger in #47
- fix sh l=0 by @mariogeiger in #49
- cue.IrrepsAndLayout, cue.EquivariantTensorProduct, cuex.RepArray by @mariogeiger in #46
- Working torch.jit.script() and torch.compile() support by @borisfom in #44
- fix documentation errors due to latest merged PR by @mariogeiger in #54
- CI: test doc in PR by @mariogeiger in #55
- Update CHANGELOG.md and fix few documentation details by @mariogeiger in #57
- Attempt for fix issue 53 by @mariogeiger in #56
- Fix SegmentedTensorProduct.flop_cost for 0 inputs (1 operand) by @mariogeiger in #59
- Test batch size zero (needs !77 in backend to be merged) by @mariogeiger in #58
- Removing tensor lists from all APIs by @borisfom in #61
- Fix tests by @mariogeiger in #62
- Added export test for variable batch by @borisfom in #65
- Prepare for next release by @mariogeiger in #64
- Fix shared weights for IWeightedSymmetricTPDispatcher by @mariogeiger in #63
- Update version number by @mariogeiger in #67
- Update headers with new year by @mariogeiger in #66
- Update CHANGELOG.md by @mariogeiger in #70
- Mini fixes in the doc by @mariogeiger in #71
New Contributors
Full Changelog: v0.1.0...v0.2.0
cuEquivariance v0.1.0
First Beta Release