Keras YOLOV8


Summary


Detection Models

| Model     | Params | FLOPs  | Input | COCO val AP | test AP | Download          |
| --------- | ------ | ------ | ----- | ----------- | ------- | ----------------- |
| YOLOV8_N  | 3.16M  | 4.39G  | 640   | 37.3        |         | yolov8_n_coco.h5  |
| YOLOV8_S  | 11.17M | 14.33G | 640   | 44.9        |         | yolov8_s_coco.h5  |
| YOLOV8_M  | 25.90M | 39.52G | 640   | 50.2        |         | yolov8_m_coco.h5  |
| YOLOV8_L  | 43.69M | 82.65G | 640   | 52.9        |         | yolov8_l_coco.h5  |
| YOLOV8_X  | 68.23M | 129.0G | 640   | 53.9        |         | yolov8_x_coco.h5  |
| YOLOV8_X6 | 97.42M | 522.6G | 1280  | 56.7 ?      |         | yolov8_x6_coco.h5 |

| Model                    | Params | FLOPs  | Input | COCO val AP | test AP | Download             |
| ------------------------ | ------ | ------ | ----- | ----------- | ------- | -------------------- |
| YOLO_NAS_S               | 12.88M | 16.96G | 640   | 47.5        |         | s_before_reparam.h5  |
| - use_reparam_conv=False | 12.18M | 15.92G | 640   | 47.5        |         | yolo_nas_s_coco.h5   |
| YOLO_NAS_M               | 33.86M | 47.12G | 640   | 51.55       |         | m_before_reparam.h5  |
| - use_reparam_conv=False | 31.92M | 43.91G | 640   | 51.55       |         | yolo_nas_m_coco.h5   |
| YOLO_NAS_L               | 44.53M | 64.53G | 640   | 52.22       |         | l_before_reparam.h5  |
| - use_reparam_conv=False | 42.02M | 59.95G | 640   | 52.22       |         | yolo_nas_l_coco.h5   |
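
The two rows per model correspond to the use_reparam_conv build flag. A minimal sketch of instantiating both variants (model_surgery.count_params, used in the deploy example below, only confirms the parameter counts against the table):

from keras_cv_attention_models import yolov8, model_surgery

# Keep the RepVGG-style training-time branches (the *_before_reparam.h5 rows)
mm = yolov8.YOLO_NAS_S(use_reparam_conv=True)
model_surgery.count_params(mm)
# Build with the branches already fused (the yolo_nas_s_coco.h5 rows)
mm = yolov8.YOLO_NAS_S(use_reparam_conv=False)
model_surgery.count_params(mm)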

Classification Models

| Model        | Params | FLOPs@640 | FLOPs@224 | Input | Top1 Acc | Download        |
| ------------ | ------ | --------- | --------- | ----- | -------- | --------------- |
| YOLOV8_N_CLS | 2.72M  | 1.65G     | 203.7M    | 224   | 66.6     | yolov8_n_cls.h5 |
| YOLOV8_S_CLS | 6.36M  | 6.24G     | 765.7M    | 224   | 72.3     | yolov8_s_cls.h5 |
| YOLOV8_M_CLS | 17.05M | 20.85G    | 2.56G     | 224   | 76.4     | yolov8_m_cls.h5 |
| YOLOV8_L_CLS | 37.48M | 49.41G    | 6.05G     | 224   | 78.0     | yolov8_l_cls.h5 |
| YOLOV8_X_CLS | 57.42M | 76.96G    | 9.43G     | 224   | 78.4     | yolov8_x_cls.h5 |

Segmentation Models

| Model        | Params | FLOPs   | Input | COCO val mask AP | Download        |
| ------------ | ------ | ------- | ----- | ---------------- | --------------- |
| YOLOV8_N_SEG | 3.41M  | 6.02G   | 640   | 30.5             | yolov8_n_seg.h5 |
| YOLOV8_S_SEG | 11.82M | 20.08G  | 640   | 36.8             | yolov8_s_seg.h5 |
| YOLOV8_M_SEG | 27.29M | 52.33G  | 640   | 40.8             | yolov8_m_seg.h5 |
| YOLOV8_L_SEG | 46.00M | 105.29G | 640   | 42.6             | yolov8_l_seg.h5 |
| YOLOV8_X_SEG | 71.83M | 164.30G | 640   | 43.4             | yolov8_x_seg.h5 |

Usage

  • Basic usage
    from keras_cv_attention_models import yolov8
    model = yolov8.YOLOV8_N(pretrained="coco")
    
    # Run prediction
    from keras_cv_attention_models import test_images
    imm = test_images.dog_cat()
    preds = model(model.preprocess_input(imm))
    bboxes, labels, confidences = model.decode_predictions(preds)[0]

    # Show result
    from keras_cv_attention_models.coco import data
    data.show_image_with_bboxes(imm, bboxes, labels, confidences)
    (result image: yolov8_n_dog_cat)
  • Use dynamic input resolution by setting input_shape=(None, None, 3). Note: for YOLO_NAS models, the actual input shape needs to be divisible by 32 (a rounding sketch follows this example).
    import tensorflow as tf
    from keras_cv_attention_models import yolov8, test_images, coco, plot_func
    model = yolov8.YOLOV8_S(input_shape=(None, None, 3), pretrained="coco")
    # >>>> Load pretrained from: ~/.keras/models/yolov8_s_coco.h5
    print(model.input_shape, model.output_shape)
    # (None, None, None, 3) (None, None, 144)
    print(model(tf.ones([1, 742, 355, 3])).shape)
    # (1, 5554, 144)
    print(model(tf.ones([1, 188, 276, 3])).shape)
    # (1, 1110, 144)
    
    imm = test_images.dog_cat()
    input_shape = (320, 224, 3)
    preds = model(model.preprocess_input(imm, input_shape=input_shape))
    bboxes, labels, confidences = model.decode_predictions(preds, input_shape=input_shape)[0]

    # Show result
    plot_func.show_image_with_bboxes(imm, bboxes, labels, confidences, num_classes=80)
    (result image: yolov8_s_dynamic_dog_cat)
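  • Since YOLO_NAS requires the actual input shape to be divisible by 32, here is a minimal sketch of rounding an arbitrary resolution up to the nearest valid one. round_to_stride is a hypothetical helper for illustration, not part of the library.
    # Hypothetical helper: round (height, width) up to multiples of the model stride
    def round_to_stride(height, width, stride=32):
        return ((height + stride - 1) // stride * stride, (width + stride - 1) // stride * stride)

    print(round_to_stride(742, 355))
    # (768, 384)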
  • Classification model
    from keras_cv_attention_models.yolov8 import yolov8
    model = yolov8.YOLOV8_N_CLS(pretrained="imagenet")
    
    # Run prediction
    from skimage.data import chelsea # Chelsea the cat
    preds = model(model.preprocess_input(chelsea()))
    print(model.decode_predictions(preds))
    # [('n02124075', 'Egyptian_cat', 0.2490207), ('n02123045', 'tabby', 0.12989485), ...]
  • Segmentation model using dynamic input resolution.
    from keras_cv_attention_models import yolov8, test_images, coco, plot_func
    mm = yolov8.YOLOV8_S_SEG(pretrained="coco", input_shape=[None, None, 3])
    print(mm.input_shape, mm.output_shape)
    # (None, None, None, 3) ((None, None, 176), (None, None, None, 32))
    
    image = test_images.dog_cat()
    input_shape = (320, 608, 3)
    preds, mask_protos = mm.predict(mm.preprocess_input(image, input_shape=input_shape))
    bboxes, labels, confidences, masks = mm.decode_predictions(preds, mask_protos=mask_protos, input_shape=input_shape)[0]
    _ = plot_func.show_image_with_bboxes_and_masks(image, bboxes, labels, confidences, masks=masks)
    (result image: yolov8_s_dynamic_segment_dog_cat)
  • Switch to deploy mode by calling model.switch_to_deploy() if using use_reparam_conv=True. This fuses each reparameterized block into a single Conv2D layer, and also applies convert_to_fused_conv_bn_model, which fuses Conv2D->BatchNorm pairs (a standalone sketch follows this example).
    import numpy as np
    from keras_cv_attention_models import yolov8, test_images, model_surgery
    
    mm = yolov8.YOLO_NAS_S(use_reparam_conv=True)
    model_surgery.count_params(mm)
    # Total params: 12,911,584.0 | Trainable params: 12,878,304.0 | Non-trainable params: 33,280.0
    preds = mm(mm.preprocess_input(test_images.dog_cat()))
    
    bb = mm.switch_to_deploy()
    model_surgery.count_params(bb)
    # Total params: 12,167,600.0 | Trainable params: 12,167,600.0 | Non-trainable params: 0.0
    preds_deploy = bb(bb.preprocess_input(test_images.dog_cat()))
    
    print(f"{np.allclose(preds, preds_deploy, atol=1e-3) = }")
    # np.allclose(preds, preds_deploy, atol=1e-3) = True
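  • convert_to_fused_conv_bn_model can also be applied on its own. A minimal sketch, assuming it is exposed under model_surgery as referenced above; after fusing Conv2D->BatchNorm pairs, non-trainable params should drop to 0:
    from keras_cv_attention_models import yolov8, model_surgery

    mm = yolov8.YOLOV8_N(pretrained="coco")
    fused = model_surgery.convert_to_fused_conv_bn_model(mm)
    model_surgery.count_params(fused)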
  • Use the PyTorch backend by setting the KECAM_BACKEND='torch' environment variable before importing.
    import os
    os.environ['KECAM_BACKEND'] = 'torch'
    
    from keras_cv_attention_models import yolov8
    model = yolov8.YOLOV8_S(input_shape=(None, None, 3), pretrained="coco")
    # >>>> Using PyTorch backend
    # >>>> Aligned input_shape: [3, None, None]
    # >>>> Load pretrained from: ~/.keras/models/yolov8_s_coco.h5
    
    print(model.input_shape, model.output_shape)
    # (None, 3, None, None) (None, None, 144)
    
    import torch
    print(model(torch.ones([1, 3, 736, 352])).shape)
    # torch.Size([1, 5313, 144])
    
    from keras_cv_attention_models import test_images
    imm = test_images.dog_cat()
    input_shape = (320, 224, 3)
    preds = model(model.preprocess_input(imm, input_shape=input_shape))
    bboxes, labels, confidences = model.decode_predictions(preds, input_shape=input_shape)[0]

    # Show result
    from keras_cv_attention_models.coco import data
    data.show_image_with_bboxes(imm, bboxes, labels, confidences, num_classes=80)

Custom detector using YOLOV8 head

  • The backbone for YOLOV8 can be any model with a pyramid stage structure. NOTE: YOLOV8 uses a default regression_len=64 for the bbox output length. It is typically 4 for other detection models; for YOLOV8, reg_max=16 -> regression_len = 16 * 4 == 64 (a sketch unpacking the output shapes follows the example below).

    from keras_cv_attention_models import efficientnet, yolov8
    bb = efficientnet.EfficientNetV2B1(input_shape=(256, 256, 3), num_classes=0)
    mm = yolov8.YOLOV8(backbone=bb)
    # >>>> features: {'stack_2_block2_output': (None, 32, 32, 48),
    #                 'stack_4_block5_output': (None, 16, 16, 112),
    #                 'stack_5_block8_output': (None, 8, 8, 192)}
    
    mm.summary()  # Trainable params: 8,025,252
    print(mm.output_shape)
    # (None, 1344, 144)
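  • To unpack the shapes above: with a 256x256 input and feature strides [8, 16, 32], the head predicts over 1344 positions, each with regression_len 64 + num_classes 80 == 144 channels. A minimal numpy sketch of the arithmetic and of the distribution-style regression decode (the actual decoding is handled by coco.decode_bboxes; the reshape/softmax here is for illustration only):
    import numpy as np

    # 1344 == 32*32 + 16*16 + 8*8 feature positions over strides 8 / 16 / 32
    print(sum((256 // stride) ** 2 for stride in [8, 16, 32]))
    # 1344

    # Each 64-length regression is 4 box sides x reg_max 16 bins:
    # softmax over the bins, then the expectation over bin indices gives each side distance
    one_anchor_regression = np.random.uniform(size=[4, 16])
    prob = np.exp(one_anchor_regression) / np.exp(one_anchor_regression).sum(-1, keepdims=True)
    distances = (prob * np.arange(16)).sum(-1)  # 4 side distances, in stride units
    print(distances.shape)
    # (4,)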
  • Currently 4 anchor types are supported; the anchors_mode parameter controls which one is used, with values in ["efficientdet", "anchor_free", "yolor", "yolov8"]. Default is "yolov8" (an arithmetic sketch of the presets follows the table below).

    from keras_cv_attention_models import efficientnet, yolov8
    bb = efficientnet.EfficientNetV2B1(input_shape=(256, 256, 3), num_classes=0)
    
    mm = yolov8.YOLOV8(backbone=bb, anchors_mode="anchor_free", regression_len=4) # Trainable params: 7,756,707
    print(mm.output_shape) # (None, 1344, 85)
    
    mm = yolov8.YOLOV8(backbone=bb, anchors_mode="efficientdet", regression_len=64) # Trainable params: 8,280,612
    print(mm.output_shape) # (None, 1344, 1296) -> 1296 == num_anchors 9 * (regression_len 64 + num_classes 80)

    Default settings for anchors_mode

    | anchors_mode | use_object_scores | num_anchors | anchor_scale | aspect_ratios | num_scales | grid_zero_start |
    | ------------ | ----------------- | ----------- | ------------ | ------------- | ---------- | --------------- |
    | efficientdet | False             | 9           | 4            | [1, 2, 0.5]   | 3          | False           |
    | anchor_free  | True              | 1           | 1            | [1]           | 1          | True            |
    | yolor        | True              | 3           | None         | presets       | None       | offset=0.5      |
    | yolov8       | False             | 1           | 1            | [1]           | 1          | False           |
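  • As an arithmetic sketch of the presets above: for "efficientdet" mode, num_anchors is aspect_ratios x num_scales, which also explains the 1296-channel output in the example.
    aspect_ratios, num_scales = [1, 2, 0.5], 3
    num_anchors = len(aspect_ratios) * num_scales
    print(num_anchors)
    # 9
    print(num_anchors * (64 + 80))  # num_anchors * (regression_len + num_classes)
    # 1296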

Verification with PyTorch version

import sys
import numpy as np
import tensorflow as tf

inputs = np.random.uniform(size=(1, 640, 640, 3)).astype("float32")

""" PyTorch yolov8n """
sys.path.append('../ultralytics')
import torch

tt = torch.load('yolov8n.pt')
_ = tt['model'].eval()
torch_model = tt['model'].float()
_, torch_out = torch_model(torch.from_numpy(inputs).permute([0, 3, 1, 2]))
torch_out_concat = [ii.reshape([1, ii.shape[1], -1]) for ii in torch_out]
torch_out_concat = torch.concat(torch_out_concat, axis=-1).permute([0, 2, 1])

""" Keras YOLOV8_N """
from keras_cv_attention_models import yolov8
mm = yolov8.YOLOV8_N(pretrained='coco', classifier_activation=None)
keras_out = mm(inputs)

""" Model outputs verification """
# [top, left, bottom, right] -> [left, top, right, bottom]
bbox_out, cls_out = tf.split(keras_out, [64, 80], axis=-1)
bbox_out = tf.gather(tf.reshape(bbox_out, [1, -1, 4, 16]), [1, 0, 3, 2], axis=-2)
bbox_out = tf.reshape(bbox_out, [1, -1, 4 * 16])
keras_out_reorder = tf.concat([bbox_out, cls_out], axis=-1)
print(f"{np.allclose(keras_out_reorder, torch_out_concat.detach(), atol=1e-4) = }")
# np.allclose(keras_out_reorder, torch_out_concat.detach(), atol=1e-4) = True

Segmentation model

import sys
import numpy as np

inputs = np.random.uniform(size=(1, 640, 640, 3)).astype("float32")

""" PyTorch yolov8n-seg """
sys.path.append('../ultralytics')
import torch

tt = torch.load('yolov8n-seg.pt')
_ = tt['model'].eval()
torch_model = tt['model'].float()
torch_out = torch_model(torch.from_numpy(inputs).permute([0, 3, 1, 2]))

""" Keras YOLOV8_N """
from keras_cv_attention_models import yolov8, coco
mm = yolov8.YOLOV8_N_SEG(pretrained='yolov8_n_seg_coco.h5', classifier_activation='sigmoid')
keras_out, keras_masks = mm(inputs)

""" Model outputs verification """
anchors = coco.get_anchor_free_anchors(input_shape=[640, 640], pyramid_levels=[3, 5], grid_zero_start=False)
keras_bbox_decoded = coco.decode_bboxes(keras_out, anchors, regression_len=64, return_centers=True).numpy()
# [top, left, bottom, right] -> [left, top, right, bottom]
print(f"{np.allclose(torch_out[0][:, :4].permute([0, 2, 1]) / 640, keras_bbox_decoded[:, :, [1, 0, 3, 2]], atol=1e-5) = }")
# np.allclose(torch_out[0][:, :4].permute([0, 2, 1]) / 640, keras_bbox_decoded[:, :, [1, 0, 3, 2]], atol=1e-5) = True
print(f"{np.allclose(torch_out[0][:, 4:].permute([0, 2, 1]), keras_bbox_decoded[:, :, 4:], atol=1e-5) = }")
# np.allclose(torch_out[0][:, 4:].permute([0, 2, 1]), keras_bbox_decoded[:, :, 4:], atol=1e-5) = True
print(f"{np.allclose(torch_out[1][-1].permute([0, 2, 3, 1]), keras_masks, atol=1e-5) = }")
# np.allclose(torch_out[1][-1].permute([0, 2, 3, 1]), keras_masks, atol=1e-5) = True

COCO eval results

python coco_eval_script.py -m yolov8.YOLOV8_N --nms_method hard --nms_iou_or_sigma 0.65 --nms_max_output_size 300 \
--nms_topk -1 --letterbox_pad 64 --input_shape 704
# Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.373
# Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.529
# Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.402
# Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.184
# Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.410
# Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.531
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.321
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.533
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.585
# Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.355
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.649
# Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.761
python coco_eval_script.py -m yolov8.YOLOV8_X6 --nms_method hard --nms_iou_or_sigma 0.65 --nms_max_output_size 300 \
--nms_topk -1 --letterbox_pad 64 --input_shape 1344
# Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.567
# Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.740
# Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.618
# Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.428
# Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.612
# Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.702
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.410
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.688
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.739
# Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.623
# Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.772
# Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.855

Training using PyTorch backend

  • A custom dataset can be created using custom_dataset_script.py and then passed as dataset_path="coco.json" for training; detailed usage can be found in Custom detection dataset.
  • Train using EfficientNetV2B0 backbone + YOLOV8_N head.
    import os, sys, torch
    os.environ["KECAM_BACKEND"] = "torch"
    
    from keras_cv_attention_models.yolov8 import train, yolov8
    from keras_cv_attention_models import efficientnet
    
    global_device = torch.device("cuda:0") if torch.cuda.is_available() and int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")) >= 0 else torch.device("cpu")
    # model Trainable params: 7,023,904, GFLOPs: 8.1815G
    bb = efficientnet.EfficientNetV2B0(input_shape=(3, 640, 640), num_classes=0)
    model = yolov8.YOLOV8_N(backbone=bb, classifier_activation=None, pretrained=None).to(global_device)  # Note: classifier_activation=None
    # model = yolov8.YOLOV8_N(input_shape=(3, None, None), classifier_activation=None, pretrained=None).to(global_device)
    ema = train.train(model, dataset_path="coco.json", initial_epoch=0)
    (training progress image: yolov8_training)
  • Predict after training using the Torch / TF backend. The bbox output format is [top, left, bottom, right], i.e. yxyx (a pixel-coordinate conversion sketch follows this example).
    from keras_cv_attention_models import efficientnet, yolov8, test_images
    from keras_cv_attention_models.coco import data
    
    bb = efficientnet.EfficientNetV2B0(input_shape=(3, 640, 640), num_classes=0, pretrained=None)
    model = yolov8.YOLOV8_N(backbone=bb, pretrained="yolov8_n.h5")
    
    imm = test_images.dog_cat()
    preds = model(model.preprocess_input(imm))
    bboxes, labels, confidences = model.decode_predictions(preds)[0]
    data.show_image_with_bboxes(imm, bboxes, labels, confidences)
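  • If pixel xyxy coordinates are needed instead, a minimal conversion sketch, assuming decode_predictions returns normalized yxyx boxes as in the verification section above:
    import numpy as np

    height, width = imm.shape[:2]
    # yxyx (normalized) -> xyxy (pixels)
    bboxes_xyxy = np.array(bboxes)[:, [1, 0, 3, 2]] * [width, height, width, height]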
  • Evaluation after training
    import kecam
    bb = kecam.efficientnet.EfficientNetV2B0(input_shape=(3, 640, 640), num_classes=0, pretrained=None)
    model = kecam.yolov8.YOLOV8_N(backbone=bb, pretrained="yolov8_n.h5")
    ee = kecam.coco.eval_func.COCOEvalCallback(data_name="coco.json", batch_size=32, nms_method="hard", nms_iou_or_sigma=0.65, rescale_mode="raw01")
    ee.model = model
    ee.on_epoch_end()