From 27dc9263015bd073a6109709fc64720add378be2 Mon Sep 17 00:00:00 2001 From: Yusha Arif Date: Wed, 18 Sep 2024 16:23:04 +0000 Subject: [PATCH] fix: renaming the `USE_NATIVE_KERAS_LAYERS` env variable --- ivy/stateful/utilities.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ivy/stateful/utilities.py b/ivy/stateful/utilities.py index a36414f43e95..78f91c889b13 100644 --- a/ivy/stateful/utilities.py +++ b/ivy/stateful/utilities.py @@ -89,9 +89,9 @@ def _maybe_update_flax_layer_weights(layer, weight_name, new_weight): import flax.nnx as nnx import jax.numpy as jnp - has_keras_layers = os.environ.get("USE_NATIVE_KERAS_LAYERS", None) == "true" + has_flax_layers = os.environ.get("USE_NATIVE_FW_LAYERS", None) == "true" transpose_weights = ( - has_keras_layers + has_flax_layers or os.environ.get("APPLY_TRANSPOSE_OPTIMIZATION", None) == "true" ) @@ -286,7 +286,7 @@ def _maybe_update_keras_layer_weights(layer, weight_name, new_weight): else: KerasVariable = tf.Variable - has_keras_layers = os.environ.get("USE_NATIVE_KERAS_LAYERS", None) == "true" + has_keras_layers = os.environ.get("USE_NATIVE_FW_LAYERS", None) == "true" transpose_weights = ( has_keras_layers or os.environ.get("APPLY_TRANSPOSE_OPTIMIZATION", None) == "true"