跳至主要內容

Mistral 系列模型整理

Kevin 吴嘉文大约 9 分钟知识笔记AIGCLLM

在本文中,我们梳理了 24 年 7 月前 Mistral 系列模型的关键信息,包括它们的主要特点、亮点以及相关资源链接。涉及模型 Mistral 7B, Mixtral 8x7B,Mixtral 8x22B,Mistral Nemo, Mistral Large 2

mistral 7B

官方博客open in new windowmistral 7B 论文open in new window

Mistral 7B 模型的亮点包括:

  • Sliding Window Attention

Mistral 采用的 window size 为 4096,而后一共有 32 层 layer,那么采用 SWA 之后,理论上在进行 attention 的时候, 理论上 可以收集到约 131K tokens 的信息。(虽然论文里提到的 window size 是 4096,但 官方提供的 huggingface 上的权重open in new windowmax_position_embeddings 为 32768,且在新一点的版本中,比如 mistral-7b-instruct-v0.2open in new window,都不采用 sliding window 了)

image-20240805103929202
image-20240805103929202

由于代用了固定的 attention 窗口大小,因此我们只需要一个大小为 W=window size 的 cache ,在计算第 i 个 token 的 cache 的时候,只需要覆盖 cache 中 i mod M 位置上的 hidden state 即可。

参考 huggingface 的 mistral 实现,Sliding window attention 通过 attention_mask 来控制:

    # huggignface mistral attn mask 实现
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
    ):
        # ... 省略部分无关代码
        past_seen_tokens = cache_position[0] if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache
        if using_sliding_window_cache:
            target_length = max(sequence_length, self.config.sliding_window)
        # StaticCache
        elif using_static_cache:
            target_length = past_key_values.get_max_length()
        # DynamicCache or no cache
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
            causal_mask = attention_mask
        else:
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            if self.config.sliding_window is not None:
                if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
                    exclude_mask.bitwise_or_(
                        torch.arange(target_length, device=device)
                        <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
                    )
            causal_mask *= exclude_mask
            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.dim() == 2:
                    mask_length = attention_mask.shape[-1]
                    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                    padding_mask = padding_mask == 0
                    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                        padding_mask, min_dtype
                    )

        return causal_mask
  • GQA (Grouped Query Attention)

GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpointsopen in new window

image-20240805103944726
image-20240805103944726

grouped-query attention 指出,Multi-Query Attentionopen in new window 提高了推理速度的同时,却可能极大地降低回复质量。因此根据上图,GQA 在推理速度和质量之间作了权衡。

以下为 GQA 文中的实验结果,值得注意的是论文中使用原 MHA checkpoint 转换为 GQA 权重后,还进行了额外的预训练:

image-20240805103956240
image-20240805103956240

此外 Mistral,Llama2 的部分模型使用 GQA 时,采用的 kv head 数量似乎都是 8。

为什么现在大家都在用 MQA 和 GQA?open in new window 文中提到 MQA 和 GQA 能获得巨大加速的一个点在于:GPU 内存强的限制。由于 MQA 和 GQA 都降低了内存中数据的读取量,减少了计算单元的等待时间,因此推理速度的提高比想象中的要快更多。

Mixtral 8*7B

论文open in new windowhuggingface 模型权重open in new window官方博客open in new windowhuggingface 模型代码open in new window混合专家模型基础(推荐)open in new window,官方给出的评分来看,mixtral 8*7 和 GPT3.5 有的一比。

  • 发布时间:23 年 12 月

  • 模型大小:8 个 expert MLP 层,一共 45B 大小。

  • 训练:除了预训练外,Mixtral MOE 后续还开源了一个经过 SFT + DPO 微调的版本。

  • 模型效果:

image-20240805104006798
image-20240805104006798
  • 架构:Mixtral 的 MOE 架构类似于,在 MoE 模型中,只有 FFN 层被视为独立的专家,而模型的其他参数是共享的。大致参数为:
image-20240805104016338
image-20240805104016338

对 moe 架构不太了解的朋友,可以参考这篇博客 混合专家模型基础(推荐)open in new window

参考 huggingface 中的 mixtral 和 mistral 实现对比,差异在于 mixtral 中将传统 transformer decoder layer 中的 FFN 替换为了 block_sparse_moe

https://github.com/open-compass/MixtralKit
https://github.com/open-compass/MixtralKit

主要逻辑为:

G(x)=Softmax(TopK(xWgate))final hidden states=i=0n1G(x)iEi(x) G(x) = \text{Softmax}(TopK(x · W_{gate}))\\ \text{final hidden states} = \sum^{n-1}_{i=0} G(x)_i·E_i(x)

其中 Ei(x)E_i(x) 为专家对应的网络,具体展示为下面 huggingface 实现中的 MixtralBlockSparseTop2MLP。mixtral 中采用了 8 个 expert,每次推理使用选取 top 2 的 expert 进行推理。比如输入一句话 你好,今天,那么我们每个 token 都会选出 top 2 的 expert 来负责这个 token 的预测,因此在推理 你好,今天 时, 有概率所有 expert 都会参与到计算当中 ,具体可以参考 MixtralSparseMoeBlock 的实现。

image-20240805104032194
image-20240805104032194

mixtral 论文中提到专家分配在不同主题(如 ArXiv 论文、生物学和哲学文档)中没有明显的模式,只有在 DM 数学中显示出边际上的差异,这可能是由于其数据集的合成性质和有限的自然语言覆盖范围所致。router 在某些句法结构上表现出一定的结构化行为(比如 python 的 self 等),同时连续标记通常被分配给相同的专家。

  • huggingface 中的 mixtral 核心代码:
