Open
Description
When using the model below for multi-class classification and passing y
targets from LabelBinarizer
, these get detected as task type: multilabel-indicator
instead of task type: multiclass
.
def get_model(hidden_layer_sizes, meta, compile_kwargs):
model = keras.Sequential()
inp = keras.layers.Input(shape=(meta["n_features_in_"]))
model.add(inp)
for hidden_layer_size in hidden_layer_sizes:
layer = keras.layers.Dense(hidden_layer_size, activation="relu")
model.add(layer)
if meta["target_type_"] == "binary":
n_output_units = 1
output_activation = "sigmoid"
loss = "binary_crossentropy"
elif meta["target_type_"] == "multiclass":
n_output_units = meta["n_classes_"]
output_activation = "softmax"
loss = "sparse_categorical_crossentropy"
else:
raise NotImplementedError(f"Unsupported task type: {meta['target_type_']}")
out = keras.layers.Dense(n_output_units, activation=output_activation)
model.add(out)
model.compile(loss=loss, optimizer=compile_kwargs["optimizer"])
return model
from scikeras.wrappers import KerasClassifier
pipeline = make_pipeline(
SimpleImputer(strategy='constant', fill_value='None'),
OneHotEncoder(sparse_output=True, handle_unknown='ignore'),
StandardScaler(copy=False, with_mean=False),
VarianceThreshold(threshold=0.05),
KerasClassifier(
model=get_model,
hidden_layer_sizes=(100,),
optimizer='adam',
epochs=5,
batch_size=15,
verbose=1,
random_state=42,
warm_start=True,
metrics='accuracy'
),
verbose=True
)
label_binarizer = LabelBinarizer(sparse_output=False)
label_binarizer.fit(y_train['attack_type'])
y_train_encoded = label_binarizer.transform(y_train['attack_type'])
label_binarizer.fit(y_test['attack_type'])
y_test_encoded = label_binarizer.transform(y_test['attack_type'])
pipeline.fit(X_train, y_train_encoded)
throws
Cell In[18], line 19, in get_multiclass_network(hidden_layer_sizes, meta, compile_kwargs)
17 loss = "sparse_categorical_crossentropy"
18 else:
---> 19 raise NotImplementedError(f"Unsupported task type: {meta['target_type_']}")
20 out = keras.layers.Dense(n_output_units, activation=output_activation)
21 model.add(out)
NotImplementedError: Unsupported task type: multilabel-indicator
Metadata
Metadata
Assignees
Labels
No labels