Skip to content

[JAX:Sparse] Implement CSR sparse kernel #28261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions jax/experimental/sparse/bcsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _bcsr_to_bcoo(indices: jax.Array, indptr: jax.Array, *,


def _bcoo_to_bcsr(indices: Array, *, shape: Sequence[int],
index_dtype: DTypeLike = jnp.int32) -> tuple[Array, Array]:
index_dtype: DTypeLike) -> tuple[Array, Array]:
"""Given BCOO (indices), return BCSR (indices, indptr).
Note: this assumes that ``indices`` are lexicographically sorted within each batch.
Expand Down Expand Up @@ -238,7 +238,9 @@ def _bcsr_fromdense_impl(mat, *, nse, n_batch, n_dense, index_dtype):
raise ValueError("bcsr_fromdense: must have 2 sparse dimensions.")
bcoo_mat = bcoo.bcoo_fromdense(mat, nse=nse, index_dtype=index_dtype,
n_dense=n_dense, n_batch=n_batch)
indices, indptr = _bcoo_to_bcsr(bcoo_mat.indices, shape=mat.shape)
indices, indptr = _bcoo_to_bcsr(
bcoo_mat.indices, shape=mat.shape, index_dtype=index_dtype
)
return bcoo_mat.data, indices, indptr


Expand Down Expand Up @@ -867,7 +869,9 @@ def from_bcoo(cls, arr: bcoo.BCOO) -> BCSR:
raise NotImplementedError(f"BSCR.from_bcoo requires n_sparse=2; got {arr.n_sparse=}")
if not arr.indices_sorted:
arr = arr.sort_indices()
indices, indptr = _bcoo_to_bcsr(arr.indices, shape=arr.shape)
indices, indptr = _bcoo_to_bcsr(
arr.indices, shape=arr.shape, index_dtype=arr.indices.dtype
)
return cls((arr.data, indices, indptr), shape=arr.shape)

@classmethod
Expand Down
2 changes: 2 additions & 0 deletions jaxlib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ package_group(
py_library_providing_imports_info(
name = "jaxlib",
srcs = [
"cpu_sparse.py",
"gpu_common_utils.py",
"gpu_linalg.py",
"gpu_prng.py",
Expand All @@ -76,6 +77,7 @@ py_library_providing_imports_info(
"//jaxlib:_jax",
"//jaxlib:xla_client",
"//jaxlib/cpu:_lapack",
"//jaxlib/cpu:_sparse",
"//jaxlib/mlir",
"//jaxlib/mlir:arithmetic_dialect",
"//jaxlib/mlir:builtin_dialect",
Expand Down
33 changes: 33 additions & 0 deletions jaxlib/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,42 @@ cc_library(
deps = [
":lapack_kernels",
":lapack_kernels_using_lapack",
":sparse_kernels",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/service:custom_call_target_registry",
],
alwayslink = 1,
)

# Eigen-backed CSR sparse-times-dense kernel implementation. Linked into both
# the CPU custom-call registry (cpu_kernels.cc) and the `_sparse` extension.
cc_library(
    name = "sparse_kernels",
    srcs = ["sparse_kernels.cc"],
    hdrs = ["sparse_kernels.h"],
    deps = [
        "@eigen_archive//:eigen3",
        "@xla//xla/ffi/api:ffi",
    ],
)

# Python extension module exposing the sparse FFI handler capsules via
# `registrations()` (see sparse.cc). Stub generation is disabled; the typed
# interface is maintained by hand in _sparse/__init__.pyi.
nanobind_extension(
    name = "_sparse",
    srcs = ["sparse.cc"],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    enable_stub_generation = False,
    features = ["-use_header_modules"],
    module_name = "_sparse",
    pytype_srcs = [
        "_sparse/__init__.pyi",
    ],
    deps = [
        ":sparse_kernels",
        "//jaxlib:kernel_nanobind_helpers",
        "@com_google_absl//absl/base",
        "@nanobind",
        "@xla//xla/ffi/api:ffi",
    ],
)
15 changes: 15 additions & 0 deletions jaxlib/cpu/_sparse/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Returns a mapping from FFI target name (e.g. "cpu_csr_sparse_dense_ffi") to
# a capsule wrapping the corresponding XLA FFI handler, as built in sparse.cc.
def registrations() -> dict: ...
3 changes: 3 additions & 0 deletions jaxlib/cpu/cpu_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <complex>

#include "jaxlib/cpu/lapack_kernels.h"
#include "jaxlib/cpu/sparse_kernels.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
#include "xla/service/custom_call_target_registry.h"
Expand Down Expand Up @@ -110,6 +111,8 @@ JAX_CPU_REGISTER_HANDLER(lapack_dgtsv_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_cgtsv_ffi);
JAX_CPU_REGISTER_HANDLER(lapack_zgtsv_ffi);

JAX_CPU_REGISTER_HANDLER(cpu_csr_sparse_dense_ffi);

#undef JAX_CPU_REGISTER_HANDLER

} // namespace
Expand Down
37 changes: 37 additions & 0 deletions jaxlib/cpu/sparse.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "nanobind/nanobind.h"
#include "jaxlib/cpu/sparse_kernels.h"
#include "jaxlib/kernel_nanobind_helpers.h"

namespace jax {
namespace {

namespace nb = nanobind;

// Builds the name -> handler-capsule dictionary that the JAX Python side
// feeds into XLA's custom-call registry.
nb::dict Registrations() {
  nb::dict registrations;
  registrations["cpu_csr_sparse_dense_ffi"] =
      EncapsulateFunction(cpu_csr_sparse_dense_ffi);
  return registrations;
}

// The `_sparse` extension module exposes a single `registrations()` function.
NB_MODULE(_sparse, m) {
  m.def("registrations", &Registrations);
}

} // namespace
} // namespace jax
216 changes: 216 additions & 0 deletions jaxlib/cpu/sparse_kernels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/cpu/sparse_kernels.h"

#include <algorithm>
#include <complex>
#include <cstdint>
#include <vector>

#include "Eigen/Core"
#include "Eigen/SparseCore"
#include "xla/ffi/api/ffi.h"

namespace jax {

// CSR storage: a row-major Eigen sparse matrix whose index arrays use
// StorageType (int32/int64, selected at dispatch time).
template <typename ElementType, typename StorageType>
using SparseMatrixType =
    Eigen::SparseMatrix<ElementType, Eigen::RowMajor, StorageType>;
// Dense operands and results, row-major.
// NOTE(review): assumes the FFI buffers are laid out row-major — confirm
// against the lowering that emits this custom call.
template <typename ElementType>
using DenseMatrixType =
    Eigen::Matrix<ElementType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

// Non-owning Eigen views over raw FFI buffers.
// NOTE(review): Aligned32 asserts 32-byte-aligned input/output buffers —
// verify this matches XLA's CPU buffer alignment guarantee.
template <typename MatrixT>
using InputMap = Eigen::Map<const MatrixT, Eigen::Aligned32>;
template <typename MatrixT>
using OutputMap = Eigen::Map<MatrixT, Eigen::Aligned32>;

// Computes out = lhs (CSR sparse) * rhs (dense), either inline on the calling
// thread for small problems or sharded over `thread_pool` by contiguous row
// batches. Returns a Future that resolves once all shards have written their
// rows of `out_matrix`.
template <typename ElementType, typename StorageType>
static ::xla::ffi::Future CsrSparseDenseKernelImpl(
    const InputMap<SparseMatrixType<ElementType, StorageType>>& lhs_matrix,
    const InputMap<DenseMatrixType<ElementType>>& rhs_matrix,
    OutputMap<DenseMatrixType<ElementType>>& out_matrix,
    ::xla::ffi::ThreadPool& thread_pool) {
  // Rule of thumb to give each task at least 100k cycles to hide the cost of
  // task scheduling.
  // TODO(willfroom) Do we want to make this configurable?
  constexpr int64_t kTargetCyclesPerTask = 100'000;
  // Based on AVX (CPI 0.5 -> 2 IPC)
  constexpr int64_t kScalarProductsPerCycle = 2 * 32 / sizeof(ElementType);
  constexpr int64_t kTaskSize = kTargetCyclesPerTask * kScalarProductsPerCycle;

  // Small problems (or no worker threads available) are multiplied inline and
  // the future is resolved immediately.
  if (lhs_matrix.nonZeros() * rhs_matrix.cols() <= kTaskSize ||
      thread_pool.num_threads() == 0) {
    out_matrix.noalias() = lhs_matrix * rhs_matrix;

    ::xla::ffi::Promise promise;
    promise.SetAvailable();
    return ::xla::ffi::Future(promise);
  } else {
    // Partition the rows of lhs into contiguous batches whose estimated cost
    // (batch nnz * rhs columns) is roughly kTaskSize each.
    std::vector<int64_t> batch_sizes;
    {
      int64_t running_batch_nnz = 0;
      int64_t running_number_rows = 0;
      for (int row = 0; row < lhs_matrix.rows(); ++row) {
        // nnz of this row from the CSR row-offset (outer index) array.
        int64_t row_nnz = lhs_matrix.outerIndexPtr()[row + 1] -
                          lhs_matrix.outerIndexPtr()[row];
        // Even if a row has no non-zero elements, the task still needs to
        // write out a zero row, so give each row a non-zero contribution to
        // avoid the pathological case of a single task having to write many
        // rows where there is a large block of zero inputs.
        running_batch_nnz += std::max(row_nnz, static_cast<int64_t>(1));
        running_number_rows++;
        if (running_batch_nnz * rhs_matrix.cols() > kTaskSize) {
          // Batch is full: record it and start a new one.
          batch_sizes.push_back(running_number_rows);
          running_batch_nnz = 0;
          running_number_rows = 0;
        } else if (row == lhs_matrix.rows() - 1 && running_number_rows > 0) {
          // Flush the final, partially-filled batch.
          batch_sizes.push_back(running_number_rows);
        }
      }
    }

    // One count per batch; the returned future resolves when every scheduled
    // task has counted down.
    ::xla::ffi::CountDownPromise promise(batch_sizes.size());
    ::xla::ffi::Future future(promise);
    int64_t batch_start = 0;
    for (int64_t size : batch_sizes) {
      // Captures are by value: the Eigen Maps are cheap non-owning views.
      // NOTE(review): assumes the underlying FFI buffers stay alive until the
      // returned future resolves — confirm the FFI contract guarantees this.
      thread_pool.Schedule([out_matrix, lhs_matrix, rhs_matrix, batch_start,
                            size, promise]() mutable {
        out_matrix.middleRows(batch_start, size).noalias() =
            lhs_matrix.middleRows(batch_start, size) * rhs_matrix;
        promise.CountDown();
      });
      batch_start += size;
    }
    return future;
  }
}

// Wraps the raw FFI buffers in typed Eigen views (CSR lhs, dense rhs/out)
// and forwards them to the batched multiply implementation.
template <typename ElementType, typename StorageType>
static ::xla::ffi::Future CsrSparseDenseKernelTypedDispatch(
    ::xla::ffi::AnyBuffer lhs_data, ::xla::ffi::AnyBuffer lhs_outer_indicies,
    ::xla::ffi::AnyBuffer lhs_inner_indicies, ::xla::ffi::AnyBuffer rhs,
    ::xla::ffi::Result<::xla::ffi::AnyBuffer> out,
    ::xla::ffi::ThreadPool thread_pool) {
  const ::xla::ffi::Span<const int64_t> rhs_dims = rhs.dimensions();
  const ::xla::ffi::Span<const int64_t> out_dims = out->dimensions();

  // lhs row count comes from the output shape; a rank-1 rhs is treated as a
  // single-column matrix.
  const int64_t lhs_rows = out_dims[0];
  const int64_t rhs_rows = rhs_dims[0];
  const int64_t rhs_cols = rhs_dims.size() > 1 ? rhs_dims[1] : 1;

  InputMap<SparseMatrixType<ElementType, StorageType>> lhs_matrix(
      lhs_rows, rhs_rows, lhs_data.element_count(),
      lhs_outer_indicies.reinterpret_data<StorageType>(),
      lhs_inner_indicies.reinterpret_data<StorageType>(),
      lhs_data.reinterpret_data<ElementType>());

  InputMap<DenseMatrixType<ElementType>> rhs_matrix(
      rhs.reinterpret_data<ElementType>(), rhs_rows, rhs_cols);
  OutputMap<DenseMatrixType<ElementType>> out_matrix(
      out->reinterpret_data<ElementType>(), lhs_rows, rhs_cols);

  return CsrSparseDenseKernelImpl<ElementType, StorageType>(
      lhs_matrix, rhs_matrix, out_matrix, thread_pool);
}

// Validates that both CSR index buffers share one integer dtype, then
// dispatches on that dtype to the fully-typed kernel above.
template <typename ElementType>
static ::xla::ffi::Future CsrSparseDenseKernelTypedDispatch(
    ::xla::ffi::AnyBuffer lhs_data, ::xla::ffi::AnyBuffer lhs_outer_indicies,
    ::xla::ffi::AnyBuffer lhs_inner_indicies, ::xla::ffi::AnyBuffer rhs,
    ::xla::ffi::Result<::xla::ffi::AnyBuffer> out,
    ::xla::ffi::ThreadPool thread_pool) {
  // Helper producing an already-failed future with an invalid-argument error.
  auto fail = [](const char* message) {
    ::xla::ffi::Promise promise;
    promise.SetError(
        ::xla::ffi::Error(::xla::ffi::ErrorCode::kInvalidArgument, message));
    return ::xla::ffi::Future(promise);
  };

  if (lhs_outer_indicies.element_type() != lhs_inner_indicies.element_type()) {
    return fail("Sparse index type mismatch");
  }

  switch (lhs_outer_indicies.element_type()) {
    case ::xla::ffi::DataType::S32:
      return CsrSparseDenseKernelTypedDispatch<ElementType, int32_t>(
          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
          thread_pool);
    case ::xla::ffi::DataType::S64:
      return CsrSparseDenseKernelTypedDispatch<ElementType, int64_t>(
          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
          thread_pool);
    default:
      return fail("Invalid index data type");
  }
}

// Top-level FFI dispatch: checks that the sparse data, rhs, and out buffers
// agree on element dtype, then fans out on that dtype. Index-dtype dispatch
// happens one level down.
static ::xla::ffi::Future CsrSparseDenseKernelDispatch(
    ::xla::ffi::AnyBuffer lhs_data, ::xla::ffi::AnyBuffer lhs_outer_indicies,
    ::xla::ffi::AnyBuffer lhs_inner_indicies, ::xla::ffi::AnyBuffer rhs,
    ::xla::ffi::Result<::xla::ffi::AnyBuffer> out,
    ::xla::ffi::ThreadPool thread_pool) {
  // Helper producing an already-failed future with an invalid-argument error.
  auto fail = [](const char* message) {
    ::xla::ffi::Promise promise;
    promise.SetError(
        ::xla::ffi::Error(::xla::ffi::ErrorCode::kInvalidArgument, message));
    return ::xla::ffi::Future(promise);
  };

  if (lhs_data.element_type() != rhs.element_type() ||
      lhs_data.element_type() != out->element_type()) {
    return fail("Element type mismatch");
  }

  switch (lhs_data.element_type()) {
    case ::xla::ffi::DataType::S32:
      return CsrSparseDenseKernelTypedDispatch<int32_t>(
          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
          thread_pool);
    case ::xla::ffi::DataType::S64:
      return CsrSparseDenseKernelTypedDispatch<int64_t>(
          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
          thread_pool);
    case ::xla::ffi::DataType::F32:
      return CsrSparseDenseKernelTypedDispatch<float>(
          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
          thread_pool);
    case ::xla::ffi::DataType::F64:
      return CsrSparseDenseKernelTypedDispatch<double>(
          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
          thread_pool);
    case ::xla::ffi::DataType::C64:
      return CsrSparseDenseKernelTypedDispatch<std::complex<float>>(
          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
          thread_pool);
    case ::xla::ffi::DataType::C128:
      return CsrSparseDenseKernelTypedDispatch<std::complex<double>>(
          lhs_data, lhs_outer_indicies, lhs_inner_indicies, rhs, out,
          thread_pool);
    default:
      return fail("Invalid data type");
  }
}

// Defines the exported FFI handler symbol. It is registered under the name
// "cpu_csr_sparse_dense_ffi" both in cpu_kernels.cc (static registry) and in
// sparse.cc (Python-visible capsule). Binding: four input buffers (CSR data,
// row offsets, column indices, dense rhs), one output buffer, plus the
// intra-op thread pool from the call context.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    cpu_csr_sparse_dense_ffi, CsrSparseDenseKernelDispatch,
    (::xla::ffi::Ffi::Bind()
         .Arg<::xla::ffi::AnyBuffer>(/*lhs_data*/)
         .Arg<::xla::ffi::AnyBuffer>(
             /*lhs_outer_indicies*/)
         .Arg<::xla::ffi::AnyBuffer>(
             /*lhs_inner_indicies*/)
         .Arg<::xla::ffi::AnyBuffer>(/*rhs*/)
         .Ret<::xla::ffi::AnyBuffer>(/*out*/)
         .Ctx<::xla::ffi::ThreadPool>(/*thread_pool*/)));

} // namespace jax
Loading
Loading