
Commit 0389d45

Bug fixes (#63)
* Update compiler.py ×7
* Update gradient_checkpointing.py ×2
* Update compiler.py
* Update gradient_checkpointing.py ×11
* Update compiler.py ×2
* Fix requires grad
* Update peft_utils.py
* Update compiler.py ×9
* Update peft_utils.py ×4
* _get_dtype
* Update utils.py
* better attribution
* Update compiler.py
* Last layer GC
* Update gradient_checkpointing.py ×10
* Saving, llama.cpp
* Update llama_cpp.py ×2
* Add error handling for forward method in patch_gradient_accumulation (#32)
* Update peft_utils.py ×2
* Update gradient_checkpointing.py ×2
* Update __init__.py
* Update gradient_checkpointing.py ×3
* Update llama_cpp.py
* Update tokenizer_utils.py ×3
* Update saving_utils.py
* Update __init__.py
* Create vllm_utils.py
* Update vllm_utils.py ×4
* Licensing, bug fixes
* Update patching_utils.py
* Update vllm_utils.py
* Update __init__.py
* Update vllm_utils.py ×9
* rotary
* Update vllm_utils.py ×12
* load lora from tensors
* 0.7.1 lora request
* Update vllm_lora_request.py ×2
* Update vllm_utils.py ×8
* Update compiler.py
* Update vllm_utils.py ×5
* Update __init__.py
* Create logging_utils.py
* Update logging_utils.py ×4
* Update vllm_utils.py
* fix_zero_training_loss
* Update dataset_utils.py
* Update training_utils.py
* Update vllm_utils.py ×8
* Update vllm_lora_worker_manager.py ×2
* Update vllm_utils.py
* Update vllm_lora_worker_manager.py ×2
* Update vllm_utils.py ×14
* Update __init__.py
* Create rl_replacements.py
* Update __init__.py
* Fixes
* Update rl_replacements.py ×20
* Update saving_utils.py ×2
* Update rl_replacements.py ×10
* Update __init__.py
* Update rl_replacements.py ×9
* Update pyproject.toml
* Update rl_replacements.py ×2
* Update __init__.py
* Update gradient_checkpointing.py ×4
* Update tokenizer_utils.py ×3
* Update gradient_checkpointing.py ×3
* Update __init__.py
* compiling issues
* Update compiler.py ×15
* SFT dataset prepare
* Update pyproject.toml
* Update rl_replacements.py ×3

---------

Co-authored-by: Edd <[email protected]>
1 parent 936b85a · commit 0389d45

File tree

6 files changed · +208 −45 lines changed


pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@ name = "unsloth_zoo"
 dynamic = ["version"]
 description = "Utils for Unsloth"
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.9,<=3.12"
 license = {file = "LICENSE"}
 keywords = ["ai", "llm",]
 authors = [
@@ -26,15 +26,15 @@ dependencies = [
     "triton ; platform_system == 'Linux'",
     "packaging",
     "tyro",
-    "transformers>=4.46.1",
+    "transformers>=4.46.1,!=4.47.0",
     "datasets>=2.16.0",
     "sentencepiece>=0.2.0",
     "tqdm",
     "psutil",
     "wheel>=0.42.0",
     "numpy",
     "accelerate>=0.34.1",
-    "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0",
+    "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2",
     "peft>=0.7.1,!=0.11.0",
     "protobuf<4.0.0",
     "huggingface_hub",

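The tightened pins are the substance of this file's change: Python 3.13 is excluded, transformers 4.47.0 is blacklisted, and trl is capped at 0.15.2. A minimal sketch (not part of the commit) of how these specifiers evaluate, using the same `packaging` library that pip relies on:

    from packaging.specifiers import SpecifierSet

    trl_spec = SpecifierSet(">=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,!=0.15.0,<=0.15.2")
    print("0.15.2" in trl_spec)   # True  - newest release the cap allows
    print("0.15.0" in trl_spec)   # False - explicitly excluded
    print("0.16.0" in trl_spec)   # False - above the new upper bound

    # Caveat of `<=` pins under PEP 440: 3.12.4 > 3.12 (the shorter release
    # pads to 3.12.0), so Python 3.12 patch releases fall outside
    # ">=3.9,<=3.12"; a `<3.13` bound would include them.
    print("3.12.4" in SpecifierSet(">=3.9,<=3.12"))  # False
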
unsloth_zoo/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 # You should have received a copy of the GNU Lesser General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.

-__version__ = "2025.3.1"
+__version__ = "2025.3.2"

 from importlib.util import find_spec
 if find_spec("unsloth") is None:

unsloth_zoo/compiler.py

Lines changed: 95 additions & 34 deletions
@@ -36,6 +36,8 @@
 from .utils import Version, is_main_process
 import triton
 from .peft_utils import get_lora_layer_modules
+from importlib.metadata import version as importlib_version
+from packaging.version import Version

 # Disable some compilations if old versions are seen
 OLD_TORCH_VERSION = Version(torch.__version__) < Version("2.5.0")
@@ -62,10 +64,22 @@ def filter(self, x): return not (self.text in x.getMessage())
 global COMBINED_UNSLOTH_NAME
 global UNSLOTH_COMPILE_LOCATION
 global UNSLOTH_CREATED_FUNCTIONS
+global UNSLOTH_COMPILE_LOCATION_USE_TEMP
 COMBINED_UNSLOTH_NAME = "unsloth_compiled_module"
 UNSLOTH_COMPILE_LOCATION = "unsloth_compiled_cache"
 UNSLOTH_CREATED_FUNCTIONS = []
-
+UNSLOTH_COMPILE_LOCATION_USE_TEMP = False
+
+# Try creating a directory for cache, or else use a temporary folder
+try:
+    os.makedirs(UNSLOTH_COMPILE_LOCATION, exist_ok = True)
+    if not os.path.exists(UNSLOTH_COMPILE_LOCATION): raise
+except:
+    from tempfile import TemporaryDirectory
+    UNSLOTH_COMPILE_LOCATION_USE_TEMP = True
+    UNSLOTH_COMPILE_LOCATION = TemporaryDirectory(ignore_cleanup_errors = True).name
+    print(f"Unsloth: We can't create folders, so as a hack, we used a temporary directory = {UNSLOTH_COMPILE_LOCATION}")
+pass

 _license_header = """
 # Unsloth Zoo - Utilities for Unsloth
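The hunk above makes module generation survive read-only working directories: if the named cache folder can't be created, compilation falls back to a throwaway temp directory (and, further down, to purely dynamic imports). A standalone sketch of the pattern with illustrative names — note `ignore_cleanup_errors` requires Python 3.10+:

    import os
    from tempfile import TemporaryDirectory

    cache_dir = "unsloth_compiled_cache"  # illustrative
    try:
        os.makedirs(cache_dir, exist_ok = True)
    except OSError:
        # Keep a reference to the TemporaryDirectory object; if it is
        # garbage-collected, its finalizer may remove the folder early.
        _tmp_dir = TemporaryDirectory(ignore_cleanup_errors = True)
        cache_dir = _tmp_dir.name
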
@@ -210,8 +224,11 @@ def create_new_function(
     add_torch_compile = False,
 ):
     # All Unsloth Zoo code licensed under LGPLv3
+    old_new_source = new_source
+
     global UNSLOTH_CREATED_FUNCTIONS
     global UNSLOTH_COMPILE_LOCATION
+    global UNSLOTH_COMPILE_LOCATION_USE_TEMP
     if new_source[0] == " ":
         spaces = new_source.find("def")
         new_source = new_source.split("\n")
@@ -237,6 +254,24 @@ def create_new_function(
     # Fix super() Not necessary anymore!
     # new_source = new_source.replace("super()", "super(type(self), self)")

+    # Check versioning
+    try: unsloth_zoo_version = importlib_version("unsloth_zoo")
+    except: unsloth_zoo_version = "0"
+    try: unsloth_version = importlib_version("unsloth")
+    except: unsloth_version = "0"
+    try: transformers_version = importlib_version("transformers")
+    except: transformers_version = "0"
+    try: trl_version = importlib_version("trl")
+    except: trl_version = "0"
+
+    versioning = '"""\n' + \
+        f'{unsloth_zoo_version}\n'\
+        f'{unsloth_version}\n'\
+        f'{transformers_version}\n'\
+        f'{trl_version}\n__UNSLOTH_VERSIONING__\n' + '"""\n'
+
+    write_new_source = versioning + new_source
+
     # Check location
     if is_main_process():
         if not os.path.exists(UNSLOTH_COMPILE_LOCATION):
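The stamp prepended here is just a module docstring holding four package versions plus a sentinel, so a later run can tell whether a cached file was generated under the same environment. A condensed, self-contained sketch of the idea (helper names are hypothetical):

    from importlib.metadata import version as importlib_version

    SENTINEL = "__UNSLOTH_VERSIONING__"

    def version_stamp():
        versions = []
        for package in ("unsloth_zoo", "unsloth", "transformers", "trl"):
            try:
                versions.append(importlib_version(package))
            except Exception:
                versions.append("0")
        return '"""\n' + "\n".join(versions) + "\n" + SENTINEL + '\n"""\n'

    def stamp_matches(cached_source):
        # Everything before the sentinel encodes the versions; any upgraded
        # package changes that prefix and should force a cache rewrite.
        if SENTINEL not in cached_source:
            return False
        stamp = version_stamp()
        return cached_source[:cached_source.find(SENTINEL)] == stamp[:stamp.find(SENTINEL)]
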
@@ -247,35 +282,72 @@ def create_new_function(
         function_location = location
         if overwrite or not os.path.isfile(function_location):
             with open(function_location, "wb", buffering = 0) as file:
-                file.write(new_source.encode("utf-8"))
+                file.write(write_new_source.encode("utf-8"))
                 file.flush()
                 os.fsync(file.fileno())
             pass
         pass
-    else:
-        # Wait until file is created
-        location = os.path.join(UNSLOTH_COMPILE_LOCATION, f"{name}.py")
-        function_location = location
-        if overwrite or not os.path.isfile(function_location):
-            while not os.path.isfile(function_location): continue
+    pass
+    # Wait until file is created
+    file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, f"{name}.py")
+    trials = 0
+    if overwrite or not os.path.isfile(file_location):
+        while not os.path.isfile(file_location):
+            if trials == 1000: raise RuntimeError("Unsloth: Failed to create dynamic compiled modules!")
+            trials += 1
+            time.sleep(0.01)
+    pass
+    # Check versioning, and overwrite if any packages changed
+    with open(file_location, "r") as f: f = f.read()
+
+    # Check if exactly equivalent:
+    rewrite = False
+    if f != write_new_source:
+        rewrite = True
+    elif not overwrite:
+        if "__UNSLOTH_VERSIONING__" not in f:
+            rewrite = True
+        else:
+            versions = f[:f.find('__UNSLOTH_VERSIONING__')]
+            if versioning[:versioning.find('__UNSLOTH_VERSIONING__')] != versions:
+                rewrite = True
+    pass
+    if rewrite:
+        return create_new_function(
+            name = name,
+            new_source = old_new_source,
+            model_location = model_location,
+            functions = functions,
+            prepend = prepend,
+            append = append,
+            overwrite = True,
+            add_torch_compile = add_torch_compile,
+        )
     pass

     # Try loading new module
     new_module = None
+    trials = 0
     while True:
+        if trials == 1000: raise RuntimeError("Unsloth: Failed to create dynamic compiled modules!")
         try:
             new_module = importlib.import_module(UNSLOTH_COMPILE_LOCATION + "." + name)
             break
         except:
-            # Instead use sys modules for dynamic loading
             module_name = f"unsloth_cache_{name}"
             file_location = os.path.join(UNSLOTH_COMPILE_LOCATION, name) + ".py"
+
+            # Instead use sys modules for dynamic loading
             spec = importlib.util.spec_from_file_location(module_name, file_location)
             new_module = importlib.util.module_from_spec(spec)
             sys.modules[module_name] = new_module
             spec.loader.exec_module(new_module)

+            # Temp modules can only use dynamic loading
+            if UNSLOTH_COMPILE_LOCATION_USE_TEMP: break
+
             time.sleep(0.01)
+            trials += 1
     pass
     pass
     if new_module is None:
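When the package-style `importlib.import_module` call fails — the cache folder may not be importable as a package, and a temp directory never is — the loop above falls back to loading the generated file by path. The same pattern in isolation (hypothetical helper name):

    import importlib.util
    import sys

    def load_module_from_path(module_name, file_location):
        spec = importlib.util.spec_from_file_location(module_name, file_location)
        module = importlib.util.module_from_spec(spec)
        # Register before exec_module so imports inside the module can resolve it.
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
        return module
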
@@ -1454,31 +1526,20 @@ def unsloth_compile_transformers(

     all_code = "\n\n".join(final_all_standalone_classes)

-    if import_from_cache:
-        try:
-            combined_module = importlib.import_module(f"{UNSLOTH_COMPILE_LOCATION}.{COMBINED_UNSLOTH_NAME}_{model_type}")
-            import_from_cache = True
-        except:
-            import_from_cache = False
-    else:
-        import_from_cache = False
-    pass
-    if not import_from_cache:
-        try:
-            combined_module = create_new_function(
-                f"{COMBINED_UNSLOTH_NAME}_{model_type}",
-                all_code,
-                model_location,
-                functions,
-                prepend = \
-                    _disabled_sdpa_code + \
-                    f"\ntorch_compile_options = {torch_compile_options}\n" + \
-                    _cross_entropy_code + "\n"
-            )
-        except Exception as exception:
-            raise RuntimeError(exception)
-            combined_module = None
-        pass
+    try:
+        combined_module = create_new_function(
+            f"{COMBINED_UNSLOTH_NAME}_{model_type}",
+            all_code,
+            model_location,
+            functions,
+            prepend = \
+                _disabled_sdpa_code + \
+                f"\ntorch_compile_options = {torch_compile_options}\n" + \
+                _cross_entropy_code + "\n"
+        )
+    except Exception as exception:
+        raise RuntimeError(exception)
+        combined_module = None

     if compile_torch_modules and not disable:

unsloth_zoo/dataset_utils.py

Lines changed: 13 additions & 4 deletions
@@ -182,7 +182,7 @@ def train_on_responses_only(
     """
     # All Unsloth Zoo code licensed under LGPLv3
     tokenizer = trainer.processing_class if hasattr(trainer, "processing_class") else trainer.tokenizer
-
+
     if not hasattr(tokenizer, "_unsloth_input_part") or \
         not hasattr(tokenizer, "_unsloth_output_part"):

@@ -288,20 +288,29 @@ def _train_on_responses_only(examples):
         return { "labels" : all_labels }
     pass

+    from multiprocessing import cpu_count
+    num_proc = cpu_count()
+
     if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None:
-        trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True)
+        trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc)
     pass

     if hasattr(trainer, "eval_dataset") and trainer.eval_dataset is not None:
         # Eval datasets could be a dict!
         if type(trainer.eval_dataset) is dict:
             for key, value in trainer.eval_dataset.items():
-                trainer.eval_dataset[key] = value.map(_train_on_responses_only, batched = True)
+                trainer.eval_dataset[key] = value.map(_train_on_responses_only, batched = True, num_proc = num_proc)
         else:
-            trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True)
+            trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True, num_proc = num_proc)
         pass
     pass

+    # Edit data collator as well if not DataCollatorForSeq2Seq
+    from transformers import DataCollatorForSeq2Seq
+    if hasattr(trainer, "data_collator") and \
+        not isinstance(trainer.data_collator, DataCollatorForSeq2Seq):
+        trainer.data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer)
+
     # Check if all labels randomnly got masked to nothing - maybe wrong chat template?
     from .training_utils import fix_zero_training_loss
     fix_zero_training_loss(None, tokenizer, trainer.train_dataset)
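Two behavioural changes land in this file: the label-masking `map` now fans out across all CPU cores, and the trainer's collator is swapped to `DataCollatorForSeq2Seq` so the newly added `labels` column is padded correctly. Hypothetical usage of the patched helper (the `instruction_part`/`response_part` names follow Unsloth's documented API; the Llama-3-style template strings are purely illustrative):

    from unsloth_zoo.dataset_utils import train_on_responses_only

    trainer = train_on_responses_only(
        trainer,
        instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
        response_part    = "<|start_header_id|>assistant<|end_header_id|>\n\n",
    )
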

unsloth_zoo/rl_replacements.py

Lines changed: 94 additions & 1 deletion
@@ -22,6 +22,7 @@
 import inspect
 import os
 import numpy as np
+from typing import Union, Callable, Optional, List, Dict

 RL_REPLACEMENTS = dict()

@@ -137,7 +138,7 @@ def accumulate_chunk(new_hidden_states_j, old_hidden_states_j, input_ids_j, mask
         fullgraph = True,
         options = torch_compile_options,
     )
-
+
     grad_inputs_chunks = torch.chunk(grad_inputs, chunks = n_chunks, dim = 0)
     new_hidden_states = torch.chunk(_new_hidden_states, chunks = n_chunks, dim = 0)
     old_hidden_states = torch.chunk(_old_hidden_states, chunks = n_chunks, dim = 0)

@@ -235,6 +236,98 @@ def grpo_accumulated_loss(
 pass
 RL_REPLACEMENTS["grpo_accumulated_loss"] = grpo_accumulated_loss

+
+from datasets import (Dataset, IterableDataset,)
+from trl.trainer.utils import ConstantLengthDataset
+# Faster SFTTrainer prepare_dataset
+def sft_prepare_dataset(
+    self,
+    dataset: Union[Dataset, IterableDataset],
+    processing_class,
+    args,
+    packing: bool,
+    formatting_func: Optional[Callable[[dict], str]],
+    dataset_name: str,
+) -> Union[Dataset, IterableDataset]:
+    # All Unsloth Zoo code licensed under LGPLv3
+    if isinstance(dataset, ConstantLengthDataset): return dataset
+
+    map_kwargs = {}
+    use_desc = isinstance(dataset, Dataset)
+
+    # Get max length
+    max_seq_length = getattr(args, "max_length", 0)
+    if max_seq_length == 0: max_seq_length = getattr(args, "max_seq_length", 0)
+    if max_seq_length == 0: max_seq_length = getattr(self, "max_seq_length", 0)
+    if max_seq_length == 0: max_seq_length = getattr(self, "max_seq", 0)
+    dataset_text_field = getattr(args, "dataset_text_field", "text")
+    do_truncation = max_seq_length != 0
+    do_formatting_func = False
+
+    # Check if already tokenized so skip
+    from transformers import DataCollatorForSeq2Seq
+    column_names = set(next(iter(dataset)).keys())
+    if "input_ids" in column_names:
+        # Most likely forgot data collator!
+        from transformers import DataCollatorForSeq2Seq
+        self.data_collator = DataCollatorForSeq2Seq(processing_class)
+        return dataset
+    elif dataset_text_field not in column_names:
+        do_formatting_func = True
+        if formatting_func is None:
+            raise RuntimeError("Unsloth: You must specify a `formatting_func`")
+    pass
+
+    # Check double BOS tokens
+    if do_formatting_func:
+        test_text = formatting_func(dataset[0])
+        if not isinstance(test_text, list):
+            raise ValueError(
+                "Unsloth: The `formatting_func` should return a list of processed strings."
+            )
+        test_text = test_text[0]
+    else:
+        test_text = dataset[0][dataset_text_field]
+    chat_template = getattr(processing_class, 'chat_template', None)
+    chat_template = '' if chat_template is None else chat_template
+    add_special_tokens = True
+
+    if getattr(processing_class, 'bos_token', None) is not None:
+        if test_text.startswith(processing_class.bos_token) or processing_class.bos_token in chat_template:
+            add_special_tokens = False
+            print("Unsloth: We found double BOS tokens - we shall remove one automatically.")
+    pass
+
+    # Create tokenize function
+    def _tokenize(example):
+        return processing_class(
+            example[dataset_text_field] if not do_formatting_func else formatting_func(example),
+            truncation = do_truncation,
+            max_length = max_seq_length,
+            return_token_type_ids = False,
+            add_special_tokens = add_special_tokens,
+        )
+    pass
+
+    map_kwargs["num_proc"] = getattr(args, "dataset_num_proc", 2)
+    if use_desc: map_kwargs["desc"] = f'Tokenizing to ["{dataset_text_field}"]'
+    dataset = dataset.map(_tokenize, batched = True, **map_kwargs)
+
+    if packing:
+        if max_seq_length == 0:
+            raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
+
+        if use_desc: map_kwargs["desc"] = f"Packing {dataset_name} dataset"
+        dataset = dataset.select_columns("input_ids").map(
+            pack_examples,
+            batched = True,
+            fn_kwargs = {"seq_length": max_seq_length,},
+            **map_kwargs,
+        )
+    return dataset
+pass
+RL_REPLACEMENTS["sft_prepare_dataset"] = sft_prepare_dataset
+
 # Unsloth Zoo - Utilities for Unsloth
 # Copyright 2023-present Daniel Han-Chen, Michael Han-Chen & the Unsloth team. All rights reserved.
 #
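`sft_prepare_dataset` is registered in `RL_REPLACEMENTS` rather than called directly; Unsloth's patcher substitutes it for TRL's own dataset preparation. A sketch of how such a replacement might be wired up (the `_prepare_dataset` attribute name is an assumption inferred from the function's signature, not confirmed by this diff):

    from trl import SFTTrainer
    from unsloth_zoo.rl_replacements import RL_REPLACEMENTS

    # Monkey-patch TRL's dataset preparation with the faster Unsloth version.
    SFTTrainer._prepare_dataset = RL_REPLACEMENTS["sft_prepare_dataset"]
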
