@@ -3,7 +3,6 @@
 import os
 import torch
 import wandb
-import deepspeed
 import functools
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -16,16 +15,13 @@
     world_info_from_env,
     get_fsdp_config,
     get_fsdp_checkpoint_config,
-    get_deepspeed_config,
 )
 from open_flamingo.train.train_utils import (
     train_one_epoch,
     random_seed,
-    load_deepspeed_checkpoint,
     find_most_recent_checkpoint,
     load_checkpoint,
     save_checkpoint,
-    save_deepspeed_checkpoint,
 )
 from open_flamingo.train.losses import (
     SUPPORTED_LOSSES,
@@ -44,8 +40,8 @@ def main():
     parser.add_argument(
         "--model_family", default="flamingo", type=str, choices=SUPPORTED_MODEL_FAMILIES
     )
-    parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
-    parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
+    parser.add_argument("--vision_encoder_path", default="ViT-SO400M-14-SigLIP-384", type=str)
+    parser.add_argument("--vision_encoder_pretrained", default="webli", type=str)
     parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
     parser.add_argument(
         "--tokenizer_path",
@@ -73,7 +69,7 @@ def main():
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
-        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default. If using deepspeed this should be a directory, not a file.",
+        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default.",
         default=None,
     )
     parser.add_argument(
@@ -187,20 +183,6 @@ def main():
         "--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid"]
     )

-    # deepspeed args
-    parser.add_argument(
-        "--deepspeed",
-        default=False,
-        action="store_true",
-        help="Use deepspeed for distributed training.",
-    )
-    parser.add_argument(
-        "--deepspeed_stage",
-        default=2,
-        type=int,
-        help="DeepSpeed distributed training stage. 1: ZeRO-1 (optimizer sharding), 2: ZeRO-2 (optimizer + gradient sharding), 3: ZeRO-3 (optimizer + gradient + model sharding)",
-    )
-
     # wandb args
     parser.add_argument("--report_to_wandb", default=False, action="store_true")
     parser.add_argument(
@@ -251,16 +233,10 @@ def main():
     if args.save_checkpoints_to_wandb and not args.report_to_wandb:
        raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")

-    if args.fsdp and args.deepspeed:
-        raise ValueError("Select either FSDP or deepspeed for distributed training.")
-
     if args.fsdp:
-        print(
-            "Warning: FSDP is experimental and not fully tested. Preference should be given to Deepspeed."
-        )
         assert (
-            "dev" in torch.__version__ and torch.__version__ > "2.0.1"
-        ), "FSDP requires torch nightly > 2.0.1"
+            torch.__version__ > "2.0.1"
+        ), "FSDP requires torch > 2.0.1"

     # Set up distributed training
     args.local_rank, args.rank, args.world_size = world_info_from_env()
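
A side note on the relaxed assertion above: comparing torch.__version__ against "2.0.1" as a string is lexicographic, so it can misorder versions (for example "2.10.0" sorts before "2.2.0"). Shown below only as a hedged alternative, not as what this PR does, is a version-aware check; it assumes the packaging package is available (it usually is, but may need pip install packaging):

# Alternative sketch: numeric version comparison instead of string comparison.
import torch
from packaging import version

assert version.parse(torch.__version__).release > (2, 0, 1), (
    "FSDP requires torch > 2.0.1"
)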
@@ -269,13 +245,7 @@ def main():
     if args.offline:
         os.environ["WANDB_MODE"] = "offline"
         os.environ["TRANSFORMERS_OFFLINE"] = "1"
-    if args.deepspeed:
-        torch.cuda.set_device(args.local_rank)
-        deepspeed.init_distributed()
-        ds_config = get_deepspeed_config(args)
-        device_id = args.local_rank
-    else:
-        device_id = init_distributed_device(args)
+    device_id = init_distributed_device(args)

     random_seed(args.seed)

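With the DeepSpeed branch removed, all process-group setup now goes through init_distributed_device(args). Purely as a rough sketch of what a torchrun-style helper of that shape typically does (this is not open_flamingo's actual implementation), it reads the standard launcher environment variables and pins the CUDA device:

# Rough sketch only; names are illustrative and this is NOT the repo's helper.
import os
import torch
import torch.distributed as dist

def init_distributed_device_sketch():
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size > 1 and not dist.is_initialized():
        dist.init_process_group(backend="nccl")  # configured by torchrun env vars
    torch.cuda.set_device(local_rank)
    return local_rank  # used as device_id for model.to(...) and DDP(device_ids=[...])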
@@ -316,8 +286,8 @@ def main():
         args.resume_from_checkpoint = find_most_recent_checkpoint(args)

     if (
-        args.resume_from_checkpoint is not None and not args.deepspeed
-    ):  # deepspeed handles checkpoint loading
+        args.resume_from_checkpoint is not None
+    ):
         resume_from_epoch, checkpoint = load_checkpoint(args, model)
     else:
         resume_from_epoch = 0
@@ -327,7 +297,6 @@ def main():
         model.init_gradient_checkpointing()

     # Initialize FSDP / DDP, and ensure the model is on GPU
-    # Deepspeed is initialized later
     if args.fsdp:
         auto_wrap_policy = functools.partial(
             lambda_auto_wrap_policy, lambda_fn=model.get_fsdp_lambda_fn()
@@ -336,7 +305,7 @@ def main():
         distributed_model = FSDP(
             model, auto_wrap_policy=auto_wrap_policy, **wrapper_kwargs
         )
-    elif not args.deepspeed:
+    else:
         model = model.to(device_id)
         distributed_model = DDP(model, device_ids=[device_id])

@@ -351,7 +320,7 @@ def main():
     )

     # load optimizer checkpoint
-    if args.resume_from_checkpoint is not None and not args.deepspeed:
+    if args.resume_from_checkpoint is not None:
         osd = checkpoint["optimizer_state_dict"]
         if args.fsdp:
             FSDP.set_state_dict_type(
@@ -370,7 +339,7 @@ def main():
     ]
     total_training_steps = (
         getattr(args, f"train_num_samples_{datasets_to_train_on[0]}")
-        // getattr(args, f"batch_size_{datasets_to_train_on[0]}")
+        // (getattr(args, f"batch_size_{datasets_to_train_on[0]}") * args.gradient_accumulation_steps * args.world_size)
     ) * args.num_epochs

     if args.rank == 0:
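
The changed denominator above makes total_training_steps count optimizer steps rather than per-GPU batches: the dataset's per-epoch sample budget is divided by the effective global batch size (per-GPU batch size * gradient accumulation steps * world size). A small worked example with invented numbers, shown only to illustrate the arithmetic:

# Illustration only; all numbers here are made up, not taken from the PR.
train_num_samples = 1_000_000        # samples per epoch for the first dataset
batch_size = 32                      # per-GPU batch size
gradient_accumulation_steps = 4
world_size = 8                       # number of GPUs
num_epochs = 2

global_batch = batch_size * gradient_accumulation_steps * world_size  # 1024 samples per optimizer step
total_training_steps = (train_num_samples // global_batch) * num_epochs
print(global_batch, total_training_steps)  # 1024 1952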
@@ -395,21 +364,9 @@ def main():
     )

     # load lr scheduler checkpoint
-    if args.resume_from_checkpoint is not None and not args.deepspeed:
+    if args.resume_from_checkpoint is not None:
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])

-    if args.deepspeed:
-        distributed_model, optimizer, _, lr_scheduler = deepspeed.initialize(
-            model=model,
-            optimizer=optimizer,
-            args=args,
-            config=ds_config,
-            lr_scheduler=lr_scheduler,
-            dist_init_required=True,
-        )
-        if args.resume_from_checkpoint is not None:
-            resume_from_epoch = load_deepspeed_checkpoint(args, distributed_model)
-
     # Initialize the loss fn
     loss_fn = get_loss_fn(args.loss)

@@ -435,10 +392,7 @@ def main():
             wandb=wandb,
         )

-        if args.deepspeed:
-            save_deepspeed_checkpoint(distributed_model, epoch, args)
-        else:
-            save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args)
+        save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args)


 if __name__ == "__main__":