Feat long clip experiment #21

Open
wants to merge 4 commits into
base: feat-long-clip
5 changes: 2 additions & 3 deletions src/open_clip/loss.py
@@ -73,7 +73,7 @@ def __call__(self, features: torch.Tensor, pca_dim: Optional[int] = None):

all_features = torch.cat(gathered_features, dim=0)
if pca_dim:
all_features = PCA(all_features)
all_features = PCA(all_features, pca_dim)
return all_features


@@ -95,7 +95,6 @@ def gather_features(
rank=rank,
world_size=world_size,
use_horovod=use_horovod,
pca_dim=pca_dim,
)
return (
gather(image_features, pca_dim=pca_dim), # apply PCA on image features if set
@@ -161,7 +160,7 @@ def get_logits(self, image_features, text_features, logit_scale, pca_dim: Option
logits_per_text = logits_per_image.T
else:
if pca_dim:
image_features = PCA(image_features)
image_features = PCA(image_features, pca_dim)
logits_per_image = logit_scale * image_features @ text_features.T
logits_per_text = logit_scale * text_features @ image_features.T

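For readers following the PCA branch outside the diff: the visible lines suggest that when `pca_dim` is set, only the image features are passed through `PCA` before the similarity matrices are computed. A rough standalone sketch of that path (assuming `PCA` returns a tensor with the original feature dimensionality, so the matmul against `text_features` still lines up; names here are illustrative, not the class method itself):

```python
import torch
from typing import Callable, Optional

def get_logits_with_pca(
    image_features: torch.Tensor,      # (N, D) L2-normalized CLIP image embeddings
    text_features: torch.Tensor,       # (N, D) L2-normalized CLIP text embeddings
    logit_scale: torch.Tensor,
    pca_dim: Optional[int] = None,
    pca_fn: Optional[Callable[[torch.Tensor, int], torch.Tensor]] = None,
):
    # "Short caption" path: compress the image features to their top pca_dim
    # principal components before scoring them against the text features.
    if pca_dim and pca_fn is not None:
        image_features = pca_fn(image_features, pca_dim)
    logits_per_image = logit_scale * image_features @ text_features.T
    logits_per_text = logit_scale * text_features @ image_features.T
    return logits_per_image, logits_per_text
```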
@@ -9,9 +9,9 @@
"timm_proj": "none"
},
"text_cfg": {
"context_length": 77,
"hf_model_name": "jinaai/jina-bert-v2-base-en-flash",
"hf_tokenizer_name": "jinaai/jina-bert-v2-base-en-flash",
"context_length": 248,
"hf_model_name": "jinaai/jina-bert-v2-base-en",
"hf_tokenizer_name": "jinaai/jina-bert-v2-base-en",
"hf_pooler_type": "mean_pooler",
"hf_trust_remote_code": true,
"proj_type": null
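This config change raises the text tower's context window from 77 to 248 tokens and switches the HF backbone from the flash-attention variant of Jina BERT v2 to the standard one. A quick, hedged sanity check of the longer context window (the model name below is a placeholder for whatever name this JSON is registered under in open_clip):

```python
from open_clip import create_model_and_transforms, get_tokenizer

model_name = 'jina-clip-long'  # placeholder; use the registered name of the config above

model, preprocess_train, preprocess_val = create_model_and_transforms(model_name)
tokenizer = get_tokenizer(model_name)  # should pick up "context_length": 248 from text_cfg

tokens = tokenizer(['A caption that runs far past the old 77-token limit ...'])
print(tokens.shape)  # expected: (1, 248), padded/truncated to the new context length
```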
2 changes: 1 addition & 1 deletion src/open_clip/utils.py
@@ -10,7 +10,7 @@ def PCA(input_tensor, PCA_dim):
mean = torch.mean(input_tensor, dim=0)
X_centered = input_tensor - mean.unsqueeze(0)
X_centered = X_centered.float()
cov_matrix = torch.mm(X_centered.T, X_centered)
cov_matrix = torch.mm(X_centered.T, X_centered).type_as(X_centered)
eigenvalues, eigenvectors = torch.linalg.eig(cov_matrix)
eigenvalues = eigenvalues.float()
eigenvectors = eigenvectors.float()
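The diff only shows the top of `PCA`, so here is a hedged sketch of what the full helper plausibly does: project onto the top `PCA_dim` eigenvectors of the covariance matrix and map back to the original feature space, so downstream matmuls keep their shapes. The sketch uses `torch.linalg.eigh`, the natural choice for a symmetric covariance matrix, whereas the actual code uses `torch.linalg.eig` and casts the complex output back to float:

```python
import torch

def pca_project(input_tensor: torch.Tensor, pca_dim: int) -> torch.Tensor:
    # Center the (N, D) features and build the D x D covariance matrix in float32.
    mean = torch.mean(input_tensor, dim=0)
    x_centered = (input_tensor - mean.unsqueeze(0)).float()
    cov_matrix = torch.mm(x_centered.T, x_centered)

    # eigh returns real eigenpairs for symmetric matrices, in ascending order.
    eigenvalues, eigenvectors = torch.linalg.eigh(cov_matrix)

    # Keep the eigenvectors belonging to the pca_dim largest eigenvalues.
    top = torch.argsort(eigenvalues, descending=True)[:pca_dim]
    components = eigenvectors[:, top]                      # (D, pca_dim)

    # Project onto the principal subspace, then map back to D dims so the
    # result can replace the original features in the contrastive loss.
    projected = x_centered @ components                    # (N, pca_dim)
    reconstructed = projected @ components.T + mean.unsqueeze(0)
    return reconstructed.type_as(input_tensor)
```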
2 changes: 1 addition & 1 deletion src/training/eval.py
@@ -396,7 +396,7 @@ def flatten(dictionary, parent_key='', separator='_'):
clip_model=model,
_tokenizer=tokenizer,
hf_tokenizer_name=args.mteb_tokenizer_name,
max_seq_length=args.mteb_max_seq_length,
max_seq_length=args.mteb_max_sequence_length,
device=args.device,
)

54 changes: 45 additions & 9 deletions src/training/main.py
@@ -12,7 +12,8 @@

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer

try:
@@ -31,6 +32,7 @@
get_tokenizer,
trace_model,
)
from open_clip.tokenizer import DEFAULT_CONTEXT_LENGTH

from training.data import MultiS3EmbeddingDataset, dynamic_collate, get_multimodal_data
from training.distributed import broadcast_object, init_distributed_device, is_master
@@ -81,8 +83,6 @@ def get_latest_checkpoint(path: str, remote: bool):


def create_embeddings_dataloader(args):
# emb_tokenizer_name,
# emb_tokenizer_max_length

import training.embloss as embeddings_loss_module

@@ -161,7 +161,11 @@ def create_embeddings_dataloader(args):
tokenizer_options={
'padding': 'max_length',
'truncation': True,
'max_length': args.emb_tokenizer_max_length,
'max_length': (
args.emb_max_sequence_length
or args.max_sequence_length
or DEFAULT_CONTEXT_LENGTH
),
'return_tensors': 'pt',
},
input_type_dict=input_type_dict,
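The tokenizer `max_length` now resolves in three steps: the embedding-specific override, then the global `--max-sequence-length`, then open_clip's default context length. Note that the `or`-chaining treats `None` as unset, which is fine here because both flags default to `None`. A tiny illustration with made-up values:

```python
from open_clip.tokenizer import DEFAULT_CONTEXT_LENGTH  # 77 in current open_clip releases

emb_max_sequence_length = None   # --emb-max-sequence-length (unset)
max_sequence_length = 248        # --max-sequence-length

max_length = emb_max_sequence_length or max_sequence_length or DEFAULT_CONTEXT_LENGTH
print(max_length)  # 248; falls back to 77 only if both flags are unset
```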
@@ -450,12 +454,13 @@ def main(args):
# create optimizer and scaler
optimizer = None
scaler = None

'''
if args.train_data or args.dataset_type == 'synthetic':
assert not args.trace, 'Cannot train with traced model'
model, optimizer, scaler = create_optimizer(
args=args, model=model, dsinit=dsinit
)
'''

# optionally resume from a checkpoint
start_epoch = 0
@@ -486,7 +491,7 @@ def main(args):
checkpoint = pt_load(
os.path.join(args.resume, 'state.pt'), map_location='cpu'
)
if 'epoch' in checkpoint:
if not 'epoch' in checkpoint:
# resuming a train checkpoint w/ epoch and optimizer state
start_epoch = checkpoint['epoch']
sd = checkpoint['state_dict']
@@ -508,14 +513,25 @@
)
else:
# loading a bare (model only) checkpoint for fine-tune or evaluation
model.load_state_dict(checkpoint)
sd = checkpoint['state_dict']
if (
not args.distributed
and next(iter(sd.items()))[0].startswith('module')
):
sd = {k[len('module.'):]: v for k, v in sd.items()}
model.load_state_dict(sd)
#model.load_state_dict(checkpoint)
logging.info(
f'=> loaded checkpoint \'{args.resume}\' (epoch {start_epoch})'
)

if args.train_data or args.dataset_type == 'synthetic':
assert not args.trace, 'Cannot train with traced model'
model, optimizer, scaler = create_optimizer(
args=args, model=model, dsinit=dsinit
)
# initialize datasets
# multimodal
tokenizer = get_tokenizer(args.model)
tokenizer = get_tokenizer(args.model, context_length=args.max_sequence_length)
data = get_multimodal_data(
args,
(preprocess_train, preprocess_val),
@@ -530,6 +546,25 @@
if args.mtl:
emb_dataset, emb_dataloader, emb_losses = create_embeddings_dataloader(args)

long_clip_dataloader = None
if args.longclip:
from training.sharegpt4v import share4v_train_dataset, share4v_val_dataset

trainset = share4v_train_dataset(preprocess_train, tokenizer)
train_sampler = (
DistributedSampler(dataset=trainset, shuffle=True)
if args.distributed
else None
)
long_clip_dataloader = torch.utils.data.DataLoader(
trainset,
batch_size=args.batch_size,
sampler=train_sampler,
num_workers=32,
pin_memory=True,
shuffle=(train_sampler is None)  # the sampler already handles shuffling when distributed
)

# create scheduler if train
scheduler = None
if 'train' in data and optimizer is not None:
@@ -639,6 +674,7 @@ def main(args):
emb_dataloader=emb_dataloader,
emb_losses=emb_losses,
tb_writer=writer,
long_clip_dataloader=long_clip_dataloader,
)
completed_epoch = epoch + 1

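One note on the `--longclip` dataloader added above: when `sampler=train_sampler` is a `DistributedSampler`, PyTorch expects `shuffle` to be left off on the `DataLoader`, and per-epoch reshuffling is driven by calling `set_epoch` on the sampler. A sketch of that pattern (not part of this PR; names follow the block above, and `start_epoch`/`args.epochs` come from the surrounding script):

```python
# Per-epoch reshuffling with a DistributedSampler (sketch, not in this PR).
for epoch in range(start_epoch, args.epochs):
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)  # gives each rank a fresh shuffle this epoch
    for images, captions, captions_short in long_clip_dataloader:
        ...  # forward/backward handled in train_one_epoch
```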
18 changes: 15 additions & 3 deletions src/training/params.py
@@ -333,6 +333,12 @@ def parse_args(args):
help='Override default image resize (& crop) mode during inference',
)
parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs)
parser.add_argument(
'--max-sequence-length',
default=None,
type=int,
help='CLIP training max sequence length.',
)
parser.add_argument(
'--grad-checkpointing',
default=False,
@@ -605,7 +611,7 @@ def parse_args(args):
help='The tokenizer to use when running the MTEB benchmark.',
)
parser.add_argument(
'--mteb-max-seq-length',
'--mteb-max-sequence-length',
type=int,
default=8192,
help='The max sequence length used during MTEB evaluation.',
@@ -653,9 +659,9 @@ def parse_args(args):
help='The tokenizer to use for the embedding dataloader.',
)
parser.add_argument(
'--emb-tokenizer-max-length',
'--emb-max-sequence-length',
type=int,
default=128,
default=None,
help='The max sequence length of the embedding dataloader.',
)
parser.add_argument(
@@ -695,6 +701,12 @@
action='store_true',
help='If set, apply PCA to the image features and collect both the long- and short-caption losses.',
)
parser.add_argument(
'--pca-dim',
type=int,
default=None,
help='PCA dimension used for the short-caption loss; only relevant when --longclip is set.',
)

args = parser.parse_args(args)

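Taken together, the new flags wire up as follows (hedged example; `--longclip` is assumed to be the `store_true` flag whose name line falls just outside the hunk above):

```python
from training.params import parse_args

args = parse_args([
    '--longclip',                       # assumed name of the store_true flag above
    '--pca-dim', '32',                  # PCA dimension for the short-caption loss
    '--max-sequence-length', '248',     # matches the new text_cfg context_length
    '--emb-max-sequence-length', '512', # embedding dataloader override (optional)
])
print(args.longclip, args.pca_dim, args.max_sequence_length)  # True 32 248
```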
78 changes: 78 additions & 0 deletions src/training/sharegpt4v.py
@@ -0,0 +1,78 @@
import json
import os
import random

import cv2
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image

data4v_root = '/home/akoukounas/ShareGPT4V/data/'
json_name = 'share-captioner_coco_lcs_sam_1246k_1107.json'
image_root = '/home/akoukounas/ShareGPT4V/data/'


class share4v_val_dataset(data.Dataset):
def __init__(self, preprocess, tokenizer):
self.data4v_root = data4v_root
self.json_name = json_name
self.image_root = image_root
self.total_len = 1000
with open(data4v_root + json_name, 'r', encoding='utf8') as fp:
self.json_data = json.load(fp)[: self.total_len]
self.preprocess = preprocess
self.tokenizer = tokenizer

def __len__(self):
return self.total_len

def __getitem__(self, index):
caption = self.json_data[index]['conversations'][1]['value']
caption = caption.replace('\n', ' ')
image_name = self.image_root + self.json_data[index]['image']
image = Image.open(image_name)
image_tensor = self.preprocess(image)
return image_tensor, caption


class share4v_train_dataset(data.Dataset):
def __init__(self, preprocess, tokenizer):
self.data4v_root = data4v_root
self.json_name = json_name
self.image_root = image_root
self.total_len = 1000
with open(data4v_root + json_name, 'r', encoding='utf8') as fp:
self.json_data = json.load(fp)[self.total_len :]
self.preprocess = preprocess
self.tokenizer = tokenizer

def __len__(self):
return len(self.json_data)

def __getitem__(self, index):
try: # try except is used in case any image is missing (we have not tested sam dataset)
caption = self.json_data[index]['conversations'][1]['value']
caption = caption.replace('\n', ' ')

caption_short = caption.split('. ')[0]

image_name = self.image_root + self.json_data[index]['image']
image = Image.open(image_name)
image_tensor = self.preprocess(image)
return image_tensor, caption, caption_short
except Exception:  # skip records whose image is missing or unreadable
print(image_name)
with open('./image_names.txt', 'a') as file:
file.write(image_name + '\n')

index += 570486 # the first 570486 entries are SAM images; beyond that index, all images are present
caption = self.json_data[index]['conversations'][1]['value']
caption = caption.replace('\n', ' ')

caption_short = caption.split('. ')[0]

image_name = self.image_root + self.json_data[index]['image']
image = Image.open(image_name)
image_tensor = self.preprocess(image)
return image_tensor, caption, caption_short
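As written, both dataset classes store the `tokenizer` but never call it, and they return raw caption strings, while the training loop below calls `.to(device)` on the text batch. Presumably the captions are meant to be tokenized somewhere between the dataset and the loop, e.g. in a collate function along these lines (a sketch under that assumption, not code from this PR):

```python
import torch
from torch.utils.data import DataLoader

def make_longclip_collate(tokenizer):
    # Tokenize long and short captions at batch time so the training loop
    # receives tensors it can move to the GPU with .to(device).
    def collate(batch):
        images, captions, captions_short = zip(*batch)
        images = torch.stack(images, dim=0)
        texts = tokenizer(list(captions))            # (B, context_length) token ids
        texts_short = tokenizer(list(captions_short))
        return images, texts, texts_short
    return collate

# Usage sketch:
# loader = DataLoader(trainset, batch_size=args.batch_size,
#                     collate_fn=make_longclip_collate(tokenizer), ...)
```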
37 changes: 27 additions & 10 deletions src/training/train.py
@@ -75,6 +75,14 @@ def __iter__(self):
def __next__(self):
return None, (None, None)

class DummyLongClipDataloader:

def __iter__(self):
return self

def __next__(self):
return None, None, None


def train_one_epoch(
model,
@@ -89,6 +97,7 @@ def train_one_epoch(
emb_dataloader=None,
emb_losses=None,
tb_writer=None,
long_clip_dataloader=None,
):
device = torch.device(args.device)
autocast = get_autocast(args.precision)
@@ -103,7 +112,12 @@
assert emb_losses is not None
else:
emb_dataloader = DummyEmbeddingsDataloader()


if args.longclip:
assert long_clip_dataloader is not None
else:
long_clip_dataloader = DummyLongClipDataloader()

# set epoch in process safe manner via sampler or shared_epoch
data['train'].set_epoch(epoch)
dataloader = data['train'].dataloader
@@ -130,8 +144,8 @@
start = time.time()

# training loop
for i, (mm_batch, (emb_dataset, (emb_batch, emb_labels))) in enumerate(zip(
dataloader, islice(emb_dataloader, 1, None)
for i, (mm_batch, (emb_dataset, (emb_batch, emb_labels)), long_clip_batch) in enumerate(zip(
dataloader, islice(emb_dataloader, 1, None), long_clip_dataloader
)):

i_accum = i // args.accum_freq
@@ -140,14 +154,15 @@
if not args.skip_scheduler:
scheduler(step)

images, texts = mm_batch
images, texts, = mm_batch
images = images.to(device=device, dtype=input_dtype, non_blocking=True)
texts = texts.to(device=device, non_blocking=True)
if args.longclip:
images, texts, texts_short = long_clip_batch
images = images.to(device=device, dtype=input_dtype, non_blocking=True)
images_short = images.clone()
texts_short = []
for text in texts:
texts_short.append(text.split(". ")[0])
texts = texts.to(device=device, non_blocking=True)
texts_short = texts_short.to(device=device, non_blocking=True)
if emb_batch:
for batch in emb_batch:
batch.to(device=device)
@@ -205,10 +220,12 @@

losses['embedding_loss'] = args.emb_loss_weight * embedding_loss

if args.longclip:
if args.longclip and args.pca_dim is not None:
modelout_short = model(images_short, texts_short)
loss_short = loss(**modelout_short, output_dict=True, pca_dim=32)
losses['short_loss'] = 0.1 * loss_short
loss_short = loss(
**modelout_short, output_dict=True, pca_dim=args.pca_dim
)
losses['short_loss'] = 0.1 * loss_short['contrastive_loss']
total_loss = sum(losses.values())
losses['loss'] = total_loss
backward(total_loss, model, scaler=scaler, deepspeed=args.deepspeed)
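Putting the pieces together: when `--longclip` and `--pca-dim` are set, the per-step objective is the full-length contrastive loss, plus the optional embedding-task loss, plus a 0.1-weighted contrastive loss computed on short captions with PCA-reduced image features. A toy, self-contained illustration with dummy scalars (key and flag names follow the code above; `--emb-loss-weight` is the assumed flag behind `args.emb_loss_weight`):

```python
import torch

# Dummy scalar losses standing in for the real model outputs.
long_caption_contrastive_loss = torch.tensor(1.20)   # full captions, full-dim features
embedding_loss = torch.tensor(0.80)                  # MTL embedding task (if --mtl)
short_caption_contrastive_loss = torch.tensor(0.95)  # short captions, PCA-reduced image features
emb_loss_weight = 1.0                                # --emb-loss-weight (assumed flag name)

losses = {
    'contrastive_loss': long_caption_contrastive_loss,
    'embedding_loss': emb_loss_weight * embedding_loss,
    'short_loss': 0.1 * short_caption_contrastive_loss,
}
total_loss = sum(losses.values())
print(float(total_loss))  # 1.20 + 0.80 + 0.095 = 2.095
```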