Mixture of Experts(MoE)

MOE是当前比较火的技术之一。比如Mistral、当前最火的deepseek都用到了这一技术。

MOE具有预训练速度更快,推理速度更快的性质。但泛化能力不足,对显存需求比较高。

起源

MOE起源于1991的《Adaptive mixture of local experts》

原论文和BP比,且有着不错的性能。

从其论文图像就可以看出MOE的雏形和如今很像。

发展

在LLM潮浪之前,谷歌多个团队都有着几篇惊艳的论文。

2013年的《Learning Factored Representations in a Deep Mixture of Experts》

作者有Ilya Sutskever。

架构图:

可以看出在原始的基础上还更多考虑了深度。

作者发现,Deep Mix of Experts 自动学习在第一层培养位置相关(“where”)的专家,在第二层自动学习培养特定于类(“what”)的专家。

2017年的《Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer》

作者有Jeff Dean和Hinton。

架构图:

稀疏门控函数选择两个专家来执行计算。他们的输出由门控网络的输出调制。类似于后面的Top2 MOE。

2020年的《GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding》

架构图也是类似的:

在前人的基础上增加了负载均衡、Auxiliary loss、为了能在多个GPU上并行运行的切片等相关操作。

基础版本MOE

MOE是FFN的替代品,其中的Expert 一般是一个 FeadFoward Network,FFN。

1
2
3
4
5
6
7
8
9
10
class BasicExpert(nn.Module):
# 一个 Expert 可以是一个最简单的, linear 层即可
# 也可以是 MLP 层
# 也可以是 更复杂的 MLP 层(active function 设置为 swiglu)
def __init__(self, feature_in, feature_out):
super().__init__()
self.linear = nn.Linear(feature_in, feature_out)

def forward(self, x):
return self.linear(x)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41

class BasicMOE(nn.Module):
def __init__(self, feature_in, feature_out, expert_number):
super().__init__()
self.experts = nn.ModuleList(
[
BasicExpert(feature_in, feature_out) for _ in range(expert_number)
]
)
# gate 就是选一个 expert
self.gate = nn.Linear(feature_in, expert_number)

def forward(self, x):
# x 的 shape 是 (batch, feature_in)
expert_weight = self.gate(x) # shape 是 (batch, expert_number)
expert_out_list = [
expert(x).unsqueeze(1) for expert in self.experts
] # 里面每一个元素的 shape 是: (batch, ) ??

# concat 起来 (batch, expert_number, feature_out)
expert_output = torch.cat(expert_out_list, dim=1)

# print(expert_output.size())

expert_weight = expert_weight.unsqueeze(1) # (batch, 1, expert_nuber)

# expert_weight * expert_out_list
output = expert_weight @ expert_output # (batch, 1, feature_out)

return output.squeeze()


def test_basic_moe():
x = torch.rand(2, 4)

basic_moe = BasicMOE(4, 3, 2)
out = basic_moe(x)
print(out)


test_basic_moe()

SparseMoE (大模型训练使用)

比较有名的是switch transformer。

编码器解码器结构

1.6万亿参数的Moe, 2048个专家

网络结构升级

预训练速度为T5-XXL的4倍

和 Basic 区别是,MOE 选择 topK 个专家,然后对这 topK 个专家的输出进行加权求和,并且把输入样本变成了大模型中真实的输入 Shape,(batch, seq_len, hidden_dim)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128

# 主要参考自 mistral MOE 的实现

class MOERouter(nn.Module):
def __init__(self, hidden_dim, expert_number, top_k):
super().__init__()
self.gate = nn.Linear(hidden_dim, expert_number)
self.expert_number = expert_number
self.top_k = top_k

def forward(self, hidden_states):
# 计算路由logits
router_logits = self.gate(hidden_states) # shape is (b * s, expert_number)

# 计算专家经过softmax之后的概率
routing_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

# 计算topk的专家的输出
router_weights, selected_experts = torch.topk(
routing_probs, self.top_k, dim=-1
) # shape都是 (b * s, top_k)

# 专家权重归一化
router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
router_weights = router_weights.to(hidden_states.dtype)

# 生成专家掩码
expert_mask = F.one_hot(
selected_experts,
num_classes=self.expert_number
) # shape是 (b * s, top_k, expert_number)
expert_mask = expert_mask.permute(2, 1, 0) # (expert_number, top_k, b * s)

return router_logits, router_weights, selected_experts, expert_mask


class MOEConfig:
def __init__(
self,
hidden_dim,
expert_number,
top_k,
shared_experts_number=2,
):
self.hidden_dim = hidden_dim
self.expert_number = expert_number
self.top_k = top_k
self.shared_experts_number = shared_experts_number

