from keras_cv_attention_models import mobilevit
import tensorflow as tf
from tensorflow import keras
from skimage.data import chelsea

# Build MobileViT_XXS; pretrained="imagenet" downloads and loads the imagenet weights.
mm = mobilevit.MobileViT_XXS(pretrained="imagenet")

# Run prediction on Chelsea the cat: resize to the model's (height, width),
# add a leading batch axis, then scale the 8-bit pixel values into [0, 1].
target_hw = mm.input_shape[1:3]
resized = tf.image.resize(chelsea(), target_hw)
imm = tf.expand_dims(resized, 0) / 255
pred = mm(imm).numpy()
top_predictions = keras.applications.imagenet_utils.decode_predictions(pred)[0]
print(top_predictions)
# [('n02124075', 'Egyptian_cat', 0.6774389), ('n02123045', 'tabby', 0.12461892), ...]
Change input resolution. For input resolutions not divisible by 64, `tf.image.resize` will be applied for the transformer blocks.
from keras_cv_attention_models import mobilevit
from skimage.data import chelsea

# An arbitrary input_shape is accepted; for resolutions not divisible by 64,
# tf.image.resize is applied for the transformer blocks.
mm = mobilevit.MobileViT_V2_100(input_shape=(260, 277, 3), pretrained="imagenet")
# >>>> Load pretrained from: ~/.keras/models/mobilevit_v2_100_256_imagenet.h5

# Run prediction using the model's bundled preprocess_input / decode_predictions helpers.
inputs = mm.preprocess_input(chelsea())
preds = mm(inputs)
print(mm.decode_predictions(preds))
# [[('n02124075', 'Egyptian_cat', 0.38652435), ('n02123159', 'tiger_cat', 0.2578847), ...]
Verification against the PyTorch version
# Compare Keras MobileViT_S against timm's PyTorch mobilevit_s on the same random input.
import sys

import numpy as np
import torch

# PyTorch mobilevit_s, loaded from a local pytorch-image-models (timm) checkout.
# The path must be added before `import timm` so the local copy is picked up.
sys.path.append("../pytorch-image-models/")
import timm

torch_model = timm.models.mobilevit_s(pretrained=True)
_ = torch_model.eval()

# Keras MobileViT_S; classifier_activation=None returns raw logits so the
# outputs are directly comparable with the timm model's outputs.
from keras_cv_attention_models import mobilevit

mm = mobilevit.MobileViT_S(pretrained="imagenet", classifier_activation=None)

# Verification: the Keras model takes channels-last input, so the same array is
# permuted to channels-first (0, 3, 1, 2) before being fed to the PyTorch model.
inputs = np.random.uniform(size=(1, *mm.input_shape[1:3], 3)).astype("float32")
torch_out = torch_model(torch.from_numpy(inputs).permute(0, 3, 1, 2)).detach().numpy()
keras_out = mm(inputs).numpy()
print(f"{np.allclose(torch_out, keras_out, atol=1e-3) = }")
# np.allclose(torch_out, keras_out, atol=1e-3) = True