Keras FastViT



| Model        | Params | FLOPs | Input | Top1 Acc | Download                 |
| ------------ | ------ | ----- | ----- | -------- | ------------------------ |
| FastViT_T8   | 4.03M  | 0.65G | 256   | 76.2     | fastvit_t8_imagenet.h5   |
| - distill    | 4.03M  | 0.65G | 256   | 77.2     | fastvit_t8_distill.h5    |
| FastViT_T12  | 7.55M  | 1.34G | 256   | 79.3     | fastvit_t12_imagenet.h5  |
| - distill    | 7.55M  | 1.34G | 256   | 80.3     | fastvit_t12_distill.h5   |
| FastViT_S12  | 9.47M  | 1.74G | 256   | 79.9     | fastvit_s12_imagenet.h5  |
| - distill    | 9.47M  | 1.74G | 256   | 81.1     | fastvit_s12_distill.h5   |
| FastViT_SA12 | 11.58M | 1.88G | 256   | 80.9     | fastvit_sa12_imagenet.h5 |
| - distill    | 11.58M | 1.88G | 256   | 81.9     | fastvit_sa12_distill.h5  |
| FastViT_SA24 | 21.55M | 3.66G | 256   | 82.7     | fastvit_sa24_imagenet.h5 |
| - distill    | 21.55M | 3.66G | 256   | 83.4     | fastvit_sa24_distill.h5  |
| FastViT_SA36 | 31.53M | 5.44G | 256   | 83.6     | fastvit_sa36_imagenet.h5 |
| - distill    | 31.53M | 5.44G | 256   | 84.2     | fastvit_sa36_distill.h5  |
| FastViT_MA36 | 44.07M | 7.64G | 256   | 83.9     | fastvit_ma36_imagenet.h5 |
| - distill    | 44.07M | 7.64G | 256   | 84.6     | fastvit_ma36_distill.h5  |


# Basic usage: build FastViT_T8 with pretrained weights and classify a test image.
from keras_cv_attention_models import fastvit, test_images

# Will download and load the pretrained "distill" weights.
mm = fastvit.FastViT_T8(pretrained="distill")

# Run prediction: preprocess the sample image, run the model, then decode
# the output into (class id, label, score) tuples.
preds = mm(mm.preprocess_input(test_images.cat()))
print(mm.decode_predictions(preds))
# [('n02124075', 'Egyptian_cat', 0.8990752), ('n02123045', 'tabby', 0.013779595), ...]

Change the input resolution by setting a new `input_shape`, or use dynamic input resolution by setting `input_shape=(None, None, 3)`.

# Fixed custom input resolution: build the model with a non-default input_shape.
from keras_cv_attention_models import fastvit, test_images

mm = fastvit.FastViT_SA12(pretrained="imagenet", input_shape=(219, 112, 3))

# Run prediction
preds = mm(mm.preprocess_input(test_images.cat()))
print(mm.decode_predictions(preds))
# [('n02124075', 'Egyptian_cat', 0.94039464), ('n02123159', 'tiger_cat', 0.0059115295), ...]
# Dynamic input resolution: build the model with unspecified height and width.
from keras_cv_attention_models import fastvit, test_images

mm = fastvit.FastViT_T8(pretrained="distill", input_shape=(None, None, 3))

# Run prediction. The model has no fixed spatial size here, so the target
# resolution is passed explicitly to preprocess_input.
preds = mm(mm.preprocess_input(test_images.cat(), input_shape=(219, 112, 3)))
print(mm.decode_predictions(preds))
# [('n02124075', 'Egyptian_cat', 0.9374073), ('n03942813', 'ping-pong_ball', 0.019263275), ...]

Switch to deploy mode by calling `model.switch_to_deploy()`. This fuses each reparameter block into a single Conv2D layer by calling `fuse_reparam_blocks` 3 times, and then applies `convert_to_fused_conv_bn_model` to fuse Conv2D->BatchNorm.

# Deploy example: fuse the model, check that the fused model reproduces the
# original predictions, then round-trip the fused weights through a file.
import numpy as np
from keras_cv_attention_models import fastvit, test_images, model_surgery

mm = fastvit.FastViT_SA12(pretrained="imagenet")
# Total params: 11,616,296 | Trainable params: 11,580,968 | Non-trainable params:35,328
preds = mm(mm.preprocess_input(test_images.cat()))

""" switch_to_deploy """
bb = mm.switch_to_deploy()
# Total params: 11,540,456 | Trainable params: 11,538,408 | Non-trainable params:2,048
preds_deploy = bb(bb.preprocess_input(test_images.cat()))

# The fused model should match the original within a small tolerance.
print(f"{np.allclose(preds, preds_deploy, atol=1e-5) = }")
# np.allclose(preds, preds_deploy, atol=1e-5) = True

""" save and load weights using deploy=True """
bb.save("aa.h5")
# deploy=True builds the already-fused architecture so the saved weights fit.
cc = fastvit.FastViT_SA12(pretrained="aa.h5", deploy=True)
print(f"{np.allclose(preds_deploy, cc(cc.preprocess_input(test_images.cat())), atol=1e-7) = }")
# np.allclose(preds_deploy, cc(cc.preprocess_input(test_images.cat())), atol=1e-7) = True

Use the PyTorch backend by setting the `KECAM_BACKEND='torch'` environment variable.

# PyTorch backend example: select the backend through an environment variable.
import os

# Set before importing keras_cv_attention_models, matching the ordering here.
# NOTE(review): presumably the backend is chosen at import time - confirm.
os.environ['KECAM_BACKEND'] = 'torch'
from keras_cv_attention_models import fastvit, test_images

mm = fastvit.FastViT_T8(pretrained="distill", input_shape=(219, 112, 3))
# >>>> Using PyTorch backend
# >>>> Load pretrained from: ~/.keras/models/fastvit_t8_distill.h5

# Run prediction
preds = mm(mm.preprocess_input(test_images.cat()))
print(mm.decode_predictions(preds))
# [('n02124075', 'Egyptian_cat', 0.937407), ('n03942813', 'ping-pong_ball', 0.019263512), ...]

Verification with PyTorch version

""" PyTorch fastvit_sa12 """
# Compare the Keras FastViT_SA12 outputs against the reference PyTorch model
# on the same random input.
import sys

import numpy as np

sys.path.append('../pytorch-image-models/')  # Needs timm
import torch
import models as torch_fastvit

torch_model = torch_fastvit.fastvit_sa12()
ss = torch.load('fastvit_sa12.pth.tar', map_location=torch.device('cpu'))
# The checkpoint may wrap the weights under 'state_dict'; otherwise use it directly.
torch_model.load_state_dict(ss.get('state_dict', ss))
_ = torch_model.eval()

""" Keras FastViT_SA12 """
from keras_cv_attention_models import fastvit
# classifier_activation=None - presumably so raw outputs are compared with
# the PyTorch model's outputs; confirm against the model definition.
mm = fastvit.FastViT_SA12(pretrained="imagenet", classifier_activation=None)

""" Verification """
# Random channels-last batch of one image at the Keras model's input resolution.
inputs = np.random.uniform(size=(1, *mm.input_shape[1:3], 3)).astype("float32")
# permute moves the channel axis from last to second for the PyTorch model.
torch_out = torch_model(torch.from_numpy(inputs).permute(0, 3, 1, 2)).detach().numpy()
keras_out = mm(inputs).numpy()
print(f"{np.allclose(torch_out, keras_out, atol=1e-4) = }")
# np.allclose(torch_out, keras_out, atol=1e-4) = True