diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..9a5e19a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,11 @@ +repos: +- repo: https://github.com/psf/black-pre-commit-mirror + rev: 24.2.0 + hooks: + - id: black +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: + - --profile=black diff --git a/README.md b/README.md index ccf2f85..f3c55a5 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,25 @@ -# mlx-onnx -MLX support for the Open Neural Network Exchange (ONNX) +# MLX ONNX + +MLX support for the Open Neural Network Exchange ([ONNX](https://onnx.ai/)) + +## Install + +```shell +pip install mlx-onnx +``` + +## Usage + +```python +from mlx.onnx import MlxBackend +from onnx import hub + +model = hub.load("mnist") +backend = MlxBackend(model) +result = backend.run(...) # pass inputs to model +``` + +## Examples + +- [ResNet](./examples/resnet/example.py) +- [Mnist](./examples/mnist/example.py) diff --git a/examples/mnist/example.py b/examples/mnist/example.py new file mode 100644 index 0000000..a1731a2 --- /dev/null +++ b/examples/mnist/example.py @@ -0,0 +1,19 @@ +# Copyright © 2024 Apple Inc. + +import mlx.core as mx +import numpy as np +from onnx import hub +from PIL import Image + +from mlx.onnx import MlxBackend + +if __name__ == "__main__": + x = ( + mx.array(np.asarray(Image.open("./nine.jpeg"))) + .reshape((1, 1, 28, 28)) + .astype(mx.float32) + ) + model = hub.load("mnist") + backend = MlxBackend(model) + res = backend.run(x) + print(f"It was a {mx.argmax(res[0]).item()}") diff --git a/examples/mnist/five.jpeg b/examples/mnist/five.jpeg new file mode 100644 index 0000000..afe5b32 Binary files /dev/null and b/examples/mnist/five.jpeg differ diff --git a/examples/mnist/four.jpeg b/examples/mnist/four.jpeg new file mode 100644 index 0000000..6aad21e Binary files /dev/null and b/examples/mnist/four.jpeg differ diff --git a/examples/mnist/nine.jpeg b/examples/mnist/nine.jpeg new file mode 100644 index 0000000..6d15287 Binary files /dev/null and b/examples/mnist/nine.jpeg differ diff --git a/examples/resnet/car.jpg b/examples/resnet/car.jpg new file mode 100644 index 0000000..8e6b40b Binary files /dev/null and b/examples/resnet/car.jpg differ diff --git a/examples/resnet/example.py b/examples/resnet/example.py new file mode 100644 index 0000000..4120007 --- /dev/null +++ b/examples/resnet/example.py @@ -0,0 +1,33 @@ +# Copyright © 2024 Apple Inc. 
+
+import mlx.core as mx
+import mlx.data as dx
+import onnx
+
+from mlx.onnx import MlxBackend
+
+
+def run(image: str):
+    dataset = (
+        dx.buffer_from_vector([{"file_name": image.encode()}])
+        .load_image("file_name", output_key="image")
+        .image_resize_smallest_side("image", 256)
+        .image_center_crop("image", 224, 224)
+        .key_transform("image", lambda x: (x - 127.0) / 128.0)
+    )
+    with open("./imagenet_labels.txt") as f:
+        labels = [l.strip() for l in f.readlines()]
+
+    model = onnx.hub.load("resnet50")
+    backend = MlxBackend(model)
+    res = []
+    for data in dataset:
+        img = mx.array(data["image"]).transpose(2, 0, 1)[None]
+        x = backend.run(img)[0]
+        res.append((labels[mx.argmax(x).item()], mx.max(x).item()))
+    return res
+
+
+if __name__ == "__main__":
+    for label, score in run("./car.jpg"):
+        print(f"Image contains a {label} with score {score:.3f}.")
diff --git a/examples/resnet/imagenet_labels.txt b/examples/resnet/imagenet_labels.txt
new file mode 100644
index 0000000..337b7dd
--- /dev/null
+++ b/examples/resnet/imagenet_labels.txt
@@ -0,0 +1,1000 @@
+tench
+goldfish
+great_white_shark
+tiger_shark
+hammerhead
+electric_ray
+stingray
+cock
+hen
+ostrich
+brambling
+goldfinch
+house_finch
+junco
+indigo_bunting
+robin
+bulbul
+jay
+magpie
+chickadee
+water_ouzel
+kite
+bald_eagle
+vulture
+great_grey_owl
+European_fire_salamander
+common_newt
+eft
+spotted_salamander
+axolotl
+bullfrog
+tree_frog
+tailed_frog
+loggerhead
+leatherback_turtle
+mud_turtle
+terrapin
+box_turtle
+banded_gecko
+common_iguana
+American_chameleon
+whiptail
+agama
+frilled_lizard
+alligator_lizard
+Gila_monster
+green_lizard
+African_chameleon
+Komodo_dragon
+African_crocodile
+American_alligator
+triceratops
+thunder_snake
+ringneck_snake
+hognose_snake
+green_snake
+king_snake
+garter_snake
+water_snake
+vine_snake
+night_snake
+boa_constrictor
+rock_python
+Indian_cobra
+green_mamba
+sea_snake
+horned_viper
+diamondback
+sidewinder
+trilobite
+harvestman
+scorpion
+black_and_gold_garden_spider
+barn_spider
+garden_spider
+black_widow
+tarantula
+wolf_spider
+tick
+centipede
+black_grouse
+ptarmigan
+ruffed_grouse
+prairie_chicken
+peacock
+quail
+partridge
+African_grey
+macaw
+sulphur-crested_cockatoo
+lorikeet
+coucal
+bee_eater
+hornbill
+hummingbird
+jacamar
+toucan
+drake
+red-breasted_merganser
+goose
+black_swan
+tusker
+echidna
+platypus
+wallaby
+koala
+wombat
+jellyfish
+sea_anemone
+brain_coral
+flatworm
+nematode
+conch
+snail
+slug
+sea_slug
+chiton
+chambered_nautilus
+Dungeness_crab
+rock_crab
+fiddler_crab
+king_crab
+American_lobster
+spiny_lobster
+crayfish
+hermit_crab
+isopod
+white_stork
+black_stork
+spoonbill
+flamingo
+little_blue_heron
+American_egret
+bittern
+crane
+limpkin
+European_gallinule
+American_coot
+bustard
+ruddy_turnstone
+red-backed_sandpiper
+redshank
+dowitcher
+oystercatcher
+pelican
+king_penguin
+albatross
+grey_whale
+killer_whale
+dugong
+sea_lion
+Chihuahua
+Japanese_spaniel
+Maltese_dog
+Pekinese
+Shih-Tzu
+Blenheim_spaniel
+papillon
+toy_terrier
+Rhodesian_ridgeback
+Afghan_hound
+basset
+beagle
+bloodhound
+bluetick
+black-and-tan_coonhound
+Walker_hound
+English_foxhound
+redbone
+borzoi
+Irish_wolfhound
+Italian_greyhound
+whippet
+Ibizan_hound
+Norwegian_elkhound
+otterhound
+Saluki
+Scottish_deerhound
+Weimaraner
+Staffordshire_bullterrier
+American_Staffordshire_terrier
+Bedlington_terrier
+Border_terrier
+Kerry_blue_terrier
+Irish_terrier
+Norfolk_terrier
+Norwich_terrier
+Yorkshire_terrier
+wire-haired_fox_terrier
+Lakeland_terrier
+Sealyham_terrier
+Airedale +cairn +Australian_terrier +Dandie_Dinmont +Boston_bull +miniature_schnauzer +giant_schnauzer +standard_schnauzer +Scotch_terrier +Tibetan_terrier +silky_terrier +soft-coated_wheaten_terrier +West_Highland_white_terrier +Lhasa +flat-coated_retriever +curly-coated_retriever +golden_retriever +Labrador_retriever +Chesapeake_Bay_retriever +German_short-haired_pointer +vizsla +English_setter +Irish_setter +Gordon_setter +Brittany_spaniel +clumber +English_springer +Welsh_springer_spaniel +cocker_spaniel +Sussex_spaniel +Irish_water_spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old_English_sheepdog +Shetland_sheepdog +collie +Border_collie +Bouvier_des_Flandres +Rottweiler +German_shepherd +Doberman +miniature_pinscher +Greater_Swiss_Mountain_dog +Bernese_mountain_dog +Appenzeller +EntleBucher +boxer +bull_mastiff +Tibetan_mastiff +French_bulldog +Great_Dane +Saint_Bernard +Eskimo_dog +malamute +Siberian_husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great_Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon_griffon +Pembroke +Cardigan +toy_poodle +miniature_poodle +standard_poodle +Mexican_hairless +timber_wolf +white_wolf +red_wolf +coyote +dingo +dhole +African_hunting_dog +hyena +red_fox +kit_fox +Arctic_fox +grey_fox +tabby +tiger_cat +Persian_cat +Siamese_cat +Egyptian_cat +cougar +lynx +leopard +snow_leopard +jaguar +lion +tiger +cheetah +brown_bear +American_black_bear +ice_bear +sloth_bear +mongoose +meerkat +tiger_beetle +ladybug +ground_beetle +long-horned_beetle +leaf_beetle +dung_beetle +rhinoceros_beetle +weevil +fly +bee +ant +grasshopper +cricket +walking_stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage_butterfly +sulphur_butterfly +lycaenid +starfish +sea_urchin +sea_cucumber +wood_rabbit +hare +Angora +hamster +porcupine +fox_squirrel +marmot +beaver +guinea_pig +sorrel +zebra +hog +wild_boar +warthog +hippopotamus +ox +water_buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian_camel +llama +weasel +mink +polecat +black-footed_ferret +otter +skunk +badger +armadillo +three-toed_sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis_monkey +marmoset +capuchin +howler_monkey +titi +spider_monkey +squirrel_monkey +Madagascar_cat +indri +Indian_elephant +African_elephant +lesser_panda +giant_panda +barracouta +eel +coho +rock_beauty +anemone_fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic_gown +accordion +acoustic_guitar +aircraft_carrier +airliner +airship +altar +ambulance +amphibian +analog_clock +apiary +apron +ashcan +assault_rifle +backpack +bakery +balance_beam +balloon +ballpoint +Band_Aid +banjo +bannister +barbell +barber_chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing_cap +bath_towel +bathtub +beach_wagon +beacon +beaker +bearskin +beer_bottle +beer_glass +bell_cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo_tie +bonnet +bookcase +bookshop +bottlecap +bow +bow_tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof_vest +bullet_train +butcher_shop +cab +caldron +candle +cannon +canoe +can_opener +cardigan +car_mirror +carousel +carpenter's_kit +carton +car_wheel +cash_machine +cassette +cassette_player +castle +catamaran +CD_player +cello +cellular_telephone +chain +chainlink_fence +chain_mail +chain_saw +chest 
+chiffonier +chime +china_cabinet +Christmas_stocking +church +cinema +cleaver +cliff_dwelling +cloak +clog +cocktail_shaker +coffee_mug +coffeepot +coil +combination_lock +computer_keyboard +confectionery +container_ship +convertible +corkscrew +cornet +cowboy_boot +cowboy_hat +cradle +crane +crash_helmet +crate +crib +Crock_Pot +croquet_ball +crutch +cuirass +dam +desk +desktop_computer +dial_telephone +diaper +digital_clock +digital_watch +dining_table +dishrag +dishwasher +disk_brake +dock +dogsled +dome +doormat +drilling_platform +drum +drumstick +dumbbell +Dutch_oven +electric_fan +electric_guitar +electric_locomotive +entertainment_center +envelope +espresso_maker +face_powder +feather_boa +file +fireboat +fire_engine +fire_screen +flagpole +flute +folding_chair +football_helmet +forklift +fountain +fountain_pen +four-poster +freight_car +French_horn +frying_pan +fur_coat +garbage_truck +gasmask +gas_pump +goblet +go-kart +golf_ball +golfcart +gondola +gong +gown +grand_piano +greenhouse +grille +grocery_store +guillotine +hair_slide +hair_spray +half_track +hammer +hamper +hand_blower +hand-held_computer +handkerchief +hard_disc +harmonica +harp +harvester +hatchet +holster +home_theater +honeycomb +hook +hoopskirt +horizontal_bar +horse_cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw_puzzle +jinrikisha +joystick +kimono +knee_pad +knot +lab_coat +ladle +lampshade +laptop +lawn_mower +lens_cap +letter_opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic_compass +mailbag +mailbox +maillot +maillot +manhole_cover +maraca +marimba +mask +matchstick +maypole +maze +measuring_cup +medicine_chest +megalith +microphone +microwave +military_uniform +milk_can +minibus +miniskirt +minivan +missile +mitten +mixing_bowl +mobile_home +Model_T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito_net +motor_scooter +mountain_bike +mountain_tent +mouse +mousetrap +moving_van +muzzle +nail +neck_brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil_filter +organ +oscilloscope +overskirt +oxcart +oxygen_mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper_towel +parachute +parallel_bars +park_bench +parking_meter +passenger_car +patio +pay-phone +pedestal +pencil_box +pencil_sharpener +perfume +Petri_dish +photocopier +pick +pickelhaube +picket_fence +pickup +pier +piggy_bank +pill_bottle +pillow +ping-pong_ball +pinwheel +pirate +pitcher +plane +planetarium +plastic_bag +plate_rack +plow +plunger +Polaroid_camera +pole +police_van +poncho +pool_table +pop_bottle +pot +potter's_wheel +power_drill +prayer_rug +printer +prison +projectile +projector +puck +punching_bag +purse +quill +quilt +racer +racket +radiator +radio +radio_telescope +rain_barrel +recreational_vehicle +reel +reflex_camera +refrigerator +remote_control +restaurant +revolver +rifle +rocking_chair +rotisserie +rubber_eraser +rugby_ball +rule +running_shoe +safe +safety_pin +saltshaker +sandal +sarong +sax +scabbard +scale +school_bus +schooner +scoreboard +screen +screw +screwdriver +seat_belt +sewing_machine +shield +shoe_shop +shoji +shopping_basket +shopping_cart +shovel +shower_cap +shower_curtain +ski +ski_mask +sleeping_bag +slide_rule +sliding_door +slot +snorkel +snowmobile +snowplow +soap_dispenser +soccer_ball +sock +solar_dish +sombrero +soup_bowl +space_bar +space_heater +space_shuttle +spatula +speedboat +spider_web +spindle +sports_car +spotlight +stage 
+steam_locomotive +steel_arch_bridge +steel_drum +stethoscope +stole +stone_wall +stopwatch +stove +strainer +streetcar +stretcher +studio_couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension_bridge +swab +sweatshirt +swimming_trunks +swing +switch +syringe +table_lamp +tank +tape_player +teapot +teddy +television +tennis_ball +thatch +theater_curtain +thimble +thresher +throne +tile_roof +toaster +tobacco_shop +toilet_seat +torch +totem_pole +tow_truck +toyshop +tractor +trailer_truck +tray +trench_coat +tricycle +trimaran +tripod +triumphal_arch +trolleybus +trombone +tub +turnstile +typewriter_keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending_machine +vestment +viaduct +violin +volleyball +waffle_iron +wall_clock +wallet +wardrobe +warplane +washbasin +washer +water_bottle +water_jug +water_tower +whiskey_jug +whistle +wig +window_screen +window_shade +Windsor_tie +wine_bottle +wing +wok +wooden_spoon +wool +worm_fence +wreck +yawl +yurt +web_site +comic_book +crossword_puzzle +street_sign +traffic_light +book_jacket +menu +plate +guacamole +consomme +hot_pot +trifle +ice_cream +ice_lolly +French_loaf +bagel +pretzel +cheeseburger +hotdog +mashed_potato +head_cabbage +broccoli +cauliflower +zucchini +spaghetti_squash +acorn_squash +butternut_squash +cucumber +artichoke +bell_pepper +cardoon +mushroom +Granny_Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard_apple +pomegranate +hay +carbonara +chocolate_sauce +dough +meat_loaf +pizza +potpie +burrito +red_wine +espresso +cup +eggnog +alp +bubble +cliff +coral_reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba_diver +rapeseed +daisy +yellow_lady's_slipper +corn +acorn +hip +buckeye +coral_fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet_tissue \ No newline at end of file diff --git a/mlx/onnx/__init__.py b/mlx/onnx/__init__.py new file mode 100644 index 0000000..028b068 --- /dev/null +++ b/mlx/onnx/__init__.py @@ -0,0 +1 @@ +from .backend import MlxBackend, MlxBackendWrapper diff --git a/mlx/onnx/backend.py b/mlx/onnx/backend.py new file mode 100644 index 0000000..9952f3b --- /dev/null +++ b/mlx/onnx/backend.py @@ -0,0 +1,155 @@ +import importlib +import os +from typing import Any, Callable, List, Tuple + +import mlx.core as mx +import numpy as np +import onnx +from onnx.helper import tensor_dtype_to_np_dtype + +onnx_ops = importlib.import_module("mlx.onnx.ops") +DEBUG = os.getenv("DEBUG", "0") == "1" + + +class MlxBackendWrapper: + @classmethod + def prepare(cls, model: onnx.ModelProto, device: str): + return MlxBackend(model) + + @classmethod + def supports_device(cls, device: str) -> bool: + return device.lower() in ["cpu", "gpu"] + + +class MlxBackend: + def __init__(self, model: onnx.ModelProto): + self._model = model + self._cache = {} + self._registered_ops = {} + self.initializer_arrays() + + def register_op(self, name: str, op: Callable): + if name in self._registered_ops: + raise ValueError(f"Op {name} already registered") + self._registered_ops[name] = op + + def initializer_arrays(self): + for i in self._model.graph.initializer: + if i.name in self._cache: + continue + self._cache[i.name] = self.parse_array(i) + + def parse_array(self, inp: onnx.TensorProto) -> mx.array: + if inp.data_type == onnx.TensorProto.FLOAT and len(inp.float_data) > 0: + return mx.array( + np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), + dtype=mx.float32, + ) + elif 
inp.data_type == onnx.TensorProto.INT32 and len(inp.int32_data) > 0: + return mx.array( + np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), + dtype=mx.int32, + ) + elif inp.data_type == onnx.TensorProto.INT64 and len(inp.int64_data) > 0: + return mx.array( + np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), + dtype=mx.int64, + ) + elif len(inp.raw_data) > 0: + return mx.array( + np.frombuffer( + inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type) + ).reshape(inp.dims) + ) + else: + raise NotImplementedError( + f"Not implemented for {inp.data_type} {inp.name} {inp.dims}" + ) + + def get_input_dict(self, inputs): + input_names = [x.name for x in self._model.graph.input] + init_names = set([x.name for x in self._model.graph.initializer]) + real_inputs = [x for x in input_names if x not in init_names] + return dict(zip(real_inputs, inputs)) + + def parse_attributes(self, attrs): + res = {} + for x in attrs: + if x.type == onnx.AttributeProto.FLOAT: + res[x.name] = float(x.f) + elif x.type == onnx.AttributeProto.INT: + res[x.name] = int(x.i) + elif x.type == onnx.AttributeProto.STRING: + res[x.name] = x.s.decode("utf-8") + elif x.type == onnx.AttributeProto.TENSOR: + res[x.name] = self.parse_array(x.t) + # Sometimes this gets passed as args to functions that expect mx.array, so just converting + # them here to simplify the op code + elif x.type == onnx.AttributeProto.FLOATS: + res[x.name] = mx.array([float(f) for f in x.floats], dtype=mx.float32) + elif x.type == onnx.AttributeProto.INTS: + res[x.name] = mx.array([int(i) for i in x.ints], dtype=mx.int64) + elif x.type == onnx.AttributeProto.STRINGS: + res[x.name] = tuple(s.decode("utf-8") for s in x.strings) + elif x.type == onnx.AttributeProto.GRAPH: + raise NotImplementedError(f"Attribute type graph not implemented") + else: + raise NotImplementedError(f"Attribute type {x.type} not implemented") + return res + + def run(self, *inputs, **kwargs: Any) -> Tuple[mx.array, ...]: + if len(inputs) == 1 and isinstance(inputs[0], List): + inputs = tuple(inputs[0]) + if not isinstance(inputs, Tuple): + inputs = tuple(inputs) + + self.initializer_arrays() + inmap = self.get_input_dict(inputs) + + for i in self._model.graph.input: + if i.name in self._cache: + continue + if i.name in inmap: + if isinstance(inmap[i.name], mx.array): + self._cache[i.name] = inmap[i.name] + elif isinstance(inmap[i.name], list): + self._cache[i.name] = [ + mx.array(x) if not isinstance(x, mx.array) else x + for x in inmap[i.name] + ] + elif isinstance(inmap[i.name], np.ndarray): + self._cache[i.name] = mx.array(inmap[i.name]) + elif inmap[i.name] is None: + self._cache[i.name] = None + else: + raise NotImplementedError( + f"Input type {inmap[i.name]} not implemented" + ) + for i, node in enumerate(self._model.graph.node): + args = [self._cache[x] if x in self._cache else None for x in node.input] + opt = self.parse_attributes(node.attribute) + if DEBUG: + print( + f"Running op {node.input} {node.op_type} with args {len(args)} and opt {opt}" + ) + # Special case for split as outputs might need to be inferred from node + if node.op_type == "Split": + if "num_outputs" not in opt and len(args) != 2: + opt["num_outputs"] = len(node.output) + res = getattr(onnx_ops, node.op_type)(*args, **opt) + elif node.op_type in self._registered_ops: + res = self._registered_ops[node.op_type](*args, **opt) + elif hasattr(onnx_ops, node.op_type): + res = getattr(onnx_ops, node.op_type)(*args, **opt) + else: + raise NotImplementedError(f"Operation {node.op_type} not 
implemented") + + if not isinstance(res, tuple): + res = (res,) + if len(node.output) > len(res): + raise ValueError( + f"Expected {len(node.output)} outputs but got {len(res)}" + ) + for name, out in zip(node.output, res): + self._cache[name] = out + return tuple(self._cache[out.name] for out in self._model.graph.output) diff --git a/mlx/onnx/ops/__init__.py b/mlx/onnx/ops/__init__.py new file mode 100644 index 0000000..0ae5745 --- /dev/null +++ b/mlx/onnx/ops/__init__.py @@ -0,0 +1,567 @@ +import functools +import math +from typing import List, Optional, Union + +import mlx.core as mx +import mlx.nn.layers as layers +import mlx.nn.losses as losses +import onnx + +from .helper import dtype_helper +from .op_conv import Conv +from .op_depth import DepthToSpace, SpaceToDepth +from .op_dropout import Dropout +from .op_image import ImageDecoder +from .op_lrn import LRN +from .op_norm import ( + BatchNormalization, + GroupNormalization, + InstanceNormalization, + LayerNormalization, +) +from .op_onehot import OneHot +from .op_pad import Pad +from .op_pool import AveragePool, MaxPool +from .op_sequence import ( + ConcatFromSequence, + SequenceAt, + SequenceConstruct, + SequenceEmpty, + SequenceErase, + SequenceInsert, + SequenceLength, + SplitToSequence, +) +from .op_slice import Slice +from .op_split import Split +from .op_topk import TopK +from .op_window import BlackmanWindow, HammingWindow, HannWindow + +# Reference Docs: https://onnx.ai/onnx/operators/ + + +def Add(x: mx.array, y: mx.array, broadcast=None, axis=None): + return x + y + + +def Sub(x: mx.array, y: mx.array): + return x - y + + +def Mul(x: mx.array, y: mx.array): + return x * y + + +def Div(x: mx.array, y: mx.array): + return x / y + + +def Neg(x: mx.array): + return -x + + +def Pow(x: mx.array, y: mx.array): + return (x**y).astype(x.dtype) + + +def Sqrt(x: mx.array): + return x.sqrt() + + +def Abs(x: mx.array): + return x.abs() + + +def Exp(x: mx.array): + return x.exp() + + +def Log(x: mx.array): + return x.log() + + +def Sin(x: mx.array): + return x.sin() + + +def Sinh(x: mx.array): + return mx.sinh(x) + + +def Asin(x: mx.array): + return mx.arcsin(x) + + +def Asinh(x: mx.array): + return mx.arcsinh(x) + + +def Cos(x: mx.array): + return x.cos() + + +def Cosh(x: mx.array): + return mx.cosh(x) + + +def Acos(x: mx.array): + return mx.arccos(x) + + +def Acosh(x: mx.array): + return mx.arccosh(x) + + +def Tan(x: mx.array): + return x.sin() / x.cos() + + +def Tanh(x: mx.array): + return mx.sinh(x) / mx.cosh(x) + + +def Atan(x: mx.array): + return mx.arctan(x) + + +def Atanh(x: mx.array): + return mx.arctanh(x) + + +def Relu(x: mx.array): + return layers.relu(x) + + +def Floor(x: mx.array): + return mx.floor(x) + + +def Ceil(x: mx.array): + return mx.ceil(x) + + +def Sigmoid(x: mx.array): + return mx.sigmoid(x) + + +def Sign(x: mx.array): + return mx.sign(x) + + +def Softplus(x: mx.array): + return layers.softplus(x) + + +def HardSwish(x: mx.array): + return layers.hardswish(x) + + +def HardSigmoid(x: mx.array, alpha=0.2, beta=0.5): + return mx.clip(x * alpha + beta, 0, 1) + + +def Softsign(x: mx.array): + return layers.softsign(x) + + +def MatMul(x: mx.array, y: mx.array): + return x @ y + + +def MatMulInteger( + x: mx.array, + y: mx.array, + a_zero_point: Optional[mx.array] = None, + b_zero_point: Optional[mx.array] = None, +): + x = x.astype(mx.float32) + y = y.astype(mx.float32) + if a_zero_point is not None: + x = x - a_zero_point + if b_zero_point is not None: + y = y - b_zero_point + return (x @ y).astype(mx.int32) + + +def 
Cast(x: mx.array, to: int, saturate=1): + if to == onnx.TensorProto.DOUBLE: + raise NotImplementedError("mlx does not support double data type") + return x.astype(dtype_helper(to)) + + +def CastLike(x: mx.array, target_type: mx.array, saturate=1): + return x.astype(target_type.dtype) + + +def ConstantOfShape(x: mx.array, value: mx.array = None): + if value is None: + value = mx.array([0]) + shape = x.tolist() + return mx.ones(shape, dtype=value.dtype) * (value if shape[0] != 0 else 1) + + +def Tile(x: mx.array, repeats: mx.array): + return mx.tile(x, repeats.tolist()) + + +def Shape(x: mx.array, end=None, start=0): + return mx.array(x.shape[start:end], dtype=mx.int64) + + +def Constant( + value: mx.array = None, + value_float=None, + value_floats=None, + value_int=None, + value_ints=None, + value_string=None, + value_strings=None, +): + if value is not None: + return value + if value_float is not None: + return mx.array(value_float, dtype=mx.float32) + if value_floats is not None: + return mx.array(list(value_floats), dtype=mx.float32) + if value_int is not None: + return mx.array(value_int, dtype=mx.int32) + if value_ints is not None: + return mx.array(list(value_ints), dtype=mx.int32) + if value_string is not None or value_strings is not None: + raise NotImplementedError() + + +def Less(x: mx.array, y: mx.array): + return x < y + + +def LessOrEqual(x: mx.array, y: mx.array): + return x <= y + + +def Equal(x: mx.array, y: mx.array): + return x == y + + +def Greater(x: mx.array, y: mx.array): + return x > y + + +def GreaterOrEqual(x: mx.array, y: mx.array): + return x >= y + + +def Where(condition: mx.array, x: mx.array, y: mx.array): + return mx.where(condition, x, y) + + +def LeakyRelu(x: mx.array, alpha=0.01): + return layers.leaky_relu(x, alpha) + + +def And(x: mx.array, y: mx.array): + return x & y + + +def Or(x: mx.array, y: mx.array): + return x | y + + +def Trilu(x: mx.array, k=0, upper=1): + if isinstance(k, mx.array): + k = k.item() + return mx.triu(x, k) if upper else mx.tril(x, k) + + +def Transpose(x: mx.array, perm: mx.array = None): + return x.transpose() if perm is None else x.transpose(perm.tolist()) + + +def Identity(x: mx.array): + return x + + +def Sum(*args: List[mx.array]): + return functools.reduce(mx.array.__add__, args) + + +def Mean(*args: List[mx.array]): + return Sum(*args) / len(args) + + +def Max(*args: List[mx.array]): + return functools.reduce(mx.maximum, args) + + +def Min(*args: List[mx.array]): + return functools.reduce(mx.minimum, args) + + +def Elu(x: mx.array, alpha=1.0): + return layers.elu(x, alpha) + + +def Celu(x: mx.array, alpha=1.0): + return layers.celu(x, alpha) + + +def Reciprocal(x: mx.array): + return x.reciprocal() + + +def Mish(x: mx.array): + return layers.mish(x) + + +def PRelu(x: mx.array, slope: mx.array): + slope = slope[0] if slope.shape[-1] != x.shape[-1] else slope + return layers.prelu(x, slope) + + +def Selu(x: mx.array, alpha=1.67326319217681884765625, gamma=1.05070102214813232421875): + return gamma * (layers.relu(x) - layers.relu(-alpha * x.exp() + alpha)) + + +def Clip(x: mx.array, min=float("-inf"), max=float("inf")): + if max is None: + max = float("inf") + if min is None: + min = float("-inf") + return mx.clip(x, min, max).astype(x.dtype) + + +def Range(start: mx.array, limit: mx.array, delta: mx.array): + return mx.arange(start.item(), limit.item(), delta.item()) + + +def Size(x: Union[mx.array, list[int]]): + return mx.array(math.prod(x if isinstance(x, list) else x.shape), dtype=mx.int64) + + +def Shrink(x: mx.array, 
bias=0.0, lambd=0.5): + return (x < -lambd) * (x + bias) + (x > lambd) * (x - bias) + + +def Reshape(x: mx.array, shape: mx.array, allowzero=0): + new_shape = [ + int(d) if d != 0 else (0 if allowzero else x.shape[i]) + for i, d in enumerate(shape.tolist()) + ] + return x.reshape(new_shape) + + +def Squeeze(x: mx.array, axes: mx.array = None): + return mx.squeeze(x, axes.tolist() if axes is not None else None) + + +def Unsqueeze(x: mx.array, axes: mx.array): + return mx.expand_dims(x, axes.tolist()) + + +def Flatten(x: mx.array, axis=1): + new_shape = math.prod([1] + list(x.shape[:axis])) + return mx.reshape( + x, + ( + new_shape, + -1, + ), + ) + + +def axes_helper(axes: Optional[mx.array] = None, noop_with_empty_axes=0): + if isinstance(axes, tuple): + return axes + if axes is not None and isinstance(axes, mx.array) and axes.size > 0: + return axes.tolist() + return [] if noop_with_empty_axes else None + + +def ReduceMax(x: mx.array, axes: mx.array = None, keepdims=1, noop_with_empty_axes=0): + return x.max(axes_helper(axes, noop_with_empty_axes), keepdims=keepdims) + + +def ReduceMin(x: mx.array, axes: mx.array = None, keepdims=1, noop_with_empty_axes=0): + if math.prod(x.shape) == 0: + return mx.array(float("inf")).astype(x.dtype) + return x.min(axes_helper(axes, noop_with_empty_axes), keepdims=keepdims) + + +def ReduceMean(x: mx.array, axes: mx.array = None, keepdims=1, noop_with_empty_axes=0): + return x.mean(axes_helper(axes, noop_with_empty_axes), keepdims=keepdims) + + +def ReduceProd(x: mx.array, axes: mx.array = None, keepdims=1, noop_with_empty_axes=0): + return x.prod(axes_helper(axes, noop_with_empty_axes), keepdims=keepdims) + + +def ReduceL1(x: mx.array, axes: mx.array = None, keepdims=1, noop_with_empty_axes=0): + return x.abs().sum(axes_helper(axes, noop_with_empty_axes), keepdims=keepdims) + + +def ReduceL2(x: mx.array, axes: mx.array = None, keepdims=1, noop_with_empty_axes=0): + return ( + x.square() + .sum(axes_helper(axes, noop_with_empty_axes), keepdims=keepdims) + .sqrt() + ) + + +def ReduceSum(x: mx.array, axes: mx.array = None, keepdims=1, noop_with_empty_axes=0): + return x.sum(axes_helper(axes, noop_with_empty_axes), keepdims=keepdims) + + +def ReduceLogSum( + x: mx.array, axes: mx.array = None, keepdims=1, noop_with_empty_axes=0 +): + return x.sum(axes_helper(axes, noop_with_empty_axes), keepdims=keepdims).log() + + +def ReduceLogSumExp( + x: mx.array, axes: mx.array = None, keepdims=1, noop_with_empty_axes=0 +): + return x.exp().sum(axes_helper(axes, noop_with_empty_axes), keepdims=keepdims).log() + + +def ReduceSumSquare( + x: mx.array, axes: mx.array = None, keepdims=1, noop_with_empty_axes=0 +): + return x.square().sum(axes_helper(axes, noop_with_empty_axes), keepdims=keepdims) + + +def Concat(*args: List[mx.array], axis): + return mx.concatenate(args, axis=axis) + + +def Gemm( + A: mx.array, + B: mx.array, + C: Optional[mx.array] = None, + alpha=1.0, + beta=1.0, + transA=0, + transB=0, + broadcast=0, +): + if transA: + A = A.transpose() + if transB: + B = B.transpose() + ret = alpha * (A @ B) + if C is not None: + ret += beta * C + return ret + + +def Softmax(x: mx.array, axis=-1): + return layers.softmax(x, axis=axis) + + +def LogSoftmax(x: mx.array, axis=-1): + return layers.log_softmax(x, axis=axis) + + +def Gelu(x: mx.array, approximate="none"): + return layers.gelu(x) if approximate == "none" else layers.gelu_fast_approx(x) + + +def Erf(x: mx.array): + return mx.erf(x) + + +def Round(x: mx.array): + return x.round() + + +def ArgMax(x: mx.array, 
axis=0, keepdims=1, select_last_index=0): + return mx.argmax(x, axis=axis, keepdims=keepdims).astype(mx.int64) + + +def ArgMin(x: mx.array, axis=0, keepdims=1, select_last_index=0): + return mx.argmin(x, axis=axis, keepdims=keepdims).astype(mx.int64) + + +def Expand(x: mx.array, shape: mx.array): + return x * mx.ones(shape.tolist()) + + +def CumSum(x: mx.array, axis: mx.array, exclusive=0, reverse=0): + return mx.cumsum(x, axis.item(), reverse=reverse, inclusive=not exclusive) + + +def EyeLike(x: mx.array, dtype=None, k=0): + if dtype is None: + dtype = x.dtype + else: + dtype = dtype_helper(dtype) + return mx.eye(x.shape[0], x.shape[1], k=k, dtype=dtype) + + +def Gather(x: mx.array, indices: mx.array, axis=0): + return mx.take(x, indices, axis=axis) + + +def GatherElements(x: mx.array, indices: mx.array, axis=0): + return mx.take_along_axis(x, indices, axis=axis) + + +def Not(x: mx.array): + return ~x + + +def Mod(x: mx.array, y: mx.array, fmod=0): + assert fmod == 0, "fmod not supported" + return x % y + + +def OptionalHasElement(x: Optional[mx.array] = None): + return mx.array(x is not None and len(x) > 0, dtype=mx.bool_) + + +def OptionalGetElement(x: Optional[mx.array] = None): + return x if x is not None else mx.array([]) + + +def IsInf(x: mx.array, detect_negative=1, detect_positive=1): + return (x == float("inf")) * bool(detect_positive) + (x == float("-inf")) * bool( + detect_negative + ) + + +def IsNaN(x: mx.array): + return x != x + + +def ThresholdedRelu(x: mx.array, alpha=1.0): + return mx.where(x > alpha, x, 0) + + +def Binarizer(x: mx.array, threshold=0.0): + return mx.where(x > threshold, 1.0, 0.0) + + +def GlobalAveragePool(x: mx.array): + return x.mean(axis=tuple(range(2, x.ndim)), keepdims=True) + + +def GlobalMaxPool(x: mx.array): + return x.max(axis=tuple(range(2, x.ndim)), keepdims=True) + + +def Xor(x: mx.array, y: mx.array): + return mx.where(x == y, False, True) + + +def Compress(x: mx.array, condition: mx.array, axis=None): + if axis is None: + x = x.flatten() + axis = 0 + axis = axis if axis >= 0 else axis + x.ndim + + # TODO: Replace with bool indexing when added + temp = [] + for i, v in enumerate(condition.tolist()): + if v: + temp.append(i) + temp = mx.array(temp, dtype=mx.int64) + return x[tuple([slice(None) if i != axis else temp for i in range(x.ndim)])] diff --git a/mlx/onnx/ops/helper.py b/mlx/onnx/ops/helper.py new file mode 100644 index 0000000..893f897 --- /dev/null +++ b/mlx/onnx/ops/helper.py @@ -0,0 +1,23 @@ +import mlx.core as mx +from onnx import TensorProto + +DTYPE_MAP = { + TensorProto.FLOAT: mx.float32, + TensorProto.UINT8: mx.uint8, + TensorProto.INT8: mx.int8, + TensorProto.UINT16: mx.uint16, + TensorProto.INT16: mx.int16, + TensorProto.INT32: mx.int32, + TensorProto.INT64: mx.int64, + TensorProto.BOOL: mx.bool_, + TensorProto.FLOAT16: mx.float16, + TensorProto.UINT32: mx.uint32, + TensorProto.UINT64: mx.uint64, + TensorProto.BFLOAT16: mx.bfloat16, + TensorProto.COMPLEX64: mx.complex64, +} + + +def dtype_helper(dtype: TensorProto.DataType) -> mx.Dtype: + assert dtype in DTYPE_MAP, f"Unsupported dtype {dtype}" + return DTYPE_MAP[dtype] diff --git a/mlx/onnx/ops/op_conv.py b/mlx/onnx/ops/op_conv.py new file mode 100644 index 0000000..e93fde7 --- /dev/null +++ b/mlx/onnx/ops/op_conv.py @@ -0,0 +1,38 @@ +import mlx.core as mx +from typing import Optional +from .pad import convert_pad, auto_pad as ap + +def Conv(x: mx.array, weight: mx.array, bias: Optional[mx.array]=None, dilations:Optional[mx.array]=None, group=1, auto_pad="NOTSET", 
kernel_shape:Optional[mx.array]=None, pads:Optional[mx.array]=None, strides:Optional[mx.array]=None): + assert group == 1, f"mlx only supports 1 group, got {group}" + if dilations is not None: + assert all(x == 1 for x in dilations.tolist()), "mlx only supports dilation 1" + + if isinstance(kernel_shape, mx.array): + kernel_shape = kernel_shape.tolist() + if isinstance(strides, mx.array): + strides = strides.tolist() + if strides is None: + strides = [1] * len(kernel_shape) + if isinstance(pads, mx.array): + pads = pads.tolist() + if pads is None: + pads = [0] * len(kernel_shape) + + if x.ndim < weight.ndim: + x = mx.expand_dims(x, 0) + + if auto_pad != "NOTSET": + padding = convert_pad(ap(x.shape, auto_pad, strides, kernel_shape)) + x = mx.pad(x, pad_width=[(0,0), (0,0)] + padding, constant_values=0) + + if x.ndim == 3: + c = mx.conv1d(x.transpose(0, 2, 1), weight.transpose(0, 2, 1), padding=pads[0] if pads is not None else 0, stride=strides[0] if strides is not None else 1) + c = c + bias if bias is not None else c + return c.transpose(0, 2, 1) + elif x.ndim == 4: + c = mx.conv2d(x.transpose(0, 2, 3, 1), weight.transpose(0, 2, 3, 1), padding=pads[:2] if pads is not None else 0, stride=strides if strides is not None else 1) + c = c + bias if bias is not None else c + return c.transpose(0, 3, 1, 2) + else: + raise NotImplementedError("mlx does not support conv other than 1d and 2d") + \ No newline at end of file diff --git a/mlx/onnx/ops/op_depth.py b/mlx/onnx/ops/op_depth.py new file mode 100644 index 0000000..fa6cfe3 --- /dev/null +++ b/mlx/onnx/ops/op_depth.py @@ -0,0 +1,57 @@ +import mlx.core as mx + +def DepthToSpace(x: mx.array, blocksize, mode="DCR"): + assert x.ndim == 4, "DepthToSpace only supports 4d input" + + b, c, h, w = x.shape + if mode == "DCR": + tmpshape = ( + b, + blocksize, + blocksize, + c // (blocksize * blocksize), + h, + w, + ) + reshaped = x.reshape(tmpshape) + transposed = mx.transpose(reshaped, [0, 3, 4, 1, 5, 2]) + else: + # assert mode == "CRD" + tmpshape = ( + b, + c // (blocksize * blocksize), + blocksize, + blocksize, + h, + w, + ) + reshaped = x.reshape(tmpshape) + transposed = mx.transpose(reshaped, [0, 1, 4, 2, 5, 3]) + finalshape = ( + b, + c // (blocksize * blocksize), + h * blocksize, + w * blocksize, + ) + return mx.reshape(transposed, finalshape) + +def SpaceToDepth(x: mx.array, blocksize:int): + assert x.ndim == 4, "SpaceToDepth only supports 4d input" + b, C, H, W = x.shape + tmpshape = ( + b, + C, + H // blocksize, + blocksize, + W // blocksize, + blocksize, + ) + reshaped = x.reshape(tmpshape).transpose([0, 3, 5, 1, 2, 4]) + finalshape = ( + b, + C * blocksize * blocksize, + H // blocksize, + W // blocksize, + ) + return reshaped.reshape(finalshape).astype(x.dtype) + \ No newline at end of file diff --git a/mlx/onnx/ops/op_dropout.py b/mlx/onnx/ops/op_dropout.py new file mode 100644 index 0000000..d2827bb --- /dev/null +++ b/mlx/onnx/ops/op_dropout.py @@ -0,0 +1,9 @@ +from typing import Optional + +import mlx.core as mx +import numpy as np + + +def Dropout(x: mx.array, ratio: int = 0.5, training_mode=0, seed: Optional[int] = None): + assert training_mode == 0, "Training mode not supported yet" + return x, mx.ones(x.shape, dtype=mx.bool_) diff --git a/mlx/onnx/ops/op_image.py b/mlx/onnx/ops/op_image.py new file mode 100644 index 0000000..5996f46 --- /dev/null +++ b/mlx/onnx/ops/op_image.py @@ -0,0 +1,26 @@ +import io + +import mlx.core as mx +import numpy as np + + +def ImageDecoder(x: mx.array, pixel_format="RGB"): + try: + import PIL.Image + 
except ImportError as e:
+        raise ImportError(
+            "Pillow is required for ImageDecoder. Please install it with `pip install Pillow`"
+        ) from e
+    img = PIL.Image.open(io.BytesIO(bytes(x)))
+    if pixel_format == "RGB":
+        img = np.array(img)
+    elif pixel_format == "BGR":
+        img = np.array(img)[:, :, ::-1]
+    elif pixel_format == "Grayscale":
+        img = img.convert("L")
+        img = np.array(img)
+        img = np.expand_dims(img, axis=2)
+    else:
+        raise ValueError(f"Unsupported pixel format: {pixel_format}")
+
+    return mx.array(img, dtype=mx.uint8)
diff --git a/mlx/onnx/ops/op_lrn.py b/mlx/onnx/ops/op_lrn.py
new file mode 100644
index 0000000..5405eac
--- /dev/null
+++ b/mlx/onnx/ops/op_lrn.py
@@ -0,0 +1,18 @@
+import math
+
+import mlx.core as mx
+
+
+def LRN(x: mx.array, size: int, alpha=0.0001, beta=0.75, bias=1.0):
+    if x.ndim != 4:
+        raise NotImplementedError("LRN only supports 4D tensors")
+    square_sum = mx.zeros(x.shape).astype(x.dtype)
+    minc = x.shape[1]
+    c1 = int(math.floor((size - 1) / 2))
+    c2 = int(math.ceil((size - 1) / 2)) + 1
+    for c in range(x.shape[1]):
+        begin = max(0, c - c1)
+        end = min(minc, c + c2)
+        square_sum[:, c, :, :] = mx.sum(x[:, begin:end, :, :] ** 2, axis=1)
+    y = x / ((bias + (alpha / size) * square_sum) ** beta)
+    return (y.astype(x.dtype),)
diff --git a/mlx/onnx/ops/op_norm.py b/mlx/onnx/ops/op_norm.py
new file mode 100644
index 0000000..f82e89a
--- /dev/null
+++ b/mlx/onnx/ops/op_norm.py
@@ -0,0 +1,30 @@
+import mlx.core as mx
+from typing import Optional
+
+def norm(x, axis=-1, eps=1e-5, mean: Optional[mx.array]=None, var: Optional[mx.array]=None):
+    mean = mean if mean is not None else mx.mean(x, axis=axis, keepdims=True)
+    var = var if var is not None else mx.rsqrt(mx.var(x, axis=axis, keepdims=True) + eps)
+    return (x - mean) * var
+
+def BatchNormalization(x: mx.array, scale: mx.array, bias: mx.array, input_mean: mx.array, input_var: mx.array, momentum=0.9, epsilon=1e-5, spatial=1):
+    assert spatial == 1, "Spatial BatchNorm not supported"
+    t_shape = [1, -1] + [1] * (x.ndim - 2)
+    var = mx.rsqrt(input_var + epsilon)
+    return norm(x, eps=epsilon, mean=input_mean.reshape(t_shape), var=var.reshape(t_shape)) * scale.reshape(t_shape) + bias.reshape(t_shape)
+
+def GroupNormalization(x: mx.array, scale: mx.array, bias: mx.array, num_groups: int, epsilon=1e-5):
+    x_shape = x.shape
+    x = x.reshape([x_shape[0], num_groups, -1])
+    x = norm(x, axis=-1, eps=epsilon)
+    return (scale.reshape([-1, 1]) * x + bias.reshape([-1, 1])).reshape(x_shape)
+
+def InstanceNormalization(x: mx.array, scale: mx.array, bias: mx.array, epsilon=1e-5):
+    return scale.reshape([-1, 1, 1]) * norm(x, axis=(2, 3), eps=epsilon) + bias.reshape([-1, 1, 1])
+
+def LayerNormalization(
+    x: mx.array, scale: mx.array, bias: mx.array, axis=-1, stash_type=1, epsilon=1e-5
+):
+    axis = [i for i in range(axis if axis >= 0 else x.ndim + axis, x.ndim)]
+    mean = x.mean(axis=axis, keepdims=True)
+    invstd = (((x - mean) ** 2).mean(axis=axis, keepdims=True) + epsilon).rsqrt()
+    return scale * norm(x, axis=axis, eps=epsilon) + bias, mean, invstd
\ No newline at end of file
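One note on `op_norm.py` above: the `var` argument of the `norm` helper is the precomputed *inverse* standard deviation (that is what `BatchNormalization` passes in), not the raw variance. A quick sanity sketch of the op, illustrative only and not part of the patch, assuming the package is installed:

```python
import mlx.core as mx
from mlx.onnx.ops.op_norm import BatchNormalization

x = mx.random.normal((2, 3, 4, 4))
ones, zeros = mx.ones((3,)), mx.zeros((3,))

# With zero running mean and unit running variance the op reduces to
# x * rsqrt(1 + eps), so the output should match that closed form.
out = BatchNormalization(x, ones, zeros, zeros, ones, epsilon=1e-5)
print(mx.allclose(out, x * mx.rsqrt(mx.array(1.0) + 1e-5)))  # True
```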
diff --git a/mlx/onnx/ops/op_onehot.py b/mlx/onnx/ops/op_onehot.py
new file mode 100644
index 0000000..5b22e13
--- /dev/null
+++ b/mlx/onnx/ops/op_onehot.py
@@ -0,0 +1,17 @@
+import mlx.core as mx
+
+
+def OneHot(indices: mx.array, depth: mx.array, values: mx.array, axis=-1):
+    if isinstance(values, mx.array):
+        values = values.tolist()
+    if isinstance(depth, mx.array):
+        depth = depth.item()
+    depth_range = mx.arange(depth)
+    if axis < 0:
+        axis = indices.ndim + axis + 1
+    ls = list(indices.shape[0:axis])
+    rs = list(indices.shape[axis : indices.ndim])
+    new_shape = [1] * len(ls) + list(depth_range.shape) + [1] * len(rs)
+    tgts = depth_range.reshape(new_shape)
+    vals = (indices % depth).reshape(ls + [1] + rs)
+    return mx.where(tgts == vals, values[1], values[0])
diff --git a/mlx/onnx/ops/op_pad.py b/mlx/onnx/ops/op_pad.py
new file mode 100644
index 0000000..0524f17
--- /dev/null
+++ b/mlx/onnx/ops/op_pad.py
@@ -0,0 +1,24 @@
+from typing import Optional
+
+import mlx.core as mx
+
+from .pad import convert_pad
+
+
+def Pad(
+    x: mx.array,
+    pads: mx.array,
+    constant_value=0.0,
+    axes: Optional[mx.array] = None,
+    mode="constant",
+    value: Optional[float] = None,
+):
+    assert mode == "constant", f"Only constant padding is supported, got {mode}"
+    if value is not None:
+        constant_value = value
+    if isinstance(pads, mx.array):
+        pads = pads.tolist()
+    if isinstance(axes, mx.array):
+        axes = axes.tolist()
+    pads = convert_pad(pads, x.ndim, axes)
+    return mx.pad(x, pads, constant_value)
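To make the `pads` plumbing above concrete: ONNX orders padding as all the begins followed by all the ends, while `mx.pad` wants per-axis `(begin, end)` pairs. A small illustrative check using the `convert_pad` helper this op imports (defined in `mlx/onnx/ops/pad.py` further down):

```python
import mlx.core as mx
from mlx.onnx.ops.pad import convert_pad

# ONNX layout: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
onnx_pads = [1, 2, 3, 4]  # axis 0 gets (1, 3), axis 1 gets (2, 4)
print(convert_pad(onnx_pads))  # [(1, 3), (2, 4)]

x = mx.zeros((2, 2))
print(mx.pad(x, convert_pad(onnx_pads)).shape)  # (2 + 1 + 3, 2 + 2 + 4) = (6, 8)
```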
diff --git a/mlx/onnx/ops/op_pool.py b/mlx/onnx/ops/op_pool.py
new file mode 100644
index 0000000..6beb12f
--- /dev/null
+++ b/mlx/onnx/ops/op_pool.py
@@ -0,0 +1,234 @@
+import math
+from typing import Callable, List, Optional
+
+import mlx.core as mx
+
+from .pad import auto_pad as ap
+from .pad import convert_pad
+
+
+def compute_strides(shape: List[int]):
+    return list(
+        reversed(mx.cumprod(mx.array([1] + list(reversed(shape))))[:-1].tolist())
+    )
+
+
+def MaxPool(
+    x: mx.array,
+    kernel_shape=None,
+    auto_pad="NOTSET",
+    ceil_mode=0,
+    dilations: Optional[mx.array] = None,
+    pads=None,
+    storage_order=0,
+    strides=None,
+):
+    return Pool(
+        x,
+        mx.max,
+        float("-inf"),
+        kernel_shape,
+        auto_pad,
+        ceil_mode,
+        dilations,
+        pads,
+        storage_order,
+        strides,
+    )
+
+
+def AveragePool(
+    x: mx.array,
+    kernel_shape=None,
+    auto_pad="NOTSET",
+    ceil_mode=0,
+    dilations: Optional[mx.array] = None,
+    pads=None,
+    storage_order=0,
+    strides=None,
+    count_include_pad=0,
+):
+    res = Pool(
+        x,
+        mx.mean,
+        0,
+        kernel_shape,
+        auto_pad,
+        ceil_mode,
+        dilations,
+        pads,
+        storage_order,
+        strides,
+    )
+    if count_include_pad:
+        return res
+    div = Pool(
+        mx.ones_like(x),
+        mx.mean,
+        0,
+        kernel_shape,
+        auto_pad,
+        ceil_mode,
+        dilations,
+        pads,
+        storage_order,
+        strides,
+    )
+    return res / div
+
+
+def Pool(
+    x: mx.array,
+    op: Callable[..., mx.array],
+    pad_fill: float,
+    kernel_shape=None,
+    auto_pad="NOTSET",
+    ceil_mode=0,
+    dilations: Optional[mx.array] = None,
+    pads=None,
+    storage_order=0,
+    strides=None,
+):
+    """
+    x: [Batch, Channel, Height, Width]
+    storage_order: how the data is laid out in the array, 0 = row major, 1 = col major
+    ceil_mode: whether to use floor or ceil when computing the output shape, 0 = floor, 1 = ceil
+    pads: [x1_begin, x2_begin...x1_end, x2_end,...]
+    """
+    assert x.ndim >= 3, "Pool only supports >= 3D input"
+    assert storage_order == 0, "Pool only supports storage_order=0 for now"
+
+    if dilations is None:
+        dilations = [1] * len(kernel_shape)
+    if isinstance(dilations, mx.array):
+        dilations = dilations.tolist()
+    if any([d > 1 for d in dilations]):
+        raise NotImplementedError("Pool does not support dilation > 1")
+
+    if isinstance(kernel_shape, mx.array):
+        kernel_shape = kernel_shape.tolist()
+    if isinstance(strides, mx.array):
+        strides = strides.tolist()
+    if strides is None:
+        strides = [1] * len(kernel_shape)
+    if isinstance(pads, mx.array):
+        pads = pads.tolist()
+    if pads is None:
+        pads = [0] * len(kernel_shape) * 2
+    if auto_pad != "NOTSET":
+        pads = ap(x.shape, auto_pad, strides, kernel_shape)
+    if any([p > 0 for p in pads]):
+        pads = convert_pad(pads)
+        x = mx.pad(x, pad_width=[(0, 0), (0, 0)] + pads, constant_values=pad_fill)
+
+    if ceil_mode == 1:
+        x = mx.pad(
+            x,
+            pad_width=[(0, 0), (0, 0)] + [(0, 1)] * (x.ndim - 2),
+            constant_values=pad_fill,
+        )
+    if x.ndim == 3:
+        res = _pool1d(x, op, kernel_shape, strides, ceil_mode)
+    elif x.ndim == 4:
+        res = _pool2d(x, op, kernel_shape, strides, ceil_mode)
+    elif x.ndim == 5:
+        res = _pool3d(x, op, kernel_shape, strides, ceil_mode)
+    else:
+        raise NotImplementedError("Pool only supports 1d, 2d, and 3d pooling")
+    return res
+
+
+def _pool1d(
+    x: mx.array,
+    op: Callable[..., mx.array],
+    kernel_shape: List[int],
+    strides: List[int],
+    ceil_mode: int,
+):
+    [bs, ch, h] = x.shape
+    [b_stride, c_stride, h_stride] = compute_strides(x.shape)
+    _rop = lambda x: math.floor(x) if ceil_mode == 0 else math.ceil(x)
+    windows = mx.as_strided(
+        x,
+        shape=(
+            bs,
+            ch,
+            _rop((h - kernel_shape[0]) / strides[0]) + 1,
+            kernel_shape[0],
+        ),
+        strides=(
+            b_stride,
+            c_stride,
+            h_stride * strides[0],
+            h_stride,
+        ),
+    )
+    return op(windows, axis=(3))
+
+
+def _pool2d(
+    x: mx.array,
+    op: Callable[..., mx.array],
+    kernel_shape: List[int],
+    strides: List[int],
+    ceil_mode: int,
+):
+    [bs, ch, h, w] = x.shape
+    [b_stride, c_stride, h_stride, w_stride] = compute_strides(x.shape)
+    _rop = lambda x: math.floor(x) if ceil_mode == 0 else math.ceil(x)
+    windows = mx.as_strided(
+        x,
+        shape=(
+            bs,
+            ch,
+            _rop((h - kernel_shape[0]) / strides[0]) + 1,
+            _rop((w - kernel_shape[1]) / strides[1]) + 1,
+            kernel_shape[0],
+            kernel_shape[1],
+        ),
+        strides=(
+            b_stride,
+            c_stride,
+            h_stride * strides[0],
+            w_stride * strides[1],
+            h_stride,
+            w_stride,
+        ),
+    )
+    return op(windows, axis=(4, 5))
+
+
+def _pool3d(
+    x: mx.array,
+    op: Callable[..., mx.array],
+    kernel_shape: List[int],
+    strides: List[int],
+    ceil_mode: int,
+):
+    [bs, ch, h, w, d] = x.shape
+    [b_stride, c_stride, h_stride, w_stride, d_stride] = compute_strides(x.shape)
+    _rop = lambda x: math.floor(x) if ceil_mode == 0 else math.ceil(x)
+    windows = mx.as_strided(
+        x,
+        shape=(
+            bs,
+            ch,
+            _rop((h - kernel_shape[0]) / strides[0]) + 1,
+            _rop((w - kernel_shape[1]) / strides[1]) + 1,
+            _rop((d - kernel_shape[2]) / strides[2]) + 1,
+            kernel_shape[0],
+            kernel_shape[1],
+            kernel_shape[2],
+        ),
+        strides=(
+            b_stride,
+            c_stride,
+            h_stride * strides[0],
+            w_stride * strides[1],
+            d_stride * strides[2],
+            h_stride,
+            w_stride,
+            d_stride,
+        ),
+    )
+    return op(windows, axis=(5, 6, 7))
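The three `_poolNd` helpers all use the same trick: `mx.as_strided` materializes a view holding every kernel window, and the reduction op then collapses the trailing window axes. A minimal illustrative run of the `MaxPool` wrapper above:

```python
import mlx.core as mx
from mlx.onnx.ops import MaxPool

x = mx.arange(16, dtype=mx.float32).reshape(1, 1, 4, 4)
# 2x2 windows with stride 2: each output element is the max of one window.
print(MaxPool(x, kernel_shape=[2, 2], strides=[2, 2]))
# [[[[ 5,  7],
#    [13, 15]]]]
```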
diff --git a/mlx/onnx/ops/op_sequence.py b/mlx/onnx/ops/op_sequence.py
new file mode 100644
index 0000000..ac55c0c
--- /dev/null
+++ b/mlx/onnx/ops/op_sequence.py
@@ -0,0 +1,66 @@
+import mlx.core as mx
+from typing import List, Optional
+
+def SplitToSequence(x: mx.array, split: Optional[mx.array]=None, axis:int=0, keepdims=1):
+    if split is None:
+        split_len = [1] * x.shape[axis]
+    elif split.ndim == 0:
+        dim = x.shape[axis]
+        _len = split.item()
+        n = dim // int(_len)
+        split_len = [_len] * n
+        left = dim - _len * n
+        if left > 0:
+            split_len.append(left)
+    else:
+        split_len = split.tolist()
+    sli = [slice(0, s) for s in x.shape]
+    res = []
+    pos = 0
+    for spl in split_len:
+        sli[axis] = slice(pos, pos + spl)
+        pos += spl
+        out = x[tuple(sli)]
+        # keepdims only applies when `split` is not given; each chunk then
+        # has size 1 on `axis`, which is squeezed out when keepdims=0
+        if split is None and not keepdims:
+            out = mx.squeeze(out, axis)
+        res.append(out)
+    return res
+
+def SequenceConstruct(*args: List[mx.array]):
+    return [*args]
+
+def SequenceLength(x):
+    return mx.array(len(x), dtype=mx.int64)
+
+def SequenceEmpty():
+    return []
+
+def SequenceAt(seq: List[mx.array], index: mx.array):
+    if isinstance(index, mx.array):
+        index = index.item()
+    return seq[index]
+
+def SequenceErase(seq: List[mx.array], index: Optional[mx.array]=None):
+    if index is None:
+        # ONNX default is to erase the last tensor in the sequence
+        index = len(seq) - 1
+    else:
+        index = index.item()
+    return seq[:index] + seq[index + 1:]
+
+def ConcatFromSequence(seq: List[mx.array], axis: int=0, new_axis=0):
+    if new_axis == 1:
+        # per the ONNX spec this behaves like numpy.stack
+        return mx.stack(seq, axis=axis)
+    return mx.concatenate(seq, axis=axis)
+
+def SequenceInsert(seq: List[mx.array], value: mx.array, ind=None):
+    if ind is not None:
+        ind = ind.item()
+    if ind is None:
+        seq.append(value)
+    else:
+        seq.insert(ind, value)
+    return seq
\ No newline at end of file
diff --git a/mlx/onnx/ops/op_slice.py b/mlx/onnx/ops/op_slice.py
new file mode 100644
index 0000000..b7662d3
--- /dev/null
+++ b/mlx/onnx/ops/op_slice.py
@@ -0,0 +1,18 @@
+import mlx.core as mx
+from typing import Optional
+
+def Slice(
+    x: mx.array,
+    starts: mx.array,
+    ends: mx.array,
+    axes: Optional[mx.array] = None,
+    steps: Optional[mx.array] = None,
+):
+    if axes is None:
+        axes = mx.arange(x.ndim)
+    if steps is None:
+        steps = mx.ones(starts.shape, dtype=mx.int64)
+    slices = [slice(0, d) for d in x.shape]
+    for start, end, axe, step in zip(starts, ends, axes, steps):
+        slices[axe.item()] = slice(start.item(), end.item(), step.item())
+    return x[tuple(slices)]
\ No newline at end of file
diff --git a/mlx/onnx/ops/op_split.py b/mlx/onnx/ops/op_split.py
new file mode 100644
index 0000000..09bc949
--- /dev/null
+++ b/mlx/onnx/ops/op_split.py
@@ -0,0 +1,22 @@
+import mlx.core as mx
+import math
+from typing import Optional
+
+def Split(x: mx.array, split: Optional[mx.array] = None, num_outputs=None, axis=0):
+    if split is None:
+        if x.shape[axis] % num_outputs == 0:
+            split = [x.shape[axis] // num_outputs] * num_outputs
+        else:
+            cnt = math.ceil(x.shape[axis] / num_outputs)
+            split = [cnt] * (num_outputs - 1) + [
+                x.shape[axis] - cnt * (num_outputs - 1)
+            ]
+        split = mx.array(split, dtype=mx.int64)
+    sli = [slice(0, s) for s in x.shape]
+    res = []
+    pos = 0
+    for spl in split.tolist():
+        sli[axis] = slice(pos, pos + spl)
+        pos += spl
+        res.append(x[tuple(sli)])
+    return tuple(res)
\ No newline at end of file
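Illustrative only, a quick look at how `Split` above divides an axis when only `num_outputs` is given (near-equal parts, with the remainder in the last chunk):

```python
import mlx.core as mx
from mlx.onnx.ops.op_split import Split

x = mx.arange(7)
parts = Split(x, num_outputs=3)  # sizes ceil(7/3) = 3, 3, and the remaining 1
print([p.tolist() for p in parts])  # [[0, 1, 2], [3, 4, 5], [6]]
```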
diff --git a/mlx/onnx/ops/op_topk.py b/mlx/onnx/ops/op_topk.py
new file mode 100644
index 0000000..669b0a9
--- /dev/null
+++ b/mlx/onnx/ops/op_topk.py
@@ -0,0 +1,29 @@
+import mlx.core as mx
+
+def TopK(x: mx.array, k: mx.array, axis=-1, largest=1, sorted=1):
+    assert sorted == 1, "[TopK] Only sorted is supported"
+    if isinstance(k, mx.array):
+        k = k.item()
+    if x.ndim == 2 and axis == 1:
+        sample = mx.arange(x.shape[0])[:, None]
+        if largest == 0:
+            sorted_indices = mx.argpartition(x, kth=k - 1, axis=axis)
+            sorted_indices = sorted_indices[:, :k]
+            sorted_indices = sorted_indices[sample, mx.argsort(x[sample, sorted_indices])]
+        else:
+            sorted_indices = mx.argpartition(-x, kth=k - 1, axis=axis)
+            sorted_indices = sorted_indices[:, :k]
+            sorted_indices = sorted_indices[sample, mx.argsort(-x[sample, sorted_indices])]
+        sorted_distances = x[sample, sorted_indices]
+        return (sorted_distances, sorted_indices.astype(mx.int64))
+
+    if largest == 0:
+        sorted_indices = mx.argsort(x, axis=axis)
+        sorted_values = mx.sort(x, axis=axis)
+    else:
+        sorted_indices = mx.argsort(-x, axis=axis)
+        sorted_values = -mx.sort(-x, axis=axis)
+    ark = mx.arange(k)
+    topk_sorted_indices = mx.take(sorted_indices, ark, axis=axis)
+    topk_sorted_values = mx.take(sorted_values, ark, axis=axis)
+    return topk_sorted_values, topk_sorted_indices.astype(mx.int64)
diff --git a/mlx/onnx/ops/op_window.py b/mlx/onnx/ops/op_window.py
new file mode 100644
index 0000000..fe78e18
--- /dev/null
+++ b/mlx/onnx/ops/op_window.py
@@ -0,0 +1,40 @@
+import math
+
+import mlx.core as mx
+
+from .helper import dtype_helper
+
+
+def start(size, output_datatype, periodic):
+    dtype = dtype_helper(output_datatype)
+    N_1 = size if periodic == 1 else size - 1
+    return mx.arange(size, dtype=dtype), N_1
+
+
+def HannWindow(size: mx.array, output_datatype=1, periodic=1):
+    if isinstance(size, mx.array):
+        size = size.item()
+    ni, N_1 = start(size, output_datatype, periodic)
+    res = mx.sin(ni * math.pi / N_1) ** 2
+    return res.astype(dtype_helper(output_datatype))
+
+
+def BlackmanWindow(size: mx.array, output_datatype=1, periodic=1):
+    if isinstance(size, mx.array):
+        size = size.item()
+    ni, N_1 = start(size, output_datatype, periodic)
+    res = (
+        0.42
+        - 0.5 * mx.cos(2 * math.pi * ni / N_1)
+        + 0.08 * mx.cos(4 * math.pi * ni / N_1)
+    )
+    return res.astype(dtype_helper(output_datatype))
+
+
+def HammingWindow(size: mx.array, output_datatype=1, periodic=1):
+    if isinstance(size, mx.array):
+        size = size.item()
+    ni, N_1 = start(size, output_datatype, periodic)
+    alpha = 25.0 / 46.0
+    res = alpha - mx.cos(2 * math.pi * ni / N_1) * (1 - alpha)
+    return res.astype(dtype_helper(output_datatype))
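A small numerical cross-check of `HannWindow` against its closed form sin²(πn/N), illustrative and not part of the patch:

```python
import math

import mlx.core as mx
from mlx.onnx.ops.op_window import HannWindow

size = 8
w = HannWindow(mx.array(size))  # periodic by default, so N = size
ref = mx.array([math.sin(math.pi * n / size) ** 2 for n in range(size)])
print(mx.allclose(w, ref, atol=1e-6))  # True
```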
diff --git a/mlx/onnx/ops/pad.py b/mlx/onnx/ops/pad.py
new file mode 100644
index 0000000..877431a
--- /dev/null
+++ b/mlx/onnx/ops/pad.py
@@ -0,0 +1,46 @@
+import mlx.core as mx
+from typing import List, Union, Optional
+import math
+
+def convert_pad(onnx_pads: List[int], ndims:Optional[int]=None, axes:Optional[List[int]]=None):
+    """
+    Convert onnx padding to mlx padding
+    Onnx padding is [x1_begin, x2_begin...x1_end, x2_end,...]
+    Mlx padding is [(x1_begin, x1_end), (x2_begin, x2_end)...]
+    """
+    if ndims and len(onnx_pads) // 2 != ndims:
+        onnx_pads = onnx_pads * ndims
+    if ndims is None:
+        ndims = len(onnx_pads) // 2
+    if axes is None:
+        axes = list(range(ndims))
+    res = [(0,0)] * ndims
+    naxes = len(axes)
+    for i in range(naxes):
+        res[axes[i]] = (onnx_pads[i], onnx_pads[i+naxes])
+    return res
+
+def auto_pad(shape: List[int], auto_pad:str, strides: Optional[Union[int, List[int]]], kernel_shape: List[int]):
+    """
+    Convert auto_pad to valid padding, valid options for auto_pad are: NOTSET, SAME_UPPER, SAME_LOWER, VALID
+    Default value is NOTSET which means explicit padding is used
+    SAME_UPPER or SAME_LOWER means pad the input so that `out_shape[i] = ceil(in_shape[i] / strides[i])` for each axis `i`.
+    """
+    res = []
+    if auto_pad == "NOTSET":
+        return res
+    if strides is None:
+        strides = [1] * len(kernel_shape)
+    if isinstance(strides, int):
+        strides = [strides] * len(kernel_shape)
+    if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
+        for (dim, stride, kdim) in zip(shape[-len(kernel_shape):], strides, kernel_shape):
+            res.append((math.ceil(dim / stride)-1)*stride+((kdim-1)+1)-dim)
+        temp = []
+        for s in res:
+            temp.append(s // 2)
+            temp.append(s-s // 2)
+        res = temp
+        return res[::2] + res[1::2] if auto_pad == "SAME_UPPER" else res[1::2] + res[::2]
+
+    raise NotImplementedError(f"auto_pad {auto_pad} not implemented")
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..5cbe926
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,19 @@
+# Copyright © 2024 Apple Inc.
+
+from setuptools import setup
+
+setup(
+    name="mlx-onnx",
+    version="0.0.1",
+    author="MLX Contributors",
+    author_email="mlx@group.apple.com",
+    description="MLX backend for ONNX",
+    url="https://github.com/ml-explore/mlx-onnx",
+    install_requires=["mlx", "onnx"],
+    extras_require={
+        "test": ["numpy", "pytest"],
+        "dev": ["pre-commit"],
+    },
+    packages=["mlx.onnx"],
+    python_requires=">=3.8",
+)
diff --git a/tests/test_onnx.py b/tests/test_onnx.py
new file mode 100644
index 0000000..0f409f6
--- /dev/null
+++ b/tests/test_onnx.py
@@ -0,0 +1,216 @@
+import os
+import unittest
+
+import mlx.core as mx
+import numpy as np
+import onnx.backend.test
+
+from mlx.onnx import MlxBackend, MlxBackendWrapper
+
+
+# need to convert to numpy for the testing suite
+class TestMlxBackend(MlxBackend):
+    def __init__(self, model):
+        super().__init__(model)
+
+    def run(self, inputs, **kwargs):
+        t = super().run(inputs, **kwargs)
+        return tuple(
+            np.array(x) if isinstance(x, mx.array) else [np.array(i) for i in x]
+            for x in t
+        )
+
+
+class TestMlxBackendWrapper(MlxBackendWrapper):
+    @classmethod
+    def prepare(cls, model: onnx.ModelProto, device: str):
+        return TestMlxBackend(model)
+
+
+btest = onnx.backend.test.BackendTest(TestMlxBackendWrapper, __name__)
+
+# btest.include("test_sce_*")
+btest.exclude("test_sce_*")
+# TODO: these are upcasting to float32
+btest.exclude("test_div_uint8_cpu")
+
+# TODO: Debug these errors
+btest.exclude("test_onehot_negative_indices_cpu")
+# TODO: Implement
+btest.exclude("test_ReplicationPad2d_*")
+btest.exclude("test_wrap_pad_*")
+btest.exclude("test_ReflectionPad2d_*")
+btest.exclude("test_edge_*")
+btest.exclude("test_reflect_pad_cpu")
+btest.exclude("test_center_crop_pad_*")
+btest.exclude("test_operator_pad_*")
+
+btest.exclude("test_operator_convtranspose_cpu")
+btest.exclude("test_ConvTranspose2d_*")
+btest.exclude("test_ConstantPad2d_*")
+btest.exclude("test_convtranspose_*")
+
+# TODO: Implement dilations / col format
+btest.exclude("test_averagepool_2d_dilations_cpu")
+btest.exclude("test_averagepool_3d_dilations_*")
+btest.exclude("test_maxpool_with_argmax_2d_precomputed_pads_cpu")
+btest.exclude("test_maxpool_2d_dilations_cpu")
+btest.exclude("test_maxpool_with_argmax_2d_precomputed_strides_cpu")
+btest.exclude("test_maxpool_3d_dilations_*")
+btest.exclude("test_MaxPool1d_stride_padding_dilation_cpu")
+btest.exclude("test_MaxPool2d_stride_padding_dilation_cpu")
+btest.exclude("test_Conv2d_groups_thnn_cpu")
+
+btest.exclude("test_maxunpool_*")
+
+# TODO: These are training parameters
+btest.exclude("test_batchnorm_example_training_mode_cpu")
+btest.exclude("test_batchnorm_epsilon_training_mode_cpu")
+btest.exclude("test_BatchNorm*")
+
+btest.exclude("test_gelu_tanh_*") +btest.exclude("test_bitshift_*") +btest.exclude("test_bitwise_*") +btest.exclude("test_gathernd_*") +btest.exclude("test_tfidfvectorizer_*") +btest.exclude("test_unique_*") +btest.exclude("test_einsum_*") +btest.exclude("test_convinteger_*") +btest.exclude("test_nonmaxsuppression_*") +btest.exclude("test_hardmax_*") +btest.exclude("test_scatternd_*") +btest.exclude("test_scatter_*") +btest.exclude("test_scatter_elements_*") +btest.exclude("test_gridsample_*") +btest.exclude("test_bernoulli_*") + +btest.exclude("test_roialign_*") +btest.exclude("test_nonzero_example_cpu") +btest.exclude("test_upsample_nearest_cpu") +btest.exclude("test_lppool_*") +btest.exclude("test_reversesequence_*") +btest.exclude("test_col2im_*") +btest.exclude("test_deform_conv_*") +btest.exclude("test_basic_deform_conv_*") +btest.exclude("test_stft_*") +btest.exclude("test_det_*") +btest.exclude("test_dft_*") +btest.exclude("test_adagrad_*") +btest.exclude("test_momentum_*") +btest.exclude("test_nesterov_momentum_cpu") +btest.exclude("test_adam_*") + +btest.exclude("test_gru_*") +btest.exclude("test_rnn_*") +btest.exclude("test_simple_rnn_*") +btest.exclude("test_lstm_*") + +btest.exclude("test_training_dropout_*") + +btest.exclude("test_melweightmatrix_cpu") +btest.exclude("test_resize_*") +btest.exclude("test_regex_*") + +btest.exclude("test_nllloss_*") +btest.exclude("test_mvn_*") + +btest.exclude("test_ai_onnx_ml_*") + +# TODO: Quantize ops +btest.exclude("test_qlinearconv_*") +btest.exclude("test_qlinearmatmul_*") +btest.exclude("test_quantizelinear_*") +btest.exclude("test_dynamicquantizelinear_*") +btest.exclude("test_dequantizelinear_*") + +# Exclude conv due to either dilation or groups +btest.exclude("test_Conv1d_dilated_cpu") +btest.exclude("test_Conv1d_groups_cpu") +btest.exclude("test_Conv2d_depthwise_cpu") +btest.exclude("test_Conv2d_depthwise_padded_cpu") +btest.exclude("test_Conv2d_depthwise_strided_cpu") +btest.exclude("test_Conv2d_depthwise_with_multiplier_cpu") +btest.exclude("test_Conv2d_dilated_cpu") +btest.exclude("test_Conv2d_groups_cpu") +btest.exclude("test_Conv3d_*") +btest.exclude("test_bvlc_alexnet_cpu") +btest.exclude("test_squeezenet_cpu") +btest.exclude("test_shufflenet_cpu") + +btest.exclude("test_cast_no_saturate_FLOAT_to_FLOAT8*") +btest.exclude("test_cast_FLOAT_to_FLOAT8*") +btest.exclude("test_cast_no_saturate_FLOAT16_to_FLOAT8*") +btest.exclude("test_cast_FLOAT16_to_FLOAT8*") +btest.exclude("test_cast_FLOAT_to_BFLOAT16_cpu") +btest.exclude("test_cast_STRING_to_FLOAT_cpu") +btest.exclude("test_cast_BFLOAT16_to_FLOAT_cpu") +btest.exclude("test_cast_FLOAT_to_STRING_cpu") + +btest.exclude("test_castlike_FLOAT_to_BFLOAT16*") +btest.exclude("test_castlike_FLOAT_to_STRING*") +btest.exclude("test_castlike_BFLOAT16_*") +btest.exclude("test_castlike_STRING*") +btest.exclude("test_castlike_FLOAT_to_FLOAT8*") + +# TODO: need to go through and handle these better +btest.exclude("test_argmax_keepdims_example_select_last_index_cpu") +btest.exclude("test_argmax_negative_axis_keepdims_example_select_last_index_cpu") +btest.exclude("test_argmax_no_keepdims_example_select_last_index_cpu") +btest.exclude("test_argmin_no_keepdims_example_select_last_index_cpu") +btest.exclude("test_argmin_negative_axis_keepdims_example_select_last_index_cpu") +btest.exclude("test_argmin_keepdims_example_select_last_index_cpu") +btest.exclude("test_scan_sum_cpu") + + +# TODO: fmod support +btest.exclude("test_mod_mixed_sign_float32_cpu") 
+btest.exclude("test_mod_mixed_sign_float16_cpu") +btest.exclude("test_mod_int64_fmod_cpu") + +# TODO: Graph tests +btest.exclude("test_range_float_type_positive_delta_expanded_cpu") +btest.exclude("test_range_int32_type_negative_delta_expanded_cpu") +btest.exclude("test_scan9_sum_cpu") +btest.exclude("test_loop16_seq_none_cpu") +btest.exclude("test_loop13_seq_cpu") +btest.exclude("test_loop11_cpu") +btest.exclude("test_if_*") +btest.exclude("test_affine_grid_*") + +# TODO: Add gradient support +btest.exclude("test_gradient_*") + +btest.exclude("test_sequence_map_*") +btest.exclude("test_strnorm_*") +btest.exclude("string") + +# float64 datatype +btest.exclude("test_castlike_FLOAT_to_DOUBLE*") +btest.exclude("test_castlike_FLOAT16_to_DOUBLE*") +btest.exclude("test_sequence_model7_cpu") +btest.exclude("test_max_float64_cpu") +btest.exclude("test_min_float64_cpu") +btest.exclude("test_reduce_log_sum_exp_*") +btest.exclude("test_operator_addconstant_cpu") +btest.exclude("test_operator_add_size1_singleton_broadcast_cpu") +btest.exclude("test_operator_add_broadcast_cpu") +btest.exclude("test_operator_add_size1_broadcast_cpu") +btest.exclude("test_operator_add_size1_right_broadcast_cpu") +btest.exclude("test_cumsum_*") +btest.exclude("test_eyelike_with_dtype_cpu") +btest.exclude("test_mod_mixed_sign_float64_cpu") +btest.exclude("test_cast_FLOAT16_to_DOUBLE_cpu") +btest.exclude("test_cast_FLOAT_to_DOUBLE_cpu") + +for x in btest.test_suite: + if "OnnxBackendRealModelTest" in str(type(x)): + model = str(x).split(" ")[0] + if os.getenv("MODELS", "0") == "0": + btest.exclude(model) + else: + btest.include(model) + +globals().update(btest.enable_report().test_cases) + +if __name__ == "__main__": + unittest.main()
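For downstream users, the backend above also exposes `register_op` for plugging in custom kernels; a hypothetical sketch (the override here is arbitrary, any callable matching the op's inputs and attributes works):

```python
import mlx.core as mx
import onnx
from mlx.onnx import MlxBackend

model = onnx.hub.load("mnist")
backend = MlxBackend(model)

# Registered ops are looked up before the built-in ones in run(),
# so this replaces the default Relu for this backend instance.
backend.register_op("Relu", lambda x: mx.maximum(x, 0))
```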