Changes to step length handling #94

Merged
10 commits merged on Feb 14, 2025
2 changes: 2 additions & 0 deletions stride/core.py
@@ -290,6 +290,8 @@ def __init__(self, *args, **kwargs):
         self.prec = None
         self.transform = kwargs.pop('transform', None)

+        self.step_size = None
+
         self.graph = Graph()
         self.prev_op = None
         self.needs_grad = kwargs.pop('needs_grad', False)
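The change to stride/core.py gives each variable a step_size attribute, initialised to None and later overwritten by the optimiser with the step length that was actually applied (see the optimiser.py hunks below). A minimal sketch of that lifecycle, using simplified stand-in classes rather than stride's real Variable and optimiser:

# Simplified stand-ins for illustration only; not stride's actual classes.
class Variable:
    def __init__(self):
        self.step_size = None                # mirrors the new field in stride/core.py

class Optimiser:
    def __init__(self, variable):
        self.variable = variable

    def step(self, next_step):
        # ... the real optimiser updates variable.data here ...
        self.variable.step_size = next_step  # mirrors the optimiser.py change below

var = Variable()
opt = Optimiser(var)
print(var.step_size)   # None before the first update
opt.step(0.35)
print(var.step_size)   # 0.35 once a step of that length has been applied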
22 changes: 13 additions & 9 deletions stride/optimisation/optimisers/optimiser.py
@@ -38,6 +38,7 @@ def __init__(self, variable, **kwargs):
         self.variable = variable
         self.step_size = kwargs.pop('step_size', 1.)
         self.test_step_size = kwargs.pop('test_step_size', 1.)
+        self.force_step = kwargs.pop('force_step', False)
         self.dump_grad = kwargs.pop('dump_grad', False)
         self.dump_prec = kwargs.pop('dump_prec', False)
         self._process_grad = kwargs.pop('process_grad', ProcessGlobalGradient(**kwargs))
@@ -201,18 +202,20 @@ async def step(self, step_size=None, grad=None, processed_grad=None, **kwargs):
             done_search = True

         if done_search:
-            # cap the step if needed
-            max_step = kwargs.pop('max_step', None)
-            max_step = np.inf if not isinstance(max_step, (int, float)) else max_step
+            if not self.force_step:
+                # cap the step if needed
+                max_step = kwargs.pop('max_step', None)
+                max_step = np.inf if not isinstance(max_step, (int, float)) else max_step

             unclipped_step = next_step

-            if next_step > -0.2:  # if bit -ve, still assume grad is right dirn
-                next_step = max(0.1, min(next_step, max_step))
-            elif max_step < np.inf and next_step < -max_step * 0.75:  # in general, prevent -ve steps
-                next_step = -max_step * 0.75
-            elif next_step < -0.2:
-                next_step = next_step * 0.25
+            if not self.force_step:
+                if next_step > -0.2:  # if bit -ve, still assume grad is right dirn
+                    next_step = max(0.1, min(next_step, max_step))
+                elif max_step < np.inf and next_step < -max_step * 0.75:  # in general, prevent -ve steps
+                    next_step = -max_step * 0.75
+                elif next_step < -0.2:
+                    next_step = next_step * 0.25

             logger.perf('\t taking final update step of %e [unclipped step of %e]' % (next_step, unclipped_step))
         else:
@@ -230,6 +233,7 @@ async def step(self, step_size=None, grad=None, processed_grad=None, **kwargs):
         if self.variable.transform is not None:
             upd_variable = self.variable.transform(upd_variable)
         self.variable.data[:] = upd_variable.data.copy()
+        self.variable.step_size = next_step

         # post-process variable after update
         await self.post_process(**kwargs)
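The capping above only runs when force_step is False; constructing the optimiser with force_step=True takes the raw line-search step unmodified. Below is a minimal, self-contained re-expression of the clipping rules for illustration; the helper name cap_step is invented here, and the real logic lives inside the optimiser's step method:

import numpy as np

def cap_step(next_step, max_step=None, force_step=False):
    """Return the step length that would actually be applied."""
    if force_step:
        # force_step=True bypasses all clipping
        return next_step

    max_step = np.inf if not isinstance(max_step, (int, float)) else max_step

    if next_step > -0.2:
        # positive or only slightly negative: trust the gradient direction,
        # but keep the step within [0.1, max_step]
        next_step = max(0.1, min(next_step, max_step))
    elif max_step < np.inf and next_step < -max_step * 0.75:
        # strongly negative step: limit its magnitude to 0.75 * max_step
        next_step = -max_step * 0.75
    elif next_step < -0.2:
        # moderately negative step: damp it rather than take it in full
        next_step = next_step * 0.25

    return next_step

print(cap_step(0.05))                   # 0.1  (floored at 0.1)
print(cap_step(3.0, max_step=2.0))      # 2.0  (capped at max_step)
print(cap_step(-0.5))                   # -0.125 (damped by a factor of 0.25)
print(cap_step(-5.0, max_step=2.0))     # -1.5 (limited to -0.75 * max_step)
print(cap_step(-5.0, force_step=True))  # -5.0 (clipping skipped entirely)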
4 changes: 2 additions & 2 deletions stride/problem/data.py
@@ -744,7 +744,8 @@ def __get_desc__(self, **kwargs):
             'inner': inner,
             'dtype': str(np.dtype(self._dtype)),
             'data': data,
-            'compression': compression if compression is not None else False
+            'compression': compression if compression is not None else False,
+            'step_size': self.step_size
         }

         return description
@@ -1219,7 +1220,6 @@ def __get_desc__(self, **kwargs):
         description = super().__get_desc__(**kwargs)
         description['time_dependent'] = self._time_dependent
         description['slow_time_dependent'] = self._slow_time_dependent
-
         return description

     def __set_desc__(self, description, **kwargs):
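With the data.py change, the last applied step length travels with the variable's description, so it is dumped and restored alongside the rest of the metadata. A small sketch of what a resulting description might contain (field values invented for illustration):

# Hypothetical description as produced by __get_desc__ after this change.
description = {
    'dtype': 'float32',
    'data': ...,                # the variable's payload (placeholder)
    'compression': False,
    'step_size': 0.35,          # last applied step, or None before any update
}

restored_step = description.get('step_size')
print(restored_step)            # 0.35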