Skip to content

Commit

Permalink
allow passing rotation center for tilt
Browse files Browse the repository at this point in the history
  • Loading branch information
KedoKudo committed Sep 17, 2024
1 parent 6d79fbf commit d6ce446
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions src/imars3d/backend/diagnostics/tilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import multiprocessing
from imars3d.backend.util.functions import clamp_max_workers
import numpy as np
from typing import Tuple
from typing import Tuple, Union, Optional
from functools import partial
from scipy.optimize import minimize_scalar
from scipy.optimize import OptimizeResult
Expand Down Expand Up @@ -103,6 +103,7 @@ def calculate_dissimilarity(
tilt: float,
image0: np.ndarray,
image1: np.ndarray,
center: Optional[Tuple[Union[float, int], Union[float, int]]] = None,
) -> float:
"""Calculate the dissimilarity between two images with given tilt.
Expand All @@ -119,6 +120,9 @@ def calculate_dissimilarity(
image1:
The second image for comparison, which is often the radiograph taken at
omega + 180 deg
center:
The center of the rotation axis, default is None, which means the center
of the image. This will be passed to the rotation function from skimage.
Returns
-------
Expand Down Expand Up @@ -168,6 +172,7 @@ def calculate_dissimilarity(
resize=True,
preserve_range=True,
order=1, # use default bi-linear interpolation for rotation
center=center,
)
# since 180 is flipped, tilting back -2 deg of the original img180 means tilting +2 deg
# of the flipped one
Expand All @@ -178,6 +183,7 @@ def calculate_dissimilarity(
resize=True,
preserve_range=True,
order=1, # use default bi-linear interpolation for rotation
center=center,
)

# p-norm
Expand All @@ -198,6 +204,7 @@ def calculate_tilt(
image180: np.ndarray,
low_bound: float = -5.0,
high_bound: float = 5.0,
center: Optional[Tuple[Union[float, int], Union[float, int]]] = None,
) -> OptimizeResult:
"""
Use optimization to find the in-plane tilt angle.
Expand All @@ -214,13 +221,16 @@ def calculate_tilt(
The lower bound of the tilt angle search space
high_bound:
The upper bound of the tilt angle search space
center:
The center of the rotation axis, default is None, which means the center
of the image. This will be passed to the rotation function from skimage.
Returns
-------
The optimization results from scipy.optimize.minimize_scalar
"""
# make the error function
err_func = partial(calculate_dissimilarity, image0=image0, image1=image180)
err_func = partial(calculate_dissimilarity, image0=image0, image1=image180, center=center)
# use bounded uni-variable optimizer to locate the tilt angle that minimize
# the dissimilarity of the 180 deg pair
res = minimize_scalar(
Expand Down Expand Up @@ -249,6 +259,9 @@ class tilt_correction(param.ParameterizedFunction):
cut_off_angle_deg: float
The angle in degrees to cut off the rotation axis tilt correction, i.e.
skip applying tilt correction for tilt angles that are too small.
center: Any
The center of the rotation axis, default is None, which means the center
of the image. This will be passed to the rotation function from skimage.
max_workers:
Number of cores to use for parallel median filtering, default is 0,
which means using all available cores.
Expand All @@ -275,6 +288,10 @@ class tilt_correction(param.ParameterizedFunction):
default=2.0,
doc="The angle in degrees to cut off the rotation axis tilt correction, i.e. skip applying tilt correction for tilt angles that are too small.",
)
center = param.Parameter(
default=None,
doc="The center of the rotation axis, default is None, which means the center of the image. This will be passed to the rotation function from skimage.",
)
# NOTE:
# The front and backend are sharing the same computing unit, therefore we can
# set a hard cap on the max_workers.
Expand Down Expand Up @@ -329,6 +346,7 @@ def __call__(self, **params):
calculate_tilt,
low_bound=params.low_bound,
high_bound=params.high_bound,
center=params.center,
),
[shm_arrays[il] for il in idx_lowrange],
[shm_arrays[ih] for ih in idx_highrange],
Expand All @@ -349,6 +367,7 @@ def __call__(self, **params):
corrected_array = apply_tilt_correction(
arrays=params.arrays,
tilt=tilt,
center=params.center,
max_workers=self.max_workers,
)
return corrected_array
Expand All @@ -366,6 +385,8 @@ class apply_tilt_correction(param.ParameterizedFunction):
The array for tilt correction
tilt: float
The rotation axis tilt angle in degrees
center: Any
The center of the rotation axis, default is None, which means the center
max_workers: int
Number of cores to use for parallel median filtering, default is 0, which means using all available cores.
tqdm_class: panel.widgets.Tqdm
Expand All @@ -379,6 +400,10 @@ class apply_tilt_correction(param.ParameterizedFunction):

arrays = param.Array(doc="The array for tilt correction", default=None)
tilt = param.Number(doc="The rotation axis tilt angle in degrees", default=None)
center = param.Parameter(
default=None,
doc="The center of the rotation axis, default is None, which means the center of the image. This will be passed to the rotation function from skimage.",
)
# NOTE:
# The front and backend are sharing the same computing unit, therefore we can
# set a hard cap on the max_workers.
Expand Down Expand Up @@ -406,7 +431,9 @@ def __call__(self, **params):
# dimensionality check
if params.arrays.ndim == 2:
logger.info(f"2D image detected, applying tilt correction with tilt = {params.tilt:.3f} deg")
corrected_array = rotate(params.arrays, -params.tilt, resize=False, preserve_range=True)
corrected_array = rotate(
params.arrays, -params.tilt, resize=False, preserve_range=True, center=params.center
)
elif params.arrays.ndim == 3:
logger.info(f"3D array detected, applying tilt correction with tilt = {params.tilt:.3f} deg")
with SharedMemoryManager() as smm:
Expand All @@ -420,7 +447,7 @@ def __call__(self, **params):
if params.tqdm_class:
kwargs["tqdm_class"] = params.tqdm_class
rst = process_map(
partial(rotate, angle=-params.tilt, resize=False, preserve_range=True),
partial(rotate, angle=-params.tilt, resize=False, preserve_range=True, center=params.center),
[shm_arrays[idx] for idx in range(params.arrays.shape[0])],
**kwargs,
)
Expand Down

0 comments on commit d6ce446

Please sign in to comment.