Skip to content

Commit

Permalink
Fix typing + ensure that we run standard checks on a schedule. (#897)
Browse files Browse the repository at this point in the history
* Add schedule

* Generalize typing of all elements in UFL layer

* Ruff formatting
  • Loading branch information
jorgensd authored Jan 15, 2025
1 parent 92c54d2 commit 1f28b82
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
10 changes: 6 additions & 4 deletions .github/workflows/pythonapp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ on:
branches:
- main
workflow_dispatch:

# Weekly build on Mondays at 8 am
schedule:
- cron: "0 8 * * 1"
jobs:
lint:
name: Lint code
Expand All @@ -26,7 +28,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
python-version: "3.12"
- uses: actions/checkout@v4
- name: Ruff check
run: |
Expand Down Expand Up @@ -144,7 +146,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
python-version: "3.12"
- name: Install dependencies
run: sudo apt-get install -y libopenblas-dev liblapack-dev ninja-build
- name: Install Python dependencies
Expand All @@ -160,7 +162,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
python-version: "3.12"
- name: Install dependencies
run: sudo apt-get install -y libopenblas-dev liblapack-dev
- name: Install Basix C++ library
Expand Down
39 changes: 21 additions & 18 deletions python/basix/ufl.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def sub_elements(self) -> list[_AbstractFiniteElement]:

# Basix specific functions
@_abstractmethod
def tabulate(self, nderivs: int, points: _npt.NDArray[np.float64]) -> _npt.ArrayLike:
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
"""Tabulate the basis functions of the element.
Args:
Expand Down Expand Up @@ -309,7 +309,7 @@ def degree(self) -> int:

def custom_quadrature(
self,
) -> tuple[_npt.NDArray[np.float64], _npt.NDArray[np.float64]]:
) -> tuple[_npt.NDArray[np.floating], _npt.NDArray[np.floating]]:
"""Return custom quadrature rule or raise a ValueError."""
raise ValueError("Element does not have a custom quadrature rule.")

Expand Down Expand Up @@ -427,7 +427,7 @@ def basix_hash(self) -> _typing.Optional[int]:
"""Return the hash of the Basix element if this is a standard Basix element."""
return self._element.hash()

def tabulate(self, nderivs: int, points: _npt.NDArray[np.float64]) -> _npt.ArrayLike:
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
"""Tabulate the basis functions of the element.
Args:
Expand Down Expand Up @@ -670,7 +670,7 @@ def __hash__(self) -> int:
"""Return a hash."""
return super().__hash__()

def tabulate(self, nderivs: int, points: _npt.NDArray[np.float64]) -> _npt.ArrayLike:
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
"""Tabulate the basis functions of the element.
Args:
Expand Down Expand Up @@ -906,7 +906,7 @@ def degree(self) -> int:
"""Degree of the element."""
return max((e.degree for e in self._sub_elements), default=-1)

def tabulate(self, nderivs: int, points: _npt.NDArray[np.float64]) -> _npt.ArrayLike:
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
"""Tabulate the basis functions of the element.
Args:
Expand Down Expand Up @@ -1121,7 +1121,7 @@ def polyset_type(self) -> _basix.PolysetType:

def custom_quadrature(
self,
) -> tuple[_npt.NDArray[np.float64], _npt.NDArray[np.float64]]:
) -> tuple[_npt.NDArray[np.floating], _npt.NDArray[np.floating]]:
"""Return custom quadrature rule or raise a ValueError."""
custom_q = None
for e in self._sub_elements:
Expand Down Expand Up @@ -1245,7 +1245,7 @@ def is_quadrature(self) -> bool:
"""Is this a quadrature element?"""
return self._sub_element.is_quadrature

