From 985a17c48ab2823fc4200b64a5fc2cca7599dbb7 Mon Sep 17 00:00:00 2001 From: a Date: Sat, 8 Oct 2022 14:18:14 +0900 Subject: [PATCH] fix sample_weight --- scikeras/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index 5ca51691..dbade3d6 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -503,7 +503,7 @@ def _fit_keras_model( # collect parameters params = self.get_params() fit_args = route_params(params, destination="fit", pass_filter=self._fit_kwargs) - fit_args["sample_weight"] = sample_weight + fit_args["sample_weight"] = [sample_weight] fit_args["epochs"] = initial_epoch + epochs fit_args["initial_epoch"] = initial_epoch fit_args.update(kwargs)