Skip to content

Commit

Permalink
Make progress on porting to keras 3
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Apr 4, 2024
1 parent 8f0aace commit ae671ac
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 186 deletions.
4 changes: 2 additions & 2 deletions scikeras/_saving_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def pack_keras_model(
tp = type(model)
out = BytesIO()
if tp not in keras.saving.object_registration.GLOBAL_CUSTOM_OBJECTS:
module = '.'.join(tp.__qualname__.split('.')[:-1])
name = tp.__qualname__.split('.')[-1]
module = ".".join(tp.__qualname__.split(".")[:-1])
name = tp.__qualname__.split(".")[-1]
keras.saving.register_keras_serializable(module, name)(tp)
save_model(model, out)
model_bytes = np.asarray(memoryview(out.getvalue()))
Expand Down
1 change: 0 additions & 1 deletion scikeras/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import inspect
from types import FunctionType
from typing import Any, Callable, Dict, Iterable, Mapping, Sequence, Type, Union

from keras import losses as losses_mod
Expand Down
13 changes: 7 additions & 6 deletions scikeras/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Mapping, Set, Tuple, Type, Union

import numpy as np
import tensorflow as tf
import keras
import numpy as np
from keras import losses as losses_module
from keras.models import Model
from scipy.sparse import isspmatrix, lil_matrix
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
from sklearn.exceptions import NotFittedError
Expand All @@ -18,8 +19,6 @@
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.utils.multiclass import type_of_target
from sklearn.utils.validation import _check_sample_weight, check_array, check_X_y
from keras import losses as losses_module
from keras.models import Model

from scikeras._utils import (
accepts_kwargs,
Expand Down Expand Up @@ -381,7 +380,9 @@ def _get_compile_kwargs(self):
strict=False,
),
)
if compile_kwargs["metrics"] is not None and not isinstance(compile_kwargs['metrics'], (dict, list)):
if compile_kwargs["metrics"] is not None and not isinstance(
compile_kwargs["metrics"], (dict, list)
):
# Keras expects a list or dict of metrics, not a single metric
compile_kwargs["metrics"] = [compile_kwargs["metrics"]]
return compile_kwargs
Expand Down Expand Up @@ -537,7 +538,7 @@ def _fit_keras_model(
self.history_ = defaultdict(list)

for key, val in hist.history.items():
if not (key == 'loss' or key[:4] == 'val_'):
if not (key == "loss" or key[:4] == "val_"):
try:
key = metric_name(key)
except ValueError:
Expand Down
2 changes: 1 addition & 1 deletion tests/multi_output_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List

import numpy as np
from sklearn.utils.multiclass import type_of_target
from keras.backend import floatx as tf_floatx
from sklearn.utils.multiclass import type_of_target

from scikeras.utils.transformers import ClassifierLabelEncoder
from scikeras.wrappers import KerasClassifier
Expand Down
14 changes: 7 additions & 7 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@
from functools import partial
from typing import Any, Dict

import keras
import numpy as np
import pytest
from keras import backend as K
from keras import losses as losses_module
from keras import metrics as metrics_module
from keras.layers import Conv2D, Dense, Flatten, Input
from keras.models import Model, Sequential
from keras.utils import to_categorical
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import load_diabetes, load_digits, load_iris
from sklearn.ensemble import (
Expand All @@ -17,13 +24,6 @@
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import keras
from keras import backend as K
from keras import losses as losses_module
from keras import metrics as metrics_module
from keras.layers import Conv2D, Dense, Flatten, Input
from keras.models import Model, Sequential
from keras.utils import to_categorical

from scikeras.wrappers import KerasClassifier, KerasRegressor

Expand Down
4 changes: 2 additions & 2 deletions tests/test_basewrapper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Test that BaseWrapper for uses other than KerasClassifier and KerasRegressor.
"""
import keras
import numpy as np
from keras import layers
from sklearn.base import TransformerMixin
from sklearn.metrics import mean_squared_error
import keras
from keras import layers

from scikeras.wrappers import BaseWrapper

Expand Down
2 changes: 1 addition & 1 deletion tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections import defaultdict
from typing import Any, DefaultDict, Dict

import pytest
import keras
import pytest
from keras.callbacks import Callback

from scikeras.wrappers import KerasClassifier
Expand Down
Loading

0 comments on commit ae671ac

Please sign in to comment.