Make progress on getting tests working again
adriangb committed Mar 25, 2024
1 parent 5399a97 commit 8f0aace
Showing 13 changed files with 95 additions and 153 deletions.
5 changes: 3 additions & 2 deletions docs/source/notebooks/Basic_Usage.md
@@ -182,16 +182,17 @@ def get_reg(meta, hidden_layer_sizes, dropout):

 ### 3.3 Defining and training the neural net regressor
 
-Training a regressor has nearly the same data flow as training a classifier. The differences include using `KerasRegressor` instead of `KerasClassifier` and adding `KerasRegressor.r_squared` as a metric. Most of the Scikit-learn regressors use the coefficient of determination or R^2 as a metric function, which measures correlation between the true labels and predicted labels.
+Training a regressor has nearly the same data flow as training a classifier. The differences include using `KerasRegressor` instead of `KerasClassifier` and adding `keras.metrics.R2Score` as a metric. Most of the Scikit-learn regressors use the coefficient of determination or R^2 as a metric function, which measures correlation between the true labels and predicted labels.
 
 ```python
+import keras
 from scikeras.wrappers import KerasRegressor
 
 
 reg = KerasRegressor(
     model=get_reg,
     loss="mse",
-    metrics=[KerasRegressor.r_squared],
+    metrics=[keras.metrics.R2Score],
     hidden_layer_sizes=(100,),
     dropout=0.5,
 )
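
With the diff applied, the notebook's regressor relies on Keras's built-in metric instead of the custom helper. A self-contained sketch of the updated pattern (the builder body and data below are stand-ins, not the notebook's exact code):

```python
import keras
import numpy as np
from scikeras.wrappers import KerasRegressor


def get_reg(meta, hidden_layer_sizes, dropout):
    # Hypothetical stand-in for the notebook's model-building function.
    model = keras.Sequential([keras.layers.Input(shape=(meta["n_features_in_"],))])
    for size in hidden_layer_sizes:
        model.add(keras.layers.Dense(size, activation="relu"))
        model.add(keras.layers.Dropout(dropout))
    model.add(keras.layers.Dense(1))
    return model


reg = KerasRegressor(
    model=get_reg,
    loss="mse",
    metrics=[keras.metrics.R2Score],
    hidden_layer_sizes=(100,),
    dropout=0.5,
)

X = np.random.uniform(size=(200, 10)).astype(np.float32)
y = X.sum(axis=1)
reg.fit(X, y)
print(reg.score(X, y))  # scikit-learn-style R^2 computed from predictions
```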
4 changes: 2 additions & 2 deletions docs/source/notebooks/Meta_Estimators.md
@@ -106,7 +106,7 @@ Because SciKeras estimators are fully compliant with the Scikit-Learn API, we ca
 from sklearn.ensemble import AdaBoostClassifier
 
 
-adaboost = AdaBoostClassifier(base_estimator=clf, random_state=0)
+adaboost = AdaBoostClassifier(estimator=clf, random_state=0)
 ```

 ## 3. Testing with a toy dataset
 
@@ -144,7 +144,7 @@ For comparison, we run the same test with an ensemble built using `sklearn.ensem
 from sklearn.ensemble import BaggingClassifier
 
 
-bagging = BaggingClassifier(base_estimator=clf, random_state=0, n_jobs=-1)
+bagging = BaggingClassifier(estimator=clf, random_state=0, n_jobs=-1)
 
 bagging_score = bagging.fit(X, y).score(X, y)
 
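Both renames track scikit-learn's own API change: `base_estimator` was deprecated in favor of `estimator` in scikit-learn 1.2 and removed in 1.4. A standalone sketch of the renamed keyword with plain scikit-learn estimators (names here are illustrative, not from this diff):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=200, random_state=0)
# `estimator=` is the scikit-learn >= 1.2 spelling of the old `base_estimator=`.
ada = AdaBoostClassifier(estimator=DecisionTreeClassifier(max_depth=1), random_state=0)
print(ada.fit(X, y).score(X, y))
```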
7 changes: 7 additions & 0 deletions scikeras/_saving_utils.py
@@ -2,6 +2,8 @@
 from typing import Any, Callable, Dict, Hashable, List, Tuple
 
 import keras as keras
+import keras.saving
+import keras.saving.object_registration
 import numpy as np
 from keras.saving.saving_lib import load_model, save_model

Expand All @@ -21,7 +23,12 @@ def pack_keras_model(
Tuple[np.ndarray, List[np.ndarray]],
]:
"""Support for Pythons's Pickle protocol."""
tp = type(model)
out = BytesIO()
if tp not in keras.saving.object_registration.GLOBAL_CUSTOM_OBJECTS:
module = '.'.join(tp.__qualname__.split('.')[:-1])
name = tp.__qualname__.split('.')[-1]
keras.saving.register_keras_serializable(module, name)(tp)
save_model(model, out)
model_bytes = np.asarray(memoryview(out.getvalue()))
return (unpack_keras_model, (model_bytes,))
Expand Down
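
For context, `pack_keras_model` implements pickle's reducer contract: it returns `(unpack_keras_model, (model_bytes,))`, so pickling saves the model to an in-memory buffer and unpickling rebuilds it. A rough sketch of how such a reducer can be wired up via `copyreg` (illustrative only; the exact registration lives elsewhere in SciKeras):

```python
import copyreg
import pickle

import keras

from scikeras._saving_utils import pack_keras_model

# Assumed wiring: route pickling of any keras.Model through the reducer,
# which serializes the model into an in-memory buffer via save_model.
copyreg.pickle(keras.Model, pack_keras_model)

model = keras.Sequential([keras.layers.Input(shape=(4,)), keras.layers.Dense(1)])
restored = pickle.loads(pickle.dumps(model))  # round-trips through save/load
```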
10 changes: 5 additions & 5 deletions scikeras/_utils.py
@@ -105,7 +105,10 @@ def unflatten_params(items, params, base_params=None):
         kwargs = {k: v for k, v in args_and_kwargs.items() if k[0] not in DIGITS}
         args = [(int(k), v) for k, v in args_and_kwargs.items() if k not in kwargs]
         args = (v for _, v in sorted(args))  # sorts by key / arg num
-        return item(*args, **kwargs)
+        try:
+            return item(*args, **kwargs)
+        except Exception as e:
+            raise e
     if isinstance(items, (list, tuple)):
         iter_type_ = type(items)
         res = []
@@ -173,10 +176,7 @@ def get_metric_class(


 def get_loss_class_function_or_string(loss: str) -> Union[losses_mod.Loss, Callable]:
-    got = losses_mod.get(loss)
-    if type(got) == FunctionType:
-        return got
-    return type(got)  # a class, e.g. if loss="BinaryCrossentropy"
+    return losses_mod.get(loss)
 
 
 def try_to_convert_strings_to_classes(
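
The simplification relies on `keras.losses.get` already doing the resolution: function-style identifiers come back as plain functions and class-name identifiers come back as `Loss` objects, so the old `FunctionType` special-casing is unnecessary. A quick illustration of the assumed Keras 3 behavior (worth double-checking against the pinned version):

```python
from keras import losses

print(losses.get("mse"))                 # a plain function
print(losses.get("BinaryCrossentropy"))  # a configured Loss object
```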
2 changes: 1 addition & 1 deletion scikeras/utils/__init__.py
@@ -60,7 +60,7 @@ def loss_name(loss: Union[str, Loss, Callable]) -> str:
     if isinstance(fn_or_cls, Loss):
         return _camel2snake(fn_or_cls.__class__.__name__)
     if hasattr(fn_or_cls, "__name__"):
-        return fn_or_cls.__name__
+        return _camel2snake(fn_or_cls.__name__)
     return fn_or_cls
 
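With `_camel2snake` now applied on the `__name__` path too, function-, instance-, and class-style losses normalize to one canonical name. A small sketch (expected outputs inferred from the helper's naming):

```python
from keras.losses import BinaryCrossentropy
from scikeras.utils import loss_name

print(loss_name(BinaryCrossentropy()))  # "binary_crossentropy" via the Loss-instance path
print(loss_name(BinaryCrossentropy))    # "binary_crossentropy" via __name__, now snake_cased
```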
57 changes: 22 additions & 35 deletions scikeras/wrappers.py
@@ -6,11 +6,9 @@
 from collections import defaultdict
 from typing import Any, Callable, Dict, Iterable, List, Mapping, Set, Tuple, Type, Union
 
-import keras
 import numpy as np
-import tensorflow as tf
-from keras import losses as losses_module
-from keras.models import Model
+import keras
 from scipy.sparse import isspmatrix, lil_matrix
 from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
 from sklearn.exceptions import NotFittedError
Expand All @@ -20,6 +18,8 @@
 from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.validation import _check_sample_weight, check_array, check_X_y
+from keras import losses as losses_module
+from keras.models import Model

 from scikeras._utils import (
     accepts_kwargs,
 
@@ -381,17 +381,20 @@ def _get_compile_kwargs(self):
                 strict=False,
             ),
         )
+        if compile_kwargs["metrics"] is not None and not isinstance(compile_kwargs["metrics"], (dict, list)):
+            # Keras expects a list or dict of metrics, not a single metric
+            compile_kwargs["metrics"] = [compile_kwargs["metrics"]]
         return compile_kwargs

-    def _build_keras_model(self):
+    def _build_keras_model(self) -> keras.Model:
         """Build the Keras model.
         This method will process all arguments and call the model building
         function with appropriate arguments.
         Returns
         -------
-        tensorflow.keras.Model
+        keras.Model
             Instantiated and compiled keras Model.
         """
         # dynamically build model, i.e. final_build_fn builds a Keras model
 
@@ -432,9 +435,16 @@

     def _ensure_compiled_model(self) -> None:
         # compile model if user gave us an un-compiled model
-        if not (hasattr(self.model_, "loss") and hasattr(self.model_, "optimizer")):
+        if not self.model_.compiled:
             kw = self._get_compile_kwargs()
             self.model_.compile(**kw)
+        # check that the model has been properly compiled, which at the very least means it
+        # has an optimizer and a loss
+        # the errors keras would give are not very helpful, wrap them here in something a bit better
+        if not getattr(self.model_, "loss", None):
+            raise ValueError("You must provide a loss or a compiled model")
+        if not getattr(self.model_, "optimizer", None):
+            raise ValueError("You must provide an optimizer or a compiled model")
 
     def _fit_keras_model(
         self,
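
The rewritten `_ensure_compiled_model` leans on Keras 3's `Model.compiled` flag and then asserts that a loss and an optimizer are actually present, replacing Keras's opaque errors with clearer ones. In practice that leaves two accepted modes, sketched below with a hypothetical builder:

```python
import keras

from scikeras.wrappers import KerasRegressor


def build_model() -> keras.Model:
    # Hypothetical builder, not part of this diff.
    return keras.Sequential([keras.layers.Input(shape=(4,)), keras.layers.Dense(1)])


# Mode 1: let SciKeras compile the model; a loss must be given to the wrapper.
est_uncompiled = KerasRegressor(model=build_model, loss="mse")

# Mode 2: hand over an already-compiled model; no loss needed on the wrapper.
model = build_model()
model.compile(optimizer="adam", loss="mse")
est_compiled = KerasRegressor(model=model)
```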
@@ -527,9 +537,12 @@ def _fit_keras_model(
         self.history_ = defaultdict(list)
 
         for key, val in hist.history.items():
-            if key == "loss" or key[:4] == "val_":
-                continue
-            key = metric_name(key)
+            if not (key == 'loss' or key[:4] == 'val_'):
+                try:
+                    key = metric_name(key)
+                except ValueError:
+                    # unknown metric, e.g. custom metric
+                    pass
             self.history_[key] += val
 
     def _check_model_compatibility(self, y: np.ndarray) -> None:
 
@@ -1734,29 +1747,3 @@ def target_encoder(self):
         interface.
         """
         return RegressorTargetEncoder()
-
-    @staticmethod
-    def r_squared(y_true, y_pred):
-        """A simple Keras implementation of R^2 that can be used as a Keras
-        metric function.
-        Larger values indicate a better fit, with 1.0 representing a perfect fit.
-        Parameters
-        ----------
-        y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
-            True labels.
-        y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
-            Predicted labels.
-        """
-        # Ensure input dytpes match
-        y_true = keras.ops.cast(y_true, dtype=y_pred.dtype)
-        # Calculate R^2
-        ss_res = keras.ops.sum(keras.ops.square(y_true, y_pred), axis=0)
-        ss_tot = keras.ops.sum(
-            keras.ops.square(y_true, tf.math.reduce_mean(y_true, axis=0)),
-            axis=0,
-        )
-        return tf.math.reduce_mean(
-            1 - ss_res / (ss_tot + keras.backend.epsilon()), axis=-1
-        )
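
The removed `r_squared` helper is superseded by Keras's built-in metric; note that its `keras.ops.square(y_true, y_pred)` calls were buggy anyway, since `square` takes a single tensor. A short sketch of the replacement, assuming Keras 3's `keras.metrics.R2Score`:

```python
import numpy as np
import keras

metric = keras.metrics.R2Score()
metric.update_state(
    y_true=np.array([[1.0], [2.0], [3.0]]),
    y_pred=np.array([[1.1], [1.9], [3.2]]),
)
print(float(metric.result()))  # approaches 1.0 for a near-perfect fit
```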
9 changes: 3 additions & 6 deletions tests/test_api.py
@@ -331,7 +331,7 @@ def test_calibratedclassifiercv(self, config):
         base_estimator = KerasClassifier(
             build_fn, epochs=1, model__hidden_layer_sizes=[]
         )
-        estimator = CalibratedClassifierCV(base_estimator=base_estimator, cv=5)
+        estimator = CalibratedClassifierCV(estimator=base_estimator, cv=5)
         basic_checks(estimator, loader)


@@ -851,11 +851,8 @@ def test_prebuilt_model(self, wrapper):
         y_pred_keras = y_pred_keras.reshape(
             -1,
         )
-        # Extract the weights into a copy of the model
-        weights = m1.get_weights()
-        m2 = keras.models.clone_model(m1)
-        m2.set_weights(weights)
-        m2.compile()  # No loss, inference models shouldn't need a loss!
+        # Make a copy of the model to make sure we don't modify the original
+        m2 = pickle.loads(pickle.dumps(m1))
         # Wrap with SciKeras
         est = wrapper(model=m2)
         # Without calling initialize, a NotFittedError is raised
39 changes: 24 additions & 15 deletions tests/test_compile_kwargs.py
@@ -1,6 +1,7 @@
 import numpy as np
 import pytest
 from sklearn.datasets import make_classification
+from keras.backend.common.variables import KerasVariable
 from keras import losses as losses_module
 from keras import metrics as metrics_module
 from keras import optimizers as optimizers_module
@@ -45,8 +46,14 @@ def test_optimizer(optimizer):
     est.fit(X, y)
     est_opt = est.model_.optimizer
     if not isinstance(optimizer, str):
-        assert float(est_opt.momentum) == pytest.approx(0.5)
-        assert float(est_opt.learning_rate) == pytest.approx(0.15, abs=1e-6)
+        momentum = est_opt.momentum
+        if isinstance(momentum, KerasVariable):
+            momentum = momentum.numpy()
+        assert float(momentum) == pytest.approx(0.5)
+        lr = est_opt.learning_rate
+        if isinstance(lr, KerasVariable):
+            lr = lr.numpy()
+        assert lr == pytest.approx(0.15, abs=1e-6)
     else:
         assert est_opt.__class__ == optimizers_module.get(optimizer).__class__
@@ -65,7 +72,7 @@ def test_optimizer_invalid_string():
         optimizer=optimizer,
         loss="binary_crossentropy",
     )
-    with pytest.raises(ValueError, match="Unknown optimizer"):
+    with pytest.raises(ValueError, match="Could not interpret optimizer"):
         est.fit(X, y)


@@ -137,7 +144,7 @@ def test_loss_invalid_string():
         num_hidden=20,
         loss=loss,
     )
-    with pytest.raises(ValueError, match="Unknown loss function"):
+    with pytest.raises(ValueError, match="Could not interpret loss"):
         est.fit(X, y)


@@ -254,16 +261,18 @@ def test_metrics_single_metric_per_output(metrics, n_outputs_):
     else:
         expected_name = metrics().name
 
-    # List of metrics
-    est = MultiOutputClassifier(
-        model=get_model,
-        loss="binary_crossentropy",
-        metrics=[
-            metrics if not isinstance(metrics, metrics_module.Metric) else metrics()
-        ],
-    )
-    est.fit(X, y)
-    assert est.model_.metrics[metric_idx].name == prefix + expected_name
+    if n_outputs_ == 1:
+        # List of metrics, not supported for multiple outputs where each output is required to get
+        # its own metrics if passing metrics as a list
+        est = MultiOutputClassifier(
+            model=get_model,
+            loss="binary_crossentropy",
+            metrics=[
+                metrics if not isinstance(metrics, metrics_module.Metric) else metrics()
+            ],
+        )
+        est.fit(X, y)
+        assert est.model_.metrics[metric_idx].name == prefix + expected_name
 
     # List of lists of metrics
     est = MultiOutputClassifier(
@@ -471,7 +480,7 @@ def test_metrics_invalid_string():
loss="binary_crossentropy",
metrics=metrics,
)
with pytest.raises(ValueError, match="Unknown metric function"):
with pytest.raises(ValueError, match="Could not interpret metric identifier"):
est.fit(X, y)


6 changes: 2 additions & 4 deletions tests/test_errors.py
@@ -151,7 +151,7 @@ def get_model(compile, meta, compile_kwargs):
         return model
 
     est = KerasRegressor(model=get_model, loss=loss, compile=compile)
-    with pytest.raises(ValueError, match="must provide a loss function"):
+    with pytest.raises(ValueError, match=r".*(?:provide a loss)|(?:Provide a `loss`).*"):
         est.fit([[0], [1]], [0, 1])


@@ -175,9 +175,7 @@ def get_model(compile, meta, compile_kwargs):
         compile=compile,
         optimizer=None,
     )
-    with pytest.raises(
-        ValueError, match="Could not interpret optimizer identifier"  # Keras error
-    ):
+    with pytest.raises(ValueError, match="You must provide an optimizer"):
         est.fit([[0], [1]], [0, 1])


12 changes: 4 additions & 8 deletions tests/test_parameters.py
@@ -153,7 +153,7 @@ def test_sample_weights_fit():
     np.testing.assert_allclose(
         actual=estimator1.predict_proba(X),
         desired=estimator2.predict_proba(X),
-        rtol=1e-5,
+        rtol=1e-4,
     )


@@ -271,14 +271,15 @@ def test_kwargs(wrapper, builder):
     kwarg_epochs = (
         2  # epochs is a special case for fit since SciKeras also uses it internally
     )
-    extra_kwargs = {"workers": 1}  # chosen because it is not a SciKeras hardcoded param
+    extra_kwargs = {"verbose": True}  # chosen because it is not a SciKeras hardcoded param
     est = wrapper(
         model=builder,
         model__hidden_layer_sizes=(100,),
         warm_start=True,  # for mocking to work properly
         batch_size=original_batch_size,  # test that this is overridden by kwargs
         fit__batch_size=original_batch_size,  # test that this is overridden by kwargs
         predict__batch_size=original_batch_size,  # test that this is overridden by kwargs
+        verbose=False,  # opposite of the extra_kwargs
     )
     X, y = np.random.random((100, 10)), np.random.randint(low=0, high=3, size=(100,))
     est.initialize(X, y)
 
@@ -312,12 +313,7 @@ def test_kwargs(wrapper, builder):
     # check that params were restored and extra_kwargs were not stored
     for param_name in ("batch_size", "fit__batch_size", "predict__batch_size"):
         assert getattr(est, param_name) == original_batch_size
-    for k in extra_kwargs.keys():
-        assert (
-            not hasattr(est, k)
-            or hasattr(est, "fit__" + k)
-            or hasattr(est, "predict__" + k)
-        )
+    assert est.verbose == False
 
 
 @pytest.mark.parametrize("kwargs", ({"epochs": 1}, {"initial_epoch": 1}))
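
`workers` disappears here because Keras 3's `Model.fit` dropped the Keras 2 `workers`/`use_multiprocessing` arguments; `verbose` takes its place as a kwarg SciKeras routes straight through to `fit`. A sketch of the routed-kwarg behavior under test (builder and data are stand-ins):

```python
import numpy as np
import keras

from scikeras.wrappers import KerasClassifier


def build_clf(meta):
    # Hypothetical builder using SciKeras-provided metadata.
    return keras.Sequential(
        [
            keras.layers.Input(shape=(meta["n_features_in_"],)),
            keras.layers.Dense(meta["n_classes_"], activation="softmax"),
        ]
    )


X = np.random.random((100, 10))
y = np.random.randint(0, 3, size=(100,))

# `verbose` set at construction is the stored parameter; passing it to
# `fit` overrides that call only and is not persisted on the estimator.
est = KerasClassifier(model=build_clf, loss="sparse_categorical_crossentropy", verbose=False)
est.fit(X, y, verbose=True)
assert est.verbose is False
```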