Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhence mixtoken for qwen #548

Merged
merged 4 commits into from
May 28, 2024
Merged

Conversation

cocoshe
Copy link
Contributor

@cocoshe cocoshe commented May 22, 2024

优化了qwen在开启mixtoken时,构造mixtoken datatse时预先读入了所有样本图片,导致的显存易溢出的问题

测试:

# 修改之后,从dataloader中得到的单个step的输入
# batch_size=1, mixtoken=false
bs1_false_images = np.load('bs1_false_images.npy')
bs1_false_input_ids = np.load('bs1_false_input_ids.npy')

# batch_size=1, mixtoken=true
bs1_true_images = np.load('bs1_true_images.npy')
bs1_true_input_ids = np.load('bs1_true_input_ids.npy')

# batch_size=4, mixtoken=false
bs4_false_images = np.load('bs4_false_images.npy')
bs4_false_input_ids = np.load('bs4_false_input_ids.npy')

# batch_size=4, mixtoken=true
bs4_true_images = np.load('bs4_true_images.npy')
bs4_true_input_ids = np.load('bs4_true_input_ids.npy')

# 修改之前,从dataloader中得到的单个step的输入
# batch_size=1, mixtoken=false
origin_bs1_false_images = np.load('origin_bs1_false_images.npy')
origin_bs1_false_input_ids = np.load('origin_bs1_false_input_ids.npy')

# batch_size=1, mixtoken=true
origin_bs1_true_images = np.load('origin_bs1_true_images.npy')
origin_bs1_true_input_ids = np.load('origin_bs1_true_input_ids.npy')

# batch_size=4, mixtoken=false
origin_bs4_false_images = np.load('origin_bs4_false_images.npy')
origin_bs4_false_input_ids = np.load('origin_bs4_false_input_ids.npy')

# batch_size=4, mixtoken=true
origin_bs4_true_images = np.load('origin_bs4_true_images.npy')
origin_bs4_true_input_ids = np.load('origin_bs4_true_input_ids.npy')


# 验证等效
print(bs1_false_images.shape, origin_bs1_false_images.shape)
print(np.sum(bs1_false_images - origin_bs1_false_images))

print(bs1_false_input_ids.shape, origin_bs1_false_input_ids.shape)
print(np.sum(bs1_false_input_ids - origin_bs1_false_input_ids))

print(bs1_true_images.shape, origin_bs1_true_images.shape)
print(np.sum(bs1_true_images - origin_bs1_true_images))

print(bs1_true_input_ids.shape, origin_bs1_true_input_ids.shape)
print(np.sum(bs1_true_input_ids - origin_bs1_true_input_ids))

print(bs4_false_images.shape, origin_bs4_false_images.shape)
print(np.sum(bs4_false_images - origin_bs4_false_images))

print(bs4_false_input_ids.shape, origin_bs4_false_input_ids.shape)
print(np.sum(bs4_false_input_ids - origin_bs4_false_input_ids))

print(bs4_true_images.shape, origin_bs4_true_images.shape)
print(np.sum(bs4_true_images - origin_bs4_true_images))

print(bs4_true_input_ids.shape, origin_bs4_true_input_ids.shape)
print(np.sum(bs4_true_input_ids - origin_bs4_true_input_ids))

输出:

(1, 3, 448, 448) (1, 3, 448, 448)
0.0
(1, 2048) (1, 2048)
0
(6, 3, 448, 448) (6, 3, 448, 448)
0.0
(1, 2048) (1, 2048)
0
(4, 3, 448, 448) (4, 3, 448, 448)
0.0
(4, 2048) (4, 2048)
0
(24, 3, 448, 448) (24, 3, 448, 448)
0.0
(4, 2048) (4, 2048)
0

Copy link

paddle-bot bot commented May 22, 2024

Thanks for your contribution!

@cocoshe
Copy link
Contributor Author

cocoshe commented May 22, 2024

@LokeZhou 辛苦review~

@LokeZhou LokeZhou added the HappyOpenSource 快乐开源活动issue与PR label May 24, 2024
@LokeZhou LokeZhou merged commit 5205811 into PaddlePaddle:develop May 28, 2024
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor HappyOpenSource 快乐开源活动issue与PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants