跳至主要內容

MOE 系列模型小记

Kevin 吴嘉文大约 11 分钟知识笔记AIGCLLM

在本文中,我们梳理了近期 (24 年 7 月前)部分 MOE 大模型的关键信息,包括它们的主要特点、亮点以及相关资源链接。涉及模型 Mixtral 8x7B,Mixtral 8x22B,DeepSeek-MoE,Qwen1.5-MoE,DeepSeek-V2

混合专家模型的 Transformer 模型

对于 MOE 的基础,相比 dense model,MOE 的预训练速度更快,推理速度更快,但需要大量的显存。此外,MOE 的训练也有一些独有的 tips,详细的 MOE 混合专家模型基础,推荐参考:

混合专家模型基础(推荐)open in new window

对于一些经典的 MOE 架构模型,可以参考:Mixture-of-Experts (MoE) 经典论文一览open in new window

Mixtral 8*7B

相关资源:论文open in new windowhuggingface 模型权重open in new window官方博客open in new windowhuggingface 模型代码open in new window混合专家模型基础(推荐)open in new window,官方给出的评分来看,mixtral 8*7 和 GPT3.5 有的一比。

  • 发布时间: 23 年 12 月

  • 模型大小 :8 个 expert MLP 层,一共 45B 大小。

  • 训练: 除了预训练外,Mixtral MOE 后续还开源了一个经过 SFT + DPO 微调的版本。

  • 模型效果:

image-20240805144805475
image-20240805144805475
  • 架构: Mixtral 的 MOE 架构类似于,在 MoE 模型中,只有 FFN 层被视为独立的专家,而模型的其他参数是共享的。大致参数为:
image-20240805145126057
image-20240805145126057

对 moe 架构不太了解的朋友,可以参考这篇博客 混合专家模型基础(推荐)open in new window

参考 huggingface 中的 mixtral 和 mistral 实现对比,差异在于 mixtral 中将传统 transformer decoder layer 中的 FFN 替换为了 block_sparse_moe

https://github.com/open-compass/MixtralKit
https://github.com/open-compass/MixtralKit

主要逻辑为:

G(x)=Softmax(TopK(xWgate))final hidden states=i=0n1G(x)iEi(x) G(x) = \text{Softmax}(TopK(x · W_{gate}))\\ \text{final hidden states} = \sum^{n-1}_{i=0} G(x)_i·E_i(x)

其中 Ei(x)E_i(x) 为专家对应的网络,具体展示为下面 huggingface 实现中的 MixtralBlockSparseTop2MLP。mixtral 中采用了 8 个 expert,每次推理使用选取 top 2 的 expert 进行推理。比如输入一句话 你好,今天,那么我们每个 token 都会选出 top 2 的 expert 来负责这个 token 的预测,因此在推理 你好,今天 时, 有概率所有 expert 都会参与到计算当中 ,具体可以参考 MixtralSparseMoeBlock 的实现。

image-20240805144907548
image-20240805144907548

mixtral 论文中提到专家分配在不同主题(如 ArXiv 论文、生物学和哲学文档)中没有明显的模式,只有在 DM 数学中显示出边际上的差异,这可能是由于其数据集的合成性质和有限的自然语言覆盖范围所致。router 在某些句法结构上表现出一定的结构化行为(比如 python 的 self 等),同时连续标记通常被分配给相同的专家。

  • huggingface 中的 mixtral 核心代码:
class MixtralDecoderLayer(nn.Module):
    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        self.block_sparse_moe = MixtralSparseMoeBlock(config)
        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        # 此处省略参数 ..
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
        	# 此处省略参数 
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        
        # Mixtral 将原本的 hidden_states = self.FFN(hidden_states) 替换为了:
        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
        
        hidden_states = residual + hidden_states
        outputs = (hidden_states,)

        return outputs

huggingface 中 block_sparse_moe 的实现(省略部分次要代码):

class MixtralSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

        self.jitter_noise = config.router_jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.gate(hidden_states)  # (batch * sequence_length, n_experts)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            # current_state: shape (n_i, hidden_dim)
            # 所有 current_state 的长度 n 总和为 batch * sequence_length
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

其中: MixtralBlockSparseTop2MLP 长这样:

class MixtralBlockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states

