From 157f2bacb57318bd1907b94d99808d58be2510da Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Sun, 1 May 2022 11:16:40 +0200 Subject: [PATCH] Fix typing errors and set up `mypy` workflow action (#176) * Address mypy errors * Change action name * Add mypy github action * Fix typing errors * Correct job name * Fix typing error * Fix merge error * Address test failure * Fix typing errors * Fix typing errors and clean up itstat default mechanism * Fix a bug and some typing errors * Exclude modules with dynamically generated functions * Make docstring phrasing imperative * Suppress typing errors * Fix typing error and clean up * Fix typing errors * Fix or suppress typing errors * Typo fix * Revert erroneous attempt to resolve typing error * Typing annotation fix and suppress some spurious typing errors * Address typing error and rephrase error messages * Fix some typing errors * Supress some typing errors * Address typing error * Fix typing errors and docstring style issues * Address test failure * Suppress/address some typing errors * Fix merge error * Fix merge error * Improve code style * Resolve mypy errors * Resolve mypy errors * Resolve mypy errors * Modify mypy configuration in workflow * Trivial edit * Bug fix * Consistency improvement * Address CodeFactor complex function * Fix type error * Switch back to ignore, can't solve this problem without code bloat * Fixed extra 's' in `LinearOperator`s string. * Trivial edit * Use type guard rather than type ignore * Fix EllipsisType import * Minor edits for docs style Co-authored-by: Michael-T-McCann Co-authored-by: Fernando Davis --- .github/workflows/mypy.yml | 36 ++++++++++++++++++++ scico/_flax.py | 2 +- scico/_generic_operators.py | 18 ++++++---- scico/_version.py | 4 +-- scico/diagnostics.py | 20 +++++------ scico/examples.py | 26 +++++++-------- scico/functional/_denoiser.py | 6 ++-- scico/functional/_functional.py | 12 +++---- scico/linop/_circconv.py | 14 ++++---- scico/linop/_diff.py | 9 +++-- scico/linop/_linop.py | 49 +++++++++++++-------------- scico/linop/_stack.py | 59 +++++++++++++++++++-------------- scico/linop/abel.py | 2 +- scico/linop/optics.py | 28 +++++++++++----- scico/linop/radon_astra.py | 6 ++-- scico/linop/radon_svmbir.py | 21 +++++++----- scico/loss.py | 28 ++++++++-------- scico/numpy/__init__.py | 2 +- scico/numpy/blockarray.py | 4 +-- scico/numpy/util.py | 14 ++++---- scico/operator/biconvolve.py | 9 ++--- scico/optimize/_ladmm.py | 14 ++++---- scico/optimize/_primaldual.py | 8 ++--- scico/optimize/admm.py | 27 ++++++++------- scico/optimize/pgm.py | 6 ++-- scico/random.py | 14 ++++---- scico/solver.py | 2 +- scico/typing.py | 10 ++++-- scico/util.py | 18 +++++----- 29 files changed, 272 insertions(+), 196 deletions(-) create mode 100644 .github/workflows/mypy.yml diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml new file mode 100644 index 000000000..bb209f2d1 --- /dev/null +++ b/.github/workflows/mypy.yml @@ -0,0 +1,36 @@ +# Install and run mypy + +name: mypy + +# Controls when the workflow will run +on: + # Triggers the workflow on push or pull request events but only for the main branch + push: + branches: [ main ] + pull_request: + branches: [ main ] + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +jobs: + mypy: + # The type of runner that the job will run on + runs-on: ubuntu-latest + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v2 + with: + submodules: recursive + - name: Install Python 3 + uses: actions/setup-python@v1 + with: + python-version: 3.8 + - name: Install dependencies + run: | + pip install mypy + - name: Run mypy + run: | + mypy --follow-imports=skip --ignore-missing-imports --exclude "(numpy|test)" scico/ diff --git a/scico/_flax.py b/scico/_flax.py index bda3facc7..cf0d1aa2a 100644 --- a/scico/_flax.py +++ b/scico/_flax.py @@ -203,5 +203,5 @@ def __call__(self, x: JaxArray) -> JaxArray: x = x.reshape((1,) + x.shape + (1,)) elif x.ndim == 3: x = x.reshape((1,) + x.shape) - y = self.model.apply(self.variables, x, train=False, mutable=False) + y = self.model.apply(self.variables, x, train=False, mutable=False) # type: ignore return y.reshape(x_shape) diff --git a/scico/_generic_operators.py b/scico/_generic_operators.py index 0657ea108..e435c0e3f 100644 --- a/scico/_generic_operators.py +++ b/scico/_generic_operators.py @@ -126,6 +126,9 @@ def __init__( #: Dtype of input self.input_dtype: DType + #: Dtype of operator + self.dtype: DType + if isinstance(input_shape, int): self.input_shape = (input_shape,) else: @@ -140,7 +143,7 @@ def __init__( if output_shape is None or output_dtype is None: tmp = self(snp.zeros(self.input_shape, dtype=input_dtype)) if output_shape is None: - self.output_shape = tmp.shape + self.output_shape = tmp.shape # type: ignore else: self.output_shape = (output_shape,) if isinstance(output_shape, int) else output_shape @@ -312,10 +315,11 @@ def freeze(self, argnum: int, val: Union[JaxArray, BlockArray]) -> Operator: f"{self.input_shape[argnum]}, got {val.shape}" ) - input_shape = tuple(s for i, s in enumerate(self.input_shape) if i != argnum) + input_shape: Union[Shape, BlockShape] + input_shape = tuple(s for i, s in enumerate(self.input_shape) if i != argnum) # type: ignore if len(input_shape) == 1: - input_shape = input_shape[0] + input_shape = input_shape[0] # type: ignore def concat_args(args): # Creates a blockarray with args and the frozen value in the correct place @@ -456,9 +460,9 @@ def __init__( ) if not hasattr(self, "_adj"): - self._adj = None + self._adj: Optional[Callable] = None if not hasattr(self, "_gram"): - self._gram = None + self._gram: Optional[Callable] = None if callable(adj_fn): self._adj = adj_fn self._gram = lambda x: self.adj(self(x)) @@ -584,7 +588,7 @@ def adj( input `y`. Args: - y: Point at which to compute adjoint. If `y` is + y: Point at which to compute adjoint. If `y` is :class:`DeviceArray` or :class:`.BlockArray`, must have `shape == self.output_shape`. If `y` is a :class:`.LinearOperator`, must have @@ -605,6 +609,7 @@ def adj( f"""Shapes do not conform: input array with shape {y.shape} does not match LinearOperator output_shape {self.output_shape}""" ) + assert self._adj is not None return self._adj(y) @property @@ -715,6 +720,7 @@ def gram( """ if self._gram is None: self._set_adjoint() + assert self._gram is not None return self._gram(x) diff --git a/scico/_version.py b/scico/_version.py index e6281e19a..08de1705b 100644 --- a/scico/_version.py +++ b/scico/_version.py @@ -41,7 +41,7 @@ def variable_assign_value(path: str, var: str) -> Any: with open(path) as f: try: # See http://stackoverflow.com/questions/2058802 - value = parse(next(filter(lambda line: line.startswith(var), f))).body[0].value.s + value = parse(next(filter(lambda line: line.startswith(var), f))).body[0].value.s # type: ignore except StopIteration: raise RuntimeError(f"Could not find initialization of variable {var}") return value @@ -70,7 +70,7 @@ def current_git_hash() -> Optional[str]: # nosec pragma: no cover Short git hash of current commit, or ``None`` if no git repo found. """ process = Popen(["git", "rev-parse", "--short", "HEAD"], shell=False, stdout=PIPE, stderr=PIPE) - git_hash = process.communicate()[0].strip().decode("utf-8") + git_hash: Optional[str] = process.communicate()[0].strip().decode("utf-8") if git_hash == "": git_hash = None return git_hash diff --git a/scico/diagnostics.py b/scico/diagnostics.py index 119b7074d..6536e1d67 100644 --- a/scico/diagnostics.py +++ b/scico/diagnostics.py @@ -68,23 +68,23 @@ def __init__( if not isinstance(fields, dict): raise TypeError("Parameter fields must be an instance of dict") # Subsampling rate of results that are to be displayed - self.period = period + self.period: int = period # Flag indicating whether to display and overwrite, or not display at all - self.overwrite = overwrite + self.overwrite: bool = overwrite # Number of spaces seperating fields in displayed tables - self.colsep = colsep + self.colsep: int = colsep # Main list of inserted values - self.iterations = [] + self.iterations: List = [] # Total length of header string in displayed tables - self.headlength = 0 + self.headlength: int = 0 # List of field names - self.fieldname = [] + self.fieldname: List[str] = [] # List of field format strings - self.fieldformat = [] + self.fieldformat: List[str] = [] # List of lengths of each field in displayed tables - self.fieldlength = [] + self.fieldlength: List[int] = [] # Names of fields in namedtuple used to record iteration values - self.tuplefields = [] + self.tuplefields: List[str] = [] # Compile regex for decomposing format strings fmre = re.compile(r"%(\+?-?)((?:\d+)?)(\.?)((?:\d+)?)([a-z])") # Iterate over field names @@ -131,7 +131,7 @@ def __init__( self.headlength -= colsep # Construct namedtuple used to record values - self.IterTuple = namedtuple("IterationStatsTuple", self.tuplefields) + self.IterTuple = namedtuple("IterationStatsTuple", self.tuplefields) # type: ignore # Set up table header string display if requested self.display = display diff --git a/scico/examples.py b/scico/examples.py index ec68ef9d6..ac38c33e4 100644 --- a/scico/examples.py +++ b/scico/examples.py @@ -12,7 +12,7 @@ import os import tempfile import zipfile -from typing import List, Optional +from typing import List, Optional, Tuple import numpy as np @@ -201,7 +201,7 @@ def tile_volume_slices(x: Array, sep_width: int = 10) -> Array: """ if x.ndim == 3: - fshape = (x.shape[0], sep_width) + fshape: Tuple[int, ...] = (x.shape[0], sep_width) else: fshape = (x.shape[0], sep_width, 3) out = snp.concatenate( @@ -214,9 +214,9 @@ def tile_volume_slices(x: Array, sep_width: int = 10) -> Array: ) if x.ndim == 3: - fshape0 = (sep_width, out.shape[1]) - fshape1 = (x.shape[2], x.shape[2] + sep_width) - trans = (1, 0) + fshape0: Tuple[int, ...] = (sep_width, out.shape[1]) + fshape1: Tuple[int, ...] = (x.shape[2], x.shape[2] + sep_width) + trans: Tuple[int, ...] = (1, 0) else: fshape0 = (sep_width, out.shape[1], 3) @@ -300,7 +300,7 @@ def create_3D_foam_phantom( r_std: float = 0.001, pad: float = 0.01, is_random: bool = False, -): +) -> JaxArray: """Construct a 3D phantom with random radii and centers. Args: @@ -316,7 +316,7 @@ def create_3D_foam_phantom( process deterministic. Default ``False``. Returns: - 3D phantom of shape im_shape + 3D phantom of shape `im_shape`. """ c_lo = 0.0 c_hi = 1.0 @@ -331,10 +331,10 @@ def create_3D_foam_phantom( radii = r_std * np.random.randn(N_sphere) + r_mean im = snp.zeros(im_shape) + c_lo - for c, r in zip(centers, radii): + for c, r in zip(centers, radii): # type: ignore dist = snp.sum((x - c) ** 2, axis=-1) if snp.mean(im[dist < r**2] - c_lo) < 0.01 * c_hi: - # In numpy: im[dist < r**2] = c_hi + # equivalent to im[dist < r**2] = c_hi in numpy im = im.at[dist < r**2].set(c_hi) return im @@ -354,13 +354,13 @@ def spnoise(img: Array, nfrac: float, nmin: float = 0.0, nmax: float = 1.0) -> A """ if isinstance(img, np.ndarray): - spm = np.random.uniform(-1.0, 1.0, img.shape) + spm = np.random.uniform(-1.0, 1.0, img.shape) # type: ignore imgn = img.copy() imgn[spm < nfrac - 1.0] = nmin imgn[spm > 1.0 - nfrac] = nmax else: - spm, key = random.uniform(shape=img.shape, minval=-1.0, maxval=1.0, seed=0) + spm, key = random.uniform(shape=img.shape, minval=-1.0, maxval=1.0, seed=0) # type: ignore imgn = img - imgn = imgn.at[spm < nfrac - 1.0].set(nmin) - imgn = imgn.at[spm > 1.0 - nfrac].set(nmax) + imgn = imgn.at[spm < nfrac - 1.0].set(nmin) # type: ignore + imgn = imgn.at[spm > 1.0 - nfrac].set(nmax) # type: ignore return imgn diff --git a/scico/functional/_denoiser.py b/scico/functional/_denoiser.py index 60ec29f44..dd5982a07 100644 --- a/scico/functional/_denoiser.py +++ b/scico/functional/_denoiser.py @@ -35,7 +35,7 @@ def __init__(self, is_rgb: bool = False): self.is_rgb = is_rgb super().__init__() - def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: + def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore r"""Apply BM3D denoiser. Args: @@ -65,7 +65,7 @@ def __init__(self): r"""Initialize a :class:`BM4D` object.""" super().__init__() - def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: + def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore r"""Apply BM4D denoiser. Args: @@ -99,7 +99,7 @@ def __init__(self, variant: str = "6M"): """ self.dncnn = denoiser.DnCNN(variant) - def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: + def prox(self, x: JaxArray, lam: float = 1.0, **kwargs) -> JaxArray: # type: ignore r"""Apply DnCNN denoiser. *Warning*: The `lam` parameter is ignored, and has no effect on diff --git a/scico/functional/_functional.py b/scico/functional/_functional.py index df8c10046..7f94fa025 100644 --- a/scico/functional/_functional.py +++ b/scico/functional/_functional.py @@ -62,10 +62,8 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: x: Point at which to evaluate this functional. """ - if not self.has_eval: - raise NotImplementedError( - f"Functional {type(self)} cannot be evaluated; has_eval={self.has_eval}" - ) + # Functionals that can be evaluated should override this method. + raise NotImplementedError(f"Functional {type(self)} cannot be evaluated.") def prox( self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs @@ -91,10 +89,8 @@ def prox( classes. These include `x0`, an initial guess for the minimizer in the definition of :math:`\mathrm{prox}`. """ - if not self.has_prox: - raise NotImplementedError( - f"Functional {type(self)} does not have a prox; has_prox={self.has_prox}" - ) + # Functionals that have a prox should override this method. + raise NotImplementedError(f"Functional {type(self)} does not have a prox.") def conj_prox( self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs diff --git a/scico/linop/_circconv.py b/scico/linop/_circconv.py index 05bb024aa..f7e7aac6b 100644 --- a/scico/linop/_circconv.py +++ b/scico/linop/_circconv.py @@ -10,7 +10,7 @@ import math import operator from functools import partial -from typing import Optional +from typing import Optional, Tuple import numpy as np @@ -19,7 +19,7 @@ import scico.numpy as snp from scico._generic_operators import Operator from scico.numpy.util import is_nested -from scico.typing import DType, JaxArray, Shape +from scico.typing import Array, DType, JaxArray, Shape from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar @@ -123,9 +123,9 @@ def __init__( self.h_dft = snp.fft.fftn(h, s=fft_shape, axes=fft_axes) output_dtype = result_type(h.dtype, input_dtype) - if h_center is not None: + if self.h_center is not None: offset = -self.h_center - shifts = np.ix_( + shifts: Tuple[Array, ...] = np.ix_( *tuple( np.exp(-1j * k * 2 * np.pi * np.fft.fftfreq(s)) for k, s in zip(offset, input_shape[-self.ndims :]) @@ -173,7 +173,7 @@ def _eval(self, x: JaxArray) -> JaxArray: hx = hx.real return hx - def _adj(self, x: JaxArray) -> JaxArray: + def _adj(self, x: JaxArray) -> JaxArray: # type: ignore x_dft = snp.fft.fftn(x, axes=self.ifft_axes) H_adj_x = snp.fft.ifftn( snp.conj(self.h_dft) * x_dft, @@ -266,7 +266,7 @@ def from_operator( ndims = ndims if center is None: - center = tuple(d // 2 for d in H.input_shape[-ndims:]) + center = tuple(d // 2 for d in H.input_shape[-ndims:]) # type: ignore # compute impulse response d = snp.zeros(H.input_shape, H.input_dtype) @@ -276,7 +276,7 @@ def from_operator( # build CircularConvolve return CircularConvolve( Hd, - H.input_shape, + H.input_shape, # type: ignore ndims=ndims, input_dtype=H.input_dtype, h_center=snp.array(center), diff --git a/scico/linop/_diff.py b/scico/linop/_diff.py index 9f0e77d52..45f4d5383 100644 --- a/scico/linop/_diff.py +++ b/scico/linop/_diff.py @@ -73,19 +73,18 @@ def __init__( functions of the LinearOperator. """ - self.axes = parse_axes(axes, input_shape) - if axes is None: - axes_list = range(len(input_shape)) + axes_list = tuple(range(len(input_shape))) elif isinstance(axes, (list, tuple)): axes_list = axes else: axes_list = (axes,) - single_kwargs = dict(append=append, circular=circular, jit=False, input_dtype=input_dtype) + self.axes = parse_axes(axes_list, input_shape) + single_kwargs = dict(input_dtype=input_dtype, append=append, circular=circular, jit=False) ops = [FiniteDifferenceSingleAxis(axis, input_shape, **single_kwargs) for axis in axes_list] super().__init__( - ops, + ops, # type: ignore jit=jit, **kwargs, ) diff --git a/scico/linop/_linop.py b/scico/linop/_linop.py index 3866b55bf..57fbac99c 100644 --- a/scico/linop/_linop.py +++ b/scico/linop/_linop.py @@ -62,7 +62,6 @@ def operator_norm(A: LinearOperator, maxiter: int = 100, key: Optional[PRNGKey] :math:`A`, .. math:: - \| A \|_2 &= \max \{ \| A \mb{x} \|_2 \, : \, \| \mb{x} \|_2 \leq 1 \} \\ &= \sqrt{ \lambda_{ \mathrm{max} }( A^H A ) } = \sigma_{\mathrm{max}}(A) \;, @@ -152,8 +151,8 @@ def valid_adjoint( u = A(x) v = AT(y) - yTu = snp.dot(y.ravel().conj(), u.ravel()) - vTx = snp.dot(v.ravel().conj(), x.ravel()) + yTu = snp.dot(y.ravel().conj(), u.ravel()) # type: ignore + vTx = snp.dot(v.ravel().conj(), x.ravel()) # type: ignore err = snp.abs(yTu - vTx) / max(snp.abs(yTu), snp.abs(vTx)) if eps is None: return err @@ -178,7 +177,6 @@ def __init__( broadcast-compatiable with `diagonal.shape`. input_dtype: `dtype` of input argument. The default, ``None``, means `diagonal.dtype`. - """ self.diagonal = ensure_on_device(diagonal) @@ -194,9 +192,9 @@ def __init__( elif not isinstance(diagonal, BlockArray) and not is_nested(input_shape): output_shape = snp.broadcast_shapes(input_shape, self.diagonal.shape) elif isinstance(diagonal, BlockArray): - raise ValueError(f"`diagonal` was a BlockArray but `input_shape` was not nested.") + raise ValueError("`diagonal` was a BlockArray but `input_shape` was not nested.") else: - raise ValueError(f"`diagonal` was a not BlockArray but `input_shape` was nested.") + raise ValueError("`diagonal` was a not BlockArray but `input_shape` was nested.") super().__init__( input_shape=input_shape, @@ -259,7 +257,7 @@ class Slice(LinearOperator): def __init__( self, idx: ArrayIndex, - input_shape: Shape, + input_shape: Union[Shape, BlockShape], input_dtype: DType = snp.float32, jit: bool = True, **kwargs, @@ -282,8 +280,9 @@ def __init__( functions of the LinearOperator. """ + output_shape: Union[Shape, BlockShape] if is_nested(input_shape): - output_shape = input_shape[idx] + output_shape = input_shape[idx] # type: ignore else: output_shape = indexed_shape(input_shape, idx) @@ -322,22 +321,7 @@ def linop_from_function(f: Callable, classname: str, f_name: Optional[str] = Non if f_name is None: f_name = f"{f.__module__}.{f.__name__}" - def __init__( - self, - input_shape: Union[Shape, BlockShape], - *args: Any, - input_dtype: DType = snp.float32, - jit: bool = True, - **kwargs: Any, - ): - self._eval = lambda x: f(x, *args, **kwargs) - super().__init__(input_shape, input_dtype=input_dtype, jit=jit) - - OpClass = type(classname, (LinearOperator,), {"__init__": __init__}) - __class__ = OpClass # needed for super() to work - - OpClass.__doc__ = f"Linear operator version of :func:`{f_name}`." - OpClass.__init__.__doc__ = rf""" + f_doc = rf""" Args: input_shape: Shape of input array. @@ -353,6 +337,23 @@ def __init__( kwargs: Keyword arguments passed to :func:`{f_name}`. """ + def __init__( + self, + input_shape: Union[Shape, BlockShape], + *args: Any, + input_dtype: DType = snp.float32, + jit: bool = True, + **kwargs: Any, + ): + self._eval = lambda x: f(x, *args, **kwargs) + super().__init__(input_shape, input_dtype=input_dtype, jit=jit) # type: ignore + + OpClass = type(classname, (LinearOperator,), {"__init__": __init__}) + __class__ = OpClass # needed for super() to work + + OpClass.__doc__ = f"Linear operator version of :func:`{f_name}`." + OpClass.__init__.__doc__ = f_doc # type: ignore + return OpClass diff --git a/scico/linop/_stack.py b/scico/linop/_stack.py index 2289904ef..eebea6a70 100644 --- a/scico/linop/_stack.py +++ b/scico/linop/_stack.py @@ -17,6 +17,7 @@ import scico.numpy as snp from scico.numpy import BlockArray +from scico.numpy.util import is_nested from scico.typing import JaxArray from ._linop import LinearOperator, _wrap_add_sub, _wrap_mul_div_scalar @@ -35,59 +36,69 @@ def __init__( r""" Args: ops: Operators to stack. - collapse: If `True` and the output would be a `BlockArray` + collapse: If ``True`` and the output would be a `BlockArray` with shape ((m, n, ...), (m, n, ...), ...), the output is instead a `DeviceArray` with shape (S, m, n, ...) where S - is the length of `ops`. Defaults to True. + is the length of `ops`. Defaults to ``True``. jit: see `jit` in :class:`LinearOperator`. """ - if not isinstance(ops, (list, tuple)): - raise ValueError("expected a list of `LinearOperator`") + + LinearOperatorStack.check_if_stackable(ops) self.ops = ops + self.collapse = collapse + + self.collapsable = all(op.output_shape == ops[0].output_shape for op in ops) + + output_shapes = tuple(op.output_shape for op in ops) + if self.collapsable and self.collapse: + output_shape = (len(ops),) + output_shapes[0] # collapse to DeviceArray + else: + output_shape = output_shapes + + super().__init__( + input_shape=ops[0].input_shape, + output_shape=output_shape, # type: ignore + input_dtype=ops[0].input_dtype, + output_dtype=ops[0].output_dtype, + jit=jit, + **kwargs, + ) + + @staticmethod + def check_if_stackable(ops: List[LinearOperator]): + """Check that input ops are suitable for stack creation.""" + if not isinstance(ops, (list, tuple)): + raise ValueError("Expected a list of `LinearOperator`") input_shapes = [op.shape[1] for op in ops] if not all(input_shapes[0] == s for s in input_shapes): raise ValueError( - "expected all `LinearOperator`s to have the same input shapes, " + "Expected all `LinearOperator`s to have the same input shapes, " f"but got {input_shapes}" ) input_dtypes = [op.input_dtype for op in ops] if not all(input_dtypes[0] == s for s in input_dtypes): raise ValueError( - "expected all `LinearOperator`s to have the same input dtype, " + "Expected all `LinearOperator`s to have the same input dtype, " f"but got {input_dtypes}." ) - self.collapse = collapse - output_shape = tuple(op.shape[0] for op in ops) # assumes BlockArray output - - # check if collapsable and adjust output_shape if needed - self.collapsable = all(output_shape[0] == s for s in output_shape) - if self.collapsable and self.collapse: - output_shape = (len(ops),) + output_shape[0] # collapse to DeviceArray + if any([is_nested(op.shape[0]) for op in ops]): + raise ValueError("Cannot stack `LinearOperator`s with nested output shapes.") output_dtypes = [op.output_dtype for op in ops] if not np.all(output_dtypes[0] == s for s in output_dtypes): - raise ValueError("expected all `LinearOperator`s to have the same output dtype") - - super().__init__( - input_shape=input_shapes[0], - output_shape=output_shape, - input_dtype=input_dtypes[0], - output_dtype=output_dtypes[0], - jit=jit, - **kwargs, - ) + raise ValueError("Expected all `LinearOperator`s to have the same output dtype.") def _eval(self, x: JaxArray) -> Union[JaxArray, BlockArray]: if self.collapsable and self.collapse: return snp.stack([op @ x for op in self.ops]) return BlockArray([op @ x for op in self.ops]) - def _adj(self, y: Union[JaxArray, BlockArray]) -> JaxArray: + def _adj(self, y: Union[JaxArray, BlockArray]) -> JaxArray: # type: ignore return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)]) def scale_ops(self, scalars: JaxArray): diff --git a/scico/linop/abel.py b/scico/linop/abel.py index f50813922..f1bfceeb8 100644 --- a/scico/linop/abel.py +++ b/scico/linop/abel.py @@ -56,7 +56,7 @@ def _eval(self, x: JaxArray) -> JaxArray: self.output_dtype ) - def _adj(self, x: JaxArray) -> JaxArray: + def _adj(self, x: JaxArray) -> JaxArray: # type: ignore return _pyabel_transform(x, direction="transpose", proj_mat_quad=self.proj_mat_quad).astype( self.input_dtype ) diff --git a/scico/linop/optics.py b/scico/linop/optics.py index 4a785d4f5..4b56ab28a 100644 --- a/scico/linop/optics.py +++ b/scico/linop/optics.py @@ -47,18 +47,19 @@ and :math:`y` to axis 1 (columns, increasing to the right). """ - # Needed to annotate a class method that returns the encapsulating class; # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -from typing import Tuple, Union +from typing import Any, Tuple, Union import numpy as np from numpy.lib.scimath import sqrt # complex sqrt import jax +from typing_extensions import TypeGuard + import scico.numpy as snp from scico.linop import Diagonal, Identity, LinearOperator from scico.numpy.util import no_nan_divide @@ -67,6 +68,11 @@ from ._dft import DFT +def _isscalar(element: Any) -> TypeGuard[Union[int, float]]: + """Type guard interface to `snp.isscalar`.""" + return snp.isscalar(element) + + def radial_transverse_frequency( input_shape: Shape, dx: Union[float, Tuple[float, ...]] ) -> np.ndarray: @@ -89,18 +95,20 @@ def radial_transverse_frequency( :math:`\sqrt{k_x^2 + k_y^2}\,`. """ - ndim = len(input_shape) # 1 or 2 dimensions + ndim: int = len(input_shape) # 1 or 2 dimensions if ndim not in (1, 2): raise ValueError("Invalid input dimensions; must be 1 or 2") - if np.isscalar(dx): + if _isscalar(dx): dx = (dx,) * ndim else: + assert isinstance(dx, tuple) if len(dx) != ndim: raise ValueError( "dx must be a scalar or have len(dx) == len(input_shape); " f"got len(dx)={len(dx)}, len(input_shape)={ndim}" ) + assert isinstance(dx, tuple) if ndim == 1: kx = 2 * np.pi * np.fft.fftfreq(input_shape[0], dx[0]) @@ -145,14 +153,16 @@ def __init__( if ndim not in (1, 2): raise ValueError("Invalid input dimensions; must be 1 or 2") - if np.isscalar(dx): + if _isscalar(dx): dx = (dx,) * ndim else: + assert isinstance(dx, tuple) if len(dx) != ndim: raise ValueError( "dx must be a scalar or have len(dx) == len(input_shape); " f"got len(dx)={len(dx)}, len(input_shape)={ndim}" ) + assert isinstance(dx, tuple) #: Illumination wavenumber; 2𝜋/wavelength self.k0: float = k0 @@ -173,7 +183,7 @@ def __init__( self.F = DFT(input_shape=input_shape, output_shape=self.padded_shape, jit=False) # Diagonal operator; phase shifting - self.D = Identity(self.kp.shape) + self.D: LinearOperator = Identity(self.kp.shape) super().__init__( input_shape=input_shape, @@ -491,14 +501,16 @@ def __init__( if ndim not in (1, 2): raise ValueError("Invalid input dimensions; must be 1 or 2") - if np.isscalar(dx): + if _isscalar(dx): dx = (dx,) * ndim else: + assert isinstance(dx, tuple) if len(dx) != ndim: raise ValueError( "dx must be a scalar or have len(dx) == len(input_shape); " f"got len(dx)={len(dx)}, len(input_shape)={ndim}" ) + assert isinstance(dx, tuple) L: Tuple[float, ...] = tuple(s * d for s, d in zip(input_shape, dx)) @@ -515,7 +527,7 @@ def __init__( self.dx_D: Tuple[float, ...] = tuple(np.abs(2 * np.pi * z / (k0 * l)) for l in L) #: Destination plane side length self.L_D: Tuple[float, ...] = tuple(np.abs(2 * np.pi * z / (k0 * d)) for d in dx) - x_D = tuple(np.r_[-l / 2 : l / 2 : d] for l, d in zip(self.L_D, self.dx_D)) + x_D = tuple(np.r_[-l / 2 : l / 2 : d] for l, d in zip(self.L_D, self.dx_D)) # type: ignore # set up radial coordinate system; either x^2 or (x^2 + y^2) if ndim == 1: diff --git a/scico/linop/radon_astra.py b/scico/linop/radon_astra.py index 3ffe5c6c3..d97bb9398 100644 --- a/scico/linop/radon_astra.py +++ b/scico/linop/radon_astra.py @@ -95,7 +95,7 @@ def __init__( "for specifics." ) else: - self.vol_geom: dict = astra.create_vol_geom(*input_shape) + self.vol_geom = astra.create_vol_geom(*input_shape) dev0 = jax.devices()[0] if dev0.device_kind == "cpu" or device == "cpu": @@ -107,9 +107,9 @@ def __init__( # Wrap our non-jax function to indicate we will supply fwd/rev mode functions self._eval = jax.custom_vjp(self._proj) - self._eval.defvjp(lambda x: (self._proj(x), None), lambda _, y: (self._bproj(y),)) + self._eval.defvjp(lambda x: (self._proj(x), None), lambda _, y: (self._bproj(y),)) # type: ignore self._adj = jax.custom_vjp(self._bproj) - self._adj.defvjp(lambda y: (self._bproj(y), None), lambda _, x: (self._proj(x),)) + self._adj.defvjp(lambda y: (self._bproj(y), None), lambda _, x: (self._proj(x),)) # type: ignore super().__init__( input_shape=self.input_shape, diff --git a/scico/linop/radon_svmbir.py b/scico/linop/radon_svmbir.py index 3c10de4ae..a73760c0d 100644 --- a/scico/linop/radon_svmbir.py +++ b/scico/linop/radon_svmbir.py @@ -11,7 +11,7 @@ `svmbir `_ package. """ -from typing import Optional +from typing import Optional, Tuple, Union import numpy as np @@ -20,7 +20,7 @@ import scico.numpy as snp from scico.loss import Loss, SquaredL2Loss -from scico.typing import JaxArray, Shape +from scico.typing import Array, JaxArray, Shape from ._linop import Diagonal, Identity, LinearOperator @@ -64,7 +64,7 @@ class TomographicProjector(LinearOperator): def __init__( self, input_shape: Shape, - angles: np.ndarray, + angles: Array, num_channels: int, center_offset: float = 0.0, is_masked: bool = False, @@ -121,7 +121,7 @@ def __init__( if len(input_shape) == 2: # 2D input self.svmbir_input_shape = (1,) + input_shape - output_shape = (len(angles), num_channels) + output_shape: Tuple[int, ...] = (len(angles), num_channels) self.svmbir_output_shape = output_shape[0:1] + (1,) + output_shape[1:2] elif len(input_shape) == 3: # 3D input self.svmbir_input_shape = input_shape @@ -162,10 +162,10 @@ def __init__( # Set up custom_vjp for _eval and _adj so jax.grad works on them. self._eval = jax.custom_vjp(self._proj_hcb) - self._eval.defvjp(lambda x: (self._proj_hcb(x), None), lambda _, y: (self._bproj_hcb(y),)) + self._eval.defvjp(lambda x: (self._proj_hcb(x), None), lambda _, y: (self._bproj_hcb(y),)) # type: ignore self._adj = jax.custom_vjp(self._bproj_hcb) - self._adj.defvjp(lambda y: (self._bproj_hcb(y), None), lambda _, x: (self._proj_hcb(x),)) + self._adj.defvjp(lambda y: (self._bproj_hcb(y), None), lambda _, x: (self._proj_hcb(x),)) # type: ignore super().__init__( input_shape=input_shape, @@ -305,6 +305,9 @@ class SVMBIRExtendedLoss(Loss): described in class :class:`.TomographicProjector`. """ + A: TomographicProjector + W: Union[Identity, Diagonal] + def __init__( self, *args, @@ -328,7 +331,7 @@ def __init__( W: Weighting diagonal operator. Must be non-negative. If ``None``, defaults to :class:`.Identity`. """ - super().__init__(*args, scale=scale, **kwargs) + super().__init__(*args, scale=scale, **kwargs) # type: ignore if not isinstance(self.A, TomographicProjector): raise ValueError("LinearOperator A must be a radon_svmbir.TomographicProjector.") @@ -367,13 +370,13 @@ def __call__(self, x: JaxArray) -> float: else: return self.scale * (self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2).sum() - def prox(self, v: JaxArray, lam: float, **kwargs) -> JaxArray: + def prox(self, v: JaxArray, lam: float = 1, **kwargs) -> JaxArray: v = v.reshape(self.A.svmbir_input_shape) y = self.y.reshape(self.A.svmbir_output_shape) weights = self.W.diagonal.reshape(self.A.svmbir_output_shape) sigma_p = snp.sqrt(lam) if "v0" in kwargs and kwargs["v0"] is not None: - v0 = np.reshape(np.array(kwargs["v0"]), self.A.svmbir_input_shape) + v0: Union[float, Array] = np.reshape(np.array(kwargs["v0"]), self.A.svmbir_input_shape) else: v0 = 0.0 diff --git a/scico/loss.py b/scico/loss.py index 2e76d9566..3bcd26bc1 100644 --- a/scico/loss.py +++ b/scico/loss.py @@ -19,7 +19,7 @@ from scico import functional, linop, operator from scico.numpy import BlockArray from scico.numpy.util import ensure_on_device, no_nan_divide -from scico.scipy.special import gammaln +from scico.scipy.special import gammaln # type: ignore from scico.solver import cg from scico.typing import JaxArray @@ -71,7 +71,7 @@ def __init__( self.y = ensure_on_device(y) if A is None: # y and x must have same shape - A = linop.Identity(input_shape=self.y.shape, input_dtype=self.y.dtype) + A = linop.Identity(input_shape=self.y.shape, input_dtype=self.y.dtype) # type: ignore self.A = A self.f = f self.scale = scale @@ -97,7 +97,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: return self.scale * self.f(self.A(x) - self.y) def prox( - self, v: Union[JaxArray, BlockArray], lam: float, **kwargs + self, v: Union[JaxArray, BlockArray], lam: float = 1, **kwargs ) -> Union[JaxArray, BlockArray]: r"""Scaled proximal operator of loss function. @@ -120,6 +120,7 @@ def prox( f"prox is not implemented for {type(self)} when A is {type(self.A)}; " "must be Identity" ) + assert self.f is not None return self.f.prox(v - self.y, self.scale * lam, **kwargs) + self.y @_loss_mul_div_wrapper @@ -180,9 +181,9 @@ def __init__( self.W: linop.Diagonal if W is None: - self.W = linop.Identity(y.shape) + self.W = linop.Identity(y.shape) # type: ignore elif isinstance(W, linop.Diagonal): - if snp.all(W.diagonal >= 0): + if snp.all(W.diagonal >= 0): # type: ignore self.W = W else: raise ValueError(f"The weights, W.diagonal, must be non-negative.") @@ -195,7 +196,6 @@ def __init__( if prox_kwargs: default_prox_kwargs.update(prox_kwargs) self.prox_kwargs = default_prox_kwargs - prox_kwargs: dict = ({"maxiter": 100, "tol": 1e-5},) if isinstance(self.A, linop.LinearOperator): self.has_prox = True @@ -204,7 +204,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: return self.scale * snp.sum(self.W.diagonal * snp.abs(self.y - self.A(x)) ** 2) def prox( - self, v: Union[JaxArray, BlockArray], lam: float, **kwargs + self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs ) -> Union[JaxArray, BlockArray]: if not isinstance(self.A, linop.LinearOperator): raise NotImplementedError( @@ -216,8 +216,8 @@ def prox( c = 2.0 * self.scale * lam A = self.A.diagonal W = self.W.diagonal - lhs = c * A.conj() * W * self.y + v - ATWA = c * A.conj() * W * A + lhs = c * A.conj() * W * self.y + v # type: ignore + ATWA = c * A.conj() * W * A # type: ignore return lhs / (ATWA + 1.0) # prox_{f}(v) = arg min 1/2 || v - x ||_2^2 + λ 𝛼 || A x - y ||^2_W @@ -237,7 +237,7 @@ def prox( hessian = self.hessian # = (2𝛼 A^T W A) lhs = linop.Identity(v.shape) + lam * hessian rhs = v + 2 * lam * 𝛼 * A.adj(W(y)) - x, _ = cg(lhs, rhs, x0, **self.prox_kwargs) + x, _ = cg(lhs, rhs, x0, **self.prox_kwargs) # type: ignore return x @property @@ -254,8 +254,8 @@ def hessian(self) -> linop.LinearOperator: return linop.LinearOperator( input_shape=A.input_shape, output_shape=A.input_shape, - eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))), - adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))), + eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))), # type: ignore + adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))), # type: ignore input_dtype=A.input_dtype, ) @@ -357,7 +357,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: return self.scale * (self.W.diagonal * snp.abs(self.y - snp.abs(self.A(x))) ** 2).sum() def prox( - self, v: Union[JaxArray, BlockArray], lam: float, **kwargs + self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs ) -> Union[JaxArray, BlockArray]: if not self.has_prox: raise NotImplementedError(f"prox is not implemented.") @@ -564,7 +564,7 @@ def __call__(self, x: Union[JaxArray, BlockArray]) -> float: return self.scale * (self.W.diagonal * snp.abs(self.y - snp.abs(self.A(x)) ** 2) ** 2).sum() def prox( - self, v: Union[JaxArray, BlockArray], lam: float, **kwargs + self, v: Union[JaxArray, BlockArray], lam: float = 1.0, **kwargs ) -> Union[JaxArray, BlockArray]: if not self.has_prox: raise NotImplementedError(f"prox is not implemented.") diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 3cd1d1807..d1eeba82c 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -14,8 +14,8 @@ many have been extended to automatically map over block array blocks as described in :mod:`scico.numpy.blockarray`. Also included are additional functions unique to SCICO in :mod:`.util`. - """ + import numpy as np import jax.numpy as jnp diff --git a/scico/numpy/blockarray.py b/scico/numpy/blockarray.py index 84e409e46..df528f3d3 100644 --- a/scico/numpy/blockarray.py +++ b/scico/numpy/blockarray.py @@ -212,7 +212,7 @@ class BlockArray(list): - """BlockArray class""" + """BlockArray class.""" # Ensure we use BlockArray.__radd__, __rmul__, etc for binary # operations of the form op(np.ndarray, BlockArray) See @@ -233,7 +233,7 @@ def __init__(self, inputs): def dtype(self): """Return the dtype of the blocks, which must currently be homogeneous. - This allows snp.zeros(x.shape, x.dtype) to work without a mechanism + This allows `snp.zeros(x.shape, x.dtype)` to work without a mechanism to handle to lists of dtypes. """ return self[0].dtype diff --git a/scico/numpy/util.py b/scico/numpy/util.py index 74eacfccd..980eca104 100644 --- a/scico/numpy/util.py +++ b/scico/numpy/util.py @@ -161,14 +161,14 @@ def indexed_shape(shape: Shape, idx: ArrayIndex) -> Tuple[int, ...]: idx = (idx,) if len(idx) > len(shape): raise ValueError(f"Slice {idx} has more dimensions than shape {shape}.") - idx_shape = list(shape) + idx_shape: List[Optional[int]] = list(shape) offset = 0 for axis, ax_idx in enumerate(idx): if ax_idx is Ellipsis: offset = len(shape) - len(idx) continue idx_shape[axis + offset] = slice_length(shape[axis + offset], ax_idx) - return tuple(filter(lambda x: x is not None, idx_shape)) + return tuple(filter(lambda x: x is not None, idx_shape)) # type: ignore def no_nan_divide( @@ -187,7 +187,7 @@ def no_nan_divide( return snp.where(y != 0, snp.divide(x, snp.where(y != 0, y, 1)), 0) -def shape_to_size(shape: Union[Shape, BlockShape]) -> Axes: +def shape_to_size(shape: Union[Shape, BlockShape]) -> int: r"""Compute the size corresponding to a (possibly nested) shape. Args: @@ -226,8 +226,8 @@ def is_real_dtype(dtype: DType) -> bool: """Determine whether a dtype is real. Args: - dtype: A numpy or scico.numpy dtype (e.g. np.float32, - snp.complex64). + dtype: A numpy or scico.numpy dtype (e.g. ``np.float32``, + ``np.complex64``). Returns: ``False`` if the dtype is complex, otherwise ``True``. @@ -252,8 +252,8 @@ def real_dtype(dtype: DType) -> DType: """Construct the corresponding real dtype for a given complex dtype. Construct the corresponding real dtype for a given complex dtype, - e.g. the real dtype corresponding to `np.complex64` is - `np.float32`. + e.g. the real dtype corresponding to ``np.complex64`` is + ``np.float32``. Args: dtype: A complex numpy or scico.numpy dtype (e.g. ``np.complex64``, diff --git a/scico/operator/biconvolve.py b/scico/operator/biconvolve.py index 77a0be36c..46a3a6326 100644 --- a/scico/operator/biconvolve.py +++ b/scico/operator/biconvolve.py @@ -7,6 +7,7 @@ """Biconvolution operator.""" +from typing import Tuple, cast import numpy as np @@ -16,7 +17,7 @@ from scico.linop import Convolve, ConvolveByX from scico.numpy import BlockArray from scico.numpy.util import is_nested -from scico.typing import BlockShape, DType, JaxArray +from scico.typing import DType, JaxArray, Shape class BiConvolve(Operator): @@ -32,7 +33,7 @@ class BiConvolve(Operator): def __init__( self, - input_shape: BlockShape, + input_shape: Tuple[Shape, Shape], input_dtype: DType = np.float32, mode: str = "full", jit: bool = True, @@ -87,7 +88,7 @@ def freeze(self, argnum: int, val: JaxArray) -> LinearOperator: if argnum == 0: return ConvolveByX( x=val, - input_shape=self.input_shape[1], + input_shape=cast(Shape, self.input_shape[1]), input_dtype=self.input_dtype, output_shape=self.output_shape, mode=self.mode, @@ -95,7 +96,7 @@ def freeze(self, argnum: int, val: JaxArray) -> LinearOperator: if argnum == 1: return Convolve( h=val, - input_shape=self.input_shape[0], + input_shape=cast(Shape, self.input_shape[0]), input_dtype=self.input_dtype, output_shape=self.output_shape, mode=self.mode, diff --git a/scico/optimize/_ladmm.py b/scico/optimize/_ladmm.py index 2be8fae21..3af0824c8 100644 --- a/scico/optimize/_ladmm.py +++ b/scico/optimize/_ladmm.py @@ -11,7 +11,7 @@ # see https://www.python.org/dev/peps/pep-0563/ from __future__ import annotations -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Tuple, Union import scico.numpy as snp from scico.diagnostics import IterationStats @@ -144,7 +144,7 @@ def __init__( # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831 itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")" - scope = {} + scope: dict[str, Callable] = {} exec("def itstat_func(obj): " + itstat_return, scope) # determine itstat options and initialize IterationStats object @@ -155,8 +155,8 @@ def __init__( } if itstat_options: default_itstat_options.update(itstat_options) - self.itstat_insert_func = default_itstat_options.pop("itstat_func", None) - self.itstat_object = IterationStats(**default_itstat_options) + self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore + self.itstat_object = IterationStats(**default_itstat_options) # type: ignore if x0 is None: input_shape = C.input_shape @@ -237,7 +237,9 @@ def norm_dual_residual(self) -> float: """ return norm(self.C.adj(self.z - self.z_old)) - def z_init(self, x0: Union[JaxArray, BlockArray]): + def z_init( + self, x0: Union[JaxArray, BlockArray] + ) -> Tuple[Union[JaxArray, BlockArray], Union[JaxArray, BlockArray]]: r"""Initialize auxiliary variable :math:`\mb{z}`. Initialized to @@ -254,7 +256,7 @@ def z_init(self, x0: Union[JaxArray, BlockArray]): z_old = z return z, z_old - def u_init(self, x0: Union[JaxArray, BlockArray]): + def u_init(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: r"""Initialize scaled Lagrange multiplier :math:`\mb{u}`. Initialized to diff --git a/scico/optimize/_primaldual.py b/scico/optimize/_primaldual.py index eb0e27299..f5c6097c3 100644 --- a/scico/optimize/_primaldual.py +++ b/scico/optimize/_primaldual.py @@ -150,7 +150,7 @@ def __init__( # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831 itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")" - scope = {} + scope: dict[str, Callable] = {} exec("def itstat_func(obj): " + itstat_return, scope) # determine itstat options and initialize IterationStats object @@ -161,8 +161,8 @@ def __init__( } if itstat_options: default_itstat_options.update(itstat_options) - self.itstat_insert_func = default_itstat_options.pop("itstat_func", None) - self.itstat_object = IterationStats(**default_itstat_options) + self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore + self.itstat_object = IterationStats(**default_itstat_options) # type: ignore if x0 is None: input_shape = C.input_shape @@ -213,7 +213,7 @@ def norm_primal_residual(self) -> float: Current value of primal residual. """ - return norm(self.x - self.x_old) / self.tau + return norm(self.x - self.x_old) / self.tau # type: ignore def norm_dual_residual(self) -> float: r"""Compute the :math:`\ell_2` norm of the dual residual. diff --git a/scico/optimize/admm.py b/scico/optimize/admm.py index 369450d49..00c8f3ea2 100644 --- a/scico/optimize/admm.py +++ b/scico/optimize/admm.py @@ -12,7 +12,7 @@ from __future__ import annotations from functools import reduce -from typing import Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Tuple, Union import jax from jax.scipy.sparse.linalg import cg as jax_cg @@ -80,7 +80,7 @@ def __init__(self, minimize_kwargs: dict = {"options": {"maxiter": 100}}): :func:`scico.solver.minimize`. """ self.minimize_kwargs = minimize_kwargs - self.info = {} + self.info: dict = {} def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: """Solve the ADMM step. @@ -154,7 +154,7 @@ class LinearSubproblemSolver(SubproblemSolver): :math:`\mb{x}` update step. """ - def __init__(self, cg_kwargs: Optional[dict] = None, cg_function: str = "scico"): + def __init__(self, cg_kwargs: Optional[dict[str, Any]] = None, cg_function: str = "scico"): """Initialize a :class:`LinearSubproblemSolver` object. Args: @@ -236,8 +236,8 @@ def compute_rhs(self) -> Union[JaxArray, BlockArray]: rhs = snp.zeros(C0.input_shape, C0.input_dtype) if self.admm.f is not None: - ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal * self.admm.f.y) - rhs += 2.0 * self.admm.f.scale * ATWy + ATWy = self.admm.f.A.adj(self.admm.f.W.diagonal * self.admm.f.y) # type: ignore + rhs += 2.0 * self.admm.f.scale * ATWy # type: ignore for rhoi, Ci, zi, ui in zip( self.admm.rho_list, self.admm.C_list, self.admm.z_list, self.admm.u_list @@ -256,7 +256,7 @@ def solve(self, x0: Union[JaxArray, BlockArray]) -> Union[JaxArray, BlockArray]: """ x0 = ensure_on_device(x0) rhs = self.compute_rhs() - x, self.info = self.cg(self.lhs_op, rhs, x0, **self.cg_kwargs) + x, self.info = self.cg(self.lhs_op, rhs, x0, **self.cg_kwargs) # type: ignore return x @@ -481,7 +481,7 @@ def __init__( # dynamically create itstat_func; see https://stackoverflow.com/questions/24733831 itstat_return = "return(" + ", ".join(["obj." + attr for attr in itstat_attrib]) + ")" - scope = {} + scope: dict[str, Callable] = {} exec("def itstat_func(obj): " + itstat_return, scope) # determine itstat options and initialize IterationStats object @@ -492,8 +492,8 @@ def __init__( } if itstat_options: default_itstat_options.update(itstat_options) - self.itstat_insert_func = default_itstat_options.pop("itstat_func", None) - self.itstat_object = IterationStats(**default_itstat_options) + self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func", None) # type: ignore + self.itstat_object = IterationStats(**default_itstat_options) # type: ignore if x0 is None: input_shape = C_list[0].input_shape @@ -531,6 +531,7 @@ def objective( if x is None: x = self.x z_list = self.z_list + assert z_list is not None out = 0.0 if self.f: out += self.f(x) @@ -581,7 +582,9 @@ def norm_dual_residual(self) -> float: out += norm(Ci.adj(zi - ziold)) ** 2 return snp.sqrt(out) - def z_init(self, x0: Union[JaxArray, BlockArray]): + def z_init( + self, x0: Union[JaxArray, BlockArray] + ) -> Tuple[List[Union[JaxArray, BlockArray]], List[Union[JaxArray, BlockArray]]]: r"""Initialize auxiliary variables :math:`\mb{z}_i`. Initialized to @@ -595,11 +598,11 @@ def z_init(self, x0: Union[JaxArray, BlockArray]): Args: x0: Initial value of :math:`\mb{x}`. """ - z_list = [Ci(x0) for Ci in self.C_list] + z_list: List[Union[JaxArray, BlockArray]] = [Ci(x0) for Ci in self.C_list] z_list_old = z_list.copy() return z_list, z_list_old - def u_init(self, x0: Union[JaxArray, BlockArray]): + def u_init(self, x0: Union[JaxArray, BlockArray]) -> List[Union[JaxArray, BlockArray]]: r"""Initialize scaled Lagrange multipliers :math:`\mb{u}_i`. Initialized to diff --git a/scico/optimize/pgm.py b/scico/optimize/pgm.py index 318dce827..4fda0e0fb 100644 --- a/scico/optimize/pgm.py +++ b/scico/optimize/pgm.py @@ -180,8 +180,8 @@ def __init__(self, kappa: float = 0.5): self.kappa: float = kappa self.xprev: Union[JaxArray, BlockArray] = None self.gradprev: Union[JaxArray, BlockArray] = None - self.Lbb1prev: float = None - self.Lbb2prev: float = None + self.Lbb1prev: Optional[float] = None + self.Lbb2prev: Optional[float] = None def update(self, v: Union[JaxArray, BlockArray]) -> float: """Update the reciprocal of the step size. @@ -469,7 +469,7 @@ def x_step(v: Union[JaxArray, BlockArray], L: float) -> Union[JaxArray, BlockArr if itstat_options: default_itstat_options.update(itstat_options) self.itstat_insert_func: Callable = default_itstat_options.pop("itstat_func") # type: ignore - self.itstat_object = IterationStats(**default_itstat_options) + self.itstat_object = IterationStats(**default_itstat_options) # type: ignore self.x: Union[JaxArray, BlockArray] = ensure_on_device(x0) # current estimate of solution diff --git a/scico/random.py b/scico/random.py index 700ed82f8..96a803a3e 100644 --- a/scico/random.py +++ b/scico/random.py @@ -108,12 +108,12 @@ def fun_alt(*args, key=None, seed=None, **kwargs): fun_alt.__doc__ = "\n\n".join( lines[0:1] + [ - f" Wrapped version of `jax.random.{fun.__name__} `_. " - "The SCICO version of this function moves the `key` argument to the end of the argument list, " - "adds an additional `seed` argument after that, and allows the `shape` argument " - "to accept a nested list, in which case a `BlockArray` is returned. " - "Always returns a `(result, key)` tuple.", - " Original docstring below.", + f" Wrapped version of `jax.random.{fun.__name__} " + f"`_. " + "The SCICO version of this function moves the `key` argument to the end of the " + "argument list, adds an additional `seed` argument after that, and allows the " + "`shape` argument to accept a nested list, in which case a `BlockArray` is returned. " + "Always returns a `(result, key)` tuple. Original docstring below.", ] + lines[1:] ) @@ -168,4 +168,4 @@ def randn( - **x** : (DeviceArray): Generated random array. - **key** : Updated random PRNGKey. """ - return normal(shape, dtype, key, seed) + return normal(shape, dtype, key, seed) # type: ignore diff --git a/scico/solver.py b/scico/solver.py index 177d1434b..703caa7fe 100644 --- a/scico/solver.py +++ b/scico/solver.py @@ -300,7 +300,7 @@ def cg( maxiter: int = 1000, info: bool = False, M: Optional[Callable] = None, -) -> Union[JaxArray, dict]: +) -> Tuple[JaxArray, dict]: r"""Conjugate Gradient solver. Solve the linear system :math:`A\mb{x} = \mb{b}`, where :math:`A` is diff --git a/scico/typing.py b/scico/typing.py index d2aad9aa3..7c9208a80 100644 --- a/scico/typing.py +++ b/scico/typing.py @@ -9,6 +9,12 @@ from typing import Any, Tuple, Union +try: + # available in python 3.10 + from types import EllipsisType # type: ignore +except ImportError: + EllipsisType = Any # type: ignore + import numpy as np import jax @@ -35,9 +41,9 @@ Axes = Union[int, Tuple[int, ...]] """Specification of one or more array axes.""" -AxisIndex = Union[slice, type(Ellipsis), int] +AxisIndex = Union[slice, EllipsisType, int] """An entity suitable for indexing/slicing of a single array axis; either a slice object, Ellipsis, or int.""" -ArrayIndex = Union[AxisIndex, Tuple[AxisIndex]] +ArrayIndex = Union[AxisIndex, Tuple[AxisIndex, ...]] """An entity suitable for indexing/slicing of multi-dimentional arrays.""" diff --git a/scico/util.py b/scico/util.py index 6e430df35..1f73bf51b 100644 --- a/scico/util.py +++ b/scico/util.py @@ -16,7 +16,7 @@ import urllib.request as urlrequest from functools import wraps from timeit import default_timer as timer -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import jax from jax.interpreters.batching import BatchTracer @@ -123,8 +123,8 @@ def __init__( """ # Initialise current and accumulated time dictionaries - self.t0 = {} - self.td = {} + self.t0: Dict[str, Optional[float]] = {} + self.td: Dict[str, float] = {} # Record default label and string indicating all labels self.default_label = default_label self.all_label = all_label @@ -188,7 +188,7 @@ def stop(self, labels: Optional[Union[str, List[str]]] = None): # All timers are affected if label is equal to self.all_label, # otherwise only the timer(s) specified by label if labels == self.all_label: - labels = self.t0.keys() + labels = list(self.t0.keys()) elif not isinstance(labels, (list, tuple)): labels = [ labels, @@ -202,7 +202,7 @@ def stop(self, labels: Optional[Union[str, List[str]]] = None): if self.t0[lbl] is not None: # Increment time accumulator from the elapsed time # since most recent start call - self.td[lbl] += t - self.t0[lbl] + self.td[lbl] += t - self.t0[lbl] # type: ignore # Set start time to None to indicate timer is not running self.t0[lbl] = None @@ -223,7 +223,7 @@ def reset(self, labels: Optional[Union[str, List[str]]] = None): # All timers are affected if label is equal to self.all_label, # otherwise only the timer(s) specified by label if labels == self.all_label: - labels = self.t0.keys() + labels = list(self.t0.keys()) elif not isinstance(labels, (list, tuple)): labels = [ labels, @@ -271,7 +271,7 @@ def elapsed(self, label: Optional[str] = None, total: bool = True) -> float: # return just the time since the current start call te = 0.0 if self.t0[label] is not None: - te = t - self.t0[label] + te = t - self.t0[label] # type: ignore if total: te += self.td[label] @@ -284,7 +284,7 @@ def labels(self) -> List[str]: List of timer labels. """ - return self.t0.keys() + return list(self.t0.keys()) def __str__(self) -> str: """Return string representation of object. @@ -313,7 +313,7 @@ def __str__(self) -> str: if self.t0[lbl] is None: ts = " Stopped" else: - ts = f" {(t - self.t0[lbl]):.2e} s" % (t - self.t0[lbl]) + ts = f" {(t - self.t0[lbl]):.2e} s" % (t - self.t0[lbl]) # type: ignore s += f"{lbl:{lfldln}s} {td:.2e} s {ts}\n" return s