Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Keras 3 #317

Merged
merged 24 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/build_docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ jobs:
- uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: 3.8
python-version: "3.x"

- name: Install pandoc
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/deploy_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ jobs:
name: Deploy Docs
runs-on: ubuntu-latest
steps:
- uses: actions/setup-python@v2
- uses: actions/setup-python@v5
with:
python-version: '3.x'
python-version: '3.10'

- name: Get GitHub Pages Data
uses: actions/github-script@v3
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ jobs:
steps:
- uses: actions/checkout@v2

- uses: actions/setup-python@v2
- uses: actions/setup-python@v5
with:
python-version: '3.8'
python-version: '3.10'

- name: Install Poetry
run: pip install --upgrade poetry
Expand Down
26 changes: 13 additions & 13 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5

- name: Update packages
run: |
Expand All @@ -33,15 +33,15 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
tensorflow-version: ["2.12.0"]
python-version: ["3.9", "3.10", "3.11"]
tensorflow-version: ["2.16.1"]
fail-fast: false

steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand Down Expand Up @@ -75,7 +75,7 @@ jobs:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -93,8 +93,8 @@ jobs:
- name: Install Nightly Versions
if: always()
run: |
poetry run python -m pip install -U tf-nightly
poetry run python -m pip install -U scipy
poetry run python -m pip install -U git+https://github.com/keras-team/keras.git
poetry run python -m pip install -U --pre --extra-index https://pypi.anaconda.org/scipy-wheels-nightly/simple scikit-learn

- name: Test with pytest
Expand All @@ -111,15 +111,15 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
tf-version: [2.12.0]
python-version: ["3.8", "3.9"]
sklearn-version: [1.0.0]
tf-version: ["2.16.1"]
python-version: ["3.9"]
sklearn-version: ["1.4.1.post1"]

steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand Down Expand Up @@ -149,15 +149,15 @@ jobs:
strategy:
matrix:
os: [MacOS, Windows] # test all OSs except Ubuntu, which is already running other tests
python-version: ["3.8", "3.11"] # test only the two extremes of supported Python versions
tensorflow-version: ["2.12.0"] # test only the two extremes of supported TF versions
python-version: ["3.9", "3.11"] # test only the two extremes of supported Python versions
tensorflow-version: ["2.16.1"] # test only the two extremes of supported TF versions
fail-fast: false

steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand Down
21 changes: 6 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,28 @@ Scikit-Learn compatible wrappers for Keras Models.

## Why SciKeras

