-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fb7916d
commit 13a58ac
Showing
33 changed files
with
4,314 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,113 @@ | ||
# MambaVision | ||
<h1 align='center'>MambaVision</h1> | ||
|
||
This is a warehouse for MambaVision-Pytorch-model, can be used to train your image-datasets for vision tasks. | ||
|
||
### [MambaVision: A Hybrid Mamba-Transformer Vision Backbone](https://arxiv.org/pdf/2407.08083) | ||
### [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962) | ||
|
||
data:image/s3,"s3://crabby-images/e7978/e79784b9d4b071f4c5932b92b5e84887aef67cf3" alt="image" | ||
data:image/s3,"s3://crabby-images/e31a0/e31a0ea96a7f6599c1fa44f762ec7df2e37d6d60" alt="image" | ||
|
||
|
||
## Preparation | ||
### Install mamba_ssm & causal_conv1d | ||
[Github Tutorial website](https://github.com/jiaowoguanren0615/Install_Mamba) | ||
|
||
### Download the dataset: | ||
[flower_dataset](https://www.kaggle.com/datasets/alxmamaev/flowers-recognition). | ||
|
||
## Project Structure | ||
``` | ||
├── datasets: Load datasets | ||
├── my_dataset.py: Customize reading data sets and define transforms data enhancement methods | ||
├── split_data.py: Define the function to read the image dataset and divide the training-set and test-set | ||
├── threeaugment.py: Additional data augmentation methods | ||
├── models: MambaVision Model | ||
├── build_models.py: Construct MambaVision models | ||
├── helpers.py: Compute scaled dot product attention | ||
├── scheduler: | ||
├──scheduler_main.py: Fundamental Scheduler module | ||
├──scheduler_factory.py: Create lr_scheduler methods according to parameters what you set | ||
├──other_files: Construct lr_schedulers (cosine_lr, poly_lr, multistep_lr, etc) | ||
├── util: | ||
├── engine.py: Function code for a training/validation process | ||
├── losses.py: Knowledge distillation loss, combined with teacher model (if any) | ||
├── lr_decay.py: Define "inverse_sqrt_lr_decay" function for "Adafactor" optimizer | ||
├── lr_sched.py: Define "adjust_learning_rate" function | ||
├── optimizer.py: Define Sophia & Adafactor & LAMB optimizer(for mambavision models training) | ||
├── samplers.py: Define the parameter of "sampler" in DataLoader | ||
├── utils.py: Record various indicator information and output and distributed environment | ||
├── estimate_model.py: Visualized evaluation indicators ROC curve, confusion matrix, classification report, etc. | ||
└── train_gpu.py: Training model startup file (including infer process) | ||
``` | ||
|
||
## Precautions | ||
Before you use the code to train your own data set, please first enter the ___train_gpu.py___ file and modify the ___data_root___, ___batch_size___, ___data_len___, ___num_workers___ and ___nb_classes___ parameters. If you want to draw the confusion matrix and ROC curve, you only need to set the ___predict___ parameter to __True__. | ||
Moreover, you can set the ___opt_auc___ parameter to True if you want to optimize your model for a better performance(maybe~). | ||
|
||
|
||
## Train this model | ||
|
||
### Parameters Meaning: | ||
``` | ||
1. nproc_per_node: <The number of GPUs you want to use on each node (machine/server)> | ||
2. CUDA_VISIBLE_DEVICES: <Specify the index of the GPU corresponding to a single node (machine/server) (starting from 0)> | ||
3. nnodes: <number of nodes (machine/server)> | ||
4. node_rank: <node (machine/server) serial number> | ||
5. master_addr: <master node (machine/server) IP address> | ||
6. master_port: <master node (machine/server) port number> | ||
``` | ||
### Transfer Learning: | ||
Step 1: Write the ___pre-training weight path___ into the ___args.fintune___ in string format. | ||
Step 2: Modify the ___args.freeze_layers___ according to your own GPU memory. If you don't have enough memory, you can set this to True to freeze the weights of the remaining layers except the last layer of classification-head without updating the parameters. If you have enough memory, you can set this to False and not freeze the model weights. | ||
|
||
#### Here is an example for setting parameters: | ||
data:image/s3,"s3://crabby-images/abcb6/abcb628cfbd3e852dad1a612e0f769c0dff1bd9d" alt="image" | ||
|
||
|
||
### Note: | ||
If you want to use multiple GPU for training, whether it is a single machine with multiple GPUs or multiple machines with multiple GPUs, each GPU will divide the batch_size equally. For example, batch_size=4 in my train_gpu.py. If I want to use 2 GPUs for training, it means that the batch_size on each GPU is 4. ___Do not let batch_size=1 on each GPU___, otherwise BN layer maybe report an error. | ||
|
||
### train model with single-machine single-GPU: | ||
``` | ||
python train_gpu.py | ||
``` | ||
|
||
### train model with single-machine multi-GPU: | ||
``` | ||
python -m torch.distributed.run --nproc_per_node=8 train_gpu.py | ||
``` | ||
|
||
### train model with single-machine multi-GPU: | ||
(using a specified part of the GPUs: for example, I want to use the second and fourth GPUs) | ||
``` | ||
CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.run --nproc_per_node=2 train_gpu.py | ||
``` | ||
|
||
### train model with multi-machine multi-GPU: | ||
(For the specific number of GPUs on each machine, modify the value of --nproc_per_node. If you want to specify a certain GPU, just add CUDA_VISIBLE_DEVICES= to specify the index number of the GPU before each command. The principle is the same as single-machine multi-GPU training) | ||
``` | ||
On the first machine: python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 --node_rank=0 --master_addr=<Master node IP address> --master_port=<Master node port number> train_gpu.py | ||
On the second machine: python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 --node_rank=1 --master_addr=<Master node IP address> --master_port=<Master node port number> train_gpu.py | ||
``` | ||
|
||
|
||
## Citation | ||
``` | ||
@article{chen2022pali, | ||
title={Pali: A jointly-scaled multilingual language-image model}, | ||
author={Chen, Xi and Wang, Xiao and Changpinyo, Soravit and Piergiovanni, AJ and Padlewski, Piotr and Salz, Daniel and Goodman, Sebastian and Grycner, Adam and Mustafa, Basil and Beyer, Lucas and others}, | ||
journal={arXiv preprint arXiv:2209.06794}, | ||
year={2022} | ||
} | ||
``` | ||
|
||
``` | ||
@article{you2019large, | ||
title={Large batch optimization for deep learning: Training bert in 76 minutes}, | ||
author={You, Yang and Li, Jing and Reddi, Sashank and Hseu, Jonathan and Kumar, Sanjiv and Bhojanapalli, Srinadh and Song, Xiaodan and Demmel, James and Keutzer, Kurt and Hsieh, Cho-Jui}, | ||
journal={arXiv preprint arXiv:1904.00962}, | ||
year={2019} | ||
} | ||
``` |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .mydataset import build_dataset, build_transform, MyDataset | ||
from .split_data import read_split_data | ||
from .threeaugment import new_data_aug_generator | ||
from .transforms import resolve_data_config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import torch | ||
from PIL import Image | ||
from torchvision import transforms | ||
from .split_data import read_split_data | ||
from torch.utils.data import Dataset | ||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform | ||
|
||
|
||
class MyDataset(Dataset): | ||
def __init__(self, image_paths, image_labels, transforms=None): | ||
self.image_paths = image_paths | ||
self.image_labels = image_labels | ||
self.transforms = transforms | ||
|
||
def __getitem__(self, item): | ||
image = Image.open(self.image_paths[item]).convert('RGB') | ||
label = self.image_labels[item] | ||
if self.transforms: | ||
image = self.transforms(image) | ||
return image, label | ||
|
||
def __len__(self): | ||
return len(self.image_paths) | ||
|
||
@staticmethod | ||
def collate_fn(batch): | ||
images, labels = tuple(zip(*batch)) | ||
images = torch.stack(images, dim=0) | ||
labels = torch.as_tensor(labels) | ||
return images, labels | ||
|
||
|
||
|
||
def build_transform(is_train, args): | ||
resize_im = args.input_size > 32 | ||
if is_train: | ||
# this should always dispatch to transforms_imagenet_train | ||
transform = create_transform( | ||
input_size=args.input_size, | ||
is_training=True, | ||
color_jitter=args.color_jitter, | ||
auto_augment=args.aa, | ||
interpolation=args.train_interpolation, | ||
re_prob=args.reprob, | ||
re_mode=args.remode, | ||
re_count=args.recount, | ||
) | ||
if not resize_im: | ||
# replace RandomResizedCropAndInterpolation with | ||
# RandomCrop | ||
transform.transforms[0] = transforms.RandomCrop( | ||
args.input_size, padding=4) | ||
return transform | ||
|
||
t = [] | ||
if resize_im: | ||
# size = int((256 / 224) * args.input_size) | ||
size = int((1.0 / 0.96) * args.input_size) | ||
t.append( | ||
# to maintain same ratio w.r.t. 224 images | ||
transforms.Resize(size, interpolation=3), | ||
) | ||
t.append(transforms.CenterCrop(args.input_size)) | ||
|
||
t.append(transforms.ToTensor()) | ||
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) | ||
return transforms.Compose(t) | ||
|
||
|
||
def build_dataset(args): | ||
train_image_path, train_image_label, val_image_path, val_image_label, class_indices = read_split_data(args.data_root) | ||
|
||
train_transform = build_transform(True, args) | ||
valid_transform = build_transform(False, args) | ||
|
||
train_set = MyDataset(train_image_path, train_image_label, train_transform) | ||
valid_set = MyDataset(val_image_path, val_image_label, valid_transform) | ||
|
||
return train_set, valid_set | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import os, cv2, json, random | ||
import pandas as pd | ||
from tqdm import tqdm | ||
from sklearn.model_selection import train_test_split | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
def read_split_data(root, plot_image=False): | ||
filepaths = [] | ||
labels = [] | ||
bad_images = [] | ||
|
||
random.seed(0) | ||
assert os.path.exists(root), 'Your root does not exists!!!' | ||
|
||
classes = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] | ||
classes.sort() | ||
class_indices = {k: v for v, k in enumerate(classes)} | ||
|
||
json_str = json.dumps({v: k for k, v in class_indices.items()}, indent=4) | ||
|
||
with open('./classes_indices.json', 'w') as json_file: | ||
json_file.write(json_str) | ||
|
||
every_class_num = [] | ||
supported = ['.jpg', '.png', '.jpeg', '.PNG', '.JPG', '.JPEG'] | ||
|
||
for klass in classes: | ||
classpath = os.path.join(root, klass) | ||
images = [os.path.join(root, klass, i) for i in os.listdir(classpath) if os.path.splitext(i)[-1] in supported] | ||
every_class_num.append(len(images)) | ||
flist = sorted(os.listdir(classpath)) | ||
desc = f'{klass:23s}' | ||
for f in tqdm(flist, ncols=110, desc=desc, unit='file', colour='blue'): | ||
fpath = os.path.join(classpath, f) | ||
fl = f.lower() | ||
index = fl.rfind('.') | ||
ext = fl[index:] | ||
if ext in supported: | ||
try: | ||
img = cv2.imread(fpath) | ||
filepaths.append(fpath) | ||
labels.append(klass) | ||
except: | ||
bad_images.append(fpath) | ||
print('defective image file: ', fpath) | ||
else: | ||
bad_images.append(fpath) | ||
|
||
Fseries = pd.Series(filepaths, name='filepaths') | ||
Lseries = pd.Series(labels, name='labels') | ||
df = pd.concat([Fseries, Lseries], axis=1) | ||
|
||
print(f'{len(df.labels.unique())} kind of images were found in the datasets') | ||
train_df, test_df = train_test_split(df, train_size=.8, shuffle=True, random_state=123, stratify=df['labels']) | ||
|
||
train_image_path = train_df['filepaths'].tolist() | ||
val_image_path = test_df['filepaths'].tolist() | ||
|
||
train_image_label = [class_indices[i] for i in train_df['labels'].tolist()] | ||
val_image_label = [class_indices[i] for i in test_df['labels'].tolist()] | ||
|
||
sample_df = train_df.sample(n=50, replace=False) | ||
ht, wt, count = 0, 0, 0 | ||
for i in range(len(sample_df)): | ||
fpath = sample_df['filepaths'].iloc[i] | ||
try: | ||
img = cv2.imread(fpath) | ||
h = img.shape[0] | ||
w = img.shape[1] | ||
ht += h | ||
wt += w | ||
count += 1 | ||
except: | ||
pass | ||
have = int(ht / count) | ||
wave = int(wt / count) | ||
aspect_ratio = have / wave | ||
print('{} images were found in the datasets.\n{} for training, {} for validation'.format( | ||
sum(every_class_num), len(train_image_path), len(val_image_path) | ||
)) | ||
print('average image height= ', have, ' average image width= ', wave, ' aspect ratio h/w= ', aspect_ratio) | ||
|
||
if plot_image: | ||
plt.bar(range(len(classes)), every_class_num, align='center') | ||
plt.xticks(range(len(classes)), classes) | ||
|
||
for i, v in enumerate(every_class_num): | ||
plt.text(x=i, y=v + 5, s=str(v), ha='center') | ||
|
||
plt.xlabel('image class') | ||
plt.ylabel('number of images') | ||
|
||
plt.title('class distribution') | ||
plt.show() | ||
|
||
return train_image_path, train_image_label, val_image_path, val_image_label, class_indices |
Oops, something went wrong.