diff --git a/swift/llm/sft.py b/swift/llm/sft.py index 5c3b55f4a..1dc7712fc 100644 --- a/swift/llm/sft.py +++ b/swift/llm/sft.py @@ -469,6 +469,14 @@ def trainer_train( if args.train_type == 'ppo': trainer_kwargs['reward_model'] = reward_model trainer_kwargs['value_model'] = value_model + if args.use_channel_loss: + channel_dataset_dict = {} + for sample in val_dataset: + channel = sample['channel'] + if channel not in channel_dataset_dict: + channel_dataset_dict[channel] = [] + channel_dataset_dict[channel].append(sample) + val_dataset = channel_dataset_dict trainer = trainer_cls( model=model, args=training_args, diff --git a/swift/llm/utils/argument.py b/swift/llm/utils/argument.py index 396197c0e..c695ac03f 100644 --- a/swift/llm/utils/argument.py +++ b/swift/llm/utils/argument.py @@ -655,6 +655,8 @@ class SftArguments(ArgumentsBase): dataset_seed: Optional[int] = None dataset_test_ratio: float = 0.01 use_loss_scale: bool = False # for agent + use_channel_loss: bool = False + channel_to_save: Optional[str] = None loss_scale_config_path: str = 'DEFAULT' system: Optional[str] = None tools_prompt: Literal['react_en', 'react_zh', 'toolbench'] = 'react_en' @@ -1198,6 +1200,10 @@ def _init_training_args(self) -> None: kwargs['accelerator_config'] = {'dispatch_batches': False} metric_for_best_model = 'rouge-l' if self.predict_with_generate else 'loss' + if self.use_channel_loss: + if self.channel_to_save is None: + raise ValueError('Please specify --channel_to_save') + metric_for_best_model = f'{self.channel_to_save}_{metric_for_best_model}' if hasattr(self, 'rlhf_type') and self.rlhf_type == 'ppo': metric_for_best_model = None diff --git a/swift/llm/utils/dataset.py b/swift/llm/utils/dataset.py index f84aafb87..9041ab3d2 100644 --- a/swift/llm/utils/dataset.py +++ b/swift/llm/utils/dataset.py @@ -66,7 +66,7 @@ def new_func(self, *args, **kwargs): standard_keys = { 'query', 'query_role', 'response', 'rejected_response', 'system', 'history', 'history_roles', 'images', 'objects', - 'videos', 'audios', 'tools', 'label' + 'videos', 'audios', 'tools', 'label', 'channel' } diff --git a/swift/llm/utils/template.py b/swift/llm/utils/template.py index ad135d688..cea6c99c9 100644 --- a/swift/llm/utils/template.py +++ b/swift/llm/utils/template.py @@ -610,6 +610,7 @@ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, An inputs[key] = example.get(key) if inputs.get('labels') is None: inputs.pop('loss_scale', None) + inputs['channel'] = example.get('channel', '') return inputs, tokenizer_kwargs def _concat_context_list(