Skip to content

Commit ee704c3

Browse files
author
JAXopt authors
committed
Merge pull request #604 from BalzaniEdoardo:fix_prox_lasso
PiperOrigin-RevId: 745656867
2 parents 4fe3f08 + 9b9ccd5 commit ee704c3

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

jaxopt/_src/prox.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def prox_lasso(x: Any,
7070
if l1reg is None:
7171
l1reg = 1.0
7272

73-
if type(l1reg) == float:
73+
if jnp.isscalar(l1reg):
7474
l1reg = tree_util.tree_map(lambda y: l1reg*jnp.ones_like(y), x)
7575

7676
def fun(u, v): return jnp.sign(u) * jax.nn.relu(jnp.abs(u) - v * scaling)

tests/prox_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ def test_prox_lasso(self):
8787
got = prox.prox_lasso(x, alpha)
8888
self.assertArraysAllClose(jnp.array(expected), jnp.array(got))
8989

90+
# test that works when regularizer is float and prox jit compiled
91+
got = jax.jit(prox.prox_lasso)(x, 0.5)
92+
expected0 = [self._prox_l1(x[0][i], 0.5) for i in range(len(x[0]))]
93+
expected1 = [self._prox_l1(x[1][i], 0.5) for i in range(len(x[0]))]
94+
expected = (jnp.array(expected0), jnp.array(expected1))
95+
self.assertArraysAllClose(jnp.array(expected), jnp.array(got))
96+
9097
def _prox_enet(self, x, lam, gamma):
9198
return (1.0 / (1.0 + lam * gamma)) * self._prox_l1(x, lam)
9299

0 commit comments

Comments
 (0)