 from torch.optim.lr_scheduler import LRScheduler
 from transformers import Trainer, TrainingArguments, get_scheduler

-from swift.tuners.module_mapping import MODEL_KEYS_MAPPING
 from swift.utils import get_logger

 logger = get_logger()
@@ -23,7 +22,6 @@ class GaLoreConfig:
     See https://arxiv.org/abs/2403.03507

     Args:
-        model_type (`str`): The model_type of Galore
         rank (`int`): The galore rank
         target_modules (`Union[str, List[str]]`): The target modules to use, if `None`,
             will use all attn and mlp linears
@@ -33,13 +31,11 @@ class GaLoreConfig:
         galore_scale(float): the scale of gradient
         optim_per_parameter(bool): Gives one optimizer per parameter
     """
-    model_type: str = None
     rank: int = 128
     target_modules: Union[str, List[str]] = None
     update_proj_gap: int = 50
     galore_scale: float = 1.0
     proj_type: str = 'std'
-    with_embedding: bool = False
     optim_per_parameter: bool = False


@@ -72,19 +68,6 @@ def step(self, *args, **kwargs) -> None:
 def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments,
                                    config: GaLoreConfig, max_steps,
                                    **defaults):
-    if not config.target_modules:
-        if config.model_type in MODEL_KEYS_MAPPING:
-            target_modules_list = [
-                MODEL_KEYS_MAPPING[config.model_type].attention.split('.{}.')
-                [1], MODEL_KEYS_MAPPING[config.model_type].mlp.split('.{}.')[1]
-            ]
-            config.target_modules = target_modules_list
-            if config.with_embedding:
-                embedding = MODEL_KEYS_MAPPING[config.model_type].embedding
-                idx = embedding.rfind('.')
-                embedding = embedding[idx + 1:]
-                target_modules_list.append(embedding)
-
     galore_params = []
     for module_name, module in model.named_modules():
         if not isinstance(module, (nn.Linear, nn.Embedding)) or \
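For context (not part of the diff): a minimal usage sketch of the slimmed-down GaLoreConfig after this change, where target_modules is supplied explicitly instead of being inferred from a model_type. The toy model, the module names 'attn'/'mlp', the max_steps value, and the assumption that create_optimizer_and_scheduler returns an (optimizer, lr_scheduler) pair and forwards extra keyword defaults to the optimizer are illustrative assumptions, not taken from the source.

import torch.nn as nn
from transformers import TrainingArguments

# Assumption: GaLoreConfig and create_optimizer_and_scheduler from this file
# are already in scope (e.g. imported from the swift trainers package).


# Toy stand-in for a transformer; the attribute names 'attn' and 'mlp' are
# chosen only so they line up with target_modules below.
class ToyBlock(nn.Module):

    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(64, 64)
        self.mlp = nn.Linear(64, 64)


model = nn.Sequential(ToyBlock(), ToyBlock())

config = GaLoreConfig(
    rank=16,                         # low-rank projection rank
    target_modules=['attn', 'mlp'],  # now passed explicitly; no model_type lookup
    update_proj_gap=50,              # refresh the projection every 50 steps
    galore_scale=1.0,                # scale applied to the projected gradient
    proj_type='std',
    optim_per_parameter=False,
)

args = TrainingArguments(output_dir='output', learning_rate=1e-4, weight_decay=0.0)

# Assumption: the function returns (optimizer, lr_scheduler) and passes extra
# keyword defaults (lr, weight_decay, ...) through to the wrapped optimizer.
optimizer, lr_scheduler = create_optimizer_and_scheduler(
    model,
    args,
    config,
    max_steps=1000,
    lr=args.learning_rate,
    weight_decay=args.weight_decay)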