From d1044ff6f512e88b23e9436c9f23945ee3a2f3d7 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Wed, 17 Jul 2024 06:36:54 +0000 Subject: [PATCH 1/3] pass_index --- SIRF_data_preparation/evaluation_utilities.py | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/SIRF_data_preparation/evaluation_utilities.py b/SIRF_data_preparation/evaluation_utilities.py index b65d5bd..01e71da 100644 --- a/SIRF_data_preparation/evaluation_utilities.py +++ b/SIRF_data_preparation/evaluation_utilities.py @@ -4,7 +4,7 @@ from typing import Iterator import matplotlib.pyplot as plt -import numpy +import numpy as np import sirf.STIR as STIR from petric import QualityMetrics @@ -15,16 +15,31 @@ def read_objectives(datadir='.'): with (Path(datadir) / 'objectives.csv').open() as csvfile: reader = csv.reader(csvfile) next(reader) # skip first (header) line - return numpy.asarray([tuple(map(float, row)) for row in reader]) + return np.asarray([tuple(map(float, row)) for row in reader]) def get_metrics(qm: QualityMetrics, iters: Iterator[int], srcdir='.'): """Read 'iter_{iter_glob}.hv' images from datadir, compute metrics and return as 2d array""" - return numpy.asarray([ + return np.asarray([ list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters]) -def plot_metrics(iters: Iterator[int], m: numpy.ndarray, labels=None, suffix=""): +def pass_index(metrics: np.ndarray, thresh: np.ndarray, window: int = 1) -> int: + """ + Returns first index of `metrics` with value <= `thresh`. + The value must remain below the threshold for at least `window`. + Raises IndexError if doesn't pass. + """ + assert metrics.shape[1] == len(thresh) + assert metrics.ndim == 2 + assert thresh.ndim == 1 + + m = (metrics <= thresh[None]).all(1) + assert window == 1 + return np.where(m[1:] & m[:-1])[0][0] + + +def plot_metrics(iters: Iterator[int], m: np.ndarray, labels=None, suffix=""): """Make 2 subplots of metrics""" if labels is None: labels = [""] * m.shape[1] From 7f3cb7080c0d903565c165088783555e6821fcdd Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Wed, 17 Jul 2024 07:11:31 +0000 Subject: [PATCH 2/3] pass_index(window) --- SIRF_data_preparation/evaluation_utilities.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/SIRF_data_preparation/evaluation_utilities.py b/SIRF_data_preparation/evaluation_utilities.py index 01e71da..5916c9c 100644 --- a/SIRF_data_preparation/evaluation_utilities.py +++ b/SIRF_data_preparation/evaluation_utilities.py @@ -35,8 +35,11 @@ def pass_index(metrics: np.ndarray, thresh: np.ndarray, window: int = 1) -> int: assert thresh.ndim == 1 m = (metrics <= thresh[None]).all(1) - assert window == 1 - return np.where(m[1:] & m[:-1])[0][0] + res = m[:-window] + for i in range(1, window): + res &= m[i:-window+i] + res &= m[window:] + return np.where(res)[0][0] def plot_metrics(iters: Iterator[int], m: np.ndarray, labels=None, suffix=""): From 8256725de9ef6d0ccf18e5c5a1fccb7f3bceee3a Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Wed, 17 Jul 2024 16:56:02 +0100 Subject: [PATCH 3/3] use scipy.ndimage.binary_erosion --- SIRF_data_preparation/evaluation_utilities.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/SIRF_data_preparation/evaluation_utilities.py b/SIRF_data_preparation/evaluation_utilities.py index 5916c9c..29cf75b 100644 --- a/SIRF_data_preparation/evaluation_utilities.py +++ b/SIRF_data_preparation/evaluation_utilities.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import numpy as np +from scipy.ndimage import binary_erosion import sirf.STIR as STIR from petric import QualityMetrics @@ -24,21 +25,18 @@ def get_metrics(qm: QualityMetrics, iters: Iterator[int], srcdir='.'): list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters]) -def pass_index(metrics: np.ndarray, thresh: np.ndarray, window: int = 1) -> int: +def pass_index(metrics: np.ndarray, thresh: Iterator, window: int = 1) -> int: """ Returns first index of `metrics` with value <= `thresh`. - The value must remain below the threshold for at least `window`. - Raises IndexError if doesn't pass. + The values must remain below the respective thresholds for at least `window` number of entries. + Otherwise raises IndexError. """ - assert metrics.shape[1] == len(thresh) + thr_arr = np.asanyarray(thresh) assert metrics.ndim == 2 - assert thresh.ndim == 1 - - m = (metrics <= thresh[None]).all(1) - res = m[:-window] - for i in range(1, window): - res &= m[i:-window+i] - res &= m[window:] + assert thr_arr.ndim == 1 + assert metrics.shape[1] == thr_arr.shape[0] + passed = (metrics <= thr_arr[None]).all(axis=1) + res = binary_erosion(passed, structure=np.ones(window), origin=-(window // 2)) return np.where(res)[0][0]