class MixtralDecoderLayer(nn.Module):
    def __init__(self, config: MixtralConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        self.block_sparse_moe = MixtralSparseMoeBlock(config)
        self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        # 此处省略参数 ..
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
        	# 此处省略参数 
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        
        # Mixtral 将原本的 hidden_states = self.FFN(hidden_states) 替换为了:
        hidden_states, router_logits = self.block_sparse_moe(hidden_states)
        
        hidden_states = residual + hidden_states
        outputs = (hidden_states,)

        return outputs

huggingface 中 block_sparse_moe 的实现(省略部分次要代码):

class MixtralSparseMoeBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])

        self.jitter_noise = config.router_jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.gate(hidden_states)  # (batch * sequence_length, n_experts)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            # current_state: shape (n_i, hidden_dim)
            # 所有 current_state 的长度 n 总和为 batch * sequence_length
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits

其中: MixtralBlockSparseTop2MLP 长这样:

class MixtralBlockSparseTop2MLP(nn.Module):
    def __init__(self, config: MixtralConfig):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states

推理

根据模型参数量 45B 来推理的话,如果用 fp16 的话推理的话,得需要至少 90GB 以上的显存,如果用 4 bit 的话,30GB 显存就够了。量化的生成速度,可以参考这个 redisopen in new window 中的评论,大致为 :

推理精度设备速度 tokens/s
Q4_K_M单卡 4090 + 7950X3D20
Q4_K_M2 x 309048.26

如果有 100+GB 以上显存,可以用 vllm 快速搭建测试 api:

docker run --gpus all \
    -e HF_TOKEN=$HF_TOKEN -p 8000:8000 \
    ghcr.io/mistralai/mistral-src/vllm:latest \
    --host 0.0.0.0 \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tensor-parallel-size 2 # 100+GB 显存 \
    --load-format pt # needed since both `pt` and `safetensors` are available

NVIDIA 的 TensorRT-LLM 博客open in new window中发出了对 Mixtral 8*7B 的吞吐量 benchmark (using input and output sequence lengths of 128):

image-20240728094958591
image-20240728094958591

文中没有给出当 sequence lengths 最大时候的吞吐量,但根据上图数据,可以猜测 2 个 H100 部署 8*7B 正常服务用户时,平均吞吐量应该可以大于 7500Tokens/秒,根据 H100 的功耗计算电费成本的话,生成 1M token 需要耗约为 0.02 度电。

Mixtral 8*22B

官方博客open in new windowhuggingface 开源模型open in new window

  • 架构:架构与 mixtral 8*7B 架构一样,在 huggingface 中使用的都是MixtralForCausalLM ,但 22B 的各方面参数大一点,比较特别的是 context window 从 32k 升级到了 65k, vocab_size 也更大一些。
  • 支持 function calling,不过好像没有透露具体的 function calling 训练细节。
  • 数学和 coding 能力明显超越 llama2 70B
  • 似乎对中文的支持不是很好。
image-20240727152529176
image-20240727152529176

Mistral 团队开源的模型,都比较注重 coding 和 math 的能力,Mixtral 系列的模型在这方便表现也是比较好:

image-20240727152702974
image-20240727152702974

Mistral Nemo

官方博客open in new windowhuggingface 模型权重open in new window

Mistral Nemo 使用的也是 MistralForCausalLM 架构,与 mistral 7B 的差别为:Mistral Nemo 的 hidden_size 从 4096 变为 5120;max_position_embeddings 变为 1024000,num_hidden_layers 增加到 40, vocab_size 增加到 131072,不用 sliding window。

  • 支持 function calling!
  • 采用了 Tekken 作为 tokenizer,比 SentencePiece 更高效(压缩率更高,官方描述是~30% more efficient at compressing,不确定是哪个方面的 efficient)

NVIDIA 在这个博客open in new window中提到:Mistral Nemo 采用这样的设计,是为了能够适配单个 NVIDIA L40S、NVIDIA GeForce RTX 4090 或 NVIDIA RTX 4500 GPU。模型采用 Megatron-LMopen in new window 训练,用了 3,072 个 H100 80GB 。

但光采用 FP16 加载整个 Mistral Nemo 就需要花 23 GB 显存,要是要跑满整个 context window size,除了量化外,还是得需要采用 offload 或者其他方法来推理

不过 mistral 官方把 12 B 的模型和其他 8B 的模型对比,感觉好像不太公平:

image-20240727154936831
image-20240727154936831

Mistral Large 2

官方博客open in new windowhuggingface 模型权重open in new window

Mistral Large 2,参数量 123B,主打多语言以及 coding 能力。采用与 mistral 7B 一样的架构,huggingface 中同样使用 MistralForCausalLM;比较值得注意的是 context window size 为 131072,不用 sliding window。

Llama 3.1 刚出不久,就拿 Mistral Large 2 和别人来对比:

image-20240805104145922
image-20240805104145922

在代码能力上,Mistral large 2 比 llama 3.1 平均效果更好。

image-20240727201347583
image-20240727201347583

除了 coding 和数学外,在 MT Bench 的评分也比 llama 3.1 高,平均生成的回复长度比 llama 3.1 要短

image-20240805104200493
image-20240805104200493

同时,中文能力相对上一代 mistral large 有大步幅提升:

image-20240805104212135
image-20240805104212135