Skip to content

Commit 7f44b07

Browse files
YuriyzabegaevIngridKJ
authored andcommitted
Sparse arrays (#1373)
* MAINT: Add sps.sparray to isinstance checks * TST: sparse_array support in ad test * MAINT: Applied suggestions from the review --------- Co-authored-by: Ingrid Kristine Jacobsen <[email protected]>
1 parent 2f37632 commit 7f44b07

File tree

7 files changed

+59
-31
lines changed

7 files changed

+59
-31
lines changed

src/porepy/models/constitutive_laws.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def aperture(self, subdomains: list[pp.Grid]) -> pp.ad.Operator:
430430
parent_cells_to_intersection_cells
431431
)
432432

433-
assert isinstance(weight_value, sps.spmatrix) # for mypy
433+
assert isinstance(weight_value, (sps.spmatrix, sps.sparray)) # for mypy
434434
average_weights = np.ravel(weight_value.sum(axis=1))
435435
nonzero = average_weights > 0
436436
average_weights[nonzero] = 1 / average_weights[nonzero]

src/porepy/numerics/ad/_ad_parser.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def evaluate(
139139
result_list[index] = pp.ad.AdArray(
140140
res, sps.csr_matrix((res.shape[0], equation_system.num_dofs()))
141141
)
142-
elif isinstance(res, (sps.spmatrix, np.ndarray)):
142+
elif isinstance(res, (sps.spmatrix, sps.sparray, np.ndarray)):
143143
# This will cover numpy arrays of higher dimensions (> 1) and sparse
144144
# matrices.
145145
#
@@ -420,7 +420,7 @@ def _get_error_message(
420420
msg += "The second argument represents the expression:\n " + msg_1 + nl
421421

422422
# Finally some information on sizes
423-
if isinstance(results[0], sps.spmatrix):
423+
if isinstance(results[0], (sps.spmatrix, sps.sparray)):
424424
msg += f"First argument is a sparse matrix of size {results[0].shape}\n"
425425
elif isinstance(results[0], pp.ad.AdArray):
426426
msg += (
@@ -430,7 +430,7 @@ def _get_error_message(
430430
elif isinstance(results[0], np.ndarray):
431431
msg += f"First argument is a numpy array of size {results[0].size}\n"
432432

433-
if isinstance(results[1], sps.spmatrix):
433+
if isinstance(results[1], (sps.spmatrix, sps.sparray)):
434434
msg += f"Second argument is a sparse matrix of size {results[1].shape}\n"
435435
elif isinstance(results[1], pp.ad.AdArray):
436436
msg += (

src/porepy/numerics/ad/forward_mode.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def __add__(self, other: AdType) -> AdArray:
175175
raise ValueError("Only 1d numpy arrays can be added to AdArrays")
176176
return AdArray(self.val + other, self.jac)
177177

178-
elif isinstance(other, sps.spmatrix):
178+
elif isinstance(other, (sps.spmatrix, sps.sparray)):
179179
raise ValueError("Sparse matrices cannot be added to AdArrays")
180180

181181
elif isinstance(other, pp.ad.AdArray):
@@ -268,7 +268,7 @@ def __mul__(self, other: AdType) -> AdArray:
268268
new_jac = self._diagvec_mul_jac(other)
269269
return AdArray(new_val, new_jac)
270270

271-
elif isinstance(other, sps.spmatrix):
271+
elif isinstance(other, (sps.spmatrix, sps.sparray)):
272272
raise ValueError(
273273
"""Sparse matrices cannot be multiplied with AdArrays elementwise.
274274
Did you mean to use the @ operator?
@@ -315,7 +315,7 @@ def __rmul__(self, other: AdType) -> AdArray:
315315
316316
"""
317317

318-
if isinstance(other, (float, sps.spmatrix, np.ndarray, int)):
318+
if isinstance(other, (float, sps.spmatrix, sps.sparray, np.ndarray, int)):
319319
# In these cases, there is no difference between left and right
320320
# multiplication, so we simply invoke the standard __mul__ function.
321321
return self.__mul__(other)
@@ -370,7 +370,7 @@ def __pow__(self, other: AdType) -> AdArray:
370370
new_jac = self._diagvec_mul_jac(other * (self.val ** (other - 1)))
371371
return AdArray(new_val, new_jac)
372372

373-
elif isinstance(other, sps.spmatrix):
373+
elif isinstance(other, (sps.spmatrix, sps.sparray)):
374374
raise ValueError("Cannot raise AdArrays to power of sparse matrices.")
375375

376376
elif isinstance(other, pp.matrix_operations.ArraySlicer):
@@ -440,7 +440,7 @@ def __rpow__(self, other: AdType) -> AdArray:
440440
new_jac = self._diagvec_mul_jac((other**self.val) * np.log(other))
441441
return AdArray(new_val, new_jac)
442442

443-
elif isinstance(other, sps.spmatrix):
443+
elif isinstance(other, (sps.spmatrix, sps.sparray)):
444444
raise ValueError("Cannot raise sparse matrices to the power of Ad arrays.")
445445

446446
elif isinstance(other, pp.ad.AdArray):
@@ -482,7 +482,7 @@ def __truediv__(self, other: AdType) -> AdArray:
482482
new_jac = self._diagvec_mul_jac(other.astype(float) ** (-1.0))
483483
return AdArray(new_val, new_jac)
484484

485-
elif isinstance(other, sps.spmatrix):
485+
elif isinstance(other, (sps.spmatrix, sps.sparray)):
486486
raise ValueError("AdArrays cannot be divided by sparse matrices.")
487487

488488
elif isinstance(other, pp.matrix_operations.ArraySlicer):
@@ -510,7 +510,7 @@ def __rtruediv__(self, other: AdType) -> AdArray:
510510
511511
"""
512512

513-
if isinstance(other, (float, int, np.ndarray, sps.spmatrix)):
513+
if isinstance(other, (float, int, np.ndarray, sps.spmatrix, sps.sparray)):
514514
# Divide a float or a numpy array by self is the same as raising self to
515515
# the power of -1 and multiplying by the float. The multiplication will
516516
# end upcalling self.__mul__, which will do the right checks for numpy
@@ -545,7 +545,7 @@ class documentation for restrictions on admissible types for this
545545
f""" {type(other)}."""
546546
)
547547

548-
elif isinstance(other, sps.spmatrix):
548+
elif isinstance(other, (sps.spmatrix, sps.sparray)):
549549
# This goes against the way equations should be formulated in the AD
550550
# framework, variables should not be right-multiplied by anything. Raise
551551
# a value error to make sure this is not done.
@@ -574,7 +574,7 @@ class documentation for restrictions on admissible types for this
574574
f""" {type(other)}."""
575575
)
576576

577-
elif isinstance(other, sps.spmatrix):
577+
elif isinstance(other, (sps.spmatrix, sps.sparray)):
578578
# This is the standard matrix-vector multiplication
579579
if self.jac.shape[0] != other.shape[1]:
580580
raise ValueError(

src/porepy/numerics/ad/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def maximum(var_0: FloatType, var_1: FloatType) -> FloatType:
416416
# Start from var_0, then change entries corresponding to inds.
417417
max_jac = jacs[0].copy()
418418

419-
if isinstance(max_jac, sps.spmatrix):
419+
if isinstance(max_jac, (sps.spmatrix, sps.sparray)):
420420
# Enforce csr format, unless the matrix is csc, in which case we keep it.
421421
if not max_jac.getformat() == "csc":
422422
max_jac = max_jac.tocsr()

src/porepy/numerics/ad/operators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ def value_and_jacobian(
573573
return AdArray(
574574
ad, sps.csr_matrix((ad.shape[0], equation_system.num_dofs()))
575575
)
576-
elif isinstance(ad, (sps.spmatrix, np.ndarray)):
576+
elif isinstance(ad, (sps.spmatrix, sps.sparray, np.ndarray)):
577577
# this case coverse both, dense and sparse matrices returned from
578578
# discretizations f.e.
579579
raise NotImplementedError(
@@ -901,7 +901,7 @@ def _parse_other(self, other):
901901
return [self, Scalar(other)]
902902
elif isinstance(other, np.ndarray):
903903
return [self, DenseArray(other)]
904-
elif isinstance(other, sps.spmatrix):
904+
elif isinstance(other, (sps.spmatrix, sps.sparray)):
905905
return [self, SparseArray(other)]
906906
elif isinstance(other, AdArray):
907907
# This may happen when using nested pp.ad.Function.

src/porepy/numerics/linalg/matrix_operations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def __matmul__(
626626
# Slice matrix, vector, or AdArray by calling relevant helper methods.
627627
if isinstance(x, np.ndarray):
628628
sliced = self._slice_vector(x)
629-
elif isinstance(x, sps.spmatrix):
629+
elif isinstance(x, (sps.spmatrix, sps.sparray)):
630630
sliced = self._slice_matrix(x)
631631
elif isinstance(x, pp.ad.AdArray):
632632
val = self._slice_vector(x.val)

tests/numerics/ad/test_operators.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,11 +1068,15 @@ def _get_dense_array(wrapped: bool) -> np.ndarray | pp.ad.DenseArray:
10681068
return array
10691069

10701070

1071-
def _get_sparse_array(wrapped: bool) -> sps.spmatrix | pp.ad.SparseArray:
1071+
def _get_sparse_array(
1072+
wrapped: bool, use_csr_matrix: bool
1073+
) -> sps.spmatrix | sps.sparray | pp.ad.SparseArray:
10721074
"""Helper to set a sparse array (scipy sparse array). Expected values in the test
10731075
are hardcoded with respect to this value. The array is either returned as-is, or
10741076
wrapped as an Ad SparseArray."""
1075-
mat = sps.csr_matrix(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])).astype(float)
1077+
inner = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
1078+
mat = sps.csr_matrix(inner) if use_csr_matrix else sps.csr_array(inner)
1079+
mat = mat.astype(float)
10761080
if wrapped:
10771081
return pp.ad.SparseArray(mat)
10781082
else:
@@ -1172,7 +1176,7 @@ def _expected_value(
11721176
except ValueError:
11731177
assert op in ["@"]
11741178
return False
1175-
elif isinstance(var_1, float) and isinstance(var_2, sps.spmatrix):
1179+
elif isinstance(var_1, float) and isinstance(var_2, (sps.spmatrix, sps.sparray)):
11761180
try:
11771181
# This should fail for all operations expect from multiplication.
11781182
val = eval(f"var_1 {op} var_2")
@@ -1188,13 +1192,15 @@ def _expected_value(
11881192
return False
11891193
elif isinstance(var_1, np.ndarray) and isinstance(var_2, np.ndarray):
11901194
return eval(f"var_1 {op} var_2")
1191-
elif isinstance(var_1, np.ndarray) and isinstance(var_2, sps.spmatrix):
1195+
elif isinstance(var_1, np.ndarray) and isinstance(
1196+
var_2, (sps.spmatrix, sps.sparray)
1197+
):
11921198
try:
11931199
return eval(f"var_1 {op} var_2")
11941200
except TypeError:
11951201
assert op in ["/", "**"]
11961202
return False
1197-
elif isinstance(var_1, sps.spmatrix) and isinstance(var_2, float):
1203+
elif isinstance(var_1, (sps.spmatrix, sps.sparray)) and isinstance(var_2, float):
11981204
if op == "**":
11991205
# SciPy has implemented a limited version matrix powers to scalars, but not
12001206
# with a satisfactory flexibility. If we try to evaluate the expression, it
@@ -1209,7 +1215,9 @@ def _expected_value(
12091215
return val
12101216
except (ValueError, NotImplementedError):
12111217
return False
1212-
elif isinstance(var_1, sps.spmatrix) and isinstance(var_2, np.ndarray):
1218+
elif isinstance(var_1, (sps.spmatrix, sps.sparray)) and isinstance(
1219+
var_2, np.ndarray
1220+
):
12131221
if op == "**":
12141222
# SciPy has implemented a limited version matrix powers to numpy arrays, but
12151223
# not with a satisfactory flexibility. If we try to evaluate the expression,
@@ -1222,10 +1230,12 @@ def _expected_value(
12221230
assert op in ["**"]
12231231
return False
12241232

1225-
elif isinstance(var_1, sps.spmatrix) and isinstance(var_2, sps.spmatrix):
1233+
elif isinstance(var_1, (sps.spmatrix, sps.sparray)) and isinstance(
1234+
var_2, (sps.spmatrix, sps.sparray)
1235+
):
12261236
try:
12271237
return eval(f"var_1 {op} var_2")
1228-
except (ValueError, TypeError):
1238+
except (ValueError, TypeError, NotImplementedError):
12291239
assert op in ["**"]
12301240
return False
12311241

@@ -1434,7 +1444,9 @@ def _expected_value(
14341444
)
14351445
return pp.ad.AdArray(val, jac)
14361446

1437-
elif isinstance(var_1, pp.ad.AdArray) and isinstance(var_2, sps.spmatrix):
1447+
elif isinstance(var_1, pp.ad.AdArray) and isinstance(
1448+
var_2, (sps.spmatrix, sps.sparray)
1449+
):
14381450
return False
14391451
elif isinstance(var_1, sps.spmatrix) and isinstance(var_2, pp.ad.AdArray):
14401452
# This combination is only allowed for matrix-vector products (op = "@")
@@ -1444,6 +1456,14 @@ def _expected_value(
14441456
return pp.ad.AdArray(val, jac)
14451457
else:
14461458
return False
1459+
elif isinstance(var_1, sps.sparray) and isinstance(var_2, pp.ad.AdArray):
1460+
# This combination is only allowed for matrix-vector products (op = "@")
1461+
if op == "@":
1462+
val = var_1 @ var_2.val
1463+
jac = var_1 @ var_2.jac
1464+
return pp.ad.AdArray(val, jac)
1465+
else:
1466+
return False
14471467

14481468
elif isinstance(var_1, pp.ad.AdArray) and isinstance(var_2, pp.ad.AdArray):
14491469
# For this case, var_2 was modified manually to be twice var_1, see comments in
@@ -1525,10 +1545,16 @@ def _expected_value(
15251545
return pp.ad.AdArray(val, jac)
15261546
elif op == "@":
15271547
return False
1548+
else:
1549+
raise ValueError(f"Unknown classes: {type(var_1)}, {type(var_2)}.")
15281550

15291551

1530-
@pytest.mark.parametrize("var_1", ["scalar", "dense", "sparse", "ad"])
1531-
@pytest.mark.parametrize("var_2", ["scalar", "dense", "sparse", "ad"])
1552+
@pytest.mark.parametrize(
1553+
"var_1", ["scalar", "dense", "sparse_matrix", "sparse_array", "ad"]
1554+
)
1555+
@pytest.mark.parametrize(
1556+
"var_2", ["scalar", "dense", "sparse_matrix", "sparse_array", "ad"]
1557+
)
15321558
@pytest.mark.parametrize("op", ["+", "-", "*", "/", "**", "@"])
15331559
@pytest.mark.parametrize("wrapped", [True, False])
15341560
def test_arithmetic_operations_on_ad_objects(
@@ -1572,8 +1598,10 @@ def _var_from_string(v, do_wrap: bool):
15721598
return _get_scalar(do_wrap)
15731599
elif v == "dense":
15741600
return _get_dense_array(do_wrap)
1575-
elif v == "sparse":
1576-
return _get_sparse_array(do_wrap)
1601+
elif v == "sparse_matrix":
1602+
return _get_sparse_array(do_wrap, use_csr_matrix=True)
1603+
elif v == "sparse_array":
1604+
return _get_sparse_array(do_wrap, use_csr_matrix=False)
15771605
elif v == "ad":
15781606
return _get_ad_array(do_wrap)
15791607
else:
@@ -1624,7 +1652,7 @@ def _compare(v1, v2):
16241652
assert np.isclose(v1, v2)
16251653
elif isinstance(v1, np.ndarray):
16261654
assert np.allclose(v1, v2)
1627-
elif isinstance(v1, sps.spmatrix):
1655+
elif isinstance(v1, (sps.spmatrix, sps.sparray)):
16281656
assert np.allclose(v1.toarray(), v2.toarray())
16291657
elif isinstance(v1, pp.ad.AdArray):
16301658
assert np.allclose(v1.val, v2.val)

0 commit comments

Comments
 (0)