根据模型参数量 45B 来推理的话,如果用 fp16 的话推理的话,得需要至少 90GB 以上的显存,如果用 4 bit 的话,30GB 显存就够了。量化的生成速度,可以参考这个 redisopen in new window 中的评论,大致为 :

推理精度设备速度 tokens/s
Q4_K_M单卡 4090 + 7950X3D20
Q4_K_M2 x 309048.26

如果有 100+GB 以上显存,可以用 vllm 快速搭建测试 api:

docker run --gpus all \
    -e HF_TOKEN=$HF_TOKEN -p 8000:8000 \
    ghcr.io/mistralai/mistral-src/vllm:latest \
    --host 0.0.0.0 \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tensor-parallel-size 2 # 100+GB 显存 \
    --load-format pt # needed since both `pt` and `safetensors` are available

NVIDIA 的 TensorRT-LLM 博客open in new window中发出了对 Mixtral 8*7B 的吞吐量 benchmark (using input and output sequence lengths of 128):

image-20240728094958591
image-20240728094958591

文中没有给出当 sequence lengths 最大时候的吞吐量,但根据上图数据,可以猜测 2 个 H100 部署 8*7B 正常服务用户时,平均吞吐量应该可以大于 7500Tokens/秒,根据 H100 的功耗计算电费成本的话,生成 1M token 需要耗约为 0.02 度电。

DeepSeek-MoE

相关资源:githubopen in new window论文 DeepSeekMoE: Towards Ultimate Expert Specialization in Mixture-of-Experts Language Modelsopen in new window

  • 训练: 整个模型在 2T 的中英文预料上训练,实现了和 DeekSeek 7B 及 LlaMA 2 7B 差不都的效果。
  • 模型效果: DeepSeekMoE 16B 推理时候,只用到了 2.8B 的参数,整体的 FLOPs 是 LlaMA 2 7B 的 39.6%;推理速度更快的同时,效果也不差。
image-20240803203140532
image-20240803203140532
  • 架构: DeepSeekMoE 16B 主要亮点在于 fine-grained expert segmentation 和 shared experts isolation.
image-20240728170828489
image-20240728170828489

Fine-grained Expert Segmentation

如上图 B,DeepSeek-MoE 在减少了每个 expert FFN intermediate hidden dimension 的同时,增加激活的 expert 的数量,依次保证总体激活的 expert 的参数量一致。DeepSeekMoE 论文种认为,组合数量的提升,有利于 gate 更准确地选择 expert。

如当我们有 16 个 expert,然后选 top 2 进行推理时,activate expert 的组合数量有 (216)=120(^{16}_{2})=120 种组合,但当将每个 expert 参数缩小 4 倍,expert 个数增加为 64 时,选取 top 8 进行推理时, activate expert 的组合书来给你就有 (864)=442165368(_{8}^{64})=442165368 种。

Shared Expert Isolation

如上图 C,设立一部分 Shared Expert,每次推理的时候都会激活。

Qwen1.5-MoE

官方博客open in new windowgithubopen in new windowhuggingface 权重open in new window

  • 架构重点: 类似于 DeepSeek-MoE,Qwen1.5-MoE 也尝试了 Finegrained experts,整个模型总共设计了 64 个 expert;而后在 routing 机制种也尝试了 Shared Expert Isolation :采用了 4 个总是被激活的共享 expert 和每次只激活其中 4 个的 60 个 routing expert。
  • 训练: 官方博客种表示:从零开始训练 MoE 模型可能效率低下,且难以提升至预期的最优性能水平。因此,Qwen1.5-MoE 首先利用已有的 Qwen-1.8B,将其改造为 Q wen1.5-MoE-A2.7B。此外,在初始化阶段引入随机性可以显著加快收敛速度,并在整个预训练过程中带来更好的整体性能表现。
  • 模型效果: 模型在推理时,总的激活参数为 2.7B。但实现的效果也不错:
image-20240728174009614
image-20240728174009614

官方博客中发布了采用 vllm 部署时候的性能(单个 NVIDIA A100-80G GPU 部署 Qwen1.5-7B 和 Qwen1.5-MoE-A2.7B):

image-20240728174347389
image-20240728174347389

看来 MoE 架构在不牺牲生成质量的情况下,的确可以极大提高吞吐量,降低大模型生成成本。

