@@ -194,9 +194,7 @@ def sub_elements(self) -> list[_AbstractFiniteElement]:
194
194
195
195
# Basix specific functions
196
196
@_abstractmethod
197
- def tabulate (
198
- self , nderivs : int , points : _npt .NDArray [np .floating ]
199
- ) -> _npt .ArrayLike :
197
+ def tabulate (self , nderivs : int , points : _npt .NDArray [np .floating ]) -> _npt .ArrayLike :
200
198
"""Tabulate the basis functions of the element.
201
199
202
200
Args:
@@ -208,9 +206,7 @@ def tabulate(
208
206
"""
209
207
210
208
@_abstractmethod
211
- def get_component_element (
212
- self , flat_component : int
213
- ) -> tuple [_typing .Any , int , int ]:
209
+ def get_component_element (self , flat_component : int ) -> tuple [_typing .Any , int , int ]:
214
210
"""Get element that represents a component, and the offset and stride of the component.
215
211
216
212
For example, for a mixed element, this will return the
@@ -431,9 +427,7 @@ def basix_hash(self) -> _typing.Optional[int]:
431
427
"""Return the hash of the Basix element if this is a standard Basix element."""
432
428
return self ._element .hash ()
433
429
434
- def tabulate (
435
- self , nderivs : int , points : _npt .NDArray [np .floating ]
436
- ) -> _npt .ArrayLike :
430
+ def tabulate (self , nderivs : int , points : _npt .NDArray [np .floating ]) -> _npt .ArrayLike :
437
431
"""Tabulate the basis functions of the element.
438
432
439
433
Args:
@@ -448,9 +442,7 @@ def tabulate(
448
442
# TODO: update FFCx to remove the need for transposing here
449
443
return tab .transpose ((0 , 1 , 3 , 2 )).reshape ((tab .shape [0 ], tab .shape [1 ], - 1 )) # type: ignore
450
444
451
- def get_component_element (
452
- self , flat_component : int
453
- ) -> tuple [_ElementBase , int , int ]:
445
+ def get_component_element (self , flat_component : int ) -> tuple [_ElementBase , int , int ]:
454
446
"""Get element that represents a component.
455
447
456
448
Element that represents a component of the element, and the
@@ -678,9 +670,7 @@ def __hash__(self) -> int:
678
670
"""Return a hash."""
679
671
return super ().__hash__ ()
680
672
681
- def tabulate (
682
- self , nderivs : int , points : _npt .NDArray [np .floating ]
683
- ) -> _npt .ArrayLike :
673
+ def tabulate (self , nderivs : int , points : _npt .NDArray [np .floating ]) -> _npt .ArrayLike :
684
674
"""Tabulate the basis functions of the element.
685
675
686
676
Args:
@@ -700,24 +690,17 @@ def tabulate(
700
690
elif len (self ._element ._reference_value_shape ) == 1 :
701
691
output .append (tbl [:, self ._component , :])
702
692
elif len (self ._element ._reference_value_shape ) == 2 :
703
- if (
704
- isinstance (self ._element , _BlockedElement )
705
- and self ._element ._has_symmetry
706
- ):
693
+ if isinstance (self ._element , _BlockedElement ) and self ._element ._has_symmetry :
707
694
# FIXME: check that this behaves as expected
708
695
output .append (tbl [:, self ._component , :])
709
696
else :
710
697
vs0 = self ._element ._reference_value_shape [0 ]
711
- output .append (
712
- tbl [:, self ._component // vs0 , self ._component % vs0 , :]
713
- )
698
+ output .append (tbl [:, self ._component // vs0 , self ._component % vs0 , :])
714
699
else :
715
700
raise NotImplementedError ()
716
701
return np .asarray (output , dtype = np .float64 )
717
702
718
- def get_component_element (
719
- self , flat_component : int
720
- ) -> tuple [_ElementBase , int , int ]:
703
+ def get_component_element (self , flat_component : int ) -> tuple [_ElementBase , int , int ]:
721
704
"""Get element that represents a component.
722
705
723
706
Element that represents a component of the element, and the
@@ -897,9 +880,7 @@ def __init__(self, sub_elements: list[_ElementBase]):
897
880
898
881
def __eq__ (self , other ) -> bool :
899
882
"""Check if two elements are equal."""
900
- if isinstance (other , _MixedElement ) and len (self ._sub_elements ) == len (
901
- other ._sub_elements
902
- ):
883
+ if isinstance (other , _MixedElement ) and len (self ._sub_elements ) == len (other ._sub_elements ):
903
884
for i , j in zip (self ._sub_elements , other ._sub_elements ):
904
885
if i != j :
905
886
return False
@@ -925,9 +906,7 @@ def degree(self) -> int:
925
906
"""Degree of the element."""
926
907
return max ((e .degree for e in self ._sub_elements ), default = - 1 )
927
908
928
- def tabulate (
929
- self , nderivs : int , points : _npt .NDArray [np .floating ]
930
- ) -> _npt .ArrayLike :
909
+ def tabulate (self , nderivs : int , points : _npt .NDArray [np .floating ]) -> _npt .ArrayLike :
931
910
"""Tabulate the basis functions of the element.
932
911
933
912
Args:
@@ -951,9 +930,7 @@ def tabulate(
951
930
tables .append (new_table )
952
931
return np .asarray (tables , dtype = np .float64 )
953
932
954
- def get_component_element (
955
- self , flat_component : int
956
- ) -> tuple [_ElementBase , int , int ]:
933
+ def get_component_element (self , flat_component : int ) -> tuple [_ElementBase , int , int ]:
957
934
"""Get element that represents a component.
958
935
959
936
Element that represents a component of the element, and the
@@ -1044,10 +1021,7 @@ def num_entity_dofs(self) -> list[list[int]]:
1044
1021
"""Number of DOFs associated with each entity."""
1045
1022
data = [e .num_entity_dofs for e in self ._sub_elements ]
1046
1023
return [
1047
- [
1048
- sum (d [tdim ][entity_n ] for d in data )
1049
- for entity_n , _ in enumerate (entities )
1050
- ]
1024
+ [sum (d [tdim ][entity_n ] for d in data ) for entity_n , _ in enumerate (entities )]
1051
1025
for tdim , entities in enumerate (data [0 ])
1052
1026
]
1053
1027
@@ -1070,19 +1044,15 @@ def num_entity_closure_dofs(self) -> list[list[int]]:
1070
1044
"""Number of DOFs associated with the closure of each entity."""
1071
1045
data = [e .num_entity_closure_dofs for e in self ._sub_elements ]
1072
1046
return [
1073
- [
1074
- sum (d [tdim ][entity_n ] for d in data )
1075
- for entity_n , _ in enumerate (entities )
1076
- ]
1047
+ [sum (d [tdim ][entity_n ] for d in data ) for entity_n , _ in enumerate (entities )]
1077
1048
for tdim , entities in enumerate (data [0 ])
1078
1049
]
1079
1050
1080
1051
@property
1081
1052
def entity_closure_dofs (self ) -> list [list [list [int ]]]:
1082
1053
"""DOF numbers associated with the closure of each entity."""
1083
1054
dofs : list [list [list [int ]]] = [
1084
- [[] for i in entities ]
1085
- for entities in self ._sub_elements [0 ].entity_closure_dofs
1055
+ [[] for i in entities ] for entities in self ._sub_elements [0 ].entity_closure_dofs
1086
1056
]
1087
1057
start_dof = 0
1088
1058
for e in self ._sub_elements :
@@ -1160,9 +1130,7 @@ def custom_quadrature(
1160
1130
custom_q = e .custom_quadrature ()
1161
1131
else :
1162
1132
p , w = e .custom_quadrature ()
1163
- if not np .allclose (p , custom_q [0 ]) or not np .allclose (
1164
- w , custom_q [1 ]
1165
- ):
1133
+ if not np .allclose (p , custom_q [0 ]) or not np .allclose (w , custom_q [1 ]):
1166
1134
raise ValueError (
1167
1135
"Subelements of mixed element use different quadrature rules"
1168
1136
)
@@ -1207,13 +1175,9 @@ def __init__(
1207
1175
)
1208
1176
if symmetry is not None :
1209
1177
if len (shape ) != 2 :
1210
- raise ValueError (
1211
- "symmetry argument can only be passed to elements of rank 2."
1212
- )
1178
+ raise ValueError ("symmetry argument can only be passed to elements of rank 2." )
1213
1179
if shape [0 ] != shape [1 ]:
1214
- raise ValueError (
1215
- "symmetry argument can only be passed to square shaped elements."
1216
- )
1180
+ raise ValueError ("symmetry argument can only be passed to square shaped elements." )
1217
1181
1218
1182
if symmetry :
1219
1183
block_size = shape [0 ] * (shape [0 ] + 1 ) // 2
@@ -1281,9 +1245,7 @@ def is_quadrature(self) -> bool:
1281
1245
"""Is this a quadrature element?"""
1282
1246
return self ._sub_element .is_quadrature
1283
1247
1284
- def tabulate (
1285
- self , nderivs : int , points : _npt .NDArray [np .floating ]
1286
- ) -> _npt .ArrayLike :
1248
+ def tabulate (self , nderivs : int , points : _npt .NDArray [np .floating ]) -> _npt .ArrayLike :
1287
1249
"""Tabulate the basis functions of the element.
1288
1250
1289
1251
Args:
@@ -1295,19 +1257,15 @@ def tabulate(
1295
1257
1296
1258
"""
1297
1259
assert len (self ._block_shape ) == 1 # TODO: block shape
1298
- assert (
1299
- self .reference_value_size == self ._block_size
1300
- ) # TODO: remove this assumption
1260
+ assert self .reference_value_size == self ._block_size # TODO: remove this assumption
1301
1261
output = []
1302
1262
for table in self ._sub_element .tabulate (nderivs , points ): # type: ignore
1303
1263
# Repeat sub element horizontally
1304
1264
assert len (table .shape ) == 2 # type: ignore
1305
1265
new_table = np .zeros (
1306
1266
(table .shape [0 ], * self ._block_shape , self ._block_size * table .shape [1 ]) # type: ignore
1307
1267
)
1308
- for i , j in enumerate (
1309
- _itertools .product (* [range (s ) for s in self ._block_shape ])
1310
- ):
1268
+ for i , j in enumerate (_itertools .product (* [range (s ) for s in self ._block_shape ])):
1311
1269
if len (j ) == 1 :
1312
1270
new_table [:, j [0 ], i :: self ._block_size ] = table
1313
1271
elif len (j ) == 2 :
@@ -1317,9 +1275,7 @@ def tabulate(
1317
1275
output .append (new_table )
1318
1276
return np .asarray (output , dtype = np .float64 )
1319
1277
1320
- def get_component_element (
1321
- self , flat_component : int
1322
- ) -> tuple [_ElementBase , int , int ]:
1278
+ def get_component_element (self , flat_component : int ) -> tuple [_ElementBase , int , int ]:
1323
1279
"""Get element that represents a component.
1324
1280
1325
1281
Element that represents a component of the element, and the
@@ -1368,29 +1324,23 @@ def dim(self) -> int:
1368
1324
@property
1369
1325
def num_entity_dofs (self ) -> list [list [int ]]:
1370
1326
"""Number of DOFs associated with each entity."""
1371
- return [
1372
- [j * self ._block_size for j in i ] for i in self ._sub_element .num_entity_dofs
1373
- ]
1327
+ return [[j * self ._block_size for j in i ] for i in self ._sub_element .num_entity_dofs ]
1374
1328
1375
1329
@property
1376
1330
def entity_dofs (self ) -> list [list [list [int ]]]:
1377
1331
"""DOF numbers associated with each entity."""
1378
1332
# TODO: should this return this, or should it take blocks into
1379
1333
# account?
1380
1334
return [
1381
- [
1382
- [k * self ._block_size + b for k in j for b in range (self ._block_size )]
1383
- for j in i
1384
- ]
1335
+ [[k * self ._block_size + b for k in j for b in range (self ._block_size )] for j in i ]
1385
1336
for i in self ._sub_element .entity_dofs
1386
1337
]
1387
1338
1388
1339
@property
1389
1340
def num_entity_closure_dofs (self ) -> list [list [int ]]:
1390
1341
"""Number of DOFs associated with the closure of each entity."""
1391
1342
return [
1392
- [j * self ._block_size for j in i ]
1393
- for i in self ._sub_element .num_entity_closure_dofs
1343
+ [j * self ._block_size for j in i ] for i in self ._sub_element .num_entity_closure_dofs
1394
1344
]
1395
1345
1396
1346
@property
@@ -1399,10 +1349,7 @@ def entity_closure_dofs(self) -> list[list[list[int]]]:
1399
1349
# TODO: should this return this, or should it take blocks into
1400
1350
# account?
1401
1351
return [
1402
- [
1403
- [k * self ._block_size + b for k in j for b in range (self ._block_size )]
1404
- for j in i
1405
- ]
1352
+ [[k * self ._block_size + b for k in j for b in range (self ._block_size )] for j in i ]
1406
1353
for i in self ._sub_element .entity_closure_dofs
1407
1354
]
1408
1355
@@ -1620,9 +1567,7 @@ def __hash__(self) -> int:
1620
1567
"""Return a hash."""
1621
1568
return super ().__hash__ ()
1622
1569
1623
- def tabulate (
1624
- self , nderivs : int , points : _npt .NDArray [np .floating ]
1625
- ) -> _npt .ArrayLike :
1570
+ def tabulate (self , nderivs : int , points : _npt .NDArray [np .floating ]) -> _npt .ArrayLike :
1626
1571
"""Tabulate the basis functions of the element.
1627
1572
1628
1573
Args:
@@ -1637,14 +1582,10 @@ def tabulate(
1637
1582
1638
1583
if points .shape != self ._points .shape :
1639
1584
raise ValueError ("Mismatch of tabulation points and element points." )
1640
- tables = np .asarray (
1641
- [np .eye (points .shape [0 ], points .shape [0 ])], dtype = points .dtype
1642
- )
1585
+ tables = np .asarray ([np .eye (points .shape [0 ], points .shape [0 ])], dtype = points .dtype )
1643
1586
return tables
1644
1587
1645
- def get_component_element (
1646
- self , flat_component : int
1647
- ) -> tuple [_ElementBase , int , int ]:
1588
+ def get_component_element (self , flat_component : int ) -> tuple [_ElementBase , int , int ]:
1648
1589
"""Get element that represents a component.
1649
1590
1650
1591
Element that represents a component of the element, and the
@@ -1810,9 +1751,7 @@ def __init__(self, cell: _basix.CellType, value_shape: tuple[int, ...]):
1810
1751
self ._cell_type = cell
1811
1752
tdim = len (_basix .topology (cell )) - 1
1812
1753
1813
- super ().__init__ (
1814
- f"RealElement({ cell .name } , { value_shape } )" , cell .name , value_shape , 0
1815
- )
1754
+ super ().__init__ (f"RealElement({ cell .name } , { value_shape } )" , cell .name , value_shape , 0 )
1816
1755
1817
1756
self ._entity_counts = []
1818
1757
if tdim >= 1 :
@@ -1840,9 +1779,7 @@ def dtype(self) -> _npt.DTypeLike:
1840
1779
"""Element float type."""
1841
1780
raise NotImplementedError ()
1842
1781
1843
- def tabulate (
1844
- self , nderivs : int , points : _npt .NDArray [np .floating ]
1845
- ) -> _npt .ArrayLike :
1782
+ def tabulate (self , nderivs : int , points : _npt .NDArray [np .floating ]) -> _npt .ArrayLike :
1846
1783
"""Tabulate the basis functions of the element.
1847
1784
1848
1785
Args:
@@ -1858,9 +1795,7 @@ def tabulate(
1858
1795
out [0 , :, self .reference_value_size * v + v ] = 1.0
1859
1796
return out
1860
1797
1861
- def get_component_element (
1862
- self , flat_component : int
1863
- ) -> tuple [_ElementBase , int , int ]:
1798
+ def get_component_element (self , flat_component : int ) -> tuple [_ElementBase , int , int ]:
1864
1799
"""Get element that represents a component.
1865
1800
1866
1801
Element that represents a component of the element, and the
@@ -2160,9 +2095,7 @@ def enriched_element(
2160
2095
map_type = elements [0 ].map_type
2161
2096
for e in elements :
2162
2097
if e .map_type != map_type :
2163
- raise ValueError (
2164
- "Enriched elements on different map types not supported."
2165
- )
2098
+ raise ValueError ("Enriched elements on different map types not supported." )
2166
2099
2167
2100
dtype = e .dtype
2168
2101
hcd = min (e .embedded_subdegree for e in elements )
@@ -2175,13 +2108,9 @@ def enriched_element(
2175
2108
if e .cell_type != ct :
2176
2109
raise ValueError ("Enriched elements on different cell types not supported." )
2177
2110
if e .polyset_type != ptype :
2178
- raise ValueError (
2179
- "Enriched elements on different polyset types not supported."
2180
- )
2111
+ raise ValueError ("Enriched elements on different polyset types not supported." )
2181
2112
if e .reference_value_shape != vshape or e .reference_value_size != vsize :
2182
- raise ValueError (
2183
- "Enriched elements on different value shapes not supported."
2184
- )
2113
+ raise ValueError ("Enriched elements on different value shapes not supported." )
2185
2114
if e .dtype != dtype :
2186
2115
raise ValueError ("Enriched elements with different dtypes no supported." )
2187
2116
nderivs = max (e .interpolation_nderivs for e in elements )
@@ -2200,9 +2129,7 @@ def enriched_element(
2200
2129
pt = 0
2201
2130
dof = 0
2202
2131
for mat in M_parts :
2203
- new_M [
2204
- dof : dof + mat .shape [0 ], :, pt : pt + mat .shape [2 ], : mat .shape [3 ]
2205
- ] = mat
2132
+ new_M [dof : dof + mat .shape [0 ], :, pt : pt + mat .shape [2 ], : mat .shape [3 ]] = mat
2206
2133
dof += mat .shape [0 ]
2207
2134
pt += mat .shape [2 ]
2208
2135
M_row .append (new_M )
@@ -2410,9 +2337,7 @@ def blocked_element(
2410
2337
A blocked finite element.
2411
2338
"""
2412
2339
if len (sub_element .reference_value_shape ) != 0 :
2413
- raise ValueError (
2414
- "Cannot create a blocked element containing a non-scalar element."
2415
- )
2340
+ raise ValueError ("Cannot create a blocked element containing a non-scalar element." )
2416
2341
2417
2342
return _BlockedElement (sub_element , shape = shape , symmetry = symmetry )
2418
2343
0 commit comments