Add CIFAR10 dataset, dataloader and training scripts #76
I added mimage as a submodule, but can also just add the |
Thanks a lot for this, I looked at it and the example is solid. The thing I'm struggling with the most is it introducing a dependency to "all of basalt" while it is used for an "example using basalt". But the thing is, we have YoloV8 coming up, so we'll need the ability to read images for that as well. Maybe the right place for this (and all our examples) is a dedicated Right now I'm thinking to leave this open and postpone a decision on this, to see how the other computer vision example turns out. Also the fact that there is no official package manager makes things harder. |
This PR adds the code required to train a CNN on the CIFAR10 dataset.
I adapted the existing DataLoader into one that looks more like a PyTorch DataLoader.
I would make it generic, but at this moment Mojo doesn't seem to support generic assignment?
That is, I created a BaseDataset so that the DataLoader could work with any dataset that implements it, but I get a compiler error:
TODO: dynamic traits not supported yet, please use a compile time generic instead of 'BaseDataset'
Maybe one of you knows how to make the compile time generics work here?
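For reference, here is a minimal sketch of the PyTorch-style pattern the DataLoader is modeled on, written in plain Python rather than Mojo. It is only an analogy: the loader is parameterized over a dataset type that provides `__len__` and `__getitem__`, which is roughly what a compile-time generic over a BaseDataset trait would express. `ToyDataset`, the sample layout, and the loader's fields are hypothetical and are not the code in this PR.

```python
from typing import Generic, Iterator, List, Protocol, Tuple, TypeVar

# Hypothetical sample layout: (flattened image, integer label).
Sample = Tuple[List[float], int]


class BaseDataset(Protocol):
    """Interface a dataset must provide (map-style, as in PyTorch)."""

    def __len__(self) -> int: ...

    def __getitem__(self, idx: int) -> Sample: ...


D = TypeVar("D", bound=BaseDataset)


class DataLoader(Generic[D]):
    """Yields consecutive batches from any dataset satisfying BaseDataset."""

    def __init__(self, dataset: D, batch_size: int) -> None:
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        self.dataset = dataset
        self.batch_size = batch_size

    def __len__(self) -> int:
        # Number of batches, counting a final partial batch.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __iter__(self) -> Iterator[Tuple[List[List[float]], List[int]]]:
        n = len(self.dataset)
        for start in range(0, n, self.batch_size):
            stop = min(start + self.batch_size, n)
            samples = [self.dataset[i] for i in range(start, stop)]
            images = [image for image, _ in samples]
            labels = [label for _, label in samples]
            yield images, labels


class ToyDataset:
    """Stand-in dataset for illustration only (not CIFAR10)."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> Sample:
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return [float(idx)] * 4, idx % 10


if __name__ == "__main__":
    for images, labels in DataLoader(ToyDataset(5), batch_size=2):
        print(len(images), labels)
```

In Python the type parameter is only checked by a static type checker, whereas the Mojo compiler message asks for the dataset type to be fixed at compile time.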
I haven't fully checked how well the model trains (add accuracy?), but the loss is about the same for PyTorch and Basalt.
The model size could also be increased for CIFAR10.
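As a rough illustration of the accuracy check mentioned above, a top-1 accuracy metric could be computed as follows. This is a plain-Python sketch, not Basalt or PyTorch code; the function name `top1_accuracy` and the list-of-lists layout for the logits are assumptions made for the example.

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        return 0.0
    correct = 0
    for row, label in zip(logits, labels):
        # Index of the largest logit is the predicted class.
        predicted = max(range(len(row)), key=row.__getitem__)
        if predicted == label:
            correct += 1
    return correct / len(labels)
```

Reporting this number for the same batches in both frameworks would give a second point of comparison besides the loss.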
The BasaltTensor constructor __init__(inout self, owned tensor: _Tensor[dtype]) results in some kind of double free that makes the program segfault, so it's just a simple copy for the moment.