"""
Code adapted from: https://github.com/loubnabnl/santacoder-finetuning
"""
import argparse
from functools import wraps
import os
import json
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
import time
from push_checkpoints import push_checkpoints
from datasets.load import load_dataset, load_from_disk
from datasets import DatasetDict, Dataset
from number_of_tokens import get_total_tokens, get_total_tokens_from_iterable
from dataset_loader import ConstantLengthDataset, PaddedDataset, TQDMWraper
from lora import hacky_model_convert, find_all_linear_names, SavePeftModelCallback
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from pathlib import Path
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
BitsAndBytesConfig,
Trainer,
TrainingArguments,
TrainerCallback,
TrainerState,
TrainerControl,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
class SaveTokenizerCallback(TrainerCallback):
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
checkpoint_folder = os.path.join(
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
self.tokenizer.save_pretrained(checkpoint_folder)
def neftune_post_forward_hook(module, input, output): # NOTE: copypasted from TRL
"""
Implements the NEFTune forward pass for the model using forward hooks. Note this works only for
torch.nn.Embedding layers. This method is slightly adapted from the original source code
that can be found here: https://github.com/neelsjain/NEFTune
Simply add it to your model as follows:
```python
model = ...
model.embed_tokens.neftune_noise_alpha = 0.1
model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
```
Args:
module (`torch.nn.Module`):
The embedding module where the hook is attached. Note that you need to set
`module.neftune_noise_alpha` to the desired noise alpha value.
input (`torch.Tensor`):
The input tensor to the model.
output (`torch.Tensor`):
The output tensor of the model (i.e. the embeddings).
"""
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
output = output + \
torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
def load_special_tokens(tokenizer):
    """
    Loads the special tokens from the special_tokens_map.json file located next to this script.
    """
    this_folder = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(this_folder, "special_tokens_map.json")) as file:
        special_tokens_map = json.load(file)
    tokenizer.add_special_tokens(special_tokens_map)
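# For reference, tokenizer.add_special_tokens expects a mapping like the
# hypothetical example below; the actual contents of special_tokens_map.json
# shipped with this repo may define different tokens:
#   {
#       "pad_token": "<pad>",
#       "additional_special_tokens": ["<fim_prefix>", "<fim_middle>", "<fim_suffix>"]
#   }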
def unwrap_model(model: nn.Module) -> nn.Module: # NOTE: copypasted from TRL
"""
Recursively unwraps a model from potential containers (as used in distributed training).
Args:
model (`torch.nn.Module`): The model to unwrap.
"""
# since there could be multiple levels of wrapping, unwrap recursively
if hasattr(model, "module"):
return unwrap_model(model.module)
else:
return model
class BetterTrainer(Trainer):
def __init__(
self,
neftune_noise_alpha: Optional[float] = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.neftune_noise_alpha = neftune_noise_alpha
@wraps(Trainer.train)
def train(self, *args, **kwargs): # NOTE: copypasted from TRL
# Activate neftune right before training.
if self.neftune_noise_alpha is not None:
self.model = self._trl_activate_neftune(self.model)
output = super().train(*args, **kwargs)
# After training we make sure to retrieve back the original forward pass method
# for the embedding layer by removing the forward post hook.
if self.neftune_noise_alpha is not None:
unwrapped_model = unwrap_model(self.model)
embeddings = unwrapped_model.get_input_embeddings()
self.neftune_hook_handle.remove()
if hasattr(embeddings, "neftune_noise_alpha"):
del embeddings.neftune_noise_alpha
return output
def _trl_activate_neftune(self, model): # NOTE: copypasted from TRL
r"""
Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://arxiv.org/abs/2310.05914
        Since the transformers `Trainer` already defines an `_activate_neftune` method, this one is renamed to avoid conflicts.
"""
print("*** Activating NEFTune ***")
unwrapped_model = unwrap_model(model)
embeddings = unwrapped_model.get_input_embeddings()
embeddings.neftune_noise_alpha = self.neftune_noise_alpha
hook_handle = embeddings.register_forward_hook(
neftune_post_forward_hook)
self.neftune_hook_handle = hook_handle
return model
def get_arg_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str,
default="bigcode/starcoderbase", help="Path to the model to train.")
parser.add_argument("--model_revision", type=str, default="main",
help="Revision of the model to train, if on the hub.")
parser.add_argument("--dataset_name", type=str,
default="bigcode/starcoderdata", help="Name of the dataset to train on.")
parser.add_argument("--dataset_revision", type=str, default="main",
help="Revision of the dataset to train on, if on the hub.")
parser.add_argument("--subset", type=str, default=None,
help="Subset of the dataset to use.")
parser.add_argument("--split", type=str, default="train",
help="Split of the dataset to use.")
parser.add_argument("--perc_valid_set", type=float, default=0.005,
help="Percentage of the dataset to use as validation set.")
parser.add_argument("--data_column", type=str, default="content",
help="Column of the dataset to use as input.")
parser.add_argument("--min_edu_score", type=float, default=0.0,
help="Minimum education score of the examples to use.")
parser.add_argument("--edu_score_column", type=str, default=None,
help="Column of the dataset to use as education score.")
parser.add_argument("--no_shuffle_train", action="store_true",
help="Do not shuffle the training set.")
parser.add_argument("--objective", type=str, default="lm",
choices=["lm", "seqcls"],
help="Objective to train on.")
parser.add_argument("--lora", action="store_true", help="Enable LoRA.")
parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank.")
parser.add_argument("--lora_alpha", type=int,
default=32, help="LoRA alpha.")
parser.add_argument("--lora_dropout", type=float,
default=0.05, help="LoRA dropout.")
parser.add_argument("--lora_bits", type=int, default=8,
choices=[4, 8], help="LoRA quantization bits.")
parser.add_argument("--lora_extreme", action="store_true",
help="Enable extreme quantization (QLoRA).")
parser.add_argument("--seq_length", type=int,
default=4096, help="Sequence length (i.e. context window size).")
parser.add_argument("--epochs", type=int, default=10,
help="Number of epochs.")
parser.add_argument("--batch_size", type=int,
default=2, help="Batch size.")
parser.add_argument("--gradient_accumulation_steps",
type=int, default=8, help="Gradient accumulation steps.")
parser.add_argument("--total_tokens", type=int,
help="Total number of tokens in the dataset. If not provided, will be computed.")
parser.add_argument("--no_approx_tokens", action="store_true",
help="Disables token approximation.")
parser.add_argument("--dataset_loader", type=str,
default="constant", choices=["constant", "padded"],
help="Dataset loader to use.")
parser.add_argument("--pad_token_id", type=int,
default=None, help="Pad token id.")
parser.add_argument("--trim_longer", action="store_true",
help="Trim longer sequences when using the padded dataset loader.")
parser.add_argument("--mask_loss_till_token_id", type=int, default=None,
help="Allows to mask the loss until a certain token id.")
parser.add_argument("--learning_rate", type=float,
default=5e-5, help="Learning rate.")
parser.add_argument("--lr_scheduler_type", type=str, default="linear", choices=[
"cosine", "linear", "constant"], help="Learning rate scheduler type.")
parser.add_argument("--num_warmup_steps", type=int,
default=100, help="Number of warmup steps.")
parser.add_argument("--weight_decay", type=float,
default=0.05, help="Weight decay.")
parser.add_argument("--attention_dropout", type=float, default=None,
help="Attention dropout -- may not be supported by all models.")
parser.add_argument("--neft_alpha", type=float,
default=None, help="NEFTune noise alpha.")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--no_fp16", action="store_false",
help="Disable fp16.")
parser.add_argument("--bf16", action="store_true", help="Use bfloat16.")
parser.add_argument("--torch_dtype", type=str, default=None, choices=[
"float16", "bfloat16", "float32"], help="Force the model to use a certain dtype.")
parser.add_argument("--no_gradient_checkpointing", action="store_false",
help="Disable gradient checkpointing.")
parser.add_argument("--seed", type=int, default=0, help="Random seed.")
parser.add_argument("--num_workers", type=int, default=None,
help="Number of workers for the dataset loader.")
parser.add_argument("--output_dir", type=str, default="./checkpoints",
help="Output directory for checkpoints.")
parser.add_argument("--log_freq", default=1, type=int,
help="Frequency of logging (in steps).")
parser.add_argument("--no_wandb", action="store_true",
help="Disable wandb logging.")
parser.add_argument("--eval_freq", default=1.0, type=float,
help="Evaluate X times per epoch, can be < 1.")
parser.add_argument("--save_freq", default=1.0, type=float,
help="Save X times per epoch, can be < 1.")
parser.add_argument("--checkpoint", type=str, default=None,
help="Checkpoint to resume training from.")
parser.add_argument("--save_strategy", type=str, default="steps", choices=[
"steps", "epoch"], help="Save strategy.")
parser.add_argument("--save_total_limit", type=int,
default=10, help="Total number of checkpoints to save.")
parser.add_argument("--push_to_hub", type=str, default=None,
help="Push the checkpoints to the hub and specify the repo name.")
parser.add_argument("--local-rank", type=int, default=0)
parser.add_argument("--custom_tokenizer", type=str,
default=None, help="Path to a custom tokenizer.")
parser.add_argument("--eval_dataset", type=str, default=None,
help="Path to the evaluation dataset. Needs to have split='test' and follow same format as train.")
parser.add_argument("--save_best_model", action="store_true")
parser.add_argument("--deepspeed", type=str)
parser.add_argument("--fa2", action="store_true")
return parser
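# A hypothetical invocation of this script, using only flags defined above
# (the model, dataset, and hyperparameter values are placeholders):
#
#   python train.py \
#       --model_path bigcode/starcoderbase --dataset_name bigcode/starcoderdata \
#       --seq_length 4096 --batch_size 2 --gradient_accumulation_steps 8 \
#       --epochs 10 --learning_rate 5e-5 --bf16 --lora --lora_r 16 --lora_alpha 32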
def is_main(args):
"""
Returns True if the process is the main process.
"""
return args.local_rank in [-1, 0]
def get_rank(args):
"""
Returns the rank of the process.
"""
return args.local_rank if args.local_rank != -1 else 0
def chars_token_ratio(dataset, tokenizer, data_column, nb_examples=400):
"""
Estimate the average number of characters per token in the dataset.
"""
total_characters, total_tokens = 0, 0
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
total_characters += len(example[data_column])
total_tokens += len(tokenizer(example[data_column]).tokens())
return total_characters / total_tokens
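# Illustrative example (hypothetical numbers): if the 400 sampled examples contain
# 200_000 characters and 50_000 tokens, the ratio is 4.0; ConstantLengthDataset
# uses this chars_per_token estimate to size its character buffer before tokenizing.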
def print_trainable_parameters(model):
"""
Prints the number of trainable parameters in the model.
"""
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
)
def get_num_gpus(args):
"""
Returns the number of GPUs used in the training.
"""
# NOTE: using torch.cuda.device_count() isn't bulletproof, but it's good enough for our purposes
return 1 if args.local_rank == -1 else torch.cuda.device_count()
def load_source_dataset(args):
"""
Loads the source dataset from the given arguments.
"""
num_gpus = get_num_gpus(args)
# if dataset is a path, load it from the path
if os.path.isdir(args.dataset_name):
dataset = load_from_disk(args.dataset_name)
# if DatasetDict, select the split
if isinstance(dataset, DatasetDict):
dataset = dataset[args.split]
else:
kwargs = {}
if args.subset:
kwargs["data_dir"] = args.subset
dataset = load_dataset(
args.dataset_name,
revision=args.dataset_revision,
split=args.split,
            num_proc=args.num_workers // num_gpus if args.num_workers else None,
**kwargs,
)
return dataset
def dataset_splits(dataset, args) -> Tuple[Dataset, Optional[Dataset]]:
"""
Splits the dataset into training and validation sets based on the arguments.
"""
if args.eval_dataset:
valid_data = load_dataset(args.eval_dataset, split="test")
train_data = dataset
elif args.perc_valid_set == 0:
train_data = dataset
valid_data = None
else:
dataset = dataset.train_test_split( # type: ignore
test_size=args.perc_valid_set, seed=args.seed)
train_data = dataset["train"]
valid_data = dataset["test"]
if args.edu_score_column:
train_data = train_data.filter(
lambda example: example[args.edu_score_column] >= args.min_edu_score
)
if not args.eval_dataset:
assert valid_data is not None
valid_data = valid_data.filter(
lambda example: example[args.edu_score_column] >= args.min_edu_score
)
if not args.no_shuffle_train:
train_data = train_data.shuffle(seed=args.seed)
return train_data, valid_data
def dataset_loader_constructor_factory(tokenizer, args):
"""
Returns the dataset loader constructor based on the arguments.
"""
if args.dataset_loader == "constant":
ctr = chars_token_ratio(
load_source_dataset(args), tokenizer, args.data_column)
print(
f"The character to token ratio of the dataset is: {ctr:.2f}")
return lambda data, infinite: ConstantLengthDataset(
tokenizer,
data,
infinite=infinite,
seq_length=args.seq_length,
chars_per_token=ctr,
content_field=args.data_column,
)
elif args.dataset_loader == "padded":
return lambda data, infinite: PaddedDataset(
tokenizer,
data,
infinite=infinite,
seq_length=args.seq_length,
content_field=args.data_column,
pad_token_id=args.pad_token_id,
trim_longer=args.trim_longer,
mask_loss_till_token_id=args.mask_loss_till_token_id,
)
else:
raise ValueError(
f"Invalid dataset loader: {args.dataset_loader}.")
def create_dataloaders(tokenizer, args, tqdm=True):
"""
Creates the dataset loaders for training and validation.
"""
# TODO: for multi-node, this won't work
num_gpus = get_num_gpus(args)
train_data, valid_data = dataset_splits(load_source_dataset(args), args)
print(
f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data) if valid_data else None}"
)
ds_constructor = dataset_loader_constructor_factory(tokenizer, args)
total_tokens = args.total_tokens
if total_tokens is None:
# approximate if dataset is too large (greater than 50k examples)
if len(train_data) > 50000 and not args.no_approx_tokens:
print(
f"Dataset is too large ({len(train_data)} examples). Approximating the number of tokens. Disable with --no_approx_tokens.")
total_tokens_50k = get_total_tokens(
train_data, tokenizer, args.data_column, 50000)
total_tokens = total_tokens_50k * (len(train_data) // 50000)
else:
total_tokens = get_total_tokens_from_iterable(
ds_constructor(train_data, infinite=False))
training_examples = total_tokens // args.seq_length
effective_batch_size = args.batch_size * \
args.gradient_accumulation_steps * num_gpus
max_steps = max(1, int(training_examples /
effective_batch_size * args.epochs))
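    # Worked example with illustrative numbers: total_tokens = 1_000_000_000 and
    # seq_length = 4096 give training_examples = 244_140; batch_size = 2,
    # gradient_accumulation_steps = 8 and 4 GPUs give an effective batch size of 64,
    # so max_steps = int(244_140 / 64 * 10 epochs) = 38_146.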
if is_main(args):
print(f" #### SCALING LAWS ####")
print(f" ###### Examples ######")
print(f"Total tokens: {total_tokens}")
print(f"Seq length: {args.seq_length}")
print(f"Training examples: {training_examples}")
print(f" ####### Batch #######")
print(f"Batch size: {args.batch_size}")
print(
f"Gradient accumulation steps: {args.gradient_accumulation_steps}")
print(f"Number of GPUs: {num_gpus}")
print(f"Effective batch size: {effective_batch_size}")
print(f"Epoch: {args.epochs}")
print(f"####### RESULT ###########")
print(f"# Max steps: {max_steps} #")
print(f"##########################")
train_dataset = ds_constructor(train_data, infinite=True)
valid_dataset = ds_constructor(
valid_data, infinite=False) if valid_data else None
if tqdm and is_main(args):
train_dataset = TQDMWraper(
train_dataset, num_iters=training_examples * args.epochs, desc="Training")
if valid_dataset:
valid_dataset = TQDMWraper(
valid_dataset, desc="Evaluating")
return max_steps, train_dataset, valid_dataset
def get_model_class(args):
"""
Returns the model class based on the arguments.
"""
if args.objective == "lm":
return AutoModelForCausalLM
elif args.objective == "seqcls":
return AutoModelForSequenceClassification
else:
raise ValueError(f"Invalid training objective: {args.objective}")
def dtype_from_str(dtype_str):
"""
Converts the string representation of a dtype to a torch dtype.
"""
if dtype_str == "float16":
return torch.float16
elif dtype_str == "bfloat16":
return torch.bfloat16
elif dtype_str == "float32":
return torch.float32
else:
raise ValueError(f"Invalid dtype: {dtype_str}")
def run_training(args, max_steps, train_data, val_data):
"""
Runs the training loop.
"""
os.makedirs(args.output_dir, exist_ok=True)
model_extra_kwargs = {}
if args.lora:
config = {}
if args.lora_bits == 8:
config["load_in_8bit"] = True
elif args.lora_bits == 4:
config["load_in_4bit"] = True
else:
assert False, f"Invalid lora_bits: {args.lora_bits}"
if args.lora_extreme: # extreme quantization
print("LOADING EXTREME QUANTIZATION!!!!!!!")
config["load_in_8bit"] = False # disable if set by user
config["load_in_4bit"] = True
config["llm_int8_threshold"] = 6.0
config["llm_int8_has_fp16_weight"] = False
config["bnb_4bit_quant_type"] = "nf4"
config["bnb_4bit_use_double_quant"] = True
dtype = None
if args.bf16:
dtype = torch.bfloat16
else:
dtype = torch.float16
config["bnb_4bit_compute_dtype"] = dtype
model_extra_kwargs["device_map"] = {
"": args.local_rank if args.local_rank != -1 else 0
}
model_extra_kwargs["quantization_config"] = BitsAndBytesConfig(
**config)
if args.fa2:
# need to set dtype to either float16 or bfloat16
if args.bf16:
model_extra_kwargs["torch_dtype"] = torch.bfloat16
else:
model_extra_kwargs["torch_dtype"] = torch.float16
if args.torch_dtype: # overrides everything else
model_extra_kwargs["torch_dtype"] = dtype_from_str(args.torch_dtype)
    if args.attention_dropout is not None:  # some models don't support this
model_extra_kwargs["attention_dropout"] = args.attention_dropout
train_data.start_iteration = 0
# calculate eval and save steps from max steps
steps_per_epoch = max_steps // args.epochs
eval_steps = int(steps_per_epoch * args.eval_freq)
eval_steps = None if eval_steps == 0 else eval_steps # disable if 0
save_steps = int(steps_per_epoch * args.save_freq)
print(f"Eval steps: {eval_steps} -- Save steps: {save_steps}")
extra_training_args = {}
if args.deepspeed:
extra_training_args["deepspeed"] = args.deepspeed
training_args = TrainingArguments(
output_dir=args.output_dir,
torch_compile=True,
dataloader_drop_last=True,
evaluation_strategy="steps" if eval_steps else "no",
max_steps=max_steps,
eval_steps=eval_steps,
save_steps=save_steps,
logging_steps=args.log_freq,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
learning_rate=args.learning_rate,
lr_scheduler_type=args.lr_scheduler_type,
warmup_steps=args.num_warmup_steps,
gradient_accumulation_steps=args.gradient_accumulation_steps,
gradient_checkpointing=args.no_gradient_checkpointing,
save_total_limit=99999 if args.lora else args.save_total_limit,
save_strategy=args.save_strategy,
fp16=args.no_fp16,
bf16=args.bf16,
weight_decay=args.weight_decay,
report_to=["wandb"] if not args.no_wandb else [],
load_best_model_at_end=args.save_best_model,
ddp_find_unused_parameters=False,
**extra_training_args,
)
print(
f"*** [{get_rank(args)}] Loading the model with '{args.objective}' objective ***")
# disable caching mechanism when using gradient checkpointing
model = get_model_class(args).from_pretrained(
args.model_path,
revision=args.model_revision,
trust_remote_code=True,
use_cache=not args.no_gradient_checkpointing,
use_flash_attention_2=args.fa2,
**model_extra_kwargs,
)
    if args.lora:
        print("Preparing model for LoRA training")
        # keep the returned model, and follow the same (confusingly named) flag
        # that gradient_checkpointing uses in the TrainingArguments above
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=args.no_gradient_checkpointing)
all_linear_layers = find_all_linear_names(model)
added_modules = set(["c_proj", "c_attn", "q_attn"])
modules = list(added_modules.union(all_linear_layers))
print(f"Target modules: {modules}")
lora_config = LoraConfig(
r=args.lora_r,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias="none",
task_type="CAUSAL_LM",
target_modules=modules,
)
model.enable_input_require_grads()
model = get_peft_model(model, lora_config)
hacky_model_convert(args, model)
    if not args.deepspeed:
        print_trainable_parameters(model)
if is_main(args) and not args.no_wandb:
import wandb
wandb_name = None
if not os.getenv("WANDB_NAME"):
date = time.strftime("%Y-%m-%d-%H-%M")
lora_str = "_lora" if args.lora else ""
model_name = args.model_path.rstrip("/").split("/")[-1]
dataset_name = args.dataset_name.rstrip("/").split("/")[-1]
wandb_name = f"{model_name}_{dataset_name}_{date}_{lora_str}"
try:
wandb.init(name=wandb_name)
        except Exception as e:
            print(
                f"Failed to initialize wandb -- you can disable it with the `--no_wandb` option.\nError: {e}")
            raise
trainer_extra_kwargs: Dict[str, Any] = {
"callbacks": [SaveTokenizerCallback(train_data.get_tokenizer())],
}
if args.lora:
trainer_extra_kwargs["callbacks"] += [SavePeftModelCallback]
trainer = BetterTrainer(
model=model,
args=training_args,
train_dataset=train_data,
eval_dataset=val_data,
neftune_noise_alpha=args.neft_alpha,
**trainer_extra_kwargs
)
print(f"*** [{get_rank(args)}] Training... ***")
if args.checkpoint:
print(f"***** Loading checkpoint from {args.checkpoint} *****")
trainer.train(args.checkpoint)
else:
# find latest checkpoint
chks = []
for checkpoint in Path(args.output_dir).glob("checkpoint-*"):
try:
num = int(checkpoint.name.split("-")[-1])
chks.append(num)
except ValueError:
continue
if len(chks) > 0:
chks.sort()
last_chk = chks[-1]
print(
f"***** Automatically detected checkpoint. Loading checkpoint from {last_chk} *****")
trainer.train(f"{args.output_dir}/checkpoint-{last_chk}")
else:
trainer.train()
if args.push_to_hub:
push_checkpoints(args.output_dir, args.push_to_hub)
if args.save_best_model:
print("Saving best model...")
model.save_pretrained(os.path.join(args.output_dir, "best/"))
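# NOTE: a minimal, hypothetical sketch of how the helpers above could be wired
# together; this is not the script's actual entry point, and the tokenizer
# handling below is an assumption (AutoTokenizer is not imported at the top of
# this file):
#
#   if __name__ == "__main__":
#       from transformers import AutoTokenizer
#       args = get_arg_parser().parse_args()
#       tokenizer = AutoTokenizer.from_pretrained(
#           args.custom_tokenizer or args.model_path, revision=args.model_revision)
#       load_special_tokens(tokenizer)
#       max_steps, train_data, val_data = create_dataloaders(tokenizer, args)
#       run_training(args, max_steps, train_data, val_data)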