From 54466f6429f3508426242a8aaa2d38beb3771e2d Mon Sep 17 00:00:00 2001 From: nicktianboli Date: Tue, 13 Feb 2024 17:38:03 +0800 Subject: [PATCH] change ad_util const name --- autofd/__init__.py | 2 +- autofd/operators/operators.py | 4 ++-- setup.cfg | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/autofd/__init__.py b/autofd/__init__.py index 8fe0d61..d62d217 100644 --- a/autofd/__init__.py +++ b/autofd/__init__.py @@ -60,4 +60,4 @@ def scale_by_learning_rate( "operators", ] -__version__ = "0.0.6" # noqa +__version__ = "0.0.7" # noqa diff --git a/autofd/operators/operators.py b/autofd/operators/operators.py index 5eddd54..48f487f 100644 --- a/autofd/operators/operators.py +++ b/autofd/operators/operators.py @@ -1196,8 +1196,8 @@ def add_values(*args): return fs[0] -ad_util.jaxval_adders[types.FunctionType] = add -ad_util.jaxval_adders[function] = add +ad_util.raw_jaxval_adders[types.FunctionType] = add +ad_util.raw_jaxval_adders[function] = add array_operators = { "neg": lambda x: numpy.negative(x), # noqa diff --git a/setup.cfg b/setup.cfg index 2dccb6e..356a90a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,7 +22,7 @@ ignore = E731 [metadata] name = autofd -version = 0.0.6 +version = 0.0.7 author = "Min Lin" author_email = "linmin@sea.com" description = "Automatic Functional Derivative in JAX" @@ -38,7 +38,7 @@ classifiers = packages = find: python_requires = >=3.9 install_requires = - jax>=0.4.16 + jax>=0.4.24 jaxtyping>=0.2.21 [options.packages.find]