# Knowledge Distillation

## Overview
Knowledge Distillation is a compression technique that transfers knowledge from a larger (teacher) model to a smaller (student) model. This allows the smaller model to approach the performance of the larger one while using significantly fewer parameters and computational resources.

This guide focuses on **response-based knowledge distillation**, a technique in which the student model is trained to replicate the outputs and behaviors of the teacher model. Within response-based knowledge distillation, two primary methods are often employed:

1. **Offline Distillation (Dataset Generation):**
    * The pre-trained teacher model first generates a new dataset of input-output pairs.
    * The student model is then trained on this teacher-generated dataset using standard fine-tuning techniques.

2. **Online Distillation (Logit Matching):**
    * During training, both the teacher model (which is typically frozen) and the student model process the same input data.
    * The student model is trained by minimizing a loss function that encourages its output logits to match the logits produced by the teacher model for the same inputs (a minimal sketch of such a loss follows this list).

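To make the logit-matching objective concrete, here is a minimal sketch of such a loss written with JAX. It is illustrative only and is not MaxText's implementation; the temperature value, the KL direction, and the `temperature**2` scaling are common conventions rather than settings taken from this recipe.

```python
import jax
import jax.numpy as jnp


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL divergence between temperature-softened teacher and student distributions."""
    # Soften both distributions with the same temperature.
    student_log_probs = jax.nn.log_softmax(student_logits / temperature, axis=-1)
    teacher_probs = jax.nn.softmax(teacher_logits / temperature, axis=-1)
    # KL(teacher || student), summed over the vocabulary dimension.
    kl = jnp.sum(teacher_probs * (jnp.log(teacher_probs + 1e-9) - student_log_probs), axis=-1)
    # Scale by temperature**2 so gradients stay comparable across temperatures,
    # then average over batch and sequence positions.
    return jnp.mean(kl) * temperature**2


# Toy shapes: batch of 2 sequences, 4 positions, vocabulary of 8 tokens.
student_logits = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8))
teacher_logits = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 8))
print(distillation_loss(student_logits, teacher_logits))
```

In practice this distillation term is usually combined with the standard cross-entropy loss on the ground-truth labels. The remainder of this guide covers only the offline (dataset generation) approach.
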
## Running Offline Distillation with MaxText

The following recipe demonstrates offline distillation using **Deepseek2-16b** (DeepSeek-V2-Lite) as the teacher model and **Llama2-7b** as the student model. Since this recipe fine-tunes the student model with Supervised Fine-Tuning (SFT), it is crucial to use the conversational (chat) variant of both the teacher and the student model. Here is a step-by-step guide:

### Prerequisites

#### a. Setup environment variables

```bash
export HF_TOKEN=<Hugging Face access token>
export BASE_DIRECTORY=<Directory to store distillation results>
export HF_REPO_NAME=<Hugging Face repository name to store teacher-generated dataset>
export USERNAME_OR_ORG=<Owner of Hugging Face repository>
export RUN_NAME=<unique name for the run>
```

Note that `BASE_DIRECTORY` should be a Google Cloud Storage path (for example, `gs://<your-bucket>/<dir>`), since later steps write checkpoints under it and list them with `gcloud storage`.

#### b. Install dependencies

```bash
git clone https://github.com/AI-Hypercomputer/maxtext.git
python3 -m venv ~/venv-maxtext
source ~/venv-maxtext/bin/activate
cd maxtext
pip install -r requirements.txt
```

### 1. Obtain and Prepare the Teacher Model

#### a. Download Model from Hugging Face

```bash
huggingface-cli login  # Provide your Hugging Face token
huggingface-cli download deepseek-ai/DeepSeek-V2-Lite-Chat --repo-type model --local-dir ~/deepseek2-16b-chat
```

#### b. Convert Checkpoint to MaxText Format
MaxText requires checkpoints to be in a specific format, so you'll need to convert the downloaded Hugging Face checkpoint into a MaxText-compatible one.

```bash
# Get unscanned checkpoint for efficient decoding
JAX_PLATFORMS=cpu \
python3 -m MaxText.convert_deepseek_unscanned_ckpt \
  --base_model_path ~/deepseek2-16b-chat \
  --maxtext_model_path ${BASE_DIRECTORY}/deepseek2-16-chat/unscanned \
  --model_size deepseek2-16b
```

### 2. Obtain and Prepare the Student Model

#### a. Download Model from Hugging Face

```bash
huggingface-cli download meta-llama/Llama-2-7b-chat-hf --repo-type model --local-dir ~/llama2-7b-chat
```

#### b. Convert Checkpoint to MaxText Format
MaxText requires checkpoints to be in a specific format, so you'll need to convert the downloaded Hugging Face checkpoint into a MaxText-compatible one.

```bash
# Get scanned checkpoint for fine-tuning
JAX_PLATFORMS=cpu \
python3 -m MaxText.llama_or_mistral_ckpt \
  --base-model-path ~/llama2-7b-chat \
  --maxtext-model-path ${BASE_DIRECTORY}/llama2-7b-chat/scanned \
  --model-size llama2-7b
```

### 3. Generate Dataset using the Teacher Model
Once the teacher model's checkpoint is in the MaxText format, you can run inference to generate the dataset that will be used to fine-tune the student model.

#### 3.a. Run the JetStream Server

Example command to run the JetStream server on `v4-8`:

```bash
python3 -m MaxText.maxengine_server MaxText/configs/base.yml \
  tokenizer_path=deepseek-ai/DeepSeek-V2-Lite-Chat tokenizer_type=huggingface \
  load_parameters_path=${BASE_DIRECTORY}/deepseek2-16-chat/unscanned/0/items \
  model_name=deepseek2-16b \
  per_device_batch_size=10 ici_tensor_parallelism=4 \
  max_target_length=2048 max_prefill_predict_length=64 \
  hf_access_token=$HF_TOKEN \
  scan_layers=False \
  multi_sampling=True decode_sampling_strategy=weighted
```

Set `multi_sampling` to `True` to generate multiple independent completions per prompt.

#### 3.b. Generate Dataset using JetStream Server
In a new terminal tab, run the following command to generate the dataset from the teacher model. Note that this is an example command for `v4-8`:

```bash
python3 -m MaxText.generate_distillation_data \
  --tokenizer-path deepseek-ai/DeepSeek-V2-Lite-Chat \
  --dataset-path HuggingFaceH4/ultrachat_200k --data-split train_sft \
  --data-columns messages \
  --max-prefill-length 64 --max-target-length 2048 \
  --hf-access-token $HF_TOKEN \
  --use-chat-template --remove-local-dataset-files \
  --num-generations 2 --batch-size 1024 --num-batches 200 \
  upload-to-hf --hf-repo-id ${HF_REPO_NAME}
```

When `multi_sampling=True` (Step 3.a), the `--num-generations` parameter specifies the number of distinct completions to generate per prompt. The `--batch-size` parameter controls how many prompts are processed per batch, and `--num-batches` defines how many such batches to run. The total number of prompt-completion pairs generated is approximately `num_batches * batch_size * num_generations`.

For example, with `--batch-size 1024`, `--num-generations 2`, and `--num-batches 200`, this would yield `200 * 1024 * 2 = 409,600` prompt-completion pairs.

Some prompts may be filtered out during pre-processing before inference: any prompt whose tokenized length exceeds `--max-prefill-length` is dropped, so the actual number of generated pairs can be somewhat lower.
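
If you want a rough estimate of how many prompts will survive this filter before launching generation, you can count prompt tokens with the teacher's tokenizer, as in the sketch below. This is an approximation, not the exact pre-processing performed by `MaxText.generate_distillation_data`: how that script applies the chat template and counts tokens may differ, and the 1,000-example sample used here is arbitrary.

```python
from datasets import load_dataset
from transformers import AutoTokenizer

MAX_PREFILL_LENGTH = 64  # should match --max-prefill-length above

# Assumes the teacher tokenizer ships a chat template (chat models generally do).
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-V2-Lite-Chat", trust_remote_code=True
)
sample = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft[:1000]")

kept = 0
for example in sample:
    # Treat the first user turn as the prompt and format it with the chat template,
    # mirroring the --use-chat-template flag above.
    prompt_ids = tokenizer.apply_chat_template(
        example["messages"][:1], add_generation_prompt=True, tokenize=True
    )
    if len(prompt_ids) <= MAX_PREFILL_LENGTH:
        kept += 1

print(f"{kept}/{len(sample)} sampled prompts fit within {MAX_PREFILL_LENGTH} prefill tokens")
```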

Additionally, the generated dataset can be uploaded to either Hugging Face or Google Cloud Storage (GCS). To upload to Hugging Face, use the `upload-to-hf --hf-repo-id <hf_repo_name>` flags. To upload to GCS, use the `upload-to-gcs --gcs-bucket <gcs bucket name> --gcs-data-path <path in gcs bucket>` flags.
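
Once the upload finishes, a quick sanity check is to load the generated dataset back and inspect a row. The sketch below assumes the dataset was uploaded to Hugging Face under the environment variables defined in the prerequisites and that it exposes the `prompt` and `completion` columns consumed in step 4.a; adjust if your repository differs.

```python
import os

from datasets import load_dataset

# Repository written by the upload-to-hf step; adjust if you uploaded to GCS instead.
repo_id = f"{os.environ['USERNAME_OR_ORG']}/{os.environ['HF_REPO_NAME']}"
dataset = load_dataset(repo_id, split="train", token=os.environ.get("HF_TOKEN"))

print(dataset)                    # expect 'prompt' and 'completion' columns
print(dataset[0]["prompt"])       # a prompt drawn from the original dataset
print(dataset[0]["completion"])   # the teacher-generated completion
```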

### 4. Fine-tune the Student Model using Supervised Fine-Tuning
You can now fine-tune the smaller student model using MaxText's supervised fine-tuning (SFT) support.

#### 4.a. Fine-tune the Student Model using the Dataset Generated in Step 3

Example command to run fine-tuning on `v4-8`:

```bash
python3 -m MaxText.sft_trainer MaxText/configs/sft.yml \
  run_name=${RUN_NAME} \
  base_output_directory=${BASE_DIRECTORY}/distillation/deepseek2-16b-distill-llama2-7b \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf tokenizer_type=huggingface \
  hf_path=${USERNAME_OR_ORG}/${HF_REPO_NAME} \
  train_split='train' train_data_columns=['prompt','completion'] \
  load_parameters_path=${BASE_DIRECTORY}/llama2-7b-chat/scanned/0/items \
  model_name=llama2-7b \
  per_device_batch_size=2 ici_expert_parallelism=-1 ici_fsdp_parallelism=4 \
  max_target_length=2048 \
  hf_access_token=$HF_TOKEN
```

#### 4.b. **[OPTIONAL]** Fine-tune the Student Model using the Original Dataset

The checkpoint produced by fine-tuning the student model on the teacher-generated dataset can be used for a subsequent fine-tuning stage. In this optional step, the student model is further fine-tuned on the original dataset that was used to prompt the teacher model in step 3.

```bash
# Get the latest checkpoint for fine-tuned student model
CHECKPOINTS_PATH=${BASE_DIRECTORY}/distillation/deepseek2-16b-distill-llama2-7b/${RUN_NAME}/checkpoints
checkpoints=$(gcloud storage ls $CHECKPOINTS_PATH)
integer_dirs=()
for dir in $checkpoints; do
  dir_name=$(basename "$dir")
  if [[ "$dir_name" =~ ^[0-9]+$ ]]; then
    integer_dirs+=("$dir_name")
  fi
done
sorted_dirs=($(printf '%s\n' "${integer_dirs[@]}" | sort -n))
largest_dir="${sorted_dirs[-1]}"
FINE_TUNED_MODEL_CKPT_PATH=${CHECKPOINTS_PATH}/${largest_dir}/items

# Fine-tune student model on original dataset
python3 -m MaxText.sft_trainer MaxText/configs/sft.yml \
  run_name=${RUN_NAME} \
  base_output_directory=${BASE_DIRECTORY}/distillation/deepseek2-16b-distill-llama2-7b \
  tokenizer_path=meta-llama/Llama-2-7b-chat-hf tokenizer_type=huggingface \
  hf_path='HuggingFaceH4/ultrachat_200k' \
  train_split='train_sft' train_data_columns=['messages'] \
  load_parameters_path=${FINE_TUNED_MODEL_CKPT_PATH} \
  model_name=llama2-7b \
  per_device_batch_size=2 ici_expert_parallelism=-1 ici_fsdp_parallelism=4 \
  max_target_length=2048 \
  hf_access_token=$HF_TOKEN
```