Skip to content

Conversation

@samjmolyneux
Copy link
Collaborator

  • By using jax.jit, and calculating the gradient and objective in a single pass with jax.value_and_grad, we get big speed boosts to sgd. The new time taken is approximately $\frac{1}{7}\text{th}$ of the original for the two normal example on my machine.

  • New history_logging_interval parameter for the stochastic_gradient_descent function allows the user to enable or disable logging of the optimisation history. The interval determines how frequently the history is logged. This makes it easier to debug optimisations and make decisions about hyperparameters.

  • To make history logging work, we add a dataclass IterationResult which is analogous to SolverResult. However, IterationResult uses frozen=True to allow for dataclass updates each iteration. Using a dataclass ensures backward compatibility for callbacks if a new attribute is logged.

  • New callbacks parameter for the stochastic_gradient_descent function allows the user to set a list of callback functions as is standard in optimisation loops. In future, the callbacks can be used for early stopping or live plotting of the results. As an example, we include a useful tqdm callback that displays a progress bar for the iterations and displays the current objective value.

  • We also add the following tests:
    • test_normalise_callbacks: Tests that _normalise_callbacks does validation and casts valid types to list[Callable[IterationResult], None] .
    • test_sgd_history_logging_intervals: Tests that the correct iterations are logged for different intervals and that the correct associated obj, fn_args and grad are too for sgd.
    • test_callback_invocation: Tests that sgd callbacks are called in the correct order with the correct IterationResult.
    • test_invalid_callback: Tests that sgd will raise an error if given an invalid callback.
    • test_logging_or_callbacks_affect_sgd_convergence: Tests that various combinations of callbacks and logging intervals all result in the same convergence behaviour and thus all have the same final obj, fn_args, grad etc.

To sgd.py:
    - Added single pass JIT grad and objective calculation for big speed
    - boosts.
    - Added history logging option for easy debugging and better understanding
    of optimisation process.
    - Added callbacks functionality, allowing for user specific callbacks,
    e.g. early stopping, live graph etc.

Added iteration_result.py to monitor current state of convergence.
Use of dataclass ensures backwards compatablity of callbacks.

Added solver_callbacks.py for callbacks and associated funcs.

Added history attributes to solver result.
@samjmolyneux samjmolyneux force-pushed the sjmolyneux/sgd-improvements branch from 62924d7 to 3fbed8d Compare October 5, 2025 21:30
@willGraham01 willGraham01 self-requested a review October 6, 2025 09:26
Copy link
Collaborator

@willGraham01 willGraham01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All-in-all I like these changes, and these are useful features that we could do with adding. Callbacks in particular should be very useful from a development/debugging perspective too,

Most of my comments relate to the design decisions for the code, considering what's in the rest of the codebase. Namely I think we can do some code recycling in places, and we tend to write our tests in a particular format (though the test cases provided are good).

Also, there are only two commits on this branch (one for codebase changes, one for tests). In general, don't be afraid to use more granular commits in your PRs (just take a look at how long the other PRs are!) - we use squash merges anyway, so everything gets condensed into a single commit on main anyway. And it's good to be able to roll things back.

@willGraham01 willGraham01 self-requested a review November 18, 2025 11:12
Copy link
Collaborator

@willGraham01 willGraham01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like everything has been addressed and was waiting on my approval.

Unless you have any pending local changes @samjmolyneux, hit merge when you're ready.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants