
Commit 2f07c8b

Author: maxtext authors
Merge pull request #1752 from AI-Hypercomputer:distillation-readme
PiperOrigin-RevId: 762477859
2 parents d702bfc + 80c0884

2 files changed: +181 -0 lines changed
Lines changed: 180 additions & 0 deletions
# Knowledge Distillation

## Overview

Knowledge Distillation is a compression technique that transfers knowledge from a larger (teacher) model to a smaller (student) model. This allows the smaller model to achieve performance levels closer to the larger one, but with significantly fewer parameters and computational resources.

This guide focuses on **response-based knowledge distillation**, a technique where the student model is trained to replicate the outputs and behaviors of the teacher model. Within response-based knowledge distillation, two primary methods are often employed:

1. **Offline Distillation (Dataset Generation):**
   * The pre-trained teacher model first generates a new dataset of input-output pairs.
   * The student model is then trained on this teacher-generated dataset using standard fine-tuning techniques.

2. **Online Distillation (Logit Matching):**
   * During the training process, both the teacher model (which is typically frozen) and the student model process the same input data simultaneously.
   * The student model is trained by minimizing a loss function that encourages its output logits to match the logits produced by the teacher model for the same inputs (see the loss sketch after this list).
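The rest of this guide implements the offline method. For intuition on the online method, here is a minimal sketch of a logit-matching loss written in JAX. It assumes generic `student_logits` and `teacher_logits` arrays and a `temperature` hyperparameter; it is illustrative only and is not MaxText's implementation.

```python
import jax
import jax.numpy as jnp

def logit_matching_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) over temperature-softened vocabulary distributions.

    Both inputs are [batch, seq_len, vocab] arrays of raw logits. The teacher's
    logits are treated as constants, since the teacher is frozen.
    """
    teacher_probs = jax.nn.softmax(teacher_logits / temperature, axis=-1)
    teacher_log_probs = jax.nn.log_softmax(teacher_logits / temperature, axis=-1)
    student_log_probs = jax.nn.log_softmax(student_logits / temperature, axis=-1)
    kl = jnp.sum(teacher_probs * (teacher_log_probs - student_log_probs), axis=-1)
    # The temperature**2 factor keeps gradient magnitudes comparable across temperatures.
    return (temperature ** 2) * jnp.mean(kl)
```

In practice this distillation term is typically combined with the standard cross-entropy loss on the ground-truth targets, weighted by a mixing coefficient.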
## Running Offline Distillation with MaxText

The following recipe demonstrates the process of offline distillation using **Deepseek2-16b** as the teacher model and **Llama2-7b** as the student model. Since this recipe fine-tunes the student model using Supervised Fine-Tuning (SFT), it is crucial to use the conversational (chat) variant of both the teacher and the student model. Here is a step-by-step guide:

### Prerequisites

#### a. Set up environment variables

```bash
export HF_TOKEN=<Hugging Face access token>
export BASE_DIRECTORY=<directory to store distillation results>
export HF_REPO_NAME=<Hugging Face repository name to store the teacher-generated dataset>
export USERNAME_OR_ORG=<owner of the Hugging Face repository>
export RUN_NAME=<unique name for the run>
```
#### b. Install dependencies

```bash
git clone https://github.com/AI-Hypercomputer/maxtext.git
python3 -m venv ~/venv-maxtext
source ~/venv-maxtext/bin/activate
cd maxtext
pip install -r requirements.txt
```
### 1. Obtain and Prepare the Teacher Model

#### a. Download Model from Hugging Face

```bash
huggingface-cli login  # Provide your Hugging Face token
huggingface-cli download deepseek-ai/DeepSeek-V2-Lite-Chat --repo-type model --local-dir ~/deepseek2-16b-chat
```

#### b. Convert Checkpoint to MaxText Format

MaxText requires checkpoints to be in a specific format, so you'll need to convert the downloaded Hugging Face checkpoint into a MaxText-compatible one.

```bash
# Get unscanned checkpoint for efficient decoding
JAX_PLATFORMS=cpu \
python3 -m MaxText.convert_deepseek_unscanned_ckpt \
  --base_model_path ~/deepseek2-16b-chat \
  --maxtext_model_path ${BASE_DIRECTORY}/deepseek2-16-chat/unscanned \
  --model_size deepseek2-16b
```
### 2. Obtain and Prepare the Student Model

#### a. Download Model from Hugging Face

```bash
huggingface-cli download meta-llama/Llama-2-7b-chat-hf --repo-type model --local-dir ~/llama2-7b-chat
```

#### b. Convert Checkpoint to MaxText Format

MaxText requires checkpoints to be in a specific format, so you'll need to convert the downloaded Hugging Face checkpoint into a MaxText-compatible one.

```bash
# Get scanned checkpoint for fine-tuning
JAX_PLATFORMS=cpu \
python3 -m MaxText.llama_or_mistral_ckpt \
  --base-model-path ~/llama2-7b-chat \
  --maxtext-model-path ${BASE_DIRECTORY}/llama2-7b-chat/scanned \
  --model-size llama2-7b
```
### 3. Generate Dataset using the Teacher Model

Once the teacher model's checkpoint is in the MaxText format, you can run inference to generate the dataset that will be used to fine-tune the student model.

#### 3.a. Run the JetStream Server

Example command to run the JetStream server on a `v4-8`:

```bash
python3 -m MaxText.maxengine_server MaxText/configs/base.yml \
  tokenizer_path=deepseek-ai/DeepSeek-V2-Lite-Chat tokenizer_type=huggingface \
  load_parameters_path=${BASE_DIRECTORY}/deepseek2-16-chat/unscanned/0/items \
  model_name=deepseek2-16b \
  per_device_batch_size=10 ici_tensor_parallelism=4 \
  max_target_length=2048 max_prefill_predict_length=64 \
  hf_access_token=$HF_TOKEN \
  scan_layers=False \
  multi_sampling=True decode_sampling_strategy=weighted
```

Set `multi_sampling` to `True` to generate multiple independent completions per prompt.
#### 3.b. Generate Dataset using JetStream Server

In a new terminal tab, run the following command to generate the dataset from the teacher model. Note that this is an example command to run on a `v4-8`:

```bash
python3 -m MaxText.generate_distillation_data \
  --tokenizer-path deepseek-ai/DeepSeek-V2-Lite-Chat \
  --dataset-path HuggingFaceH4/ultrachat_200k --data-split train_sft \
  --data-columns messages \
  --max-prefill-length 64 --max-target-length 2048 \
  --hf-access-token $HF_TOKEN \
  --use-chat-template --remove-local-dataset-files \
  --num-generations 2 --batch-size 1024 --num-batches 200 \
  upload-to-hf --hf-repo-id ${HF_REPO_NAME}
```
When `multi_sampling=True` (Step 3.a), the `--num-generations` parameter specifies the number of distinct completions to generate per prompt. The `--batch-size` parameter controls how many prompts are processed per batch, and `--num-batches` defines how many such batches to run. The total number of prompt-completion pairs generated is approximately `num_batches * batch_size * num_generations`.

For example, with `--batch-size 1024`, `--num-generations 2`, and `--num-batches 200`, this would yield `200 * 1024 * 2 = 409,600` prompt-completion pairs.
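As a quick sanity check of the expected dataset size (plain Python, values taken from the command above):

```python
num_batches, batch_size, num_generations = 200, 1024, 2
total_pairs = num_batches * batch_size * num_generations
print(total_pairs)  # 409600 -- an upper bound, since some prompts are filtered out (see below)
```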
Note that some prompts may be filtered out by the pre-processing logic before inference: any prompt sequence longer than `max-prefill-length` is dropped during the pre-processing stage.
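A rough sketch of that length filter is shown below. The use of `AutoTokenizer` and `apply_chat_template` here is an illustrative assumption about how chat prompts get tokenized, not the exact pre-processing code inside `MaxText.generate_distillation_data`.

```python
from transformers import AutoTokenizer

MAX_PREFILL_LENGTH = 64  # matches --max-prefill-length in the command above

# Hypothetical filter: keep only prompts whose chat-formatted, tokenized form
# fits within the prefill budget.
tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2-Lite-Chat")

def keep_prompt(prompt_text):
    # Format a single user turn with the model's chat template, then count tokens.
    token_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt_text}], add_generation_prompt=True
    )
    return len(token_ids) <= MAX_PREFILL_LENGTH
```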
Additionally, the generated dataset can be uploaded to either Hugging Face or Google Cloud Storage (GCS). To upload to Hugging Face, use the `upload-to-hf --hf-repo-id <hf_repo_name>` flags. To upload to GCS, use the `upload-to-gcs --gcs-bucket <gcs bucket name> --gcs-data-path <path in gcs bucket>` flags.
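If you uploaded to Hugging Face, a quick way to verify the dataset before fine-tuning is to load it with the `datasets` library. The `prompt`/`completion` column names below are the ones consumed by `train_data_columns` in Step 4.a; treat the exact schema as an assumption to check rather than a guarantee.

```python
import os
from datasets import load_dataset

# Assumes the dataset was uploaded with `upload-to-hf --hf-repo-id ${HF_REPO_NAME}`.
repo_id = f"{os.environ['USERNAME_OR_ORG']}/{os.environ['HF_REPO_NAME']}"
ds = load_dataset(repo_id, split="train", token=os.environ["HF_TOKEN"])

print(ds.column_names)  # expected to include 'prompt' and 'completion'
print(ds[0])            # inspect one teacher-generated example
```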
### 4. Fine-tune the Student Model using Supervised Fine-Tuning

You can now fine-tune your smaller student model using the supervised fine-tuning (SFT) technique in MaxText.

#### 4.a. Fine-tune the Student Model using the Dataset Generated in Step 3

Example command to run fine-tuning on a `v4-8`:

```bash
python3 -m MaxText.sft_trainer MaxText/configs/sft.yml \
  run_name=${RUN_NAME} \
  base_output_directory=${BASE_DIRECTORY}/distillation/deepseek2-16b-distill-llama2-7b \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf tokenizer_type=huggingface \
  hf_path=${USERNAME_OR_ORG}/${HF_REPO_NAME} \
  train_split='train' train_data_columns=['prompt','completion'] \
  load_parameters_path=${BASE_DIRECTORY}/llama2-7b-chat/scanned/0/items \
  model_name=llama2-7b \
  per_device_batch_size=2 ici_expert_parallelism=-1 ici_fsdp_parallelism=4 \
  max_target_length=2048 \
  hf_access_token=$HF_TOKEN
```
#### 4.b. **[OPTIONAL]** Fine-tune the Student Model using the Original Dataset

The checkpoint from the student model's fine-tuning on the teacher-generated dataset can be used for a subsequent fine-tuning stage. In this step, the student model is fine-tuned on the original dataset that was initially provided to the teacher model for generating the distillation dataset.

```bash
# Get the latest checkpoint of the fine-tuned student model
CHECKPOINTS_PATH=${BASE_DIRECTORY}/distillation/deepseek2-16b-distill-llama2-7b/${RUN_NAME}/checkpoints
checkpoints=$(gcloud storage ls $CHECKPOINTS_PATH)
# Collect only the numerically named step directories
integer_dirs=()
for dir in $checkpoints; do
  dir_name=$(basename "$dir")
  if [[ "$dir_name" =~ ^[0-9]+$ ]]; then
    integer_dirs+=("$dir_name")
  fi
done
# Sort numerically and take the highest step number
sorted_dirs=($(printf '%s\n' "${integer_dirs[@]}" | sort -n))
largest_dir="${sorted_dirs[-1]}"
FINE_TUNED_MODEL_CKPT_PATH=${CHECKPOINTS_PATH}/${largest_dir}/items

# Fine-tune the student model on the original dataset
python3 -m MaxText.sft_trainer MaxText/configs/sft.yml \
  run_name=${RUN_NAME} \
  base_output_directory=${BASE_DIRECTORY}/distillation/deepseek2-16b-distill-llama2-7b \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf tokenizer_type=huggingface \
  hf_path='HuggingFaceH4/ultrachat_200k' \
  train_split='train_sft' train_data_columns=['messages'] \
  load_parameters_path=${FINE_TUNED_MODEL_CKPT_PATH} \
  model_name=llama2-7b \
  per_device_batch_size=2 ici_expert_parallelism=-1 ici_fsdp_parallelism=4 \
  max_target_length=2048 \
  hf_access_token=$HF_TOKEN
```

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ google-cloud-monitoring
 google-api-core
 google-api-python-client
 grain[parquet]>=0.2.6
+huggingface_hub
 flax>=v0.10.6
 jaxtyping
 ml-collections

 (0)