Skip to content

Commit 99df181

Browse files
fix adalora (#3714)
1 parent 70051cc commit 99df181

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

swift/llm/train/tuner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset
215215
elif args.train_type == 'adalora':
216216
lora_kwargs.pop('lorap_lr_ratio', None)
217217
lora_kwargs['rank_pattern'] = None
218+
from swift.plugin.optimizer import calculate_max_steps
218219
adalora_config = AdaLoraConfig(
219220
task_type=task_type,
220221
**lora_kwargs,
@@ -226,6 +227,7 @@ def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset
226227
beta1=args.adalora_beta1,
227228
beta2=args.adalora_beta2,
228229
orth_reg_weight=args.adalora_orth_reg_weight,
230+
total_step=calculate_max_steps(args.training_args, train_dataset),
229231
)
230232
model = Swift.prepare_model(model, adalora_config)
231233
logger.info(f'adalora_config: {adalora_config}')

tests/tuners/test_peft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def test_lora_reload_by_peft(self):
121121
def test_peft_adalora_injection(self):
122122
model = SbertForSequenceClassification(SbertConfig())
123123
model2 = copy.deepcopy(model)
124-
adalora_config = AdaLoraConfig(target_modules=['query', 'key', 'value'])
124+
adalora_config = AdaLoraConfig(target_modules=['query', 'key', 'value'], total_step=1)
125125
model = Swift.prepare_model(model, adalora_config)
126126
model.save_pretrained(self.tmp_dir, safe_serialization=False)
127127
with open(os.path.join(self.tmp_dir, 'configuration.json'), 'w') as f:

0 commit comments

Comments (0)