
Commit f6e2c27

examples: use compress_model, remove compress_quantized_weights; update save flow and docs; Fixes #2105
1 parent: f9e7426

File tree: 1 file changed (+51 −93 lines)

examples/quantize_and_pack_int4.ipynb

@@ -15,7 +15,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": 12,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -25,8 +25,7 @@
 "from compressed_tensors.quantization import (\n",
 " QuantizationConfig,\n",
 " QuantizationStatus,\n",
-" apply_quantization_config,\n",
-" compress_quantized_weights\n",
+" apply_quantization_config\n",
 ")\n",
 "from compressed_tensors.compressors import ModelCompressor\n",
 "from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator\n",
@@ -37,51 +36,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": 3,
+"execution_count": 13,
 "metadata": {},
 "outputs": [
-{
-"data": {
-"application/vnd.jupyter.widget-view+json": {
-"model_id": "c883cdc8ecd04866bd01d61796b81c26",
-"version_major": 2,
-"version_minor": 0
-},
-"text/plain": [
-"config.json: 0%| | 0.00/560 [00:00<?, ?B/s]"
-]
-},
-"metadata": {},
-"output_type": "display_data"
-},
-{
-"data": {
-"application/vnd.jupyter.widget-view+json": {
-"model_id": "32b18b14b6774ce7b61d2854a1ed5f49",
-"version_major": 2,
-"version_minor": 0
-},
-"text/plain": [
-"model.safetensors: 0%| | 0.00/4.40G [00:00<?, ?B/s]"
-]
-},
-"metadata": {},
-"output_type": "display_data"
-},
-{
-"data": {
-"application/vnd.jupyter.widget-view+json": {
-"model_id": "370c6d18521a4b65833a411728be1ed7",
-"version_major": 2,
-"version_minor": 0
-},
-"text/plain": [
-"generation_config.json: 0%| | 0.00/129 [00:00<?, ?B/s]"
-]
-},
-"metadata": {},
-"output_type": "display_data"
-},
 {
 "data": {
 "text/plain": [
@@ -113,7 +70,7 @@
 ")"
 ]
 },
-"execution_count": 3,
+"execution_count": 13,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -122,7 +79,7 @@
 "# load a dense, unquantized tiny llama model\n",
 "device = \"cuda:0\"\n",
 "model_name = \"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\"\n",
-"model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=\"auto\")\n",
+"model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device, torch_dtype=torch.bfloat16)\n",
 "model"
 ]
 },
@@ -139,7 +96,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 23,
+"execution_count": 14,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -164,7 +121,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 15,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -177,7 +134,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": null,
+"execution_count": 16,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -198,14 +155,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 28,
+"execution_count": 17,
 "metadata": {},
 "outputs": [
 {
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"Running calibration: 512it [00:33, 15.42it/s]\n"
+"Running calibration: 512it [00:58, 8.82it/s]\n"
 ]
 }
 ],
@@ -233,20 +190,24 @@
 "\n",
 "Notice that at this point, the weight itself is still a floating point and has not been quantized. \n",
 "\n",
-"To convert the weights to an integer type, we need to apply the `compress_quantized_weights` function. After compressing the weights, a forward pass of the model can no longer be run in PyTorch"
+"To convert the weights to an integer type, we need to apply the `compress_model` function. After compressing the weights, a forward pass of the model can no longer be run in PyTorch.\n",
+"\n",
+"After compressing the quantized model with the `pack-quantized` format, weights are represented as logical int4 values packed into int32 containers (`weight_packed`), with the original shape recorded in `weight_shape`.\n",
+"\n",
+"This packed representation is what gets saved to disk when using `ModelCompressor.compress_model(model)`."
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 29,
+"execution_count": 18,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Scale: tensor([17296.], device='cuda:4', dtype=torch.float16), Zero Point: tensor([0], device='cuda:4', dtype=torch.int8)\n",
-"Weight min: -1.587890625 max: 1.0283203125 dtype: torch.float16\n"
+"Scale: tensor([-3.0465e+26], device='cuda:0', dtype=torch.bfloat16), Zero Point: tensor([0], device='cuda:0', dtype=torch.int8)\n",
+"Weight min: -1.5859375 max: 1.03125 dtype: torch.bfloat16\n"
 ]
 }
 ],
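
For intuition about the packed layout described in the new markdown above, here is a minimal, self-contained sketch of the idea. It is illustrative only: pack_int4_rows is a hypothetical helper, not the compressed-tensors packing routine, and the library's exact bit layout may differ. Eight signed int4 values fit in one int32 container, so a 2048 x 2048 weight packs into a 2048 x 256 int32 tensor, matching the weight_packed shape printed in the cells below.

import torch

def pack_int4_rows(q: torch.Tensor) -> torch.Tensor:
    """Pack logical int4 values (range [-8, 7]) along the last dim into int32 containers."""
    assert q.shape[-1] % 8 == 0
    # keep only the low 4 bits of each value (its two's-complement nibble)
    nibbles = (q.to(torch.int64) & 0xF).reshape(*q.shape[:-1], -1, 8)
    packed = torch.zeros(nibbles.shape[:-1], dtype=torch.int64)
    for i in range(8):
        packed |= nibbles[..., i] << (4 * i)  # drop each nibble into its 4-bit slot
    # reinterpret the 32-bit pattern as a signed int32 value
    packed = torch.where(packed >= 2**31, packed - 2**32, packed)
    return packed.to(torch.int32)

q = torch.randint(-8, 8, (2048, 2048), dtype=torch.int8)
print(pack_int4_rows(q).shape)  # torch.Size([2048, 256])
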
@@ -262,64 +223,62 @@
 },
 {
 "cell_type": "code",
-"execution_count": 30,
+"execution_count": 19,
 "metadata": {},
 "outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"Compressing model: 154it [00:02, 59.75it/s]"
+]
+},
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Scale: tensor([17296.], device='cuda:4', dtype=torch.float16), Zero Point: tensor([0], device='cuda:4', dtype=torch.int8)\n",
-"Weight min: 0 max: 0 dtype: torch.int8\n"
+"Compressed weight scale: tensor([-3.0465e+26], device='cuda:0', dtype=torch.bfloat16), zero point: tensor([0], device='cuda:0', dtype=torch.int8)\n",
+"Compressed weight dtype: torch.int32\n",
+"Compressed weight shape: torch.Size([2048, 256])\n",
+"Uncompressed weight shape: tensor([2048, 2048], device='cuda:0')\n"
+]
+},
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"\n"
 ]
 }
 ],
 "source": [
 "# convert quantized weights to integers\n",
-"model.apply(compress_quantized_weights)\n",
+"compressor = ModelCompressor(quantization_config=config)\n",
+"compressor.compress_model(model)\n",
 "\n",
 "state_dict = model.state_dict()\n",
 "example_layer = \"model.layers.0.self_attn.q_proj.weight\"\n",
 "scale = state_dict[example_layer + \"_scale\"]\n",
 "zero_point = state_dict[example_layer + \"_zero_point\"]\n",
-"weight = state_dict[example_layer]\n",
-"print(f\"Scale: {scale}, Zero Point: {zero_point}\")\n",
-"print(f\"Weight min: {torch.min(weight)} max: {torch.max(weight)} dtype: {weight.dtype}\")"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"After compressing the quantized model, the weight matrix has a range of int4 but is stored in an int8. \n",
-"\n",
-"We can further compress the model on disk using the `pack-quantized` format we specified in the config. This compression format will pack the int4 weights into int32"
+"weight = state_dict[example_layer + \"_packed\"]\n",
+"shape = state_dict[example_layer + \"_shape\"]\n",
+"print(f\"Compressed weight scale: {scale}, zero point: {zero_point}\")\n",
+"print(f\"Compressed weight dtype: {weight.dtype}\")\n",
+"print(f\"Compressed weight shape: {weight.shape}\")\n",
+"print(f\"Uncompressed weight shape: {shape}\")"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 31,
+"execution_count": 20,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Compression format: pack-quantized\n"
-]
-},
-{
-"name": "stderr",
-"output_type": "stream",
-"text": [
-"Quantized Compression: 100%|██████████| 509/509 [00:03<00:00, 153.70it/s]\n"
-]
-},
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Size of the model's weights on disk using safetensors: 712.23 MB\n"
+"Compression format: pack-quantized\n",
+"Size of the model's weights on disk using safetensors: 712.25 MB\n"
 ]
 }
 ],
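
The updated cell above reads weight_packed and weight_shape straight from the state dict. For completeness, the reverse mapping under the same assumed nibble layout looks roughly like this; again a sketch, not the library's decompressor, and unpack_int4_rows is a hypothetical helper.

import torch

def unpack_int4_rows(packed: torch.Tensor, original_last_dim: int) -> torch.Tensor:
    """Recover logical int4 values from int32 containers, assuming the layout sketched earlier."""
    wide = packed.to(torch.int64) & 0xFFFFFFFF       # view each container as its unsigned 32-bit pattern
    cols = []
    for i in range(8):
        n = (wide >> (4 * i)) & 0xF                  # isolate one 4-bit slot
        cols.append(torch.where(n >= 8, n - 16, n))  # sign-extend 0..15 back to -8..7
    q = torch.stack(cols, dim=-1).reshape(*packed.shape[:-1], -1)
    return q[..., :original_last_dim].to(torch.int8)

Keeping the original dimensions in a separate weight_shape tensor is what lets a loader restore the [2048, 2048] layout from the [2048, 256] containers, including trimming any padding a non-multiple-of-8 dimension would require.
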
@@ -330,9 +289,8 @@
 "compression_format = config.format\n",
 "print(f\"Compression format: {compression_format}\")\n",
 "\n",
-"compressor = ModelCompressor(quantization_config=config)\n",
-"compressed_state_dict = compressor.compress(model)\n",
-"model.save_pretrained(output_dir, state_dict=compressed_state_dict)\n",
+"\n",
+"model.save_pretrained(output_dir, state_dict=model.state_dict())\n",
 "compressor.update_config(output_dir)\n",
 "\n",
 "compressed_size_on_disk_mb = os.path.getsize(os.path.join(output_dir, \"model.safetensors\")) / 1024 / 1024\n",
@@ -356,7 +314,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.12"
+"version": "3.12.12"
 }
 },
 "nbformat": 4,