DeepSeek-V2

githubopen in new window, 论文链接open in new window权重下载open in new windowhuggingface 模型代码open in new window

DeepSeek-V2 文中推出了 DeepSeek-V2-Lite 与 DeepSeek-V2 一小一大 2 个版本。

推理速度:

DeepSeek V2 首先对模型进行了 KV Cache 量化,将参数转换为了 FP8。在单机 8 卡 H800 的节点上部署 DeepSeek-V2,可以达到约 50K tokens/秒 的吞吐量

image-20240728180045838
image-20240728180045838

推理效果,中文水平更强一些,英文水平于 Mixtral 8*22B 有的一比:

image-20240803213442775
image-20240803213442775
  • 模型架构重点 :其中,架构采用了 MLA 取代 MHA,同时 MOE 架构采用了 DeepSeekMoE 的 fine-grained expert segmentation 和 shared experts isolation。整体的 DeepSeek Layer 架构如下:
image-20240728180700469
image-20240728180700469
  • DeepSeekMoE

如上图展示的,DeepSeek-V2 同样采用了 DeepSeekMoE 的策略,其中有 2 个 shared experts , 160 个 routed experts(每次只激活 6 个)

  • Multi-Head Latent Attention

DeepSeek-V2 中着重讲了这一部分的优化。

image-20240728185206897
image-20240728185206897

Low-Rank Key-Value Joint Compression

为了减少 KV cache,MLA 提出将 k, v 的计算方式变为:

ctKV=WDKVhtktC=WUKctKVvtC=WUVctKV \begin{aligned} \bold c_t^{KV} &= W^{DKV}\bold h_t\\ \bold k_t^C &= W^{UK}\bold c^{KV}_t\\ \bold v_t^C &= W^{UV}\bold c^{KV}_t\\ \end{aligned}

其中,ctKVRdc,WDKVRdc×d,WUKRdhnh×dcc_t^{KV} \in \mathbb{R}^{d_c},W^{DKV} \in \mathbb R^{d_c\times d}, W^{UK} \in \mathbb{R}^{d_h n_h \times d_c}

q 的计算方法变为:

ctQ=WDQhtqtC=WUQctQ \begin{aligned} \bold c_t^{Q} &= W^{DQ}\bold h_t\\ \bold q_t^C &= W^{UQ}\bold c^{Q}_t\\ \end{aligned}

其中,ctQRdc,WDQRdc×d,WUQRdhnh×dcc_t^Q \in \mathbb{R}^{d_c'},W^{DQ} \in \mathbb R^{d_c'\times d}, W^{UQ} \in \mathbb{R}^{d_h n_h \times d_c'}

因此,k,v 均从 ctKV\bold c_t^{KV} 进一步计算得来。在推理时候,传统 MHA 需要 cache k,vk,v,但通过以上变化后,只需要 cache ctKV\bold c_t^{KV} 即可。这样,q,kq,k 点积就变成了。

qTk=(WUQWDQht)T(WUKctKV)=hT((WUQWDQ)TWUK)ctKV \begin{aligned} q^Tk & =(W^{UQ}W^{DQ}\bold h_t)^T(W^{UK}\bold c^{KV}_t)\\ &=\bold h^T((W^{UQ}W^{DQ})^TW^{UK})\bold c^{KV}_t \end{aligned}

推理过程中,可以合并 (WUQWDQ)TWUK)(W^{UQ}W^{DQ})^TW^{UK}),以此达到减少 cache 同时,不会增加太多的计算量。

Decoupled Rotary Position Embedding

以上方案的一个问题是,不兼容 RoPE。由于 RoPE 的存在,

qq 不再是单纯的 WhW\bold h,而是需要内积上相对位置矩阵 RR,因此就无法简单得合并 (WUQWDQ)TWUK)(W^{UQ}W^{DQ})^TW^{UK})

MLA 采用了以下 decoupled RoPE 方案:

