
Conversation


Zhikaiiii (Collaborator) commented Apr 11, 2024

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Model or Dataset Support

PR information

Previous PR: #647

  1. Integrate more model patch functions for TorchAcc (a sketch of the patching pattern follows this list).
  2. Support computing speed metrics after a number of warmup steps, since TorchAcc spends time on graph compilation at the beginning of training.
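For readers less familiar with the pattern, "model patch functions" generally means a registry mapping a model type to a function that rewrites that model's layers (e.g. attention) into accelerator-friendly versions. A minimal sketch of the pattern; all names are hypothetical, not TorchAcc's or ms-swift's actual API:

```python
# Hypothetical sketch of a per-model patch registry; not the PR's real API.
MODEL_PATCHERS = {}

def register_patcher(model_type):
    """Register a function that rewrites a model for the accelerator."""
    def decorator(fn):
        MODEL_PATCHERS[model_type] = fn
        return fn
    return decorator

@register_patcher('llama')
def patch_llama(model):
    # e.g. replace attention with an accelerator-friendly implementation
    return model

def patch_model(model, model_type):
    """Apply the registered patcher if one exists, otherwise no-op."""
    patcher = MODEL_PATCHERS.get(model_type)
    return patcher(model) if patcher else model
```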

Experiment results


We tested the following models with TorchAcc and SWIFT:

  1. llama2-13b

| method | train_sample/s | train_sample/s after warmup |
| --- | --- | --- |
| torchacc + 2fsdp | 3.775 | 4.426 (1.13x) |
| torchacc + 2ddp | 4.997 (1.28x) | 5.416 (1.38x) |
| swift + 2ddp | 3.899 | 3.912 |

  2. baichuan2-13b

| method | train_sample/s | train_sample/s after warmup |
| --- | --- | --- |
| torchacc + 2fsdp | 5.014 (1.32x) | 6.039 (1.60x) |
| torchacc + 2ddp | 6.218 (1.63x) | 6.861 (1.80x) |
| swift + 2ddp | 3.812 | 3.815 |

  3. chatglm3-6b

| method | train_sample/s | train_sample/s after warmup |
| --- | --- | --- |
| torchacc + 2fsdp | 9.859 (1.82x) | 11.896 (2.19x) |
| swift + 2ddp | 5.431 | - |

  4. yi-34b

| method | train_sample/s | train_sample/s after warmup |
| --- | --- | --- |
| torchacc + 4fsdp | 2.349 | 2.978 (1.24x) |
| swift + 2ddp + 2mp | 2.411 | 2.411 |

  5. llama3-8b

| method | train_sample/s | train_sample/s after warmup |
| --- | --- | --- |
| torchacc + 2ddp | 9.569 (1.17x) | 10.593 (1.30x) |
| swift + 2ddp | 8.126 | - |

  6. qwen1.5-14b

| method | train_sample/s | train_sample/s after warmup |
| --- | --- | --- |
| torchacc + 2ddp | 5.293 (1.07x) | 5.765 (1.17x) |
| swift + 2ddp | 4.944 | - |

baoleai and others added 30 commits January 25, 2024 14:09
* [TorchAcc] Fix batch split when padding_to is not None.

* fix lint
```python
logging_steps: int = 5
dataloader_num_workers: int = 1
dataloader_pin_memory: bool = True
dataloader_drop_last: bool = True
```
Collaborator:
Why not use the training_args default value of False here?

swift/llm/sft.py Outdated
```python
logger.info(f'val_dataset_sample: {val_dataset_sample}')
val_idxs = random_state.permutation(val_dataset_sample)
val_dataset = val_dataset.select(val_idxs)
training_args.train_dataset_sample = train_dataset.shape[
```
Collaborator:
Why is this done here?

Collaborator:
train_dataset_sample will be removed in the next version; use '{dataset_name}#{train_sample}|{val_sample}' to control the sample count of an individual dataset.

Collaborator Author:
> train_dataset_sample will be removed in the next version; use '{dataset_name}#{train_sample}|{val_sample}' to control the sample count of an individual dataset.

  1. We need a total train_dataset_sample here to compute warmup_step later (see the sketch after this list).
  2. Also, the train_dataset_sample in this PR is taken from the train_dataset produced by Refactor dataset #802 and passed to the corresponding field of SwiftArgumentMixin, so it should be independent of the earlier dataset processing; let me know if I've misunderstood.
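To illustrate point 1, a minimal sketch (the helper and its parameters are hypothetical) of why the total sample count is needed: a fractional metric_warmup_step has to be resolved against the total number of training steps, which is derived from train_dataset_sample.

```python
# Hypothetical helper -- not the PR's code. Resolves metric_warmup_step
# (int = absolute step, float < 1 = ratio) to an absolute step count.
def resolve_metric_warmup_step(metric_warmup_step, train_dataset_sample,
                               total_batch_size, num_train_epochs=1):
    total_steps = (train_dataset_sample // total_batch_size) * num_train_epochs
    if isinstance(metric_warmup_step, float) and metric_warmup_step < 1:
        return int(metric_warmup_step * total_steps)
    return int(metric_warmup_step)
```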

```python
neftune_alpha: Optional[float] = None
deepspeed_config_path: Optional[str] = None
model_cache_dir: Optional[str] = None
metric_warmup_step: Optional[float] = 0
```
Collaborator:
Could you explain what these three hyperparameters mean? Do they only take effect with torchacc?

Could you add a comment noting that they are only effective under torchacc?

Collaborator:
Should this be a float or an int?

Collaborator Author:
> Could you explain what these three hyperparameters mean? Do they only take effect with torchacc?
>
> Could you add a comment noting that they are only effective under torchacc?

The first two should be general-purpose; fsdp_num is only effective under torchacc. I'll add a comment.
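A sketch of what the requested annotation could look like (not the PR's final code; fsdp_num's default value is assumed):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SwiftArgumentMixin:
    # Compute speed metrics starting from this step. An int is an absolute
    # step count; a float in (0, 1) is a fraction of the total steps.
    # General-purpose, not TorchAcc-specific.
    metric_warmup_step: Optional[float] = 0
    # Number of FSDP workers; only effective when training with TorchAcc.
    fsdp_num: int = 1  # default value assumed for illustration
```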

```diff
-from typing import List
+from typing import List, Optional, Tuple
+
+import einops
```
Collaborator:
ms-swift does not have this dependency; verify whether it will cause an import error.
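One common way to avoid a hard import failure, shown as a hedged sketch (the guard helper is illustrative):

```python
# Sketch: make einops optional so importing the module succeeds on
# installs without it, and fail only when the feature is actually used.
try:
    import einops
except ImportError:
    einops = None

def _require_einops():
    if einops is None:
        raise ImportError(
            'this code path requires `einops`; install it with `pip install einops`')
```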

```python
default='token', metadata={'choices': ['token', 'sentence']})
additional_saved_files: Optional[List[str]] = None
metric_warmup_step: Optional[float] = 0
train_dataset_sample: Optional[int] = -1
```
Collaborator:
Same as above: train_dataset_sample will be removed in the next version, which may affect this.


```python
logs=None,
**kwargs):
logs['global_step'] = state.global_step
if state.global_step >= self.metric_warmup_step and self.warmup_start_time == 0:
```
Collaborator:
Could you explain the logic here?

Collaborator Author:
> Could you explain the logic here?

Same as above.

```python
self.training_bar = tqdm(
    desc='Train', total=state.max_steps, dynamic_ncols=True)
self.current_step = 0
self.warmup_start_time = 0
```
Collaborator:
Could you explain this part?

Collaborator Author:
> Could you explain this part?

  1. TorchAcc compiles the computation graph at the start of training, so the early training speed is much slower than the later stable speed. We therefore added this warmup_step: the average training speed is computed starting from the warmup step.
  2. The logic: once the current step reaches the warmup step, we call transformers' speed_metrics function to compute and update the metric (a sketch follows this list).
  3. args.metric_warmup_step can be an int or a float, representing an absolute step count or a ratio.
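Putting the three points together, here is a minimal sketch of the described callback logic (class and attribute names are illustrative rather than the PR's exact code; speed_metrics is the existing helper in transformers.trainer_utils):

```python
import time

from transformers import TrainerCallback
from transformers.trainer_utils import speed_metrics

class WarmupSpeedCallback(TrainerCallback):
    """Report training speed measured only after a warmup phase."""

    def __init__(self, metric_warmup_step=0):
        # int -> absolute step; float in (0, 1) -> fraction of max_steps
        self.metric_warmup_step = metric_warmup_step
        self.warmup_start_time = 0
        self.warmup_start_step = 0

    def on_train_begin(self, args, state, control, **kwargs):
        if isinstance(self.metric_warmup_step, float) and 0 < self.metric_warmup_step < 1:
            self.metric_warmup_step = int(self.metric_warmup_step * state.max_steps)

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Start timing once the warmup step is reached (compilation is done).
        if state.global_step >= self.metric_warmup_step and self.warmup_start_time == 0:
            self.warmup_start_time = time.time()
            self.warmup_start_step = state.global_step
        # From then on, log the post-warmup average speed.
        num_steps = state.global_step - self.warmup_start_step
        if self.warmup_start_time > 0 and num_steps > 0 and logs is not None:
            logs.update(speed_metrics('warmup', self.warmup_start_time, num_steps=num_steps))
```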

swift/llm/sft.py Outdated
```python
logger.info(f'The logging file will be saved in: {logging_path}')
trainer.train(training_args.resume_from_checkpoint)

if args.use_profiler:
```
Collaborator:
Could you describe the logic here, or isolate it behind an environment variable?

Collaborator Author:
> Could you describe the logic here, or isolate it behind an environment variable?

This adds a profiler for the training run; it is independent of whether torchacc is used.
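For reference, a hedged sketch of flag-gated profiling around trainer.train with the standard torch.profiler API (args.use_profiler comes from the PR; the output path is illustrative, and args/trainer are the objects from the surrounding sft.py context):

```python
import torch

if args.use_profiler:
    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA],
            # traces are viewable in TensorBoard
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./profiler_output')):
        trainer.train(training_args.resume_from_checkpoint)
else:
    trainer.train(training_args.resume_from_checkpoint)
```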

@Jintao-Huang Jintao-Huang merged commit fdb7a4d into modelscope:main May 22, 2024
tastelikefeet added a commit to tastelikefeet/swift that referenced this pull request May 24, 2024
…3_paligemma

* commit '20bc628746772836fe3838e16e87fb27c39b5ec8':
  fix val_dataset (modelscope#992)
  update custom_val_dataset (modelscope#991)
  [TorchAcc][Experimental] Integrate more model in torchacc (modelscope#683)
  fix cpu 'torch._C' has no attribute '_cuda_resetPeakMemoryStats' (modelscope#914)
  refactor readme web-ui (modelscope#983)
  support  transformers==4.41 (modelscope#979)
  support more models (modelscope#971)
tastelikefeet added a commit to tastelikefeet/swift that referenced this pull request May 28, 2024
* main: (23 commits)
  fix gr limit (modelscope#1016)
  fix minicpm-v (modelscope#1010)
  fix cogvlm2 history (modelscope#1005)
  Updated a link in Command-line-parameters.md (modelscope#1001)
  fix template example copy (modelscope#1003)
  Feat/phi3 paligemma (modelscope#998)
  fix pt deploy lora (modelscope#999)
  fix args (modelscope#996)
  fix val_dataset (modelscope#992)
  update custom_val_dataset (modelscope#991)
  [TorchAcc][Experimental] Integrate more model in torchacc (modelscope#683)
  fix cpu 'torch._C' has no attribute '_cuda_resetPeakMemoryStats' (modelscope#914)
  refactor readme web-ui (modelscope#983)
  support  transformers==4.41 (modelscope#979)
  support more models (modelscope#971)
  Fix minicpm device map (modelscope#978)
  fix typing (modelscope#974)
  fix vllm eos_token (modelscope#973)
  Support minicpm-v-v2_5-chat (modelscope#970)
  support cogvlm2-en-chat-19b (modelscope#967)
  ...
hjh0119 pushed a commit to hjh0119/swift that referenced this pull request Jul 22, 2024