class SparseMOE(nn.Module):
# 稀疏 MOE 模型,这里每一个 token 都会过 topk 个专家,得到对应token 的 hidden_embeddings
def __init__(self, config):
super().__init__()

self.hidden_dim = config.hidden_dim

self.expert_number = config.expert_number
self.top_k = config.top_k

self.experts = nn.ModuleList(
[
BasicExpert(self.hidden_dim, self.hidden_dim) for _ in range(self.expert_number)
]
)

self.router = MOERouter(self.hidden_dim, self.expert_number, self.top_k)

def forward(self, x):
# x shape is (b, s, hidden_dim)
batch_size, seq_len, hidden_dim = x.size()

# 合并前两个维度,因为不是 Sample 维度了,而是 token 维度
hidden_states = x.view(-1, hidden_dim) # shape is(b * s, hidden_dim)

router_logits, router_weights, selected_experts_indices, expert_mask = self.router(hidden_states)
# 其中 selected_experts_indices shape 是 (b * s, top_k)
# 其中 expert_mask shape 是 (expert_number, top_k, b * s)

final_hidden_states = torch.zeros(
(batch_size * seq_len, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device
)

for expert_idx in range(self.expert_number):
expert_layer = self.experts[expert_idx]
# expert_mask[expert_idx] shape 是 (top_k, b * s)
idx, top_x = torch.where(expert_mask[expert_idx])
# idx 和 top_x 都是一维 tensor
# idx 的值是 0 或 1, 表示这个 token 是作为当前专家的 top1 还是 top2
# top_x 的值是 token 在 batch*seq_len 中的位置索引
# 例如对于 batch_size=2, seq_len=4 的输入:
# top_x 的值范围是 0-7, 表示在展平后的 8 个 token 中的位置
# idx 的值是 0/1, 表示这个 token 把当前专家作为其 top1/top2 专家

# hidden_states 的 shape 是 (b * s, hidden_dim)
# 需要取到 top_x 对应的 hidden_states
current_state = hidden_states.unsqueeze(
0
)[:, top_x, :].reshape(-1, hidden_dim) # (selected_token_number, hidden_dim)

# router_weight 的 shape 是 (b * s, top_k)
current_hidden_states = expert_layer(
current_state
) * router_weights[top_x, idx].unsqueeze(-1) # (selected_token_number, 1) 这里有广播

# 把当前专家的输出加到 final_hidden_states 中
# 方式1 的写法性能更好,并且方式1容易出现
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
# 方式2
# final_hidden_states[top_x] += current_hidden_states.to(hidden_states.dtype)
# 方式1 的写法性能更差,并且方式1容易出现错误,+= 操作在处理重复索引时需要多次读写内存,可能会导致竞争条件

# 把 final_hidden_states 还原到原来的 shape
final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)

return final_hidden_states, router_logits # shape 是 (b * s, expert_number)


def test_token_level_moe():
x = torch.rand(2, 4, 16)
config = MOEConfig(16, 2, 2)
token_level_moe = SparseMOE(config)
out = token_level_moe(x)
print(out[0].shape, out[1].shape)


test_token_level_moe()

数学理解

苏神对这一类型的MOE做了一个偏数学的分析。

1、一个常规的Dense模型FFN,可以等价改写为n个Expert向量v1,v2,⋯,vn之和;

FFN可以写为
$$
\begin{equation}\boldsymbol{y} = f(\boldsymbol{x}\boldsymbol{W}^{(A)})\boldsymbol{W}^{(B)}\end{equation}
$$

2、为了节省计算量,我们试图挑出k个向量求和来逼近原本的n个向量之和;

MOE要解决的问题是:

能否只挑k个向量的和来逼近n个向量的和呢?这样就可以将计算量降低到k/n了。

3、转化为数学问题求解后,我们发现挑选规则是模长最大的k个向量;

写成数学公式是

记$\gamma_i = 1 - \lambda_i$,则

在这里苏神强行地假设了$v_i$两两正交。(我觉得这个可以和高维空间任取两个向量几乎正交的性质结合起来)

则上式最优解显然就是让模长$||v_i||$最小的$n-k$个$\gamma_i$等于1。

4、直接去算n个Expert的模长然后选k个实际上是不省计算量的,所以要重新设计Expert;

5、将$v_i$归一化得到$e_i$,然后用另外的小模型(Router)预测模长$p_i$,最终的Expert为$p_ie_i$;

归一化即$\boldsymbol{e}_i = \boldsymbol{v}_i/ \Vert\boldsymbol{v}_i\Vert$

