Skip to content

Commit

Permalink
Merge pull request #67 from SyneRBI/tb-walltime-offset
Browse files Browse the repository at this point in the history
metrics: exclude callback time
  • Loading branch information
casperdcl authored Jul 17, 2024
2 parents 6a0c66b + e92046a commit d38a919
Showing 1 changed file with 15 additions and 13 deletions.
28 changes: 15 additions & 13 deletions petric.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,22 @@ def __init__(self, transverse_slice=None, coronal_slice=None, vmax=None, logdir=
def __call__(self, algo: Algorithm):
if self.skip_iteration(algo):
return
t = getattr(self, '__time', None) or time()
log.debug("logging iter %d...", algo.iteration)
# initialise `None` values
self.transverse_slice = algo.x.dimensions()[0] // 2 if self.transverse_slice is None else self.transverse_slice
self.coronal_slice = algo.x.dimensions()[1] // 2 if self.coronal_slice is None else self.coronal_slice
self.vmax = algo.x.max() if self.vmax is None else self.vmax

self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration)
self.tb.add_scalar("objective", algo.get_last_loss(), algo.iteration, t)
if self.x_prev is not None:
normalised_change = (algo.x - self.x_prev).norm() / algo.x.norm()
self.tb.add_scalar("normalised_change", normalised_change, algo.iteration)
self.tb.add_scalar("normalised_change", normalised_change, algo.iteration, t)
self.x_prev = algo.x.clone()
self.tb.add_image("transverse",
np.clip(algo.x.as_array()[self.transverse_slice:self.transverse_slice + 1] / self.vmax, 0, 1),
algo.iteration)
self.tb.add_image("coronal", np.clip(algo.x.as_array()[None, :, self.coronal_slice] / self.vmax, 0, 1),
algo.iteration)
x_arr = algo.x.as_array()
self.tb.add_image("transverse", np.clip(x_arr[self.transverse_slice:self.transverse_slice + 1] / self.vmax, 0,
1), algo.iteration, t)
self.tb.add_image("coronal", np.clip(x_arr[None, :, self.coronal_slice] / self.vmax, 0, 1), algo.iteration, t)
log.debug("...logged")


Expand All @@ -118,8 +118,9 @@ def __init__(self, reference_image, whole_object_mask, background_mask, interval
def __call__(self, algo: Algorithm):
if self.skip_iteration(algo):
return
t = getattr(self, '__time', None) or time()
for tag, value in self.evaluate(algo.x).items():
self.tb_summary_writer.add_scalar(tag, value, algo.iteration)
self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t)

def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
assert not any(self.filter.values()), "Filtering not implemented"
Expand Down Expand Up @@ -153,15 +154,16 @@ def __init__(self, seconds=300, outdir=OUTDIR, transverse_slice=None, coronal_sl

def reset(self, seconds=None):
self.limit = time() + (self._seconds if seconds is None else seconds)
self.offset = 0

def __call__(self, algo: Algorithm):
if (now := time()) > self.limit:
if (now := time()) > self.limit + self.offset:
log.warning("Timeout reached. Stopping algorithm.")
raise StopIteration
if self.callbacks:
for c in self.callbacks:
c(algo)
self.limit += time() - now
for c in self.callbacks:
c.__time = now - self.offset # privately inject walltime-excluding-petric-callbacks
c(algo)
self.offset += time() - now

@staticmethod
def mean_absolute_error(y, x):
Expand Down

0 comments on commit d38a919

Please sign in to comment.