File tree Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Expand file tree Collapse file tree 2 files changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -70,7 +70,7 @@ def prox_lasso(x: Any,
70
70
if l1reg is None :
71
71
l1reg = 1.0
72
72
73
- if type (l1reg ) == float :
73
+ if jnp . isscalar (l1reg ):
74
74
l1reg = tree_util .tree_map (lambda y : l1reg * jnp .ones_like (y ), x )
75
75
76
76
def fun (u , v ): return jnp .sign (u ) * jax .nn .relu (jnp .abs (u ) - v * scaling )
Original file line number Diff line number Diff line change @@ -87,6 +87,13 @@ def test_prox_lasso(self):
87
87
got = prox .prox_lasso (x , alpha )
88
88
self .assertArraysAllClose (jnp .array (expected ), jnp .array (got ))
89
89
90
+ # test that works when regularizer is float and prox jit compiled
91
+ got = jax .jit (prox .prox_lasso )(x , 0.5 )
92
+ expected0 = [self ._prox_l1 (x [0 ][i ], 0.5 ) for i in range (len (x [0 ]))]
93
+ expected1 = [self ._prox_l1 (x [1 ][i ], 0.5 ) for i in range (len (x [0 ]))]
94
+ expected = (jnp .array (expected0 ), jnp .array (expected1 ))
95
+ self .assertArraysAllClose (jnp .array (expected ), jnp .array (got ))
96
+
90
97
def _prox_enet (self , x , lam , gamma ):
91
98
return (1.0 / (1.0 + lam * gamma )) * self ._prox_l1 (x , lam )
92
99
You can’t perform that action at this time.
0 commit comments