我们也不一定用L2 Normalize,也可以gamma参数恒等于1的RMS Norm等等等。

PS:当实际上当前主流的MOE都没有归一化操作,苏神也在他博客的评论区指出。

6、此时,我们就可以先算全体$p_i$,挑出k个后才去计算$e_i$,达到节省计算量的目的。

ShareExpert SparseMoE (deepseek 版本)

和前面的sparsemoe不同的是,有一个shared experts 的模型,是所有 token 共享的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class ShareExpertMOE(nn.Module):
def __init__(self, config):
super().__init__()

self.moe_model = SparseMOE(config)
self.shared_experts = nn.ModuleList(
[
BasicExpert(
config.hidden_dim, config.hidden_dim
) for _ in range(config.shared_experts_number)
]
)

def forward(self, x):
# x shape 是 (b, s, hidden_dim)
# 首先过 moe 模型
sparse_moe_out, router_logits = self.moe_model(x)

# 针对的还是 x 的每一个
# 然后过 shared experts
shared_experts_out = [
expert(x) for expert in self.shared_experts
] # 每一个 expert 的输出 shape 是 (b, s, hidden_dim)

shared_experts_out = torch.stack(
shared_experts_out, dim=0
).sum(dim=0)

# 把 sparse_moe_out 和 shared_experts_out 加起来
return sparse_moe_out + shared_experts_out, router_logits


def test_share_expert_moe():
x = torch.rand(2, 4, 16)
config = MOEConfig(16, 2, 2)
share_expert_moe = ShareExpertMOE(config)
out = share_expert_moe(x)
print(out[0].shape, out[1].shape)


test_share_expert_moe()

模型训练测试代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86

def switch_load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
"""
计算 Switch Transformers 的负载均衡损失

Args:
router_logits: shape [batch_size * sequence_length, num_experts]
num_experts: 专家数量

Returns:
total_loss: 总损失 = auxiliary_loss + z_loss
"""
# 计算路由概率
router_probs = torch.softmax(router_logits, dim=-1) # [b*s, num_experts]

# 获取每个token的最优专家
_, selected_experts = torch.topk(router_probs, k=2, dim=-1)

# 创建one-hot矩阵表示选中的专家
mask = torch.nn.functional.one_hot(selected_experts, num_experts).float()

# 计算每个专家的期望负载 (理想情况下应该是 1/num_experts)
expected_load = torch.ones_like(router_probs) / num_experts

# 计算实际负载 (每个专家处理的token数量除以总token数量)
# 在batch维度上计算平均值
actual_load = mask.mean(dim=0)

# 计算auxiliary loss
# 这会惩罚负载分布与期望负载的差异
aux_loss = torch.sum(actual_load * router_probs.mean(dim=0)) * num_experts

# 计算z_loss (可选)
# 这会惩罚过大的路由logits
z_loss = torch.mean(torch.square(router_logits))
z_loss_weight = 0.001 # 可调整的超参数

# 总损失
total_loss = aux_loss + z_loss * z_loss_weight

return total_loss

def test_moe_training():
# Create a simple dataset
batch_size = 32
seq_len = 16
hidden_dim = 32
num_batches = 100

# Initialize model and optimizer
config = MOEConfig(hidden_dim=hidden_dim,
expert_number=4,
top_k=2,
shared_experts_number=2)
model = ShareExpertMOE(config)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
model.train()
for batch in range(num_batches):
# Generate random input data
x = torch.randn(batch_size, seq_len, hidden_dim)
target = torch.randn(batch_size, seq_len, hidden_dim)

# Forward pass
output, router_logits = model(x)

# Compute losses
# MSE loss for prediction
mse_loss = F.mse_loss(output, target)

aux_loss = switch_load_balancing_loss(router_logits, config.expert_number)
# Combined loss
total_loss = mse_loss + 0.01 * aux_loss

# Backward pass and optimize
optimizer.zero_grad()
total_loss.backward()
optimizer.step()

if batch % 10 == 0:
print(f"Batch {batch}, Loss: {total_loss.item():.4f} "
f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})")

# Run the training test
test_moe_training()

微调

(1)冻结MOE,这个更好,和全量更新相当。

(2)冻结除MOE以外的,类似于迁移学习。

参考资料

  1. LLM MOE的进化之路,从普通简化 MOE,到 sparse moe,再到 deepseek 使用的 share_expert sparse moe
  2. MoE环游记:1、从几何意义出发

Mixture of Experts(MoE)
https://lijianxiong.work/2025/20250215/
作者
LJX
发布于
2025年2月15日
许可协议