-
Notifications
You must be signed in to change notification settings - Fork 903
[Feature] Update loss_scale method call to pass through inputs.extra_kwargs #6156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary of ChangesHello @CJack812, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! 此拉取请求引入了一项新功能,允许用户在训练过程中更精细地控制损失缩放。通过将数据集样本中定义的额外参数传递给损失缩放方法,框架的灵活性和可定制性得到了显著提升,从而支持更高级的训练策略。除了这项主要功能外,还对多个文件中的f-string语法进行了统一和修正,将双引号内的键引用改为单引号,以提高代码一致性和避免潜在的语法问题。 Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
本次PR的核心目标是向loss_scale
方法传递extra_kwargs
,这是一个非常有价值的改进,增强了训练的灵活性,实现得很好。
然而,除了核心功能外,本次PR还包含了大量跨多个文件的代码格式化改动。虽然这有助于代码风格的统一,但也为代码审查带来了许多噪音。建议未来将功能性变更和代码格式化分开提交。
最重要的问题是,PR中新增了一个名为stepfun.ipynb
的Jupyter notebook文件。此文件似乎与本次PR的主要功能无关,并且其输出单元格中包含ModuleNotFoundError
错误。通常不建议将带有输出(尤其是错误信息)的notebook文件提交到代码库中。请考虑从本次合并请求中移除此文件。
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "a9dd41ed", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/home/debian/workspace/multi_model/ms-swift/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | ||
" from .autonotebook import tqdm as notebook_tqdm\n" | ||
] | ||
}, | ||
{ | ||
"ename": "ModuleNotFoundError", | ||
"evalue": "No module named 'librosa'", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[31m---------------------------------------------------------------------------\u001b[39m", | ||
"\u001b[31mModuleNotFoundError\u001b[39m Traceback (most recent call last)", | ||
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mllm\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m PtEngine, RequestConfig\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_logger, seed_everything\n\u001b[32m 3\u001b[39m logger = get_logger()\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:104\u001b[39m, in \u001b[36m_LazyModule.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 102\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m._class_to_module.keys():\n\u001b[32m 103\u001b[39m module = \u001b[38;5;28mself\u001b[39m._get_module(\u001b[38;5;28mself\u001b[39m._class_to_module[name])\n\u001b[32m--> \u001b[39m\u001b[32m104\u001b[39m value = \u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mmodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 105\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 106\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mmodule \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m has no attribute \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:104\u001b[39m, in \u001b[36m_LazyModule.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 102\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m._class_to_module.keys():\n\u001b[32m 103\u001b[39m module = \u001b[38;5;28mself\u001b[39m._get_module(\u001b[38;5;28mself\u001b[39m._class_to_module[name])\n\u001b[32m--> \u001b[39m\u001b[32m104\u001b[39m value = \u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mmodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 105\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 106\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mmodule \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.\u001b[34m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m has no attribute \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m)\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:103\u001b[39m, in \u001b[36m_LazyModule.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 101\u001b[39m value = \u001b[38;5;28mself\u001b[39m._get_module(name)\n\u001b[32m 102\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m._class_to_module.keys():\n\u001b[32m--> \u001b[39m\u001b[32m103\u001b[39m module = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_get_module\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_class_to_module\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 104\u001b[39m value = \u001b[38;5;28mgetattr\u001b[39m(module, name)\n\u001b[32m 105\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:112\u001b[39m, in \u001b[36m_LazyModule._get_module\u001b[39m\u001b[34m(self, module_name)\u001b[39m\n\u001b[32m 111\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_get_module\u001b[39m(\u001b[38;5;28mself\u001b[39m, module_name: \u001b[38;5;28mstr\u001b[39m):\n\u001b[32m--> \u001b[39m\u001b[32m112\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mimportlib\u001b[49m\u001b[43m.\u001b[49m\u001b[43mimport_module\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43m.\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m+\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodule_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[34;43m__name__\u001b[39;49m\u001b[43m)\u001b[49m\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/importlib/__init__.py:90\u001b[39m, in \u001b[36mimport_module\u001b[39m\u001b[34m(name, package)\u001b[39m\n\u001b[32m 88\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[32m 89\u001b[39m level += \u001b[32m1\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m90\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_bootstrap\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_gcd_import\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m[\u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpackage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m)\u001b[49m\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/llm/infer/infer_engine/pt_engine.py:19\u001b[39m\n\u001b[32m 16\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtransformers\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m GenerationConfig, LogitsProcessorList\n\u001b[32m 17\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtransformers\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m is_torch_npu_available\n\u001b[32m---> \u001b[39m\u001b[32m19\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mllm\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m InferRequest, Template, TemplateMeta, get_model_tokenizer, safe_snapshot_download, to_device\n\u001b[32m 20\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mplugin\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Metric\n\u001b[32m 21\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtuners\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Swift\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:103\u001b[39m, in \u001b[36m_LazyModule.__getattr__\u001b[39m\u001b[34m(self, name)\u001b[39m\n\u001b[32m 101\u001b[39m value = \u001b[38;5;28mself\u001b[39m._get_module(name)\n\u001b[32m 102\u001b[39m \u001b[38;5;28;01melif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m._class_to_module.keys():\n\u001b[32m--> \u001b[39m\u001b[32m103\u001b[39m module = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_get_module\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_class_to_module\u001b[49m\u001b[43m[\u001b[49m\u001b[43mname\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 104\u001b[39m value = \u001b[38;5;28mgetattr\u001b[39m(module, name)\n\u001b[32m 105\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/utils/import_utils.py:112\u001b[39m, in \u001b[36m_LazyModule._get_module\u001b[39m\u001b[34m(self, module_name)\u001b[39m\n\u001b[32m 111\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m_get_module\u001b[39m(\u001b[38;5;28mself\u001b[39m, module_name: \u001b[38;5;28mstr\u001b[39m):\n\u001b[32m--> \u001b[39m\u001b[32m112\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mimportlib\u001b[49m\u001b[43m.\u001b[49m\u001b[43mimport_module\u001b[49m\u001b[43m(\u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43m.\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m \u001b[49m\u001b[43m+\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodule_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[34;43m__name__\u001b[39;49m\u001b[43m)\u001b[49m\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/lib/python3.12/importlib/__init__.py:90\u001b[39m, in \u001b[36mimport_module\u001b[39m\u001b[34m(name, package)\u001b[39m\n\u001b[32m 88\u001b[39m \u001b[38;5;28;01mbreak\u001b[39;00m\n\u001b[32m 89\u001b[39m level += \u001b[32m1\u001b[39m\n\u001b[32m---> \u001b[39m\u001b[32m90\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_bootstrap\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_gcd_import\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m[\u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m:\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpackage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlevel\u001b[49m\u001b[43m)\u001b[49m\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/llm/template/__init__.py:2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;66;03m# Copyright (c) Alibaba, Inc. and its affiliates.\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m template\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbase\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m MaxLengthError, Template\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01mconstant\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m TemplateType\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/llm/template/template/__init__.py:1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (baidu, bert, deepseek, dots, emu3, gemma, glm, idefics3, internlm, internvl, kwai, llama, llava, llm,\n\u001b[32m 2\u001b[39m megrez, microsoft, midashenglm, minicpm, minimax, mistral, molmo, moonshot, mplug, openbuddy, pixtral,\n\u001b[32m 3\u001b[39m qwen, seed, stepfun, valley, yi)\n", | ||
"\u001b[36mFile \u001b[39m\u001b[32m~/workspace/multi_model/ms-swift/swift/llm/template/template/stepfun.py:8\u001b[39m\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mnn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mrnn\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m pad_sequence\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mnn\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mfunctional\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mas\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mF\u001b[39;00m\n\u001b[32m----> \u001b[39m\u001b[32m8\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mlibrosa\u001b[39;00m\n\u001b[32m 10\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mswift\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_env_args\n\u001b[32m 11\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01m.\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mbase\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Template\n", | ||
"\u001b[31mModuleNotFoundError\u001b[39m: No module named 'librosa'" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from swift.llm import PtEngine, RequestConfig\n", | ||
"from swift.utils import get_logger, seed_everything\n", | ||
"logger = get_logger()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "945a5203", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"\n", | ||
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", | ||
"\n", | ||
"\n", | ||
"def _infer_model(pt_engine, system=None, messages=None, audios=None):\n", | ||
" seed_everything(42)\n", | ||
" request_config = RequestConfig(max_tokens=128, temperature=0)\n", | ||
" if messages is None:\n", | ||
" messages = []\n", | ||
" if system is not None:\n", | ||
" messages += [{'role': 'system', 'content': system}]\n", | ||
" messages += [{'role': 'user', 'content': '你好'}]\n", | ||
" resp = pt_engine.infer([{'messages': messages}], request_config=request_config)\n", | ||
" response = resp[0].choices[0].message.content\n", | ||
" messages += [{'role': 'assistant', 'content': response}]\n", | ||
" messages += [{'role': 'user', 'content': '<audio>这段语音说了什么'}]\n", | ||
" else:\n", | ||
" messages = messages.copy()\n", | ||
" if audios is None:\n", | ||
" audios = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/weather.wav']\n", | ||
" resp = pt_engine.infer([{'messages': messages, 'audios': audios}], request_config=request_config)\n", | ||
" response = resp[0].choices[0].message.content\n", | ||
" messages += [{'role': 'assistant', 'content': response}]\n", | ||
" logger.info(f'model: {pt_engine.model_info.model_name}, messages: {messages}')\n", | ||
" return response" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "bfd3a2f7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "4711dea0", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
" \n", | ||
"pt_engine = PtEngine('/mnt/nfs/models/Step-Audio-2-mini')\n", | ||
"messages = [{'role': 'user', 'content': '<audio>Caption the audio.'}]\n", | ||
"response = _infer_model(pt_engine, messages=messages)\n", | ||
"pt_engine.default_template.template_backend = 'jinja'\n", | ||
"response2 = _infer_model(pt_engine, messages=messages)\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f533916d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"assert response == response2 == \"The audio contains a male voice speaking the phrase '今天天气真好呀' in Mandarin.\"\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
res_context_list, loss_scale_list = self.loss_scale(res_context_list, res_context_types, inputs.messages, | ||
**inputs.extra_kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR type
PR information
问题描述
在
swift/llm/template/base.py
中调用self.loss_scale
时,没有将额外的参数传递给底层的get_loss_scale
函数,导致用户无法通过数据集样本自定义参数来灵活控制 loss_scale行为。修改内容
文件:
swift/llm/template/base.py
修改前:
修改后:
get_loss_scale 函数签名原本就支持额外的 **kwargs 参数:
但在template base.py的调用处没有利用这一特性,导致额外的参数无法传递
用户现在可以在数据集的样本中定义 extra_kwargs,通过这些参数来自定义loss_scale
无破坏性修改:完全向后兼容,不影响现有代码的正常运行
支持高级用法:例如:
Experiment results
已在qwen3-omni模型的megatron sft训练和qwen2.5-omni模型的deepspeed sft 训练中试验了该功能。具体为:在样本中增加extra_kwargs ,自定义loss_scale根据这些extra_kwargs 调整 loss_scale , 模型性能提升。
该修改为框架基础设施优化,不直接影响模型性能指标,但显著提升了框架的灵活性和可定制性,为用户实现更精细化的训练控制提供了可能。