Skip to content

Conversation

CJack812
Copy link
Contributor

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

问题描述

swift/llm/template/base.py 中调用 self.loss_scale 时,没有将额外的参数传递给底层的 get_loss_scale 函数,导致用户无法通过数据集样本自定义参数来灵活控制 loss_scale行为。

修改内容

文件: swift/llm/template/base.py

修改前:

res_context_list, loss_scale_list = self.loss_scale(res_context_list, res_context_types, inputs.messages)

修改后:

res_context_list, loss_scale_list = self.loss_scale(res_context_list, res_context_types, inputs.messages, **inputs.extra_kwargs)

get_loss_scale 函数签名原本就支持额外的 **kwargs 参数:

def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs) -> Tuple[List[str], List[float]]:

但在template base.py的调用处没有利用这一特性,导致额外的参数无法传递

用户现在可以在数据集的样本中定义 extra_kwargs,通过这些参数来自定义loss_scale
无破坏性修改:完全向后兼容,不影响现有代码的正常运行
支持高级用法:例如:

  • 基于样本难度动态调整损失权重
  • 根据任务类型应用不同的缩放策略

Experiment results

已在qwen3-omni模型的megatron sft训练和qwen2.5-omni模型的deepspeed sft 训练中试验了该功能。具体为:在样本中增加extra_kwargs ,自定义loss_scale根据这些extra_kwargs 调整 loss_scale , 模型性能提升。

该修改为框架基础设施优化,不直接影响模型性能指标,但显著提升了框架的灵活性和可定制性,为用户实现更精细化的训练控制提供了可能。

Copy link
Contributor

Summary of Changes

Hello @CJack812, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

此拉取请求引入了一项新功能,允许用户在训练过程中更精细地控制损失缩放。通过将数据集样本中定义的额外参数传递给损失缩放方法,框架的灵活性和可定制性得到了显著提升,从而支持更高级的训练策略。除了这项主要功能外,还对多个文件中的f-string语法进行了统一和修正,将双引号内的键引用改为单引号,以提高代码一致性和避免潜在的语法问题。

Highlights

  • 动态损失缩放参数传递: 在 swift/llm/template/base.py 中,self.loss_scale 方法现在会传递 inputs.extra_kwargs。这允许用户通过数据集样本中的自定义参数来灵活控制损失缩放行为,例如根据样本难度或任务类型调整损失权重。
  • 代码风格统一: 对多个文件中的f-string语法进行了统一和修正,将双引号内的字典键引用改为单引号,以提高代码一致性并避免潜在的语法问题。
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

本次PR的核心目标是向loss_scale方法传递extra_kwargs,这是一个非常有价值的改进,增强了训练的灵活性,实现得很好。

然而,除了核心功能外,本次PR还包含了大量跨多个文件的代码格式化改动。虽然这有助于代码风格的统一,但也为代码审查带来了许多噪音。建议未来将功能性变更和代码格式化分开提交。

最重要的问题是,PR中新增了一个名为stepfun.ipynb的Jupyter notebook文件。此文件似乎与本次PR的主要功能无关,并且其输出单元格中包含ModuleNotFoundError错误。通常不建议将带有输出(尤其是错误信息)的notebook文件提交到代码库中。请考虑从本次合并请求中移除此文件。

