Skip to content

Commit

Permalink
Support for Keras 3 (#317)
Browse files Browse the repository at this point in the history
* Drops support for older versions of Keras
* Expected to require followups, please open issues
  • Loading branch information
adriangb authored Apr 10, 2024
1 parent 9ec5ca6 commit 390789f
Show file tree
Hide file tree
Showing 41 changed files with 491 additions and 624 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build_docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ jobs:
- uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: 3.8
python-version: "3.x"

- name: Install pandoc
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/deploy_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ jobs:
name: Deploy Docs
runs-on: ubuntu-latest
steps:
- uses: actions/setup-python@v2
- uses: actions/setup-python@v5
with:
python-version: '3.x'
python-version: '3.10'

- name: Get GitHub Pages Data
uses: actions/github-script@v3
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ jobs:
steps:
- uses: actions/checkout@v2

- uses: actions/setup-python@v2
- uses: actions/setup-python@v5
with:
python-version: '3.8'
python-version: '3.10'

- name: Install Poetry
run: pip install --upgrade poetry
Expand Down
26 changes: 13 additions & 13 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5

- name: Update packages
run: |
Expand All @@ -33,15 +33,15 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
tensorflow-version: ["2.12.0"]
python-version: ["3.9", "3.10", "3.11"]
tensorflow-version: ["2.16.1"]
fail-fast: false

steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand Down Expand Up @@ -75,7 +75,7 @@ jobs:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -93,8 +93,8 @@ jobs:
- name: Install Nightly Versions
if: always()
run: |
poetry run python -m pip install -U tf-nightly
poetry run python -m pip install -U scipy
poetry run python -m pip install -U git+https://github.com/keras-team/keras.git
poetry run python -m pip install -U --pre --extra-index https://pypi.anaconda.org/scipy-wheels-nightly/simple scikit-learn
- name: Test with pytest
Expand All @@ -111,15 +111,15 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
tf-version: [2.12.0]
python-version: ["3.8", "3.9"]
sklearn-version: [1.0.0]
tf-version: ["2.16.1"]
python-version: ["3.9"]
sklearn-version: ["1.4.1.post1"]

steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand Down Expand Up @@ -149,15 +149,15 @@ jobs:
strategy:
matrix:
os: [MacOS, Windows] # test all OSs except Ubuntu, which is already running other tests
python-version: ["3.8", "3.11"] # test only the two extremes of supported Python versions
tensorflow-version: ["2.12.0"] # test only the two extremes of supported TF versions
python-version: ["3.9", "3.11"] # test only the two extremes of supported Python versions
tensorflow-version: ["2.16.1"] # test only the two extremes of supported TF versions
fail-fast: false

steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand Down
21 changes: 6 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,28 @@ Scikit-Learn compatible wrappers for Keras Models.

## Why SciKeras

SciKeras is derived from and API compatible with `tf.keras.wrappers.scikit_learn`. The original TensorFlow (TF) wrappers are not actively maintained,
and [will be removed](https://github.com/tensorflow/tensorflow/pull/36137#issuecomment-726271760) in a future release.
SciKeras is derived from and API compatible with the now deprecated / removed `tf.keras.wrappers.scikit_learn`.

An overview of the advantages and differences as compared to the TF wrappers can be found in our
An overview of the differences as compared to the TF wrappers can be found in our
[migration](https://www.adriangb.com/scikeras/stable/migration.html) guide.

## Installation

This package is available on PyPi:

```bash
# Normal tensorflow
# Tensorflow
pip install scikeras[tensorflow]

# or tensorflow-cpu
pip install scikeras[tensorflow-cpu]
```

SciKeras packages TensorFlow as an optional dependency because there are
several flavors of TensorFlow available (`tensorflow`, `tensorflow-cpu`, etc.).
Depending on _one_ of them in particular disallows the usage of the other, which is why
they need to be optional.

`pip install scikeras[tensorflow]` is basically equivalent to `pip install scikeras tensorflow`
Note that `pip install scikeras[tensorflow]` is basically equivalent to `pip install scikeras tensorflow`
and is offered just for convenience. You can also install just SciKeras with
`pip install scikeras`, but you will need a version of tensorflow installed at
runtime or SciKeras will throw an error when you try to import it.

The current version of SciKeras depends on `scikit-learn>=1.0.0` and `TensorFlow>=2.7.0`.
The current version of SciKeras depends on `scikit-learn>=1.4.1post1` and `Keras>=3.2.0`.

### Migrating from `tf.keras.wrappers.scikit_learn`
### Migrating from `keras.wrappers.scikit_learn`

Please see the [migration](https://www.adriangb.com/scikeras/stable/migration.html) section of our documentation.

Expand Down
46 changes: 23 additions & 23 deletions docs/source/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ on the overall functionality of the wrappers and hence will refer to
Detailed information on usage of specific classes is available in the
:ref:`scikeras-api` documentation.

SciKeras wraps the Keras :py:class:`~tensorflow.keras.Model` to
SciKeras wraps the Keras :py:class:`~keras.Model` to
provide an interface that should be familiar for Scikit-Learn users and is compatible
with most of the Scikit-Learn ecosystem.

To get started, define your :py:class:`~tensorflow.keras.Model` architecture like you always do,
To get started, define your :py:class:`~keras.Model` architecture like you always do,
but within a callable top-level function (we will call this function ``model_build_fn`` for
the remained of these docs, but you are free to name it as you wish).
Then pass this function to :py:class:`.BaseWrapper` in the ``model`` parameter.
Expand All @@ -42,9 +42,9 @@ estimator. The finished code could look something like this:
Let's see what SciKeras did:

- wraps ``tensorflow.keras.Model`` in an sklearn interface
- wraps ``keras.Model`` in an sklearn interface
- handles encoding and decoding of the target ``y``
- compiles the :py:class:`~tensorflow.keras.Model` (unless you do it yourself in ``model_build_fn``)
- compiles the :py:class:`~keras.Model` (unless you do it yourself in ``model_build_fn``)
- makes all ``Keras`` objects serializable so that they can be used with :py:mod:`~sklearn.model_selection`.

SciKeras abstracts away the incompatibilities and data conversions,
Expand Down Expand Up @@ -112,7 +112,7 @@ offer an easy way to compile and tune compilation parameters. Examples:
.. code:: python
from tensorflow.keras.optimizers import Adam
from keras.optimizers import Adam
def model_build_fn():
model = Model(...)
Expand Down Expand Up @@ -164,7 +164,7 @@ see the :ref:`scikeras-api` documentation.

``compile_kwargs``
++++++++++++++++++++++++
This is a dictionary of parameters destined for :py:func:`tensorflow.Keras.Model.compile`.
This is a dictionary of parameters destined for :py:func:`keras.Model.compile`.
This dictionary can be used like ``model.compile(**compile_kwargs)``.
All optimizers, losses and metrics will be compiled to objects,
even if string shorthands (e.g. ``optimizer="adam"``) were passed.
Expand Down Expand Up @@ -192,7 +192,7 @@ To work around this issue, SciKeras implements a data conversion
abstraction in the form of Scikit-Learn style transformers,
one for ``X`` (features) and one for ``y`` (target).
By implementing a custom transformer, you can split a single input ``X`` into multiple inputs
for :py:class:`tensorflow.keras.Model` or perform any other manipulation you need.
for :py:class:`keras.Model` or perform any other manipulation you need.
To override the default transformers, simply override
:py:func:`scikeras.wrappers.BaseWrappers.target_encoder` or
:py:func:`scikeras.wrappers.BaseWrappers.feature_encoder` for ``y`` and ``X`` respectively.
Expand Down Expand Up @@ -248,8 +248,8 @@ All special prefixes are stored in the ``prefixes_`` class attribute
of :py:class:`scikeras.wrappers.BaseWrappers`. Currently, they are:

- ``model__``: passed to ``model_build_fn`` (or whatever function is passed to the ``model`` param of :class:`scikeras.wrappers.BaseWrapper`).
- ``fit__``: passed to :func:`tensorflow.keras.Model.fit`
- ``predict__``: passed to :func:`tensorflow.keras.Model.predict`. Note that internally SciKeras also uses :func:`tensorflow.keras.Model.predict` within :func:`scikeras.wrappers.BaseWrapper.score` and so this prefix applies to both.
- ``fit__``: passed to :func:`keras.Model.fit`
- ``predict__``: passed to :func:`keras.Model.predict`. Note that internally SciKeras also uses :func:`keras.Model.predict` within :func:`scikeras.wrappers.BaseWrapper.score` and so this prefix applies to both.
- ``callbacks__``: used to instantiate callbacks.
- ``optimizer__``: used to instantiate optimizers.
- ``loss__``: used to instantiate losses.
Expand Down Expand Up @@ -280,7 +280,7 @@ Optimizer
.. code:: python
from scikeras.wrappers import KerasClassifier
from tensorflow import keras
import keras
clf = KerasClassifier(
model=model_build_fn,
Expand All @@ -305,7 +305,7 @@ Losses

.. code:: python
from tensorflow.keras.losses import BinaryCrossentropy, CategoricalCrossentropy
from keras.losses import BinaryCrossentropy, CategoricalCrossentropy
clf = KerasClassifier(
...,
Expand All @@ -322,7 +322,7 @@ Additionally, SciKeras supports routed parameters to each individual loss, or to

.. code:: python
from tensorflow.keras.losses import BinaryCrossentropy, CategoricalCrossentropy
from keras.losses import BinaryCrossentropy, CategoricalCrossentropy
clf = KerasClassifier(
...,
Expand All @@ -348,7 +348,7 @@ Here are several support use cases:

.. code:: python
from tensorflow.keras.metrics import BinaryAccuracy, AUC
from keras.metrics import BinaryAccuracy, AUC
clf = KerasClassifier(
...,
Expand Down Expand Up @@ -388,7 +388,7 @@ SciKeras can route parameters to callbacks.
clf = KerasClassifier(
...,
callbacks=tf.keras.callbacks.EarlyStopping
callbacks=keras.callbacks.EarlyStopping
callbacks__monitor="loss",
)
Expand All @@ -399,21 +399,21 @@ Just like metrics and losses, callbacks support several syntaxes to compile them
# for multiple callbacks using dict syntax
clf = KerasClassifier(
...,
callbacks={"bl": tf.keras.callbacks.BaseLogger, "es": tf.keras.callbacks.EarlyStopping}
callbacks={"bl": keras.callbacks.BaseLogger, "es": keras.callbacks.EarlyStopping}
callbacks__es__monitor="loss",
)
# or using list sytnax
clf = KerasClassifier(
...,
callbacks=[tf.keras.callbacks.BaseLogger, tf.keras.callbacks.EarlyStopping]
callbacks=[keras.callbacks.BaseLogger, keras.callbacks.EarlyStopping]
callbacks__1__monitor="loss", # EarlyStopping(monitor="loss")
)
Keras callbacks are event based, and are triggered depending on the methods they implement.
For example:

.. code:: python
from tensorflow import keras
import keras
class MyCallback(keras.callbacks.Callback):
Expand All @@ -433,9 +433,9 @@ simply use the ``fit__`` or ``predict__`` routing prefixes on your callback:

clf = KerasClassifier(
...,
callbacks=tf.keras.callbacks.Callback, # called from both fit and predict
fit__callbacks=tf.keras.callbacks.Callback, # called only from fit
predict__callbacks=tf.keras.callbacks.Callback, # called only from predict
callbacks=keras.callbacks.Callback, # called from both fit and predict
fit__callbacks=keras.callbacks.Callback, # called only from fit
predict__callbacks=keras.callbacks.Callback, # called only from predict
)

Any routed constructor parameters must also use the corresponding prefix to get routed correctly.
Expand All @@ -449,7 +449,7 @@ which tells SciKeras to pass that parameter as an positional argument instead of

.. code:: python
from tensorflow import keras
import keras
class Schedule:
"""Exponential decay lr scheduler.
Expand Down Expand Up @@ -478,6 +478,6 @@ as the scoring functions for :class:`scikeras.wrappers.KerasClassifier`
and :class:`scikeras.wrappers.KerasRegressor` respectively. To override these scoring functions,


.. _Keras Callbacks docs: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks
.. _Keras Callbacks docs: https://keras.io/api/callbacks/

.. _Keras Metrics docs: https://www.tensorflow.org/api_docs/python/tf/keras/metrics
.. _Keras Metrics docs: https://keras.io/api/metrics/
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Welcome to SciKeras's documentation!

The goal of scikeras is to make it possible to use Keras/TensorFlow with sklearn.
This is achieved by providing a wrapper around Keras that has an Scikit-Learn interface.
SciKeras is the successor to ``tf.keras.wrappers.scikit_learn``, and offers many
SciKeras is the successor to ``keras.wrappers.scikit_learn``, and offers many
improvements over the TensorFlow version of the wrappers. See :ref:`Migration<Migration>` for a more details.

SciKeras tries to make things easy for you while staying out of your way.
Expand Down
16 changes: 6 additions & 10 deletions docs/source/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,16 @@ To install with pip, run:

.. code:: bash
pip install scikeras[tensorflow]
pip install scikeras
We recommend to use a virtual environment for this.

You will need to manually install TensorFlow; due to TensorFlow's packaging it is not a direct dependency of SciKeras.
You can do this by running:
This will install SciKeras and Keras.
Keras does not automatically install a backend.
For example, to install TensorFlow you can do:

.. code:: bash
pip install tensorflow
This allows you to install an alternative TensorFlow binary, for example `tensorflow-cpu`_.

You can also install SciKeras without any dependencies, for example to install a nightly version of Scikit-Learn:

.. code:: bash
Expand All @@ -34,8 +31,8 @@ You can also install SciKeras without any dependencies, for example to install a
As of SciKeras v0.5.0, the minimum required versions are as follows:

- TensorFlow: v2.7.0
- Scikit-Learn: v1.0.0
- Keras: v3.2.0
- Scikit-Learn: v1.4.1post1

Developer Installation
~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -56,4 +53,3 @@ We use Poetry_ to manage dependencies.
.. _Poetry: https://python-poetry.org/
.. _tensorflow-cpu: https://pypi.org/project/tensorflow-cpu/
Loading

0 comments on commit 390789f

Please sign in to comment.