Does qwen2.5 not support MoE? #450

Open

DumoeDss opened this issue Feb 6, 2025 · 2 comments

Comments
@DumoeDss

DumoeDss commented Feb 6, 2025

As the title says: if it isn't supported, are there plans to support Megatron-Core-MoE later on?

@lostkevin
Contributor

Hi, I don't see any qwen2.5 MoE model in https://www.modelscope.cn/collections/Qwen25-dbc4d30adb768. Since the current qwen2.5 implementation calls qwen2's training code as its backend, you can try modifying the entry script yourself first if you have custom requirements.

@DumoeDss
Author

I updated hf2mcore_qwen2_dense_and_moe_gqa.py, mainly adding initialization for the gate weights, and also adjusted the model loading for the mcore-to-hf direction.
Now, after converting qwen2.5-0.5B to an mcore MoE checkpoint and then converting that back to an hf MoE checkpoint, every completion comes out as gibberish:
"prompt": "圣诞快乐~"
"text": "Lon")). async newValue deleting 같습니다每一位peated SysonClick�-value DAY𝔱 внимIssuer"

Code:

# Excerpt from the modified hf2mcore_qwen2_dense_and_moe_gqa.py; it assumes the
# script's original imports (torch, torch.nn as nn, transformers' AutoConfig and
# AutoModelForCausalLM, and the Megatron/conversion helpers the script already uses).
def convert_checkpoint_from_megatron_to_transformers(mgmodel, hfmodel, args):

    if args.fp16:
        mgmodel = mgmodel.half()
        hfmodel = hfmodel.half()
    elif args.bf16:
        mgmodel = mgmodel.bfloat16()
        hfmodel = hfmodel.bfloat16()
    
    num_query_groups = args.num_query_groups
    hidden_size = args.hidden_size
    head_dim = hidden_size // args.num_attention_heads
    use_te = args.transformer_impl == "transformer_engine"
    value_num_per_group = args.num_attention_heads // num_query_groups
    q_dim_per_group = hidden_size // num_query_groups
    kv_dim_per_group = head_dim
    with torch.no_grad():
        hfmodel.model.embed_tokens.weight.copy_(mgmodel.embedding.word_embeddings.weight)
        for mglayer, hflayer in zip(mgmodel.decoder.layers, hfmodel.model.layers):
            if use_te:
                hflayer.input_layernorm.weight.copy_(mglayer.self_attention.linear_qkv.layer_norm_weight)
            else:
                hflayer.input_layernorm.weight.copy_(mglayer.input_layernorm.weight)

            qkv_weight = mglayer.self_attention.linear_qkv.weight.view(num_query_groups, -1, head_dim, hidden_size)
            q_weight, k_weight, v_weight = torch.split(qkv_weight, split_size_or_sections=[value_num_per_group, 1, 1], dim=1)
            hflayer.self_attn.q_proj.weight.copy_(q_weight.reshape(-1, hidden_size))
            hflayer.self_attn.k_proj.weight.copy_(k_weight.reshape(-1, hidden_size))
            hflayer.self_attn.v_proj.weight.copy_(v_weight.reshape(-1, hidden_size))

            qkv_bias = mglayer.self_attention.linear_qkv.bias.view(num_query_groups, -1)
            q_bias, k_bias, v_bias = torch.split(qkv_bias, split_size_or_sections=[q_dim_per_group, kv_dim_per_group, kv_dim_per_group], dim=1)
            q_bias = q_bias.contiguous().view(-1)
            k_bias = k_bias.contiguous().view(-1)
            v_bias = v_bias.contiguous().view(-1)

            hflayer.self_attn.q_proj.bias.copy_(q_bias)
            hflayer.self_attn.k_proj.bias.copy_(k_bias)
            hflayer.self_attn.v_proj.bias.copy_(v_bias)

            hflayer.self_attn.o_proj.weight.copy_(mglayer.self_attention.linear_proj.weight)

            if args.num_experts is None:
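                # Dense MLP: Megatron's linear_fc1 stores the gate_proj rows followed by the up_proj rows.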
                gate_weight, fc1_weight = torch.split(mglayer.mlp.linear_fc1.weight, split_size_or_sections=args.ffn_hidden_size)
                hflayer.mlp.gate_proj.weight.copy_(gate_weight)
                hflayer.mlp.up_proj.weight.copy_(fc1_weight)
                hflayer.mlp.down_proj.weight.copy_(mglayer.mlp.linear_fc2.weight)
            else:
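                # MoE MLP: the Megatron router maps to HF's mlp.gate; each local expert maps to an HF expert.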
                hflayer.mlp.gate.weight.copy_(mglayer.mlp.router.weight)
                for mgexpert, hfexpert in zip(mglayer.mlp.experts.local_experts, hflayer.mlp.experts):
                    gate_weight, up_weight = torch.split(mgexpert.linear_fc1.weight,
                                                        split_size_or_sections=args.moe_ffn_hidden_size)
                    hfexpert.gate_proj.weight.copy_(gate_weight)
                    hfexpert.up_proj.weight.copy_(up_weight)
                    hfexpert.down_proj.weight.copy_(mgexpert.linear_fc2.weight)

                hflayer.mlp.shared_expert_gate.weight.copy_(mglayer.mlp.shared_expert_gate.weight)
                shared_expert_gate_weight, shared_expert_up_weight = \
                    torch.split(mglayer.mlp.shared_expert.linear_fc1.weight,
                                split_size_or_sections=args.shared_moe_ffn_hidden_size)
                hflayer.mlp.shared_expert.gate_proj.weight.copy_(shared_expert_gate_weight)
                hflayer.mlp.shared_expert.up_proj.weight.copy_(shared_expert_up_weight)
                hflayer.mlp.shared_expert.down_proj.weight.copy_(mglayer.mlp.shared_expert.linear_fc2.weight)
                
            if use_te and not args.num_experts:
                hflayer.post_attention_layernorm.weight.copy_(mglayer.mlp.linear_fc1.layer_norm_weight)
            else:
                hflayer.post_attention_layernorm.weight.copy_(mglayer.pre_mlp_layernorm.weight)

        hfmodel.model.norm.weight.copy_(mgmodel.decoder.final_layernorm.weight)
        if args.untie_embeddings_and_output_weights:
            hfmodel.lm_head.weight.copy_(mgmodel.output_layer.weight)
    
def convert_checkpoint_from_transformers_to_megatron(hfmodel, mgmodel, args):

    if args.fp16:
        mgmodel = mgmodel.half()
        hfmodel = hfmodel.half()
    elif args.bf16:
        mgmodel = mgmodel.bfloat16()
        hfmodel = hfmodel.bfloat16()

    assert args.num_query_groups >= args.target_tensor_model_parallel_size

    num_attention_heads = args.num_attention_heads
    num_query_groups = args.num_query_groups
    hidden_size = args.hidden_size
    head_dim = hidden_size // num_attention_heads
    use_te = args.transformer_impl == "transformer_engine"

    with torch.no_grad():
        mgmodel.embedding.word_embeddings.weight.copy_(hfmodel.model.embed_tokens.weight)
        for mglayer, hflayer in zip(mgmodel.decoder.layers, hfmodel.model.layers):
            if use_te:
                mglayer.self_attention.linear_qkv.layer_norm_weight.copy_(hflayer.input_layernorm.weight)
            else:
                mglayer.input_layernorm.weight.copy_(hflayer.input_layernorm.weight)

            q_proj_weight = hflayer.self_attn.q_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)
            k_proj_weight = hflayer.self_attn.k_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)
            v_proj_weight = hflayer.self_attn.v_proj.weight.view(num_query_groups, -1, head_dim, hidden_size)
            qkv_proj = torch.cat([q_proj_weight, k_proj_weight, v_proj_weight], dim=1).view(-1, hidden_size).contiguous()
            mglayer.self_attention.linear_qkv.weight.copy_(qkv_proj)

            q_proj_bias = hflayer.self_attn.q_proj.bias.view(num_query_groups, -1)
            k_proj_bias = hflayer.self_attn.k_proj.bias.view(num_query_groups, -1)
            v_proj_bias = hflayer.self_attn.v_proj.bias.view(num_query_groups, -1)
            qkv_bias = torch.cat([q_proj_bias, k_proj_bias, v_proj_bias], dim=1).view(-1).contiguous()
            mglayer.self_attention.linear_qkv.bias.copy_(qkv_bias)

            mglayer.self_attention.linear_proj.weight.copy_(hflayer.self_attn.o_proj.weight)

            if args.num_experts is None:
                fc1_weight = torch.cat([hflayer.mlp.gate_proj.weight, hflayer.mlp.up_proj.weight])
                mglayer.mlp.linear_fc1.weight.copy_(fc1_weight)
                mglayer.mlp.linear_fc2.weight.copy_(hflayer.mlp.down_proj.weight)
            else:
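                # If the HF checkpoint is already MoE, copy router/experts/shared expert directly;
                # otherwise (dense checkpoint) the except branch below upcycles it into experts.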
                try:
                    mglayer.mlp.router.weight.copy_(hflayer.mlp.gate.weight)
                    for hf_expert, expert in zip(hflayer.mlp.experts, mglayer.mlp.experts.local_experts):
                        fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
                        expert.linear_fc1.weight.copy_(fc1_weight)
                        expert.linear_fc2.weight.copy_(hf_expert.down_proj.weight)
                    mglayer.mlp.shared_expert_gate.weight.copy_(hflayer.mlp.shared_expert_gate.weight)
                    shared_fc1_weight = torch.cat(
                        [hflayer.mlp.shared_expert.gate_proj.weight, hflayer.mlp.shared_expert.up_proj.weight])
                    mglayer.mlp.shared_expert.linear_fc1.weight.copy_(shared_fc1_weight)
                    mglayer.mlp.shared_expert.linear_fc2.weight.copy_(hflayer.mlp.shared_expert.down_proj.weight)
                except AttributeError:
                    # The HF checkpoint is dense (no mlp.gate / mlp.experts), so
                    # upcycle it: randomly initialize the router and build each
                    # expert from a slice of the dense MLP.
                    nn.init.normal_(mglayer.mlp.router.weight, mean=0, std=0.02)

                    split_size = args.ffn_hidden_size // args.num_splits
                    gate_proj_splits = torch.split(hflayer.mlp.gate_proj.weight, split_size_or_sections=split_size)
                    up_proj_splits = torch.split(hflayer.mlp.up_proj.weight, split_size_or_sections=split_size)
                    down_proj_splits = torch.split(hflayer.mlp.down_proj.weight, split_size_or_sections=split_size, dim=1)
                    extra_size = args.moe_ffn_hidden_size - split_size

                    for idx, expert in enumerate(mglayer.mlp.experts.local_experts):
                        gate_part = gate_proj_splits[idx % args.num_splits]
                        up_part = up_proj_splits[idx % args.num_splits]
                        down_part = down_proj_splits[idx % args.num_splits]
                        # Newly created rows/columns that pad each expert up to moe_ffn_hidden_size.
                        extra_gate = torch.empty(extra_size, hidden_size, device=gate_part.device, dtype=gate_part.dtype)
                        extra_up = torch.empty(extra_size, hidden_size, device=up_part.device, dtype=up_part.dtype)
                        extra_fc2 = torch.empty(hidden_size, extra_size, device=down_part.device, dtype=down_part.dtype)
                        nn.init.normal_(extra_gate, mean=0, std=0.02)
                        nn.init.normal_(extra_up, mean=0, std=0.02)
                        nn.init.normal_(extra_fc2, mean=0, std=0.02)
                        # Keep linear_fc1's layout consistent with the dense branch: all gate rows first, then all up rows.
                        expert.linear_fc1.weight.copy_(torch.cat([gate_part, extra_gate, up_part, extra_up]))
                        expert.linear_fc2.weight.copy_(torch.cat([down_part, extra_fc2], dim=1))

                    # Randomly initialize the shared-expert weights (the dense checkpoint has none).
                    nn.init.normal_(mglayer.mlp.shared_expert_gate.weight, mean=0, std=0.02)
                    nn.init.normal_(mglayer.mlp.shared_expert.linear_fc1.weight, mean=0, std=0.02)
                    nn.init.normal_(mglayer.mlp.shared_expert.linear_fc2.weight, mean=0, std=0.02)
            if use_te and not args.num_experts:
                mglayer.mlp.linear_fc1.layer_norm_weight.copy_(hflayer.post_attention_layernorm.weight)
            else:
                mglayer.pre_mlp_layernorm.weight.copy_(hflayer.post_attention_layernorm.weight)

        mgmodel.decoder.final_layernorm.weight.copy_(hfmodel.model.norm.weight)
        if args.untie_embeddings_and_output_weights:
            mgmodel.output_layer.weight.copy_(hfmodel.lm_head.weight)

def main():
    initialize_megatron(extra_args_provider=add_extra_args)
    args = get_args()

    if args.convert_checkpoint_from_megatron_to_transformers:
        # config = AutoConfig.from_pretrained(args.hf_ckpt_path)
        # hf_model = AutoModelForCausalLM.from_pretrained(args.hf_ckpt_path, torch_dtype=config.torch_dtype)
        # mg_model = load_megatron_model(args)
        config = AutoConfig.from_pretrained(args.load, trust_remote_code=True)
        hf_model = AutoModelForCausalLM.from_config(config)
        mg_model = model_provider()
        convert_checkpoint_from_megatron_to_transformers(mg_model, hf_model, args)
        save_hfmodel(args, hf_model)
    else:
        config = AutoConfig.from_pretrained(args.load)
        hf_model = AutoModelForCausalLM.from_pretrained(args.load, torch_dtype=config.torch_dtype)
        mg_model = model_provider()
        convert_checkpoint_from_transformers_to_megatron(hf_model, mg_model, args)
        if not args.num_experts:
            check_hf_mg_forward(hf_model, mg_model, args)
        save_mgmodel(mg_model, args)
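
(Note: in main(), check_hf_mg_forward only runs when num_experts is unset, so the MoE path gets no forward-consistency check. Below is a rough sketch of a weight-level check that can stand in for it: it compares the parameters the original and converted HF checkpoints still share, i.e. embeddings, attention and layernorms, so mismatches there would point at the attention/embedding conversion rather than the MoE upcycling. Paths are placeholders.)

import torch
from transformers import AutoModelForCausalLM

# Compare only the keys both checkpoints have in common; the MoE-specific
# weights (router, experts, shared expert) have no dense counterpart.
orig_sd = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B").state_dict()
conv_sd = AutoModelForCausalLM.from_pretrained("./qwen2.5-0.5b-moe-hf").state_dict()

for name in sorted(orig_sd.keys() & conv_sd.keys()):
    a, b = orig_sd[name].float(), conv_sd[name].float()
    if a.shape != b.shape:
        print("shape mismatch:", name, tuple(a.shape), tuple(b.shape))
    elif not torch.allclose(a, b, atol=1e-2):  # loose tolerance: the round trip goes through bf16
        print("value mismatch:", name, (a - b).abs().max().item())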
