Notebook with experimental Newton implementation #944

Draft
wants to merge 1 commit into main
181 changes: 181 additions & 0 deletions experiment-newton.ipynb
@@ -0,0 +1,181 @@
{
@jessegrabowski (Member) commented on Jul 19, 2024:
Line #7.        grad = pt.linalg.solve(jac, f_x, assume_a="sym")

You're assuming the Jacobian is symmetric, but that shouldn't be true in general, right?



The PR author (Member) replied:
Oh, you're right. The Jacobian is symmetric in the case of minimization, but might not be for other root-finding problems. We should add an option for that.

A contributor replied:
Ah, good catch! The Jacobian is indeed generally asymmetric for n > 1.

At first I thought you were talking about the Hessian. That one is symmetric, though with a really weird caveat: if a function is twice differentiable but its second derivatives aren't continuous, the symmetry can fail. (In cases like that you can get some genuinely strange behavior, such as a fractal trail that is flat at every point but still ascends.)
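A minimal sketch of the option discussed above, threading a hypothetical assume_a parameter through the step (an illustration, not part of this PR):

    def _newton_step(func, x, args, assume_a="gen"):
        f_x = func(x, *args)
        jac = pt.jacobian(f_x, x)
        # "gen" makes no structural assumption about jac; minimize() could
        # pass "sym", since the Jacobian of a gradient is a Hessian.
        grad = pt.linalg.solve(jac, f_x, assume_a=assume_a)
        return f_x, x - grad, grad, jac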

"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "ff134dc2-ad8c-41b9-a8da-8cc7b5352b9d",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"import pytensor\n",
"import pytensor.tensor as pt\n",
"from scipy import linalg\n",
"from pytensor.scan.utils import until\n",
"from functools import partial"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "759d75fe-6b86-42a5-a9d3-96af6de75053",
"metadata": {},
"outputs": [],
"source": [
"def _newton_step(func, x, args):\n",
" f_x = func(x, *args)\n",
" jac = pt.jacobian(f_x, x)\n",
"\n",
" # TODO It would be nice to return the factored matrix for the pullback\n",
" # TODO Handle errors of the factorization\n",
" grad = pt.linalg.solve(jac, f_x, assume_a=\"sym\")\n",
"\n",
" return f_x, x - grad, grad, jac\n",
"\n",
"def _check_convergence(f_x, x, new_x, grad, tol):\n",
" # TODO What convergence criterion? Norm of grad etc...\n",
" converged = pt.lt(pt.linalg.norm(f_x, ord=1), tol)\n",
" return converged\n",
"\n",
"def _scan_step(x, n_steps, *args, func, tol):\n",
" f_x, new_x, grad, jac = _newton_step(func, x, args)\n",
" is_converged = _check_convergence(f_x, x, new_x, grad, tol)\n",
" return (new_x, n_steps + 1, jac), until(is_converged)\n",
"\n",
"def root(\n",
" func: Callable,\n",
" x0: pt.TensorVariable, # rank 1\n",
" args: tuple[pt.Variable, ...],\n",
" max_iter: int = 113,\n",
" tol: float = 1e-8,\n",
") -> tuple[\n",
" pt.TensorVariable, dict,\n",
"]:\n",
" root_func = partial(\n",
" _scan_step,\n",
" func=func,\n",
" tol=tol,\n",
" )\n",
"\n",
" outputs, updates = pytensor.scan(\n",
" root_func,\n",
" outputs_info=[x0, pt.constant(0, dtype=\"int64\"), None],\n",
" non_sequences=args,\n",
" n_steps=max_iter,\n",
" strict=True,\n",
" )\n",
"\n",
" x_trace, n_steps_trace, jac_trace = outputs\n",
" assert not updates\n",
"\n",
" return x_trace[-1], {\"n_steps\": n_steps_trace[-1], \"jac\": jac_trace[-1]}\n",
"\n",
"\n",
"def minimize(cost: Callable, x0: pt.TensorVariable, args):\n",
" def func(x):\n",
" return pt.grad(cost(x), x)\n",
"\n",
" return root(func, x0, args)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "21304789-4eab-49de-9db7-a5bb327712b2",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b031e81a-c615-4af5-b2d9-897ee46f15dc",
"metadata": {},
"outputs": [],
"source": [
"x0 = pt.tensor(\"x0\", shape=(3,))\n",
"#x0 = pt.full((3,), [2., 2., 2.])\n",
"#x0 = x0.copy()\n",
"\n",
"mu = pt.tensor(\"mu\", shape=())\n",
"\n",
"def func(x, mu):\n",
" cost = pt.sum((x ** 2 - mu) ** 2)\n",
" return pt.grad(cost, x)\n",
"\n",
"\n",
"x_root, stats = root(func, x0, args=[mu], tol=1e-8)\n",
"\n",
"(x_root_dmu,) = pt.grad(x_root[0], [mu])\n",
"\n",
"f_x = func(x_root, mu)\n",
"dfunc_dmu = pt.jacobian(f_x, mu, consider_constant=[x_root])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0d54e9a4-89ed-4670-b069-ea58bb4e85e5",
"metadata": {},
"outputs": [],
"source": [
"func = pytensor.function([x0, mu], [x_root, stats[\"n_steps\"], stats[\"jac\"], dfunc_dmu])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "07747b6d-71ca-4bc3-9546-45e3122890d4",
"metadata": {},
"outputs": [],
"source": [
"x_root, n_steps, jac, dfunc_dmu_val = func(np.ones(3) * 3, np.full((), 5.))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "2bf94004-465e-4c04-a23a-971c43b637a7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.2236068, 0.2236068, 0.2236068])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Dervivative of x_root with respect to mu\n",
"-linalg.solve(jac, dfunc_dmu_val, assume_a=\"sym\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dev-cuda",
"language": "python",
"name": "dev-cuda"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
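For reference, the last cell is an instance of the implicit function theorem: at a root x*(mu) with f(x*(mu), mu) = 0, differentiating both sides with respect to mu gives

$$
\frac{\partial f}{\partial x}\,\frac{\partial x^*}{\partial \mu} + \frac{\partial f}{\partial \mu} = 0
\qquad\Longrightarrow\qquad
\frac{\partial x^*}{\partial \mu} = -\left(\frac{\partial f}{\partial x}\right)^{-1}\frac{\partial f}{\partial \mu},
$$

which is exactly what -linalg.solve(jac, dfunc_dmu_val, assume_a="sym") computes, and the identity a custom pullback for root would implement (see the TODO in _newton_step). The assume_a="sym" is valid in this example because jac is the Hessian of a scalar cost.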