SciKeras is derived from and API compatible with `tf.keras.wrappers.scikit_learn`. The original TensorFlow (TF) wrappers are not actively maintained,
and [will be removed](https://github.com/tensorflow/tensorflow/pull/36137#issuecomment-726271760) in a future release.
SciKeras is derived from and API compatible with the now deprecated / removed `tf.keras.wrappers.scikit_learn`.

An overview of the advantages and differences as compared to the TF wrappers can be found in our
An overview of the differences as compared to the TF wrappers can be found in our
[migration](https://www.adriangb.com/scikeras/stable/migration.html) guide.

## Installation

This package is available on PyPi:

```bash
# Normal tensorflow
# Tensorflow
pip install scikeras[tensorflow]

# or tensorflow-cpu
pip install scikeras[tensorflow-cpu]
```

SciKeras packages TensorFlow as an optional dependency because there are
several flavors of TensorFlow available (`tensorflow`, `tensorflow-cpu`, etc.).
Depending on _one_ of them in particular disallows the usage of the other, which is why
they need to be optional.

`pip install scikeras[tensorflow]` is basically equivalent to `pip install scikeras tensorflow`
Note that `pip install scikeras[tensorflow]` is basically equivalent to `pip install scikeras tensorflow`
and is offered just for convenience. You can also install just SciKeras with
`pip install scikeras`, but you will need a version of tensorflow installed at
runtime or SciKeras will throw an error when you try to import it.

The current version of SciKeras depends on `scikit-learn>=1.0.0` and `TensorFlow>=2.7.0`.
The current version of SciKeras depends on `scikit-learn>=1.4.1post1` and `Keras>=3.2.0`.

### Migrating from `tf.keras.wrappers.scikit_learn`
### Migrating from `keras.wrappers.scikit_learn`

Please see the [migration](https://www.adriangb.com/scikeras/stable/migration.html) section of our documentation.

Expand Down
46 changes: 23 additions & 23 deletions docs/source/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ on the overall functionality of the wrappers and hence will refer to
Detailed information on usage of specific classes is available in the
:ref:`scikeras-api` documentation.

SciKeras wraps the Keras :py:class:`~tensorflow.keras.Model` to
SciKeras wraps the Keras :py:class:`~keras.Model` to
provide an interface that should be familiar for Scikit-Learn users and is compatible
with most of the Scikit-Learn ecosystem.

To get started, define your :py:class:`~tensorflow.keras.Model` architecture like you always do,
To get started, define your :py:class:`~keras.Model` architecture like you always do,
but within a callable top-level function (we will call this function ``model_build_fn`` for
the remained of these docs, but you are free to name it as you wish).
Then pass this function to :py:class:`.BaseWrapper` in the ``model`` parameter.
Expand All @@ -42,9 +42,9 @@ estimator. The finished code could look something like this:

Let's see what SciKeras did:

- wraps ``tensorflow.keras.Model`` in an sklearn interface
- wraps ``keras.Model`` in an sklearn interface
- handles encoding and decoding of the target ``y``
- compiles the :py:class:`~tensorflow.keras.Model` (unless you do it yourself in ``model_build_fn``)
- compiles the :py:class:`~keras.Model` (unless you do it yourself in ``model_build_fn``)
- makes all ``Keras`` objects serializable so that they can be used with :py:mod:`~sklearn.model_selection`.

SciKeras abstracts away the incompatibilities and data conversions,
Expand Down Expand Up @@ -112,7 +112,7 @@ offer an easy way to compile and tune compilation parameters. Examples:

.. code:: python

from tensorflow.keras.optimizers import Adam
from keras.optimizers import Adam

def model_build_fn():
model = Model(...)
Expand Down Expand Up @@ -164,7 +164,7 @@ see the :ref:`scikeras-api` documentation.

``compile_kwargs``
++++++++++++++++++++++++
This is a dictionary of parameters destined for :py:func:`tensorflow.Keras.Model.compile`.
This is a dictionary of parameters destined for :py:func:`keras.Model.compile`.
This dictionary can be used like ``model.compile(**compile_kwargs)``.
All optimizers, losses and metrics will be compiled to objects,
even if string shorthands (e.g. ``optimizer="adam"``) were passed.
Expand Down Expand Up @@ -192,7 +192,7 @@ To work around this issue, SciKeras implements a data conversion
abstraction in the form of Scikit-Learn style transformers,
one for ``X`` (features) and one for ``y`` (target).
By implementing a custom transformer, you can split a single input ``X`` into multiple inputs
for :py:class:`tensorflow.keras.Model` or perform any other manipulation you need.
for :py:class:`keras.Model` or perform any other manipulation you need.
To override the default transformers, simply override
:py:func:`scikeras.wrappers.BaseWrappers.target_encoder` or
:py:func:`scikeras.wrappers.BaseWrappers.feature_encoder` for ``y`` and ``X`` respectively.
Expand Down Expand Up @@ -248,8 +248,8 @@ All special prefixes are stored in the ``prefixes_`` class attribute
of :py:class:`scikeras.wrappers.BaseWrappers`. Currently, they are:

- ``model__``: passed to ``model_build_fn`` (or whatever function is passed to the ``model`` param of :class:`scikeras.wrappers.BaseWrapper`).
- ``fit__``: passed to :func:`tensorflow.keras.Model.fit`
- ``predict__``: passed to :func:`tensorflow.keras.Model.predict`. Note that internally SciKeras also uses :func:`tensorflow.keras.Model.predict` within :func:`scikeras.wrappers.BaseWrapper.score` and so this prefix applies to both.
- ``fit__``: passed to :func:`keras.Model.fit`
- ``predict__``: passed to :func:`keras.Model.predict`. Note that internally SciKeras also uses :func:`keras.Model.predict` within :func:`scikeras.wrappers.BaseWrapper.score` and so this prefix applies to both.
- ``callbacks__``: used to instantiate callbacks.
- ``optimizer__``: used to instantiate optimizers.
- ``loss__``: used to instantiate losses.
Expand Down Expand Up @@ -280,7 +280,7 @@ Optimizer
.. code:: python

from scikeras.wrappers import KerasClassifier
from tensorflow import keras
import keras

clf = KerasClassifier(
model=model_build_fn,
Expand All @@ -305,7 +305,7 @@ Losses

.. code:: python

from tensorflow.keras.losses import BinaryCrossentropy, CategoricalCrossentropy
from keras.losses import BinaryCrossentropy, CategoricalCrossentropy

clf = KerasClassifier(
...,
Expand All @@ -322,7 +322,7 @@ Additionally, SciKeras supports routed parameters to each individual loss, or to

.. code:: python

from tensorflow.keras.losses import BinaryCrossentropy, CategoricalCrossentropy
from keras.losses import BinaryCrossentropy, CategoricalCrossentropy

clf = KerasClassifier(
...,
Expand All @@ -348,7 +348,7 @@ Here are several support use cases:

.. code:: python

from tensorflow.keras.metrics import BinaryAccuracy, AUC
from keras.metrics import BinaryAccuracy, AUC

clf = KerasClassifier(
...,
Expand Down Expand Up @@ -388,7 +388,7 @@ SciKeras can route parameters to callbacks.

clf = KerasClassifier(
...,
callbacks=tf.keras.callbacks.EarlyStopping
callbacks=keras.callbacks.EarlyStopping
callbacks__monitor="loss",
)

Expand All @@ -399,21 +399,21 @@ Just like metrics and losses, callbacks support several syntaxes to compile them
# for multiple callbacks using dict syntax
clf = KerasClassifier(
...,
callbacks={"bl": tf.keras.callbacks.BaseLogger, "es": tf.keras.callbacks.EarlyStopping}
callbacks={"bl": keras.callbacks.BaseLogger, "es": keras.callbacks.EarlyStopping}
callbacks__es__monitor="loss",
)
# or using list sytnax
clf = KerasClassifier(
...,
callbacks=[tf.keras.callbacks.BaseLogger, tf.keras.callbacks.EarlyStopping]
callbacks=[keras.callbacks.BaseLogger, keras.callbacks.EarlyStopping]
callbacks__1__monitor="loss", # EarlyStopping(monitor="loss")
)

Keras callbacks are event based, and are triggered depending on the methods they implement.
For example:

.. code:: python
from tensorflow import keras
import keras

class MyCallback(keras.callbacks.Callback):

Expand All @@ -433,9 +433,9 @@ simply use the ``fit__`` or ``predict__`` routing prefixes on your callback:

clf = KerasClassifier(
...,
callbacks=tf.keras.callbacks.Callback, # called from both fit and predict
fit__callbacks=tf.keras.callbacks.Callback, # called only from fit
predict__callbacks=tf.keras.callbacks.Callback, # called only from predict
callbacks=keras.callbacks.Callback, # called from both fit and predict
fit__callbacks=keras.callbacks.Callback, # called only from fit
predict__callbacks=keras.callbacks.Callback, # called only from predict
)

Any routed constructor parameters must also use the corresponding prefix to get routed correctly.
Expand All @@ -449,7 +449,7 @@ which tells SciKeras to pass that parameter as an positional argument instead of

.. code:: python

from tensorflow import keras
import keras

class Schedule:
"""Exponential decay lr scheduler.
Expand Down Expand Up @@ -478,6 +478,6 @@ as the scoring functions for :class:`scikeras.wrappers.KerasClassifier`
and :class:`scikeras.wrappers.KerasRegressor` respectively. To override these scoring functions,


.. _Keras Callbacks docs: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks
.. _Keras Callbacks docs: https://keras.io/api/callbacks/

.. _Keras Metrics docs: https://www.tensorflow.org/api_docs/python/tf/keras/metrics
.. _Keras Metrics docs: https://keras.io/api/metrics/
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Welcome to SciKeras's documentation!

The goal of scikeras is to make it possible to use Keras/TensorFlow with sklearn.
This is achieved by providing a wrapper around Keras that has an Scikit-Learn interface.
SciKeras is the successor to ``tf.keras.wrappers.scikit_learn``, and offers many
SciKeras is the successor to ``keras.wrappers.scikit_learn``, and offers many
improvements over the TensorFlow version of the wrappers. See :ref:`Migration<Migration>` for a more details.

SciKeras tries to make things easy for you while staying out of your way.
Expand Down
16 changes: 6 additions & 10 deletions docs/source/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,16 @@ To install with pip, run:

.. code:: bash

pip install scikeras[tensorflow]
pip install scikeras

We recommend to use a virtual environment for this.

You will need to manually install TensorFlow; due to TensorFlow's packaging it is not a direct dependency of SciKeras.
You can do this by running:
This will install SciKeras and Keras.
Keras does not automatically install a backend.
For example, to install TensorFlow you can do:

.. code:: bash

pip install tensorflow

This allows you to install an alternative TensorFlow binary, for example `tensorflow-cpu`_.

You can also install SciKeras without any dependencies, for example to install a nightly version of Scikit-Learn:

.. code:: bash
Expand All @@ -34,8 +31,8 @@ You can also install SciKeras without any dependencies, for example to install a

As of SciKeras v0.5.0, the minimum required versions are as follows:

- TensorFlow: v2.7.0
- Scikit-Learn: v1.0.0
- Keras: v3.2.0
- Scikit-Learn: v1.4.1post1

Developer Installation
~~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -56,4 +53,3 @@ We use Poetry_ to manage dependencies.


.. _Poetry: https://python-poetry.org/
.. _tensorflow-cpu: https://pypi.org/project/tensorflow-cpu/
Loading
Loading