36
36
from .utils import Version , is_main_process
37
37
import triton
38
38
from .peft_utils import get_lora_layer_modules
39
+ from importlib .metadata import version as importlib_version
40
+ from packaging .version import Version
39
41
40
42
# Disable some compilations if old versions are seen
41
43
OLD_TORCH_VERSION = Version (torch .__version__ ) < Version ("2.5.0" )
@@ -62,10 +64,22 @@ def filter(self, x): return not (self.text in x.getMessage())
62
64
global COMBINED_UNSLOTH_NAME
63
65
global UNSLOTH_COMPILE_LOCATION
64
66
global UNSLOTH_CREATED_FUNCTIONS
67
+ global UNSLOTH_COMPILE_LOCATION_USE_TEMP
65
68
COMBINED_UNSLOTH_NAME = "unsloth_compiled_module"
66
69
UNSLOTH_COMPILE_LOCATION = "unsloth_compiled_cache"
67
70
UNSLOTH_CREATED_FUNCTIONS = []
68
-
71
+ UNSLOTH_COMPILE_LOCATION_USE_TEMP = False
72
+
73
+ # Try creating a directory for cache, or else use a temporary folder
74
+ try :
75
+ os .makedirs (UNSLOTH_COMPILE_LOCATION , exist_ok = True )
76
+ if not os .path .exists (UNSLOTH_COMPILE_LOCATION ): raise
77
+ except :
78
+ from tempfile import TemporaryDirectory
79
+ UNSLOTH_COMPILE_LOCATION_USE_TEMP = True
80
+ UNSLOTH_COMPILE_LOCATION = TemporaryDirectory (ignore_cleanup_errors = True ).name
81
+ print (f"Unsloth: We can't create folders, so as a hack, we used a temporary directory = { UNSLOTH_COMPILE_LOCATION } " )
82
+ pass
69
83
70
84
_license_header = """
71
85
# Unsloth Zoo - Utilities for Unsloth
@@ -210,8 +224,11 @@ def create_new_function(
210
224
add_torch_compile = False ,
211
225
):
212
226
# All Unsloth Zoo code licensed under LGPLv3
227
+ old_new_source = new_source
228
+
213
229
global UNSLOTH_CREATED_FUNCTIONS
214
230
global UNSLOTH_COMPILE_LOCATION
231
+ global UNSLOTH_COMPILE_LOCATION_USE_TEMP
215
232
if new_source [0 ] == " " :
216
233
spaces = new_source .find ("def" )
217
234
new_source = new_source .split ("\n " )
@@ -237,6 +254,24 @@ def create_new_function(
237
254
# Fix super() Not necessary anymore!
238
255
# new_source = new_source.replace("super()", "super(type(self), self)")
239
256
257
+ # Check versioning
258
+ try : unsloth_zoo_version = importlib_version ("unsloth_zoo" )
259
+ except : unsloth_zoo_version = "0"
260
+ try : unsloth_version = importlib_version ("unsloth" )
261
+ except : unsloth_version = "0"
262
+ try : transformers_version = importlib_version ("transformers" )
263
+ except : transformers_version = "0"
264
+ try : trl_version = importlib_version ("trl" )
265
+ except : trl_version = "0"
266
+
267
+ versioning = '"""\n ' + \
268
+ f'{ unsloth_zoo_version } \n ' \
269
+ f'{ unsloth_version } \n ' \
270
+ f'{ transformers_version } \n ' \
271
+ f'{ trl_version } \n __UNSLOTH_VERSIONING__\n ' + '"""\n '
272
+
273
+ write_new_source = versioning + new_source
274
+
240
275
# Check location
241
276
if is_main_process ():
242
277
if not os .path .exists (UNSLOTH_COMPILE_LOCATION ):
@@ -247,35 +282,72 @@ def create_new_function(
247
282
function_location = location
248
283
if overwrite or not os .path .isfile (function_location ):
249
284
with open (function_location , "wb" , buffering = 0 ) as file :
250
- file .write (new_source .encode ("utf-8" ))
285
+ file .write (write_new_source .encode ("utf-8" ))
251
286
file .flush ()
252
287
os .fsync (file .fileno ())
253
288
pass
254
289
pass
255
- else :
256
- # Wait until file is created
257
- location = os .path .join (UNSLOTH_COMPILE_LOCATION , f"{ name } .py" )
258
- function_location = location
259
- if overwrite or not os .path .isfile (function_location ):
260
- while not os .path .isfile (function_location ): continue
290
+ pass
291
+ # Wait until file is created
292
+ file_location = os .path .join (UNSLOTH_COMPILE_LOCATION , f"{ name } .py" )
293
+ trials = 0
294
+ if overwrite or not os .path .isfile (file_location ):
295
+ while not os .path .isfile (file_location ):
296
+ if trials == 1000 : raise RuntimeError ("Unsloth: Failed to create dynamic compiled modules!" )
297
+ trials += 1
298
+ time .sleep (0.01 )
299
+ pass
300
+ # Check versioning, and overwrite if any packages changed
301
+ with open (file_location , "r" ) as f : f = f .read ()
302
+
303
+ # Check if exactly equivalent:
304
+ rewrite = False
305
+ if f != write_new_source :
306
+ rewrite = True
307
+ elif not overwrite :
308
+ if "__UNSLOTH_VERSIONING__" not in f :
309
+ rewrite = True
310
+ else :
311
+ versions = f [:f .find ('__UNSLOTH_VERSIONING__' )]
312
+ if versioning [:versioning .find ('__UNSLOTH_VERSIONING__' )] != versions :
313
+ rewrite = True
314
+ pass
315
+ if rewrite :
316
+ return create_new_function (
317
+ name = name ,
318
+ new_source = old_new_source ,
319
+ model_location = model_location ,
320
+ functions = functions ,
321
+ prepend = prepend ,
322
+ append = append ,
323
+ overwrite = True ,
324
+ add_torch_compile = add_torch_compile ,
325
+ )
261
326
pass
262
327
263
328
# Try loading new module
264
329
new_module = None
330
+ trials = 0
265
331
while True :
332
+ if trials == 1000 : raise RuntimeError ("Unsloth: Failed to create dynamic compiled" )
266
333
try :
267
334
new_module = importlib .import_module (UNSLOTH_COMPILE_LOCATION + "." + name )
268
335
break
269
336
except :
270
- # Instead use sys modules for dynamic loading
271
337
module_name = f"unsloth_cache_{ name } "
272
338
file_location = os .path .join (UNSLOTH_COMPILE_LOCATION , name ) + ".py"
339
+
340
+ # Instead use sys modules for dynamic loading
273
341
spec = importlib .util .spec_from_file_location (module_name , file_location )
274
342
new_module = importlib .util .module_from_spec (spec )
275
343
sys .modules [module_name ] = new_module
276
344
spec .loader .exec_module (new_module )
277
345
346
+ # Temp modules can only use dynamic loading
347
+ if UNSLOTH_COMPILE_LOCATION_USE_TEMP : break
348
+
278
349
time .sleep (0.01 )
350
+ trials += 1
279
351
pass
280
352
pass
281
353
if new_module is None :
@@ -1454,31 +1526,20 @@ def unsloth_compile_transformers(
1454
1526
1455
1527
all_code = "\n \n " .join (final_all_standalone_classes )
1456
1528
1457
- if import_from_cache :
1458
- try :
1459
- combined_module = importlib .import_module (f"{ UNSLOTH_COMPILE_LOCATION } .{ COMBINED_UNSLOTH_NAME } _{ model_type } " )
1460
- import_from_cache = True
1461
- except :
1462
- import_from_cache = False
1463
- else :
1464
- import_from_cache = False
1465
- pass
1466
- if not import_from_cache :
1467
- try :
1468
- combined_module = create_new_function (
1469
- f"{ COMBINED_UNSLOTH_NAME } _{ model_type } " ,
1470
- all_code ,
1471
- model_location ,
1472
- functions ,
1473
- prepend = \
1474
- _disabled_sdpa_code + \
1475
- f"\n torch_compile_options = { torch_compile_options } \n " + \
1476
- _cross_entropy_code + "\n "
1477
- )
1478
- except Exception as exception :
1479
- raise RuntimeError (exception )
1480
- combined_module = None
1481
- pass
1529
+ try :
1530
+ combined_module = create_new_function (
1531
+ f"{ COMBINED_UNSLOTH_NAME } _{ model_type } " ,
1532
+ all_code ,
1533
+ model_location ,
1534
+ functions ,
1535
+ prepend = \
1536
+ _disabled_sdpa_code + \
1537
+ f"\n torch_compile_options = { torch_compile_options } \n " + \
1538
+ _cross_entropy_code + "\n "
1539
+ )
1540
+ except Exception as exception :
1541
+ raise RuntimeError (exception )
1542
+ combined_module = None
1482
1543
1483
1544
if compile_torch_modules and not disable :
1484
1545
0 commit comments