|
2 | 2 | "cells": [ |
3 | 3 | { |
4 | 4 | "cell_type": "raw", |
5 | | - "id": "4f595ac8", |
| 5 | + "id": "d2098092", |
6 | 6 | "metadata": {}, |
7 | 7 | "source": [ |
8 | 8 | "<a href=\"https://colab.research.google.com/github/adriangb/scikeras/blob/docs-deploy/refs/heads/master/notebooks/Benchmarks.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Run in Google Colab</a>" |
9 | 9 | ] |
10 | 10 | }, |
11 | 11 | { |
12 | 12 | "cell_type": "markdown", |
13 | | - "id": "ecb3854e", |
| 13 | + "id": "418feec2", |
14 | 14 | "metadata": {}, |
15 | 15 | "source": [ |
16 | 16 | "# SciKeras Benchmarks\n", |
|
31 | 31 | { |
32 | 32 | "cell_type": "code", |
33 | 33 | "execution_count": 1, |
34 | | - "id": "593ee4bf", |
| 34 | + "id": "9a8a6197", |
35 | 35 | "metadata": { |
36 | 36 | "execution": { |
37 | | - "iopub.execute_input": "2024-12-12T21:33:44.339321Z", |
38 | | - "iopub.status.busy": "2024-12-12T21:33:44.339129Z", |
39 | | - "iopub.status.idle": "2024-12-12T21:33:48.703555Z", |
40 | | - "shell.execute_reply": "2024-12-12T21:33:48.702471Z" |
| 37 | + "iopub.execute_input": "2024-12-12T21:44:12.045338Z", |
| 38 | + "iopub.status.busy": "2024-12-12T21:44:12.044960Z", |
| 39 | + "iopub.status.idle": "2024-12-12T21:44:15.912883Z", |
| 40 | + "shell.execute_reply": "2024-12-12T21:44:15.912025Z" |
41 | 41 | } |
42 | 42 | }, |
43 | 43 | "outputs": [ |
|
46 | 46 | "output_type": "stream", |
47 | 47 | "text": [ |
48 | 48 | "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", |
49 | | - "E0000 00:00:1734039224.808030 7945 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", |
50 | | - "E0000 00:00:1734039224.813871 7945 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" |
51 | | - ] |
52 | | - }, |
53 | | - { |
54 | | - "name": "stderr", |
55 | | - "output_type": "stream", |
56 | | - "text": [ |
57 | | - "/home/runner/work/scikeras/scikeras/scikeras/__init__.py:20: UserWarning: \n", |
58 | | - " This project is now deprecated. Keras has re-introduced wrappers with a similar API to those in SciKeras, but they will be better maintained.\n", |
59 | | - " SciKeras was a project to meet a specific need that was developed by a single developer.\n", |
60 | | - " I no longer use Keras nor do I have the time to maintain this project, which became increasingly difficult with multiple versions of Keras and Scikit-Learn to support.\n", |
61 | | - " I thank all of the users and contributors over the years and hope that the new Keras wrappers will meet your needs.\n", |
62 | | - " TODO: add link to Keras docs and release here.\n", |
63 | | - " \n", |
64 | | - " warn(\n" |
| 49 | + "E0000 00:00:1734039852.622019 6721 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", |
| 50 | + "E0000 00:00:1734039852.628823 6721 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" |
65 | 51 | ] |
66 | 52 | } |
67 | 53 | ], |
|
74 | 60 | }, |
75 | 61 | { |
76 | 62 | "cell_type": "markdown", |
77 | | - "id": "17dbc11b", |
| 63 | + "id": "833491fc", |
78 | 64 | "metadata": {}, |
79 | 65 | "source": [ |
80 | 66 | "Silence TensorFlow logging to keep output succinct." |
|
83 | 69 | { |
84 | 70 | "cell_type": "code", |
85 | 71 | "execution_count": 2, |
86 | | - "id": "0a4d3079", |
| 72 | + "id": "c9a078fe", |
87 | 73 | "metadata": { |
88 | 74 | "execution": { |
89 | | - "iopub.execute_input": "2024-12-12T21:33:48.707462Z", |
90 | | - "iopub.status.busy": "2024-12-12T21:33:48.706695Z", |
91 | | - "iopub.status.idle": "2024-12-12T21:33:48.712417Z", |
92 | | - "shell.execute_reply": "2024-12-12T21:33:48.711511Z" |
| 75 | + "iopub.execute_input": "2024-12-12T21:44:15.917190Z", |
| 76 | + "iopub.status.busy": "2024-12-12T21:44:15.916308Z", |
| 77 | + "iopub.status.idle": "2024-12-12T21:44:15.923235Z", |
| 78 | + "shell.execute_reply": "2024-12-12T21:44:15.922092Z" |
93 | 79 | } |
94 | 80 | }, |
95 | 81 | "outputs": [], |
|
103 | 89 | { |
104 | 90 | "cell_type": "code", |
105 | 91 | "execution_count": 3, |
106 | | - "id": "8eca2d91", |
| 92 | + "id": "00b82d8f", |
107 | 93 | "metadata": { |
108 | 94 | "execution": { |
109 | | - "iopub.execute_input": "2024-12-12T21:33:48.715552Z", |
110 | | - "iopub.status.busy": "2024-12-12T21:33:48.715114Z", |
111 | | - "iopub.status.idle": "2024-12-12T21:33:49.454687Z", |
112 | | - "shell.execute_reply": "2024-12-12T21:33:49.453891Z" |
| 95 | + "iopub.execute_input": "2024-12-12T21:44:15.927424Z", |
| 96 | + "iopub.status.busy": "2024-12-12T21:44:15.926213Z", |
| 97 | + "iopub.status.idle": "2024-12-12T21:44:16.486812Z", |
| 98 | + "shell.execute_reply": "2024-12-12T21:44:16.485869Z" |
113 | 99 | } |
114 | 100 | }, |
115 | 101 | "outputs": [], |
|
121 | 107 | }, |
122 | 108 | { |
123 | 109 | "cell_type": "markdown", |
124 | | - "id": "0eaab19f", |
| 110 | + "id": "2c216fe1", |
125 | 111 | "metadata": {}, |
126 | 112 | "source": [ |
127 | 113 | "## 2. Dataset\n", |
|
132 | 118 | { |
133 | 119 | "cell_type": "code", |
134 | 120 | "execution_count": 4, |
135 | | - "id": "d8ee4792", |
| 121 | + "id": "ca09c0ad", |
136 | 122 | "metadata": { |
137 | 123 | "execution": { |
138 | | - "iopub.execute_input": "2024-12-12T21:33:49.458697Z", |
139 | | - "iopub.status.busy": "2024-12-12T21:33:49.458291Z", |
140 | | - "iopub.status.idle": "2024-12-12T21:33:49.726766Z", |
141 | | - "shell.execute_reply": "2024-12-12T21:33:49.726133Z" |
| 124 | + "iopub.execute_input": "2024-12-12T21:44:16.490743Z", |
| 125 | + "iopub.status.busy": "2024-12-12T21:44:16.490241Z", |
| 126 | + "iopub.status.idle": "2024-12-12T21:44:16.840144Z", |
| 127 | + "shell.execute_reply": "2024-12-12T21:44:16.838555Z" |
142 | 128 | } |
143 | 129 | }, |
144 | 130 | "outputs": [], |
|
157 | 143 | }, |
158 | 144 | { |
159 | 145 | "cell_type": "markdown", |
160 | | - "id": "1a8087e2", |
| 146 | + "id": "497036c9", |
161 | 147 | "metadata": {}, |
162 | 148 | "source": [ |
163 | 149 | "## 3. Define Keras Model\n", |
|
168 | 154 | { |
169 | 155 | "cell_type": "code", |
170 | 156 | "execution_count": 5, |
171 | | - "id": "ed313491", |
| 157 | + "id": "ccf7b3ce", |
172 | 158 | "metadata": { |
173 | 159 | "execution": { |
174 | | - "iopub.execute_input": "2024-12-12T21:33:49.729503Z", |
175 | | - "iopub.status.busy": "2024-12-12T21:33:49.729208Z", |
176 | | - "iopub.status.idle": "2024-12-12T21:33:49.736201Z", |
177 | | - "shell.execute_reply": "2024-12-12T21:33:49.735017Z" |
| 160 | + "iopub.execute_input": "2024-12-12T21:44:16.843737Z", |
| 161 | + "iopub.status.busy": "2024-12-12T21:44:16.843410Z", |
| 162 | + "iopub.status.idle": "2024-12-12T21:44:16.850249Z", |
| 163 | + "shell.execute_reply": "2024-12-12T21:44:16.849087Z" |
178 | 164 | } |
179 | 165 | }, |
180 | 166 | "outputs": [], |
|
204 | 190 | }, |
205 | 191 | { |
206 | 192 | "cell_type": "markdown", |
207 | | - "id": "f2977461", |
| 193 | + "id": "fe064f93", |
208 | 194 | "metadata": {}, |
209 | 195 | "source": [ |
210 | 196 | "## 4. Keras benchmarks" |
|
213 | 199 | { |
214 | 200 | "cell_type": "code", |
215 | 201 | "execution_count": 6, |
216 | | - "id": "5134057a", |
| 202 | + "id": "285d78c3", |
217 | 203 | "metadata": { |
218 | 204 | "execution": { |
219 | | - "iopub.execute_input": "2024-12-12T21:33:49.739648Z", |
220 | | - "iopub.status.busy": "2024-12-12T21:33:49.739224Z", |
221 | | - "iopub.status.idle": "2024-12-12T21:33:49.745228Z", |
222 | | - "shell.execute_reply": "2024-12-12T21:33:49.744110Z" |
| 205 | + "iopub.execute_input": "2024-12-12T21:44:16.853253Z", |
| 206 | + "iopub.status.busy": "2024-12-12T21:44:16.852708Z", |
| 207 | + "iopub.status.idle": "2024-12-12T21:44:16.857705Z", |
| 208 | + "shell.execute_reply": "2024-12-12T21:44:16.856708Z" |
223 | 209 | } |
224 | 210 | }, |
225 | 211 | "outputs": [], |
|
230 | 216 | { |
231 | 217 | "cell_type": "code", |
232 | 218 | "execution_count": 7, |
233 | | - "id": "4a74c804", |
| 219 | + "id": "83eff969", |
234 | 220 | "metadata": { |
235 | 221 | "execution": { |
236 | | - "iopub.execute_input": "2024-12-12T21:33:49.749067Z", |
237 | | - "iopub.status.busy": "2024-12-12T21:33:49.747670Z", |
238 | | - "iopub.status.idle": "2024-12-12T21:33:49.754258Z", |
239 | | - "shell.execute_reply": "2024-12-12T21:33:49.753204Z" |
| 222 | + "iopub.execute_input": "2024-12-12T21:44:16.860753Z", |
| 223 | + "iopub.status.busy": "2024-12-12T21:44:16.860272Z", |
| 224 | + "iopub.status.idle": "2024-12-12T21:44:16.864872Z", |
| 225 | + "shell.execute_reply": "2024-12-12T21:44:16.863931Z" |
240 | 226 | } |
241 | 227 | }, |
242 | 228 | "outputs": [], |
|
248 | 234 | { |
249 | 235 | "cell_type": "code", |
250 | 236 | "execution_count": 8, |
251 | | - "id": "a01b3bea", |
| 237 | + "id": "99dff865", |
252 | 238 | "metadata": { |
253 | 239 | "execution": { |
254 | | - "iopub.execute_input": "2024-12-12T21:33:49.757809Z", |
255 | | - "iopub.status.busy": "2024-12-12T21:33:49.756784Z", |
256 | | - "iopub.status.idle": "2024-12-12T21:33:55.699278Z", |
257 | | - "shell.execute_reply": "2024-12-12T21:33:55.698767Z" |
| 240 | + "iopub.execute_input": "2024-12-12T21:44:16.868299Z", |
| 241 | + "iopub.status.busy": "2024-12-12T21:44:16.868003Z", |
| 242 | + "iopub.status.idle": "2024-12-12T21:44:23.808386Z", |
| 243 | + "shell.execute_reply": "2024-12-12T21:44:23.807184Z" |
258 | 244 | } |
259 | 245 | }, |
260 | 246 | "outputs": [ |
261 | 247 | { |
262 | 248 | "name": "stdout", |
263 | 249 | "output_type": "stream", |
264 | 250 | "text": [ |
265 | | - "Training time: 5.69\n", |
| 251 | + "Training time: 6.47\n", |
266 | 252 | "\r", |
267 | | - "\u001b[1m 1/16\u001b[0m \u001b[32m━\u001b[0m\u001b[37m━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 48ms/step" |
| 253 | + "\u001b[1m 1/16\u001b[0m \u001b[32m━\u001b[0m\u001b[37m━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m1s\u001b[0m 78ms/step" |
268 | 254 | ] |
269 | 255 | }, |
270 | 256 | { |
271 | 257 | "name": "stdout", |
272 | 258 | "output_type": "stream", |
273 | 259 | "text": [ |
274 | 260 | "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", |
275 | | - "\u001b[1m16/16\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 5ms/step " |
| 261 | + "\u001b[1m10/16\u001b[0m \u001b[32m━━━━━━━━━━━━\u001b[0m\u001b[37m━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 6ms/step " |
| 262 | + ] |
| 263 | + }, |
| 264 | + { |
| 265 | + "name": "stdout", |
| 266 | + "output_type": "stream", |
| 267 | + "text": [ |
| 268 | + "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", |
| 269 | + "\u001b[1m16/16\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 9ms/step" |
276 | 270 | ] |
277 | 271 | }, |
278 | 272 | { |
279 | 273 | "name": "stdout", |
280 | 274 | "output_type": "stream", |
281 | 275 | "text": [ |
282 | 276 | "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r", |
283 | | - "\u001b[1m16/16\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 5ms/step\n" |
| 277 | + "\u001b[1m16/16\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 9ms/step\n" |
284 | 278 | ] |
285 | 279 | }, |
286 | 280 | { |
|
305 | 299 | }, |
306 | 300 | { |
307 | 301 | "cell_type": "markdown", |
308 | | - "id": "5df17b4c", |
| 302 | + "id": "912da84c", |
309 | 303 | "metadata": {}, |
310 | 304 | "source": [ |
311 | 305 | "## 5. SciKeras benchmark" |
|
314 | 308 | { |
315 | 309 | "cell_type": "code", |
316 | 310 | "execution_count": 9, |
317 | | - "id": "e8291e96", |
| 311 | + "id": "5850890a", |
318 | 312 | "metadata": { |
319 | 313 | "execution": { |
320 | | - "iopub.execute_input": "2024-12-12T21:33:55.701227Z", |
321 | | - "iopub.status.busy": "2024-12-12T21:33:55.700867Z", |
322 | | - "iopub.status.idle": "2024-12-12T21:33:55.703738Z", |
323 | | - "shell.execute_reply": "2024-12-12T21:33:55.703276Z" |
| 314 | + "iopub.execute_input": "2024-12-12T21:44:23.811749Z", |
| 315 | + "iopub.status.busy": "2024-12-12T21:44:23.811419Z", |
| 316 | + "iopub.status.idle": "2024-12-12T21:44:23.816069Z", |
| 317 | + "shell.execute_reply": "2024-12-12T21:44:23.815174Z" |
324 | 318 | } |
325 | 319 | }, |
326 | 320 | "outputs": [], |
|
335 | 329 | { |
336 | 330 | "cell_type": "code", |
337 | 331 | "execution_count": 10, |
338 | | - "id": "00f2669d", |
| 332 | + "id": "a3644eb5", |
339 | 333 | "metadata": { |
340 | 334 | "execution": { |
341 | | - "iopub.execute_input": "2024-12-12T21:33:55.705565Z", |
342 | | - "iopub.status.busy": "2024-12-12T21:33:55.705115Z", |
343 | | - "iopub.status.idle": "2024-12-12T21:33:58.980436Z", |
344 | | - "shell.execute_reply": "2024-12-12T21:33:58.979833Z" |
| 335 | + "iopub.execute_input": "2024-12-12T21:44:23.821206Z", |
| 336 | + "iopub.status.busy": "2024-12-12T21:44:23.818791Z", |
| 337 | + "iopub.status.idle": "2024-12-12T21:44:29.681714Z", |
| 338 | + "shell.execute_reply": "2024-12-12T21:44:29.681188Z" |
345 | 339 | } |
346 | 340 | }, |
347 | 341 | "outputs": [ |
348 | 342 | { |
349 | 343 | "name": "stdout", |
350 | 344 | "output_type": "stream", |
351 | 345 | "text": [ |
352 | | - "Training time: 3.12\n", |
| 346 | + "Training time: 5.72\n", |
353 | 347 | "Accuracy: 0.89\n" |
354 | 348 | ] |
355 | 349 | } |
|
364 | 358 | }, |
365 | 359 | { |
366 | 360 | "cell_type": "markdown", |
367 | | - "id": "9215d012", |
| 361 | + "id": "b44c12f4", |
368 | 362 | "metadata": {}, |
369 | 363 | "source": [ |
370 | 364 | "As you can see, the overhead for SciKeras is <1 sec, and the accuracy is identical." |
|
0 commit comments