Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Keras 3 #317

Merged
merged 24 commits into from
Apr 10, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Many fixes
adriangb committed Apr 7, 2024
commit 5c23f1dbb9318c82fe7ae8e365b563b3f0cc5db5
44 changes: 22 additions & 22 deletions scikeras/utils/random_state.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import random
from contextlib import contextmanager
from typing import Generator
from typing import Generator, Iterator

import numpy as np

@@ -13,7 +13,7 @@
def tf_set_seed(seed: int) -> None:
tf.random.set_seed(seed)

def tf_get_seed() -> int:
def tf_get_seed() -> Iterator[int]:
if context.executing_eagerly():
return context.global_seed()
else:
@@ -48,23 +48,23 @@ def tensorflow_random_state(seed: int) -> Generator[None, None, None]:
orig_random_state = random.getstate()
orig_np_random_state = np.random.get_state()
tf_random_seed = tf_get_seed()
determism_enabled = tf_enable_op_determinism()

# Set values
os.environ["TF_DETERMINISTIC_OPS"] = "1"
random.seed(seed)
np.random.seed(seed)
tf_set_seed(seed)

yield

# Reset values
if origin_gpu_det is not None:
os.environ["TF_DETERMINISTIC_OPS"] = origin_gpu_det
else:
os.environ.pop("TF_DETERMINISTIC_OPS")
random.setstate(orig_random_state)
np.random.set_state(orig_np_random_state)
tf_set_seed(tf_random_seed)
if determism_enabled:
tf_disable_op_determinism()
determinism_enabled = None
try:
# Set values
os.environ["TF_DETERMINISTIC_OPS"] = "1"
random.seed(seed)
np.random.seed(seed)
tf_set_seed(seed)
determinism_enabled = tf_enable_op_determinism()
yield
finally:
# Reset values
if origin_gpu_det is not None:
os.environ["TF_DETERMINISTIC_OPS"] = origin_gpu_det
else:
os.environ.pop("TF_DETERMINISTIC_OPS")
random.setstate(orig_random_state)
np.random.set_state(orig_np_random_state)
tf_set_seed(tf_random_seed)
if determinism_enabled is False:
tf_disable_op_determinism()
79 changes: 37 additions & 42 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
from unittest import mock

import keras.backend
import numpy as np
import pytest
from keras import Sequential
@@ -114,53 +113,49 @@ def test_random_states_env_vars(self, estimator, pyhash, gpu):
assert "PYTHONHASHSEED" not in os.environ


@pytest.mark.parametrize("set_floatx_and_backend_config", ["float64"], indirect=True)
def test_sample_weights_fit():
"""Checks that the `sample_weight` parameter when passed to `fit`
has the intended effect.
"""
with keras.backend.set_floatx("float64"):
# build estimator
estimator = KerasClassifier(
model=dynamic_classifier,
model__hidden_layer_sizes=(100,),
epochs=10,
random_state=0,
)
estimator1 = clone(estimator)
estimator2 = clone(estimator)

# we create 10000 points
X = np.array([1] * 10000).reshape(-1, 1)
y = [1] * 5000 + [-1] * 5000
# build estimator
estimator = KerasClassifier(
model=dynamic_classifier,
model__hidden_layer_sizes=(100,),
epochs=10,
random_state=0,
)
estimator1 = clone(estimator)
estimator2 = clone(estimator)

# heavily weight towards y=1 points
sw_first_class = [0.8] * 5000 + [0.2] * 5000
# train estimator 1 with weights
estimator1.fit(X, y, sample_weight=sw_first_class)
# train estimator 2 without weights
estimator2.fit(X, y)
# estimator1 should tilt towards y=1
# estimator2 should predict about equally
average_diff_pred_prob_1 = np.average(
np.diff(estimator1.predict_proba(X), axis=1)
)
average_diff_pred_prob_2 = np.average(
np.diff(estimator2.predict_proba(X), axis=1)
)
assert average_diff_pred_prob_2 < average_diff_pred_prob_1
# we create 10000 points
X = np.array([1] * 10000).reshape(-1, 1)
y = [1] * 5000 + [-1] * 5000

# equal weighting
sw_equal = [0.5] * 5000 + [0.5] * 5000
# train estimator 1 with weights
estimator1.fit(X, y, sample_weight=sw_equal)
# train estimator 2 without weights
estimator2.fit(X, y)
# both estimators should have about the same predictions
np.testing.assert_allclose(
actual=estimator1.predict_proba(X),
desired=estimator2.predict_proba(X),
rtol=1e-4,
)
# heavily weight towards y=1 points
sw_first_class = [0.8] * 5000 + [0.2] * 5000
# train estimator 1 with weights
estimator1.fit(X, y, sample_weight=sw_first_class)
# train estimator 2 without weights
estimator2.fit(X, y)
# estimator1 should tilt towards y=1
# estimator2 should predict about equally
average_diff_pred_prob_1 = np.average(np.diff(estimator1.predict_proba(X), axis=1))
average_diff_pred_prob_2 = np.average(np.diff(estimator2.predict_proba(X), axis=1))
assert average_diff_pred_prob_2 < average_diff_pred_prob_1

# equal weighting
sw_equal = [0.5] * 5000 + [0.5] * 5000
# train estimator 1 with weights
estimator1.fit(X, y, sample_weight=sw_equal)
# train estimator 2 without weights
estimator2.fit(X, y)
# both estimators should have about the same predictions
np.testing.assert_allclose(
actual=estimator1.predict_proba(X),
desired=estimator2.predict_proba(X),
rtol=1e-3,
)


def test_sample_weights_score():
4 changes: 2 additions & 2 deletions tests/test_scikit_learn_checks.py
Original file line number Diff line number Diff line change
@@ -88,15 +88,15 @@ def test_fully_compliant_estimators_low_precision(estimator, check):
),
],
)
@pytest.mark.parametrize("set_floatx_and_backend_config", ["float64"], indirect=True)
def test_fully_compliant_estimators_high_precision(estimator, check):
"""Checks that require higher training epochs."""
check_name = check.func.__name__
if check_name not in higher_precision:
pytest.skip(
"This test is run as part of test_fully_compliant_estimators_low_precision."
)
with set_floatx("float64"):
check(estimator)
check(estimator)


class SubclassedClassifier(KerasClassifier):