[qt,1R;qt,2R;;qt,nhR]=qtR=RoPE(WQRctQ),ktR=RoPE(WKRht),qt,i=[qt,iC;qt,iR],kt,i=[kt,iC;ktR],ot,i=j=1tSoftmaxj(qt,iTkj,idh+dhR)vj,iC,ut=WO[ot,1;ot,2;;ot,nh], \begin{aligned} \left[ q_{t,1}^R ; q_{t,2}^R ; \cdots ; q_{t,n_h}^R \right] &= q_t^R = \text{RoPE}(W^{QR} \bold c_t^Q), \quad \quad \quad \\ \bold k_t^R &= \text{RoPE}(W^{KR} \bold h_t), \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \\ \bold q_{t,i} &= \left[\bold q_{t,i}^C ; \bold q_{t,i}^R \right], \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \\ \bold k_{t,i} &= \left[\bold k_{t,i}^C ; \bold k_{t}^R \right], \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \quad \\ \bold o_{t,i} &= \sum_{j=1}^t \text{Softmax}_j \left( \frac{\bold q_{t,i}^T \bold k_{j,i}}{\sqrt{d_h + d_h^R}} \right) \bold v_{j,i}^C, \quad \quad \\ \bold u_t &= W^O \left[ o_{t,1} ; o_{t,2} ; \cdots ; o_{t,n_h} \right], \quad \quad \quad \quad \quad \end{aligned}

大概思路是,在原先的 qk 中,增加几个维度,用来注入 RoPE 位置信息,比较值得注意的是,k 新增加的维度 ktR\bold k_{t}^R 是所有 head 共享的。其中,WQRRdhRnh×dcW^{QR} \in \mathbb{R}^{d_h^R n_h \times d_c'}, WKRRdhR×dW^{KR} \in \mathbb{R}^{d_h^R \times d}, 因此,q,k 的维度增加到了 (dc+dhR)(d_c + d_h^R)

更深入的 MLA 解读,可以参考:缓存与效果的极限拉扯:从 MHA、MQA、GQA 到 MLAopen in new window 或 deepseek v2 原论文。

文中给出了 MLA 与 MHA, GQA 的效果对比:

MLA 的 KV 的 cache 数量比 MHA, GQA 少了不少。

image-20240803211417135
image-20240803211417135

MLA KV cache 比 MHA 小的同时,效果也不会太差。

image-20240803211623833
image-20240803211623833
  • 预训练: 相比于 DeepSeek 67B,DeepSeek-V2 训练集中有更多的中文数据,同时 DeepSeek V2 对数据过滤算法进行了改进,包括筛除有争议的内容等;DeepSeek V2 使用于 DeepSeek 67B 同样的 tokenizer,vocab size 为 100k,预训练语料约有 8.1T tokens,其中中文比英文多了 12%。整个预训练花费了 172.8K GPU hours 的算力。
  • 超长上下文: 在适配了 Yarn 之后,额外在 32k 的数据集上训练了 1000 steps,batch size 为 576。文中表示,尽管是在 32K 数据集上训练,但在 128K 的大海捞针测试中,模型表现也不错:
image-20240803213356533
image-20240803213356533
  • SFT :在包含了 1.5M 组训练实例的数据集上,微调了 2 个 epoch。
  • RLHF: 采用了 GRPO 来节省 RL 训练的成本,主要是将 PPO 过程中的 advantage 替换成了 A^i,t=rimean(r)std(r)\hat A_{i,t} = \frac {r_i - mean(r)}{std(r)},因此在 RLHF 过程中就不需要 Value model 了。具体算法如下:
image-20240804172843610
image-20240804172843610

在 RLHF 训练过程中,采取了 2 阶段训练。首先进行了 reasoning alignment,而后进行 human preference alignment。

更多训练细节欢迎参考 DeepSeek-V2 论文open in new window

Mixtral 8*22B

官方博客open in new windowhuggingface 开源模型open in new window

  • 架构:架构与 mixtral 8*7B 架构一样,在 huggingface 中使用的都是MixtralForCausalLM ,但 22B 的各方面参数大一点,比较特别的是 context window 从 32k 升级到了 65k, vocab_size 也更大一些。
  • 支持 function calling,不过好像没有透露具体的 function calling 训练细节。
  • 数学和 coding 能力明显超越 llama2 70B
  • 似乎对中文的支持不是很好。
image-20240805145309786
image-20240805145309786

Mistral 团队开源的模型,都比较注重 coding 和 math 的能力,Mixtral 系列的模型在这方便表现也是比较好:

image-20240805145321370
image-20240805145321370