
Commit 7d3dee3

(1) adding support for evaluation skipping; (2) updating model and data… (#728)
* (1) adding support for evaluation skipping; (2) update model and data download instructions; (3) clean up
* adding reference TTT per issue 727 request
1 parent 10786e3 commit 7d3dee3

File tree

4 files changed: +21 -30 lines changed


llama2_70b_lora/README.md

Lines changed: 15 additions & 26 deletions

````diff
@@ -41,19 +41,26 @@ git clone https://github.com/mlperf/logging.git mlperf-logging
 pip install -e mlperf-logging
 ```
 ## Download Data and Model
-data can be downloaded from:
-[mlperf drive - train data](https://drive.google.com/file/d/1-JgY1mEafcJ7qhggt6UR3OEKAciIPd5s/view?usp=sharing)
-[mlperf drive - validation data](https://drive.google.com/file/d/1jrm6Lacrq49AYv0uB_Qy22xRmfPixQvs/view?usp=sharing)
-[mlperf drive - llama-v2 model](https://drive.google.com/drive/folders/1sTeuxkPhwkNPKIPFnOLIYCcK53oB3Ypc?usp=sharing)
-As defaults the scripts assume the model is under at ```./llama-v2-fused-qkv``` and the both train and validation are under ```dataset``` folder.
+MLCommons hosts the model for download exclusively by MLCommons Members. You must first agree to the [confidentiality notice](https://docs.google.com/forms/d/e/1FAIpQLSc_8VIvRmXM3I8KQaYnKf7gy27Z63BBoI_I1u02f4lw6rBp3g/viewform), then follow the [link](https://drive.google.com/drive/folders/11tBZvvrh0FCm3XuR5E849K42TqftYdUF) to a directory containing [Rclone download instructions](https://docs.google.com/document/d/1Yp2T_TsVfg8uEoEv0wa-dGP4R7r1EOHucTvDNWznWzE/edit#heading=h.at8a3matgbrk). Follow steps 1-3 to install and activate Rclone. Finally, download the model to the desired download directory (default ./models):
+```
+mkdir models
+cd models
+rclone copy mlc-llama2:Llama2-70b-fused-qkv-mlperf ./Llama2-70b-fused-qkv-mlperf -P
+```
+Similarly, download the data to the desired download directory (default ./dataset):
+```
+mkdir dataset
+cd dataset
+rclone copy mlc-llama2:training/scrolls_gov_report_8k ./scrolls_gov_report_8k -P
+```
 
 ## Llama2-70B on 8 devices
 
 Run:
 ```bash
 accelerate launch --config_file configs/default_config.yaml scripts/train.py \
 --dataset_path "./dataset" \
---model_path "/software/users/ihubara/lora_clean/llama-v2-fused-qkv" \
+--model_path "/models/llama-v2-fused-qkv" \
 --max_seq_len 8192 \
 --bf16 True \
 --logging_steps 24 \
@@ -81,23 +88,5 @@ where the Accelerate config file is [this one](https://github.com/regisss/lora/b
 
 > Using flash attention with `--use_flash_attn` is necessary for training on 8k-token sequences.
 
-Learning curves of such a run can be found here: https://huggingface.co/regisss/test_5/tensorboard
-
-
-## Evaluation
-
-To run evaluation for summarizing texts, you can run:
-- Without LoRA adapter weights:
-```
-python scripts/eval.py --model_name meta-llama/Llama-2-70b-hf --max_new_tokens 900 --seq_length 8192 --do_sample --dataset_name "tau/scrolls" --dataset_config_name "gov_report"
-```
-- With LoRA adapter weights:
-```
-python scripts/eval.py --peft_model_name path_to_my_lora_model --max_new_tokens 900 --seq_length 8192 --do_sample --dataset_name "tau/scrolls" --dataset_config_name "gov_report"
-```
-## expected outcome
-
-A clean output (train and eval loss) of a singel run with 440 steps can be found under
-```
-convergence_example.txt
-```
+## Reference code running time
+On 8xA100 cards, the TTT of the reference $\textcolor{red}{\textbf{UNOPTIMIZED}}$ code is 120-140 minutes on average.
````
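As a quick sanity check after the downloads, it can help to confirm that the data and model landed where the training scripts expect them. The snippet below is a hypothetical helper, not part of the repository; the directory names simply mirror the Rclone targets in the diff above:

```python
from pathlib import Path

# Hypothetical pre-flight check (not in the repo): verify the default layout
# produced by the Rclone commands above, i.e. the fused-QKV checkpoint under
# ./models and the Scrolls gov_report data under ./dataset.
expected = [
    Path("models/Llama2-70b-fused-qkv-mlperf"),
    Path("dataset/scrolls_gov_report_8k"),
]

for path in expected:
    status = "ok" if path.is_dir() else "MISSING"
    print(f"{status:7} {path}")
```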

llama2_70b_lora/run_llama_70B_scrolls_r16.sh

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 accelerate launch --config_file configs/default_config.yaml scripts/train.py \
 --dataset_path "./dataset" \
---model_path "/software/users/ihubara/lora_clean/llama-v2-fused-qkv" \
+--model_path "./models/llama-v2-fused-qkv" \
 --max_seq_len 8192 \
 --bf16 True \
 --logging_steps 24 \
```

llama2_70b_lora/scripts/mlperf_logging_utils.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -90,7 +90,7 @@ def __init__(self, logger, train_dataset_length, eval_dataset_length,lora_alpha)
         }
 
     def on_train_begin(self, args, state, control, **kwargs):
-        self.gbs=int(args.per_device_train_batch_size * args.gradient_accumulation_steps * os.getenv("WORLD_SIZE", 1))
+        self.gbs=int(args.per_device_train_batch_size * args.gradient_accumulation_steps * int(os.getenv("WORLD_SIZE", 1)))
         self.mllogger.event(
             key=constants.CACHE_CLEAR, value="True",
         )
@@ -170,7 +170,7 @@ def on_step_begin(
             )
             control.should_log = True
 
-            if state.global_step % (state.eval_steps) == 0 and state.global_step > 0:
+            if state.global_step % (state.eval_steps) == 0 and state.global_step > args.eval_delay:
                 self.mllogger.end(
                     constants.BLOCK_STOP,
                     value="",
```

llama2_70b_lora/scripts/train.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -15,7 +15,7 @@
 
 from dataclasses import dataclass, field
 from typing import Optional
-
+import os
 from datasets import load_dataset
 from mlperf_logging_utils import LoraLogger, MLPerfCallback
 from transformers import HfArgumentParser, Trainer, TrainingArguments
@@ -136,6 +136,7 @@ class ScriptArguments:
 
 def main(args):
     loralogger = LoraLogger(target_eval_loss=args.target_eval_loss)
+    gbs=args.per_device_train_batch_size * args.gradient_accumulation_steps * int(os.getenv("WORLD_SIZE", 1))
     training_arguments = TrainingArguments(
         output_dir=args.output_dir,
         per_device_train_batch_size=args.per_device_train_batch_size,
@@ -154,6 +155,7 @@ def main(args):
         save_strategy="no",
         max_steps=args.max_steps,
         eval_steps=args.eval_steps,
+        eval_delay=int(0.125*gbs+2)*args.eval_steps,
         save_steps=args.save_steps,
         logging_steps=args.logging_steps,
         push_to_hub=args.push_to_hub,
```
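This is where the evaluation skipping named in the commit message is wired in: `eval_delay` is a standard `TrainingArguments` field that tells the Hugging Face `Trainer` not to evaluate before a given step, and here it is derived from the global batch size. A small sketch of the heuristic follows; the helper name and the batch-size numbers are invented for illustration:

```python
import os

# Pretend we launch on 8 workers (launchers export WORLD_SIZE as a string).
os.environ.setdefault("WORLD_SIZE", "8")

def compute_eval_delay(per_device_train_batch_size: int,
                       gradient_accumulation_steps: int,
                       eval_steps: int) -> int:
    """Defer the first evaluation by int(0.125 * gbs + 2) eval intervals,
    expressed in optimizer steps, mirroring the expression added above."""
    gbs = (per_device_train_batch_size
           * gradient_accumulation_steps
           * int(os.getenv("WORLD_SIZE", 1)))
    return int(0.125 * gbs + 2) * eval_steps

# With micro-batch 1, 4 accumulation steps and 8 workers: gbs = 32, so
# eval_delay = int(0.125 * 32 + 2) * 24 = 6 * 24 = 144 steps when eval_steps=24.
print(compute_eval_delay(1, 4, 24))  # 144
```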