Comment on lines +1 to +139
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "a9dd41ed",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/debian/workspace/multi_model/ms-swift/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'librosa'",
"output_type": "error",
"traceback": [
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)",
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mllm\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m PtEngine, RequestConfig\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_logger, seed_everything\n\u001b[32m 3\u001b[39m logger = get_logger()\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:104\u001b[39m, in \u001b[36m_LazyModule.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 102\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m._class_to_module.keys():\n\u001b[32m 103\u001b[39m module = \u001b[38;5;28mself\u001b[39m._get_module(\u001b[38;5;28mself\u001b[39m._class_to_module[name])\n\u001b[32m--> \u001b[39m\u001b[32m104\u001b[39m value = \u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mmodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 105\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 106\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mmodule \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m has no attribute \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:104\u001b[39m, in \u001b[36m_LazyModule.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 102\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m._class_to_module.keys():\n\u001b[32m 103\u001b[39m module = \u001b[38;5;28mself\u001b[39m._get_module(\u001b[38;5;28mself\u001b[39m._class_to_module[name])\n\u001b[32m--> \u001b[39m\u001b[32m104\u001b[39m value = \u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mmodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 105\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 106\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mmodule \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m has no attribute \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:103\u001b[39m, in \u001b[36m_LazyModule.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 101\u001b[39m value = \u001b[38;5;28mself\u001b[39m._get_module(name)\n\u001b[32m 102\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m._class_to_module.keys():\n\u001b[32m--> \u001b[39m\u001b[32m103\u001b[39m module = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_get_module\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_class_to_module\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 104\u001b[39m value = \u001b[38;5;28mgetattr\u001b[39m(module, name)\n\u001b[32m 105\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:112\u001b[39m, in \u001b[36m_LazyModule._get_module\u001b[39m\u001b[34m(self, module_name)\u001b[39m\n\u001b[32m 111\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_get_module\u001b[39m(\u001b[38;5;28mself\u001b[39m, module_name: \u001b[38;5;28mstr\u001b[39m):\n\u001b[32m--> \u001b[39m\u001b[32m112\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mimportlib\u001b[49m\u001b[43m.\u001b[49m\u001b[43mimport_module\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43m.\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m+\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodule_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[34;43m__name__\u001b[39;49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/importlib/__init__.py:90\u001b[39m, in \u001b[36mimport_module\u001b[39m\u001b[34m(name, package)\u001b[39m\n\u001b[32m 88\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[32m 89\u001b[39m level += \u001b[32m1\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m90\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_bootstrap\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_gcd_import\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m[\u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpackage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/llm/infer/infer_engine/pt_engine.py:19\u001b[39m\n\u001b[32m 16\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtransformers\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m GenerationConfig, LogitsProcessorList\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtransformers\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m is_torch_npu_available\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mllm\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m InferRequest, Template, TemplateMeta, get_model_tokenizer, safe_snapshot_download, to_device\n\u001b[32m 20\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mplugin\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Metric\n\u001b[32m 21\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtuners\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Swift\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:103\u001b[39m, in \u001b[36m_LazyModule.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 101\u001b[39m value = \u001b[38;5;28mself\u001b[39m._get_module(name)\n\u001b[32m 102\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m._class_to_module.keys():\n\u001b[32m--> \u001b[39m\u001b[32m103\u001b[39m module = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_get_module\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_class_to_module\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 104\u001b[39m value = \u001b[38;5;28mgetattr\u001b[39m(module, name)\n\u001b[32m 105\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:112\u001b[39m, in \u001b[36m_LazyModule._get_module\u001b[39m\u001b[34m(self, module_name)\u001b[39m\n\u001b[32m 111\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_get_module\u001b[39m(\u001b[38;5;28mself\u001b[39m, module_name: \u001b[38;5;28mstr\u001b[39m):\n\u001b[32m--> \u001b[39m\u001b[32m112\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mimportlib\u001b[49m\u001b[43m.\u001b[49m\u001b[43mimport_module\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43m.\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m+\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodule_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[34;43m__name__\u001b[39;49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/importlib/__init__.py:90\u001b[39m, in \u001b[36mimport_module\u001b[39m\u001b[34m(name, package)\u001b[39m\n\u001b[32m 88\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[32m 89\u001b[39m level += \u001b[32m1\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m90\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_bootstrap\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_gcd_import\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m[\u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpackage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/llm/template/__init__.py:2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Copyright (c) Alibaba, Inc. and its affiliates.\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m template\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbase\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m MaxLengthError, Template\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01mconstant\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m TemplateType\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/llm/template/template/__init__.py:1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (baidu, bert, deepseek, dots, emu3, gemma, glm, idefics3, internlm, internvl, kwai, llama, llava, llm,\n\u001b[32m 2\u001b[39m megrez, microsoft, midashenglm, minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral,\n\u001b[32m 3\u001b[39m qwen, seed, stepfun, valley, yi)\n",
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/llm/template/template/stepfun.py:8\u001b[39m\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mnn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mrnn\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m pad_sequence\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mnn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mfunctional\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mF\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m8\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlibrosa\u001b[39;00m\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_env_args\n\u001b[32m 11\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbase\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Template\n",
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'librosa'"
]
}
],
"source": [
"from swift.llm import PtEngine, RequestConfig\n",
"from swift.utils import get_logger, seed_everything\n",
"logger = get_logger()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "945a5203",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
"\n",
"\n",
"def _infer_model(pt_engine, system=None, messages=None, audios=None):\n",
" seed_everything(42)\n",
" request_config = RequestConfig(max_tokens=128, temperature=0)\n",
" if messages is None:\n",
" messages = []\n",
" if system is not None:\n",
" messages += [{'role': 'system', 'content': system}]\n",
" messages += [{'role': 'user', 'content': '你好'}]\n",
" resp = pt_engine.infer([{'messages': messages}], request_config=request_config)\n",
" response = resp[0].choices[0].message.content\n",
" messages += [{'role': 'assistant', 'content': response}]\n",
" messages += [{'role': 'user', 'content': '<audio>这段语音说了什么'}]\n",
" else:\n",
" messages = messages.copy()\n",
" if audios is None:\n",
" audios = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/weather.wav']\n",
" resp = pt_engine.infer([{'messages': messages, 'audios': audios}], request_config=request_config)\n",
" response = resp[0].choices[0].message.content\n",
" messages += [{'role': 'assistant', 'content': response}]\n",
" logger.info(f'model: {pt_engine.model_info.model_name}, messages: {messages}')\n",
" return response"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bfd3a2f7",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "4711dea0",
"metadata": {},
"outputs": [],
"source": [
"\n",
" \n",
"pt_engine = PtEngine('/mnt/nfs/models/Step-Audio-2-mini')\n",
"messages = [{'role': 'user', 'content': '<audio>Caption the audio.'}]\n",
"response = _infer_model(pt_engine, messages=messages)\n",
"pt_engine.default_template.template_backend = 'jinja'\n",
"response2 = _infer_model(pt_engine, messages=messages)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f533916d",
"metadata": {},
"outputs": [],
"source": [
"assert response == response2 == \"The audio contains a male voice speaking the phrase '今天天气真好呀' in Mandarin.\"\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

这个新添加的 Jupyter notebook (stepfun.ipynb) 似乎与本次 PR 的核心功能无关。它包含了带有 ModuleNotFoundError 错误的输出单元格,这表明它可能是一个实验性文件或存在未声明的依赖。通常不建议将带有输出(尤其是错误)的 notebook 文件提交到代码库中。请考虑从本次合并请求中移除此文件,以保持代码库的整洁,并在未来的独立 PR 中提交(如果需要的话)。

Comment on lines +1149 to +1150
res_context_list, loss_scale_list = self.loss_scale(res_context_list, res_context_types, inputs.messages,
**inputs.extra_kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

核心修改正确,通过 **inputs.extra_kwargs 将额外参数传递给 loss_scale 函数,增强了框架的灵活性。但除了这个核心改动,本次PR在大量文件中引入了许多仅涉及引号替换的格式化修改,这增加了代码审查的噪音。建议未来将功能性修改和纯格式化修改分在不同的PR中提交。

@CJack812 CJack812 closed this Oct 15, 2025
@CJack812 CJack812 reopened this Oct 15, 2025
@CJack812 CJack812 closed this Oct 15, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant