Skip to content

Commit 20bc628

Browse files
authored
fix val_dataset (modelscope#992)
1 parent 03e5e3a commit 20bc628

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed

swift/llm/infer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,16 +393,17 @@ def llm_infer(args: InferArguments) -> None:
393393
'model_author': args.model_author
394394
}
395395
if len(args.val_dataset) > 0:
396-
_, val_dataset = get_dataset(args.dataset, args.dataset_test_ratio, **dataset_kwargs)
397-
else:
398396
_, val_dataset = get_dataset(args.val_dataset, 1.0, **dataset_kwargs)
397+
else:
398+
_, val_dataset = get_dataset(args.dataset, args.dataset_test_ratio, **dataset_kwargs)
399399
_, val_dataset = args._handle_dataset_compat(_, val_dataset)
400+
assert val_dataset is not None
400401
if args.show_dataset_sample >= 0 and val_dataset.shape[0] > args.show_dataset_sample:
401402
random_state = np.random.RandomState(args.dataset_seed)
402403
logger.info(f'show_dataset_sample: {args.show_dataset_sample}')
403404
val_dataset = sample_dataset(val_dataset, args.show_dataset_sample, random_state)
404-
405405
logger.info(f'val_dataset: {val_dataset}')
406+
406407
if args.verbose is None:
407408
if len(val_dataset) >= 100:
408409
args.verbose = False

swift/llm/utils/argument.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -213,15 +213,18 @@ def handle_compatibility(self: Union['SftArguments', 'InferArguments']) -> None:
213213
v = _mapping[k]
214214
setattr(self, _name, v)
215215
break
216-
if isinstance(self.dataset, str):
217-
self.dataset = [self.dataset]
218-
if len(self.dataset) == 1 and ',' in self.dataset[0]:
219-
self.dataset = self.dataset[0].split(',')
220-
for i, dataset in enumerate(self.dataset):
221-
if dataset in dataset_name_mapping:
222-
self.dataset[i] = dataset_name_mapping[dataset]
223-
for d in self.dataset:
224-
assert ',' not in d, f'dataset: {d}, please use `/`'
216+
for key in ['dataset', 'val_dataset']:
217+
_dataset = getattr(self, key)
218+
if isinstance(_dataset, str):
219+
_dataset = [_dataset]
220+
if len(_dataset) == 1 and ',' in _dataset[0]:
221+
_dataset = _dataset[0].split(',')
222+
for i, d in enumerate(_dataset):
223+
if d in dataset_name_mapping:
224+
_dataset[i] = dataset_name_mapping[d]
225+
for d in _dataset:
226+
assert ',' not in d, f'dataset: {d}, please use `/`'
227+
setattr(self, key, _dataset)
225228
if self.truncation_strategy == 'ignore':
226229
self.truncation_strategy = 'delete'
227230
if self.safe_serialization is not None:
@@ -1072,12 +1075,12 @@ def __post_init__(self) -> None:
10721075
self.torch_dtype, _, _ = self.select_dtype()
10731076
self.prepare_template()
10741077
if self.eval_human is None:
1075-
if not len(self.dataset) > 0:
1078+
if len(self.dataset) == 0 and len(self.val_dataset) == 0:
10761079
self.eval_human = True
10771080
else:
10781081
self.eval_human = False
10791082
logger.info(f'Setting self.eval_human: {self.eval_human}')
1080-
elif self.eval_human is False and not len(self.dataset) > 0:
1083+
elif self.eval_human is False and len(self.dataset) == 0 and len(self.val_dataset) == 0:
10811084
raise ValueError('Please provide the dataset or set `--load_dataset_config true`.')
10821085

10831086
# compatibility

0 commit comments

Comments
 (0)