def tabulate(self, nderivs: int, points: _npt.NDArray[np.float64]) -> _npt.ArrayLike:
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
"""Tabulate the basis functions of the element.
Args:
Expand Down Expand Up @@ -1501,7 +1501,7 @@ def has_tensor_product_factorisation(self) -> bool:

def custom_quadrature(
self,
) -> tuple[_npt.NDArray[np.float64], _npt.NDArray[np.float64]]:
) -> tuple[_npt.NDArray[np.floating], _npt.NDArray[np.floating]]:
"""Return custom quadrature rule or raise a ValueError."""
return self._sub_element.custom_quadrature()

Expand All @@ -1522,14 +1522,15 @@ class _QuadratureElement(_ElementBase):
def __init__(
self,
cell: _basix.CellType,
points: _npt.NDArray[np.float64],
weights: _npt.NDArray[np.float64],
points: _npt.NDArray[np.floating],
weights: _npt.NDArray[np.floating],
pullback: _AbstractPullback,
degree: _typing.Optional[int] = None,
dtype: _typing.Optional[_npt.DTypeLike] = np.float64,
):
"""Initialise the element."""
self._points = points
self._weights = weights
self._points = points.astype(dtype)
self._weights = weights.astype(dtype)
repr = f"QuadratureElement({cell.name}, {points!r}, {weights!r}, {pullback})".replace(
"\n", ""
)
Expand All @@ -1544,7 +1545,7 @@ def __init__(
@property
def dtype(self) -> _npt.DTypeLike:
"""Element float type."""
raise NotImplementedError()
raise self.points.dtype

@property
def basix_sobolev_space(self) -> _basix.SobolevSpace:
Expand All @@ -1566,7 +1567,7 @@ def __hash__(self) -> int:
"""Return a hash."""
return super().__hash__()

def tabulate(self, nderivs: int, points: _npt.NDArray[np.float64]) -> _npt.ArrayLike:
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
"""Tabulate the basis functions of the element.
Args:
Expand All @@ -1581,7 +1582,7 @@ def tabulate(self, nderivs: int, points: _npt.NDArray[np.float64]) -> _npt.Array

if points.shape != self._points.shape:
raise ValueError("Mismatch of tabulation points and element points.")
tables = np.asarray([np.eye(points.shape[0], points.shape[0])])
tables = np.asarray([np.eye(points.shape[0], points.shape[0])], dtype=points.dtype)
return tables

def get_component_element(self, flat_component: int) -> tuple[_ElementBase, int, int]:
Expand All @@ -1600,7 +1601,7 @@ def get_component_element(self, flat_component: int) -> tuple[_ElementBase, int,

def custom_quadrature(
self,
) -> tuple[_npt.NDArray[np.float64], _npt.NDArray[np.float64]]:
) -> tuple[_npt.NDArray[np.floating], _npt.NDArray[np.floating]]:
"""Return custom quadrature rule or raise a ValueError."""
return self._points, self._weights

Expand Down Expand Up @@ -1778,7 +1779,7 @@ def dtype(self) -> _npt.DTypeLike:
"""Element float type."""
raise NotImplementedError()

def tabulate(self, nderivs: int, points: _npt.NDArray[np.float64]) -> _npt.ArrayLike:
def tabulate(self, nderivs: int, points: _npt.NDArray[np.floating]) -> _npt.ArrayLike:
"""Tabulate the basis functions of the element.
Args:
Expand Down Expand Up @@ -2252,6 +2253,7 @@ def quadrature_element(
weights: _typing.Optional[_npt.NDArray[np.floating]] = None,
pullback: _AbstractPullback = _ufl.identity_pullback,
symmetry: _typing.Optional[bool] = None,
dtype: _typing.Optional[_npt.DTypeLike] = None,
) -> _ElementBase:
"""Create a quadrature element.
Expand All @@ -2268,6 +2270,7 @@ def quadrature_element(
pullback: Map name.
symmetry: Set to ``True`` if the tensor is symmetric. Valid for
rank 2 elements only.
dtype: Data type of quadrature points and weights
Returns:
A 'quadrature' finite element.
Expand All @@ -2288,7 +2291,7 @@ def quadrature_element(
assert points is not None
assert weights is not None

e = _QuadratureElement(cell, points, weights, pullback, degree)
e = _QuadratureElement(cell, points, weights, pullback, degree, dtype=dtype)
if value_shape == ():
if symmetry is not None:
raise ValueError("Cannot pass a symmetry argument to this element.")
Expand Down

0 comments on commit 1f28b82

Please sign in to comment.