Skip to content

Commit

Permalink
Upload code
Browse files Browse the repository at this point in the history
  • Loading branch information
jiaowoguanren0615 committed Jul 14, 2024
1 parent fb7916d commit 13a58ac
Show file tree
Hide file tree
Showing 33 changed files with 4,314 additions and 1 deletion.
113 changes: 112 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,113 @@
# MambaVision
<h1 align='center'>MambaVision</h1>

This is a warehouse for MambaVision-Pytorch-model, can be used to train your image-datasets for vision tasks.

### [MambaVision: A Hybrid Mamba-Transformer Vision Backbone](https://arxiv.org/pdf/2407.08083)
### [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962)

![image](https://github.com/jiaowoguanren0615/VisionTransformer/blob/main/sample_png/KAN-model.jpg)
![image](https://production-media.paperswithcode.com/methods/Screen_Shot_2021-01-26_at_9.43.31_PM_uI4jjMq.png)


## Preparation
### Install mamba_ssm & causal_conv1d
[Github Tutorial website](https://github.com/jiaowoguanren0615/Install_Mamba)

### Download the dataset:
[flower_dataset](https://www.kaggle.com/datasets/alxmamaev/flowers-recognition).

## Project Structure
```
├── datasets: Load datasets
├── my_dataset.py: Customize reading data sets and define transforms data enhancement methods
├── split_data.py: Define the function to read the image dataset and divide the training-set and test-set
├── threeaugment.py: Additional data augmentation methods
├── models: MambaVision Model
├── build_models.py: Construct MambaVision models
├── helpers.py: Compute scaled dot product attention
├── scheduler:
├──scheduler_main.py: Fundamental Scheduler module
├──scheduler_factory.py: Create lr_scheduler methods according to parameters what you set
├──other_files: Construct lr_schedulers (cosine_lr, poly_lr, multistep_lr, etc)
├── util:
├── engine.py: Function code for a training/validation process
├── losses.py: Knowledge distillation loss, combined with teacher model (if any)
├── lr_decay.py: Define "inverse_sqrt_lr_decay" function for "Adafactor" optimizer
├── lr_sched.py: Define "adjust_learning_rate" function
├── optimizer.py: Define Sophia & Adafactor & LAMB optimizer(for mambavision models training)
├── samplers.py: Define the parameter of "sampler" in DataLoader
├── utils.py: Record various indicator information and output and distributed environment
├── estimate_model.py: Visualized evaluation indicators ROC curve, confusion matrix, classification report, etc.
└── train_gpu.py: Training model startup file (including infer process)
```

## Precautions
Before you use the code to train your own data set, please first enter the ___train_gpu.py___ file and modify the ___data_root___, ___batch_size___, ___data_len___, ___num_workers___ and ___nb_classes___ parameters. If you want to draw the confusion matrix and ROC curve, you only need to set the ___predict___ parameter to __True__.
Moreover, you can set the ___opt_auc___ parameter to True if you want to optimize your model for a better performance(maybe~).


## Train this model

### Parameters Meaning:
```
1. nproc_per_node: <The number of GPUs you want to use on each node (machine/server)>
2. CUDA_VISIBLE_DEVICES: <Specify the index of the GPU corresponding to a single node (machine/server) (starting from 0)>
3. nnodes: <number of nodes (machine/server)>
4. node_rank: <node (machine/server) serial number>
5. master_addr: <master node (machine/server) IP address>
6. master_port: <master node (machine/server) port number>
```
### Transfer Learning:
Step 1: Write the ___pre-training weight path___ into the ___args.fintune___ in string format.
Step 2: Modify the ___args.freeze_layers___ according to your own GPU memory. If you don't have enough memory, you can set this to True to freeze the weights of the remaining layers except the last layer of classification-head without updating the parameters. If you have enough memory, you can set this to False and not freeze the model weights.

#### Here is an example for setting parameters:
![image](https://github.com/jiaowoguanren0615/VisionTransformer/blob/main/sample_png/transfer_learning.jpg)


### Note:
If you want to use multiple GPU for training, whether it is a single machine with multiple GPUs or multiple machines with multiple GPUs, each GPU will divide the batch_size equally. For example, batch_size=4 in my train_gpu.py. If I want to use 2 GPUs for training, it means that the batch_size on each GPU is 4. ___Do not let batch_size=1 on each GPU___, otherwise BN layer maybe report an error.

### train model with single-machine single-GPU:
```
python train_gpu.py
```

### train model with single-machine multi-GPU:
```
python -m torch.distributed.run --nproc_per_node=8 train_gpu.py
```

### train model with single-machine multi-GPU:
(using a specified part of the GPUs: for example, I want to use the second and fourth GPUs)
```
CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.run --nproc_per_node=2 train_gpu.py
```

### train model with multi-machine multi-GPU:
(For the specific number of GPUs on each machine, modify the value of --nproc_per_node. If you want to specify a certain GPU, just add CUDA_VISIBLE_DEVICES= to specify the index number of the GPU before each command. The principle is the same as single-machine multi-GPU training)
```
On the first machine: python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 --node_rank=0 --master_addr=<Master node IP address> --master_port=<Master node port number> train_gpu.py
On the second machine: python -m torch.distributed.run --nproc_per_node=1 --nnodes=2 --node_rank=1 --master_addr=<Master node IP address> --master_port=<Master node port number> train_gpu.py
```


## Citation
```
@article{chen2022pali,
title={Pali: A jointly-scaled multilingual language-image model},
author={Chen, Xi and Wang, Xiao and Changpinyo, Soravit and Piergiovanni, AJ and Padlewski, Piotr and Salz, Daniel and Goodman, Sebastian and Grycner, Adam and Mustafa, Basil and Beyer, Lucas and others},
journal={arXiv preprint arXiv:2209.06794},
year={2022}
}
```

```
@article{you2019large,
title={Large batch optimization for deep learning: Training bert in 76 minutes},
author={You, Yang and Li, Jing and Reddi, Sashank and Hseu, Jonathan and Kumar, Sanjiv and Bhojanapalli, Srinadh and Song, Xiaodan and Demmel, James and Keutzer, Kurt and Hsieh, Cho-Jui},
journal={arXiv preprint arXiv:1904.00962},
year={2019}
}
```
Binary file added confusion_matrix.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .mydataset import build_dataset, build_transform, MyDataset
from .split_data import read_split_data
from .threeaugment import new_data_aug_generator
from .transforms import resolve_data_config
80 changes: 80 additions & 0 deletions datasets/mydataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
from PIL import Image
from torchvision import transforms
from .split_data import read_split_data
from torch.utils.data import Dataset
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, create_transform


class MyDataset(Dataset):
def __init__(self, image_paths, image_labels, transforms=None):
self.image_paths = image_paths
self.image_labels = image_labels
self.transforms = transforms

def __getitem__(self, item):
image = Image.open(self.image_paths[item]).convert('RGB')
label = self.image_labels[item]
if self.transforms:
image = self.transforms(image)
return image, label

def __len__(self):
return len(self.image_paths)

@staticmethod
def collate_fn(batch):
images, labels = tuple(zip(*batch))
images = torch.stack(images, dim=0)
labels = torch.as_tensor(labels)
return images, labels



def build_transform(is_train, args):
resize_im = args.input_size > 32
if is_train:
# this should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=args.input_size,
is_training=True,
color_jitter=args.color_jitter,
auto_augment=args.aa,
interpolation=args.train_interpolation,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
)
if not resize_im:
# replace RandomResizedCropAndInterpolation with
# RandomCrop
transform.transforms[0] = transforms.RandomCrop(
args.input_size, padding=4)
return transform

t = []
if resize_im:
# size = int((256 / 224) * args.input_size)
size = int((1.0 / 0.96) * args.input_size)
t.append(
# to maintain same ratio w.r.t. 224 images
transforms.Resize(size, interpolation=3),
)
t.append(transforms.CenterCrop(args.input_size))

t.append(transforms.ToTensor())
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
return transforms.Compose(t)


def build_dataset(args):
train_image_path, train_image_label, val_image_path, val_image_label, class_indices = read_split_data(args.data_root)

train_transform = build_transform(True, args)
valid_transform = build_transform(False, args)

train_set = MyDataset(train_image_path, train_image_label, train_transform)
valid_set = MyDataset(val_image_path, val_image_label, valid_transform)

return train_set, valid_set

97 changes: 97 additions & 0 deletions datasets/split_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os, cv2, json, random
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt


def read_split_data(root, plot_image=False):
filepaths = []
labels = []
bad_images = []

random.seed(0)
assert os.path.exists(root), 'Your root does not exists!!!'

classes = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
classes.sort()
class_indices = {k: v for v, k in enumerate(classes)}

json_str = json.dumps({v: k for k, v in class_indices.items()}, indent=4)

with open('./classes_indices.json', 'w') as json_file:
json_file.write(json_str)

every_class_num = []
supported = ['.jpg', '.png', '.jpeg', '.PNG', '.JPG', '.JPEG']

for klass in classes:
classpath = os.path.join(root, klass)
images = [os.path.join(root, klass, i) for i in os.listdir(classpath) if os.path.splitext(i)[-1] in supported]
every_class_num.append(len(images))
flist = sorted(os.listdir(classpath))
desc = f'{klass:23s}'
for f in tqdm(flist, ncols=110, desc=desc, unit='file', colour='blue'):
fpath = os.path.join(classpath, f)
fl = f.lower()
index = fl.rfind('.')
ext = fl[index:]
if ext in supported:
try:
img = cv2.imread(fpath)
filepaths.append(fpath)
labels.append(klass)
except:
bad_images.append(fpath)
print('defective image file: ', fpath)
else:
bad_images.append(fpath)

Fseries = pd.Series(filepaths, name='filepaths')
Lseries = pd.Series(labels, name='labels')
df = pd.concat([Fseries, Lseries], axis=1)

print(f'{len(df.labels.unique())} kind of images were found in the datasets')
train_df, test_df = train_test_split(df, train_size=.8, shuffle=True, random_state=123, stratify=df['labels'])

train_image_path = train_df['filepaths'].tolist()
val_image_path = test_df['filepaths'].tolist()

train_image_label = [class_indices[i] for i in train_df['labels'].tolist()]
val_image_label = [class_indices[i] for i in test_df['labels'].tolist()]

sample_df = train_df.sample(n=50, replace=False)
ht, wt, count = 0, 0, 0
for i in range(len(sample_df)):
fpath = sample_df['filepaths'].iloc[i]
try:
img = cv2.imread(fpath)
h = img.shape[0]
w = img.shape[1]
ht += h
wt += w
count += 1
except:
pass
have = int(ht / count)
wave = int(wt / count)
aspect_ratio = have / wave
print('{} images were found in the datasets.\n{} for training, {} for validation'.format(
sum(every_class_num), len(train_image_path), len(val_image_path)
))
print('average image height= ', have, ' average image width= ', wave, ' aspect ratio h/w= ', aspect_ratio)

if plot_image:
plt.bar(range(len(classes)), every_class_num, align='center')
plt.xticks(range(len(classes)), classes)

for i, v in enumerate(every_class_num):
plt.text(x=i, y=v + 5, s=str(v), ha='center')

plt.xlabel('image class')
plt.ylabel('number of images')

plt.title('class distribution')
plt.show()

return train_image_path, train_image_label, val_image_path, val_image_label, class_indices
Loading

0 comments on commit 13a58ac

Please sign in to comment.