
Releases: NVIDIA/cuEquivariance

v0.3.0

07 Mar 22:33

0.3.0 (2025-03-05)

The main changes are:

  1. [JAX] New JIT Uniform 1d kernel with JAX bindings
    1. Computes any polynomial based on 1d uniform STPs
    2. Supports arbitrary derivatives
    3. Provides optional fused scatter/gather for the inputs and outputs
    4. 🎉 We observed a ~3x speedup for MACE with cuEquivariance-JAX v0.3.0 compared to cuEquivariance-Torch v0.2.0 🎉
  2. [Torch] Adds torch.compile support
  3. [Torch] Beta (limited) Torch bindings for the new JIT Uniform 1d kernel
    1. enable the new kernel by setting the environment variable CUEQUIVARIANCE_OPS_USE_JIT=1 (see the sketch after this list)
  4. [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d
    1. this is a temporary API that will change: cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed
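
A minimal sketch of enabling the beta kernel from Python (assuming the variable is read when cuequivariance_torch loads its kernels, so it must be set before the library is used):

import os

# Opt in to the beta JIT Uniform 1d kernel; set this before cuequivariance_torch
# initializes its kernels (or export the variable in the shell instead).
os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"

import cuequivariance_torch as cuet  # imported after setting the variable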

Breaking Changes

  • [Torch/JAX] Removed cue.TensorProductExecution and added cue.Operation which is more lightweight and better aligned with the backend.
  • [JAX] In cuex.equivariant_tensor_product, the arguments dtype_math and dtype_output are renamed to math_dtype and output_dtype respectively, making the naming consistent with the rest of the library (see the sketch after this list).
  • [JAX] In cuex.equivariant_tensor_product, the arguments algorithm, precision, use_custom_primitive and use_custom_kernels have been removed. This change avoids a proliferation of arguments that are not used in all implementations. An argument impl: str has been added instead to select the implementation.
  • [JAX] Removed cuex.symmetric_tensor_product. The cuex.tensor_product function now handles any non-homogeneous polynomial.
  • [JAX] The batching support (jax.vmap) of cuex.equivariant_tensor_product is now limited to specific use cases.
  • [JAX] The interface of cuex.tensor_product has changed. It now takes a list of tuple[cue.Operation, cue.SegmentedTensorProduct] instead of a single cue.SegmentedTensorProduct. This change allows cuex.tensor_product to execute any kind of non-homogeneous polynomial.
  • [JAX] Removed cuex.flax_linen.Linear to reduce maintenance burden. Use cue.descriptors.linear together with cuex.equivariant_tensor_product instead:
e = cue.descriptors.linear(input.irreps, output_irreps)
w = self.param(name, jax.random.normal, (e.inputs[0].dim,), input.dtype)
output = cuex.equivariant_tensor_product(e, w, input)
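
Continuing the snippet above, a hedged sketch of a call with the renamed keyword arguments (the dtype values are illustrative, and the impl argument is omitted because its accepted values are not listed here):

import jax.numpy as jnp

# Renamed in v0.3.0: dtype_math -> math_dtype, dtype_output -> output_dtype.
output = cuex.equivariant_tensor_product(
    e, w, input,
    math_dtype=jnp.float32,
    output_dtype=jnp.float32,
)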

Fixed

  • [Torch/JAX] Fixed cue.descriptors.full_tensor_product, which was ignoring the irreps3_filter argument.
  • [Torch/JAX] Fixed a rare bug with np.bincount when using an old version of NumPy. The input is now flattened so it works with all versions.
  • [Torch] Identified a bug in the CUDA kernel used by cuet.TransposeSegments and cuet.TransposeIrrepsLayout and disabled it for these modules.

Added

  • [Torch/JAX] Added __mul__ to cue.EquivariantTensorProduct to allow rescaling the equivariant tensor product (see the sketch after this list).
  • [JAX] Added JAX bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomial as long as the contraction pattern (subscripts) has only one mode. It handles batched/shared/indexed inputs and outputs. Indexed inputs/outputs are processed through atomic operations.
  • [JAX] Added an indices argument to cuex.equivariant_tensor_product and cuex.tensor_product to handle the scatter/gather fusion.
  • [Torch] Beta (limited) Torch bindings for the new JIT Uniform 1d kernel (enable the new kernel by setting the environment variable CUEQUIVARIANCE_OPS_USE_JIT=1)
  • [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (this is a temporary API that will change: cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed)
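
A short sketch of the new rescaling operator (the descriptor and scale factor are illustrative assumptions):

import cuequivariance as cue

e = cue.descriptors.linear(cue.Irreps("O3", "32x0e + 8x1o"), cue.Irreps("O3", "16x0e"))
e_scaled = e * 0.5  # __mul__ rescales the equivariant tensor product by a constant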

Full Changelog: v0.2.0...v0.3.0

v0.2.0

24 Jan 13:39

Changelog

Breaking Changes

  • Minimal python version is now 3.10 in all packages.
  • cuet.TensorProduct and cuet.EquivariantTensorProduct now require inputs of shape (batch_size, dim) or (1, dim). Inputs of shape (dim,) are no longer allowed (see the sketch after this list).
  • cuex.IrrepsArray is now an alias for cuex.RepArray.
  • cuex.RepArray.irreps and cuex.RepArray.segments are no longer methods; they are now properties.
  • cuex.IrrepsArray.is_simple is replaced by cuex.RepArray.is_irreps_array.
  • The function cuet.spherical_harmonics is replaced by the Torch Module cuet.SphericalHarmonics. This was done to allow the use of torch.jit.script and torch.compile.
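
A minimal sketch of the new input shape requirement (the sizes are illustrative):

import torch

x = torch.randn(32)            # shape (dim,) is no longer accepted
x = x.reshape(1, -1)           # reshape to (1, dim) for a shared/broadcast input
batch = torch.randn(128, 32)   # batched inputs keep the shape (batch_size, dim)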

Added

  • Add experimental support for torch.compile (see the sketch after this list). Known issue: export to C++ does not work yet.
  • Add cue.IrrepsAndLayout: A simple class that inherits from cue.Rep and contains a cue.Irreps and a cue.IrrepsLayout.
  • Add cuex.RepArray for representing an array of any kind of representation (not only irreps, as previously with cuex.IrrepsArray).
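
A hedged sketch of the experimental torch.compile support on a cuequivariance_torch module (the cuet.SphericalHarmonics constructor argument shown, a list of degrees, is an assumption and not spelled out in these notes):

import torch
import cuequivariance_torch as cuet

sh = cuet.SphericalHarmonics([0, 1, 2])   # assumed constructor: list of degrees
sh_compiled = torch.compile(sh)           # experimental; export to C++ is a known issue
out = sh_compiled(torch.randn(16, 3))     # 3D vectors in, spherical harmonics out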

Fixed

  • Add support for empty batch dimension in cuet (cuequivariance_torch).
  • Move README.md and LICENSE into the source distribution.
  • Fix cue.SegmentedTensorProduct.flop_cost for the special case of 1 operand.

Improved

  • No more special case for degree 0 in cuet.SymmetricTensorProduct.

Full Changelog: v0.1.0...v0.2.0

cuEquivariance v0.1.0

23 Nov 14:46

First Beta Release