Mistral 系列模型整理
在本文中,我们梳理了 24 年 7 月前 Mistral 系列模型的关键信息,包括它们的主要特点、亮点以及相关资源链接。涉及模型 Mistral 7B, Mixtral 8x7B,Mixtral 8x22B,Mistral Nemo, Mistral Large 2
mistral 7B
Mistral 7B 模型的亮点包括:
- Sliding Window Attention
Mistral 采用的 window size 为 4096,而后一共有 32 层 layer,那么采用 SWA 之后,理论上在进行 attention 的时候, 理论上 可以收集到约 131K tokens 的信息。(虽然论文里提到的 window size 是 4096,但 官方提供的 huggingface 上的权重 中 max_position_embeddings
为 32768,且在新一点的版本中,比如 mistral-7b-instruct-v0.2,都不采用 sliding window 了)
由于代用了固定的 attention 窗口大小,因此我们只需要一个大小为 W=window size
的 cache ,在计算第 i 个 token 的 cache 的时候,只需要覆盖 cache 中 i mod M
位置上的 hidden state 即可。
参考 huggingface 的 mistral 实现,Sliding window attention 通过 attention_mask 来控制:
# huggignface mistral attn mask 实现
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
):
# ... 省略部分无关代码
past_seen_tokens = cache_position[0] if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
target_length = max(sequence_length, self.config.sliding_window)
# StaticCache
elif using_static_cache:
target_length = past_key_values.get_max_length()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if self.config.sliding_window is not None:
if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
exclude_mask.bitwise_or_(
torch.arange(target_length, device=device)
<= (cache_position.reshape(-1, 1) - self.config.sliding_window)
)
causal_mask *= exclude_mask
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
- GQA (Grouped Query Attention)
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
grouped-query attention 指出,Multi-Query Attention 提高了推理速度的同时,却可能极大地降低回复质量。因此根据上图,GQA 在推理速度和质量之间作了权衡。
以下为 GQA 文中的实验结果,值得注意的是论文中使用原 MHA checkpoint 转换为 GQA 权重后,还进行了额外的预训练:
此外 Mistral,Llama2 的部分模型使用 GQA 时,采用的 kv head 数量似乎都是 8。
为什么现在大家都在用 MQA 和 GQA? 文中提到 MQA 和 GQA 能获得巨大加速的一个点在于:GPU 内存强的限制。由于 MQA 和 GQA 都降低了内存中数据的读取量,减少了计算单元的等待时间,因此推理速度的提高比想象中的要快更多。
Mixtral 8*7B
论文,huggingface 模型权重, 官方博客,huggingface 模型代码,混合专家模型基础(推荐),官方给出的评分来看,mixtral 8*7 和 GPT3.5 有的一比。
发布时间:23 年 12 月
模型大小:8 个 expert MLP 层,一共 45B 大小。
训练:除了预训练外,Mixtral MOE 后续还开源了一个经过 SFT + DPO 微调的版本。
模型效果:
- 架构:Mixtral 的 MOE 架构类似于,在 MoE 模型中,只有 FFN 层被视为独立的专家,而模型的其他参数是共享的。大致参数为:
对 moe 架构不太了解的朋友,可以参考这篇博客 混合专家模型基础(推荐)。
参考 huggingface 中的 mixtral 和 mistral 实现对比,差异在于 mixtral 中将传统 transformer decoder layer 中的 FFN 替换为了 block_sparse_moe
。
主要逻辑为:
其中 为专家对应的网络,具体展示为下面 huggingface 实现中的 MixtralBlockSparseTop2MLP
。mixtral 中采用了 8 个 expert,每次推理使用选取 top 2 的 expert 进行推理。比如输入一句话 你好,今天
,那么我们每个 token 都会选出 top 2 的 expert 来负责这个 token 的预测,因此在推理 你好,今天
时, 有概率所有 expert 都会参与到计算当中 ,具体可以参考 MixtralSparseMoeBlock
的实现。
mixtral 论文中提到专家分配在不同主题(如 ArXiv 论文、生物学和哲学文档)中没有明显的模式,只有在 DM 数学中显示出边际上的差异,这可能是由于其数据集的合成性质和有限的自然语言覆盖范围所致。router 在某些句法结构上表现出一定的结构化行为(比如 python 的 self 等),同时连续标记通常被分配给相同的专家。
- huggingface 中的 mixtral 核心代码:
class MixtralDecoderLayer(nn.Module):
def __init__(self, config: MixtralConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.block_sparse_moe = MixtralSparseMoeBlock(config)
self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
# 此处省略参数 ..
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, self_attn_weights, present_key_value = self.self_attn(
# 此处省略参数
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
# Mixtral 将原本的 hidden_states = self.FFN(hidden_states) 替换为了:
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
return outputs
huggingface 中 block_sparse_moe
的实现(省略部分次要代码):
class MixtralSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
self.jitter_noise = config.router_jitter_noise
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits = self.gate(hidden_states) # (batch * sequence_length, n_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
# current_state: shape (n_i, hidden_dim)
# 所有 current_state 的长度 n 总和为 batch * sequence_length
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
其中: MixtralBlockSparseTop2MLP
长这样:
class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size
self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
current_hidden_states = self.w2(current_hidden_states)
return current_hidden_states
推理
根据模型参数量 45B 来推理的话,如果用 fp16 的话推理的话,得需要至少 90GB 以上的显存,如果用 4 bit 的话,30GB 显存就够了。量化的生成速度,可以参考这个 redis 中的评论,大致为 :
推理精度 | 设备 | 速度 tokens/s |
---|---|---|
Q4_K_M | 单卡 4090 + 7950X3D | 20 |
Q4_K_M | 2 x 3090 | 48.26 |
如果有 100+GB 以上显存,可以用 vllm 快速搭建测试 api:
docker run --gpus all \
-e HF_TOKEN=$HF_TOKEN -p 8000:8000 \
ghcr.io/mistralai/mistral-src/vllm:latest \
--host 0.0.0.0 \
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
--tensor-parallel-size 2 # 100+GB 显存 \
--load-format pt # needed since both `pt` and `safetensors` are available
NVIDIA 的 TensorRT-LLM 博客中发出了对 Mixtral 8*7B 的吞吐量 benchmark (using input and output sequence lengths of 128):
文中没有给出当 sequence lengths 最大时候的吞吐量,但根据上图数据,可以猜测 2 个 H100 部署 8*7B 正常服务用户时,平均吞吐量应该可以大于 7500Tokens/秒,根据 H100 的功耗计算电费成本的话,生成 1M token 需要耗约为 0.02 度电。
Mixtral 8*22B
- 架构:架构与 mixtral 8*7B 架构一样,在 huggingface 中使用的都是
MixtralForCausalLM
,但 22B 的各方面参数大一点,比较特别的是 context window 从 32k 升级到了 65k,vocab_size
也更大一些。 - 支持 function calling,不过好像没有透露具体的 function calling 训练细节。
- 数学和 coding 能力明显超越 llama2 70B
- 似乎对中文的支持不是很好。
Mistral 团队开源的模型,都比较注重 coding 和 math 的能力,Mixtral 系列的模型在这方便表现也是比较好:
Mistral Nemo
Mistral Nemo 使用的也是 MistralForCausalLM
架构,与 mistral 7B 的差别为:Mistral Nemo 的 hidden_size
从 4096 变为 5120;max_position_embeddings
变为 1024000,num_hidden_layers
增加到 40, vocab_size 增加到 131072,不用 sliding window。
- 支持 function calling!
- 采用了 Tekken 作为 tokenizer,比 SentencePiece 更高效(压缩率更高,官方描述是~30% more efficient at compressing,不确定是哪个方面的 efficient)
NVIDIA 在这个博客中提到:Mistral Nemo 采用这样的设计,是为了能够适配单个 NVIDIA L40S、NVIDIA GeForce RTX 4090 或 NVIDIA RTX 4500 GPU。模型采用 Megatron-LM 训练,用了 3,072 个 H100 80GB 。
但光采用 FP16 加载整个 Mistral Nemo 就需要花 23 GB 显存,要是要跑满整个 context window size,除了量化外,还是得需要采用 offload 或者其他方法来推理
不过 mistral 官方把 12 B 的模型和其他 8B 的模型对比,感觉好像不太公平:
Mistral Large 2
Mistral Large 2,参数量 123B,主打多语言以及 coding 能力。采用与 mistral 7B 一样的架构,huggingface 中同样使用 MistralForCausalLM
;比较值得注意的是 context window size 为 131072,不用 sliding window。
Llama 3.1 刚出不久,就拿 Mistral Large 2 和别人来对比:
在代码能力上,Mistral large 2 比 llama 3.1 平均效果更好。
除了 coding 和数学外,在 MT Bench 的评分也比 llama 3.1 高,平均生成的回复长度比 llama 3.1 要短
同时,中文能力相对上一代 mistral large 有大步幅提升: