Skip to content

Commit 686fbb2

Browse files
committed
Ruff formatting
1 parent bb13e8a commit 686fbb2

File tree

1 file changed

+36
-111
lines changed

1 file changed

+36
-111
lines changed

python/basix/ufl.py

Lines changed: 36 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,7 @@ def sub_elements(self) -> list[_AbstractFiniteElement]:
194194

195195
# Basix specific functions
196196
@_abstractmethod
197-
def tabulate(
198-
self, nderivs: int, points: _npt.NDArray[np.floating]
199-
) -> _npt.ArrayLike:
197+
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
200198
"""Tabulate the basis functions of the element.
201199
202200
Args:
@@ -208,9 +206,7 @@ def tabulate(
208206
"""
209207

210208
@_abstractmethod
211-
def get_component_element(
212-
self, flat_component: int
213-
) -> tuple[_typing.Any, int, int]:
209+
def get_component_element(self, flat_component: int) -> tuple[_typing.Any, int, int]:
214210
"""Get element that represents a component, and the offset and stride of the component.
215211
216212
For example, for a mixed element, this will return the
@@ -431,9 +427,7 @@ def basix_hash(self) -> _typing.Optional[int]:
431427
"""Return the hash of the Basix element if this is a standard Basix element."""
432428
return self._element.hash()
433429

434-
def tabulate(
435-
self, nderivs: int, points: _npt.NDArray[np.floating]
436-
) -> _npt.ArrayLike:
430+
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
437431
"""Tabulate the basis functions of the element.
438432
439433
Args:
@@ -448,9 +442,7 @@ def tabulate(
448442
# TODO: update FFCx to remove the need for transposing here
449443
return tab.transpose((0, 1, 3, 2)).reshape((tab.shape[0], tab.shape[1], -1)) # type: ignore
450444

451-
def get_component_element(
452-
self, flat_component: int
453-
) -> tuple[_ElementBase, int, int]:
445+
def get_component_element(self, flat_component: int) -> tuple[_ElementBase, int, int]:
454446
"""Get element that represents a component.
455447
456448
Element that represents a component of the element, and the
@@ -678,9 +670,7 @@ def __hash__(self) -> int:
678670
"""Return a hash."""
679671
return super().__hash__()
680672

681-
def tabulate(
682-
self, nderivs: int, points: _npt.NDArray[np.floating]
683-
) -> _npt.ArrayLike:
673+
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
684674
"""Tabulate the basis functions of the element.
685675
686676
Args:
@@ -700,24 +690,17 @@ def tabulate(
700690
elif len(self._element._reference_value_shape) == 1:
701691
output.append(tbl[:, self._component, :])
702692
elif len(self._element._reference_value_shape) == 2:
703-
if (
704-
isinstance(self._element, _BlockedElement)
705-
and self._element._has_symmetry
706-
):
693+
if isinstance(self._element, _BlockedElement) and self._element._has_symmetry:
707694
# FIXME: check that this behaves as expected
708695
output.append(tbl[:, self._component, :])
709696
else:
710697
vs0 = self._element._reference_value_shape[0]
711-
output.append(
712-
tbl[:, self._component // vs0, self._component % vs0, :]
713-
)
698+
output.append(tbl[:, self._component // vs0, self._component % vs0, :])
714699
else:
715700
raise NotImplementedError()
716701
return np.asarray(output, dtype=np.float64)
717702

718-
def get_component_element(
719-
self, flat_component: int
720-
) -> tuple[_ElementBase, int, int]:
703+
def get_component_element(self, flat_component: int) -> tuple[_ElementBase, int, int]:
721704
"""Get element that represents a component.
722705
723706
Element that represents a component of the element, and the
@@ -897,9 +880,7 @@ def __init__(self, sub_elements: list[_ElementBase]):
897880

898881
def __eq__(self, other) -> bool:
899882
"""Check if two elements are equal."""
900-
if isinstance(other, _MixedElement) and len(self._sub_elements) == len(
901-
other._sub_elements
902-
):
883+
if isinstance(other, _MixedElement) and len(self._sub_elements) == len(other._sub_elements):
903884
for i, j in zip(self._sub_elements, other._sub_elements):
904885
if i != j:
905886
return False
@@ -925,9 +906,7 @@ def degree(self) -> int:
925906
"""Degree of the element."""
926907
return max((e.degree for e in self._sub_elements), default=-1)
927908

928-
def tabulate(
929-
self, nderivs: int, points: _npt.NDArray[np.floating]
930-
) -> _npt.ArrayLike:
909+
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
931910
"""Tabulate the basis functions of the element.
932911
933912
Args:
@@ -951,9 +930,7 @@ def tabulate(
951930
tables.append(new_table)
952931
return np.asarray(tables, dtype=np.float64)
953932

954-
def get_component_element(
955-
self, flat_component: int
956-
) -> tuple[_ElementBase, int, int]:
933+
def get_component_element(self, flat_component: int) -> tuple[_ElementBase, int, int]:
957934
"""Get element that represents a component.
958935
959936
Element that represents a component of the element, and the
@@ -1044,10 +1021,7 @@ def num_entity_dofs(self) -> list[list[int]]:
10441021
"""Number of DOFs associated with each entity."""
10451022
data = [e.num_entity_dofs for e in self._sub_elements]
10461023
return [
1047-
[
1048-
sum(d[tdim][entity_n] for d in data)
1049-
for entity_n, _ in enumerate(entities)
1050-
]
1024+
[sum(d[tdim][entity_n] for d in data) for entity_n, _ in enumerate(entities)]
10511025
for tdim, entities in enumerate(data[0])
10521026
]
10531027

@@ -1070,19 +1044,15 @@ def num_entity_closure_dofs(self) -> list[list[int]]:
10701044
"""Number of DOFs associated with the closure of each entity."""
10711045
data = [e.num_entity_closure_dofs for e in self._sub_elements]
10721046
return [
1073-
[
1074-
sum(d[tdim][entity_n] for d in data)
1075-
for entity_n, _ in enumerate(entities)
1076-
]
1047+
[sum(d[tdim][entity_n] for d in data) for entity_n, _ in enumerate(entities)]
10771048
for tdim, entities in enumerate(data[0])
10781049
]
10791050

10801051
@property
10811052
def entity_closure_dofs(self) -> list[list[list[int]]]:
10821053
"""DOF numbers associated with the closure of each entity."""
10831054
dofs: list[list[list[int]]] = [
1084-
[[] for i in entities]
1085-
for entities in self._sub_elements[0].entity_closure_dofs
1055+
[[] for i in entities] for entities in self._sub_elements[0].entity_closure_dofs
10861056
]
10871057
start_dof = 0
10881058
for e in self._sub_elements:
@@ -1160,9 +1130,7 @@ def custom_quadrature(
11601130
custom_q = e.custom_quadrature()
11611131
else:
11621132
p, w = e.custom_quadrature()
1163-
if not np.allclose(p, custom_q[0]) or not np.allclose(
1164-
w, custom_q[1]
1165-
):
1133+
if not np.allclose(p, custom_q[0]) or not np.allclose(w, custom_q[1]):
11661134
raise ValueError(
11671135
"Subelements of mixed element use different quadrature rules"
11681136
)
@@ -1207,13 +1175,9 @@ def __init__(
12071175
)
12081176
if symmetry is not None:
12091177
if len(shape) != 2:
1210-
raise ValueError(
1211-
"symmetry argument can only be passed to elements of rank 2."
1212-
)
1178+
raise ValueError("symmetry argument can only be passed to elements of rank 2.")
12131179
if shape[0] != shape[1]:
1214-
raise ValueError(
1215-
"symmetry argument can only be passed to square shaped elements."
1216-
)
1180+
raise ValueError("symmetry argument can only be passed to square shaped elements.")
12171181

12181182
if symmetry:
12191183
block_size = shape[0] * (shape[0] + 1) // 2
@@ -1281,9 +1245,7 @@ def is_quadrature(self) -> bool:
12811245
"""Is this a quadrature element?"""
12821246
return self._sub_element.is_quadrature
12831247

1284-
def tabulate(
1285-
self, nderivs: int, points: _npt.NDArray[np.floating]
1286-
) -> _npt.ArrayLike:
1248+
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
12871249
"""Tabulate the basis functions of the element.
12881250
12891251
Args:
@@ -1295,19 +1257,15 @@ def tabulate(
12951257
12961258
"""
12971259
assert len(self._block_shape) == 1 # TODO: block shape
1298-
assert (
1299-
self.reference_value_size == self._block_size
1300-
) # TODO: remove this assumption
1260+
assert self.reference_value_size == self._block_size # TODO: remove this assumption
13011261
output = []
13021262
for table in self._sub_element.tabulate(nderivs, points): # type: ignore
13031263
# Repeat sub element horizontally
13041264
assert len(table.shape) == 2 # type: ignore
13051265
new_table = np.zeros(
13061266
(table.shape[0], *self._block_shape, self._block_size * table.shape[1]) # type: ignore
13071267
)
1308-
for i, j in enumerate(
1309-
_itertools.product(*[range(s) for s in self._block_shape])
1310-
):
1268+
for i, j in enumerate(_itertools.product(*[range(s) for s in self._block_shape])):
13111269
if len(j) == 1:
13121270
new_table[:, j[0], i :: self._block_size] = table
13131271
elif len(j) == 2:
@@ -1317,9 +1275,7 @@ def tabulate(
13171275
output.append(new_table)
13181276
return np.asarray(output, dtype=np.float64)
13191277

1320-
def get_component_element(
1321-
self, flat_component: int
1322-
) -> tuple[_ElementBase, int, int]:
1278+
def get_component_element(self, flat_component: int) -> tuple[_ElementBase, int, int]:
13231279
"""Get element that represents a component.
13241280
13251281
Element that represents a component of the element, and the
@@ -1368,29 +1324,23 @@ def dim(self) -> int:
13681324
@property
13691325
def num_entity_dofs(self) -> list[list[int]]:
13701326
"""Number of DOFs associated with each entity."""
1371-
return [
1372-
[j * self._block_size for j in i] for i in self._sub_element.num_entity_dofs
1373-
]
1327+
return [[j * self._block_size for j in i] for i in self._sub_element.num_entity_dofs]
13741328

13751329
@property
13761330
def entity_dofs(self) -> list[list[list[int]]]:
13771331
"""DOF numbers associated with each entity."""
13781332
# TODO: should this return this, or should it take blocks into
13791333
# account?
13801334
return [
1381-
[
1382-
[k * self._block_size + b for k in j for b in range(self._block_size)]
1383-
for j in i
1384-
]
1335+
[[k * self._block_size + b for k in j for b in range(self._block_size)] for j in i]
13851336
for i in self._sub_element.entity_dofs
13861337
]
13871338

13881339
@property
13891340
def num_entity_closure_dofs(self) -> list[list[int]]:
13901341
"""Number of DOFs associated with the closure of each entity."""
13911342
return [
1392-
[j * self._block_size for j in i]
1393-
for i in self._sub_element.num_entity_closure_dofs
1343+
[j * self._block_size for j in i] for i in self._sub_element.num_entity_closure_dofs
13941344
]
13951345

13961346
@property
@@ -1399,10 +1349,7 @@ def entity_closure_dofs(self) -> list[list[list[int]]]:
13991349
# TODO: should this return this, or should it take blocks into
14001350
# account?
14011351
return [
1402-
[
1403-
[k * self._block_size + b for k in j for b in range(self._block_size)]
1404-
for j in i
1405-
]
1352+
[[k * self._block_size + b for k in j for b in range(self._block_size)] for j in i]
14061353
for i in self._sub_element.entity_closure_dofs
14071354
]
14081355

@@ -1620,9 +1567,7 @@ def __hash__(self) -> int:
16201567
"""Return a hash."""
16211568
return super().__hash__()
16221569

1623-
def tabulate(
1624-
self, nderivs: int, points: _npt.NDArray[np.floating]
1625-
) -> _npt.ArrayLike:
1570+
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
16261571
"""Tabulate the basis functions of the element.
16271572
16281573
Args:
@@ -1637,14 +1582,10 @@ def tabulate(
16371582

16381583
if points.shape != self._points.shape:
16391584
raise ValueError("Mismatch of tabulation points and element points.")
1640-
tables = np.asarray(
1641-
[np.eye(points.shape[0], points.shape[0])], dtype=points.dtype
1642-
)
1585+
tables = np.asarray([np.eye(points.shape[0], points.shape[0])], dtype=points.dtype)
16431586
return tables
16441587

1645-
def get_component_element(
1646-
self, flat_component: int
1647-
) -> tuple[_ElementBase, int, int]:
1588+
def get_component_element(self, flat_component: int) -> tuple[_ElementBase, int, int]:
16481589
"""Get element that represents a component.
16491590
16501591
Element that represents a component of the element, and the
@@ -1810,9 +1751,7 @@ def __init__(self, cell: _basix.CellType, value_shape: tuple[int, ...]):
18101751
self._cell_type = cell
18111752
tdim = len(_basix.topology(cell)) - 1
18121753

1813-
super().__init__(
1814-
f"RealElement({cell.name}, {value_shape})", cell.name, value_shape, 0
1815-
)
1754+
super().__init__(f"RealElement({cell.name}, {value_shape})", cell.name, value_shape, 0)
18161755

18171756
self._entity_counts = []
18181757
if tdim >= 1:
@@ -1840,9 +1779,7 @@ def dtype(self) -> _npt.DTypeLike:
18401779
"""Element float type."""
18411780
raise NotImplementedError()
18421781

1843-
def tabulate(
1844-
self, nderivs: int, points: _npt.NDArray[np.floating]
1845-
) -> _npt.ArrayLike:
1782+
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
18461783
"""Tabulate the basis functions of the element.
18471784
18481785
Args:
@@ -1858,9 +1795,7 @@ def tabulate(
18581795
out[0, :, self.reference_value_size * v + v] = 1.0
18591796
return out
18601797

1861-
def get_component_element(
1862-
self, flat_component: int
1863-
) -> tuple[_ElementBase, int, int]:
1798+
def get_component_element(self, flat_component: int) -> tuple[_ElementBase, int, int]:
18641799
"""Get element that represents a component.
18651800
18661801
Element that represents a component of the element, and the
@@ -2160,9 +2095,7 @@ def enriched_element(
21602095
map_type = elements[0].map_type
21612096
for e in elements:
21622097
if e.map_type != map_type:
2163-
raise ValueError(
2164-
"Enriched elements on different map types not supported."
2165-
)
2098+
raise ValueError("Enriched elements on different map types not supported.")
21662099

21672100
dtype = e.dtype
21682101
hcd = min(e.embedded_subdegree for e in elements)
@@ -2175,13 +2108,9 @@ def enriched_element(
21752108
if e.cell_type != ct:
21762109
raise ValueError("Enriched elements on different cell types not supported.")
21772110
if e.polyset_type != ptype:
2178-
raise ValueError(
2179-
"Enriched elements on different polyset types not supported."
2180-
)
2111+
raise ValueError("Enriched elements on different polyset types not supported.")
21812112
if e.reference_value_shape != vshape or e.reference_value_size != vsize:
2182-
raise ValueError(
2183-
"Enriched elements on different value shapes not supported."
2184-
)
2113+
raise ValueError("Enriched elements on different value shapes not supported.")
21852114
if e.dtype != dtype:
21862115
raise ValueError("Enriched elements with different dtypes no supported.")
21872116
nderivs = max(e.interpolation_nderivs for e in elements)
@@ -2200,9 +2129,7 @@ def enriched_element(
22002129
pt = 0
22012130
dof = 0
22022131
for mat in M_parts:
2203-
new_M[
2204-
dof : dof + mat.shape[0], :, pt : pt + mat.shape[2], : mat.shape[3]
2205-
] = mat
2132+
new_M[dof : dof + mat.shape[0], :, pt : pt + mat.shape[2], : mat.shape[3]] = mat
22062133
dof += mat.shape[0]
22072134
pt += mat.shape[2]
22082135
M_row.append(new_M)
@@ -2410,9 +2337,7 @@ def blocked_element(
24102337
A blocked finite element.
24112338
"""
24122339
if len(sub_element.reference_value_shape) != 0:
2413-
raise ValueError(
2414-
"Cannot create a blocked element containing a non-scalar element."
2415-
)
2340+
raise ValueError("Cannot create a blocked element containing a non-scalar element.")
24162341

24172342
return _BlockedElement(sub_element, shape=shape, symmetry=symmetry)
24182343

0 commit comments

Comments
 (0)