MoE originated in the 1991 paper "Adaptive Mixtures of Local Experts".
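2013: "Learning Factored Representations in a Deep Mixture of Experts"
Ilya Sutskever is among the authors.
The authors found that the Deep Mixture of Experts automatically learns location-dependent ("where") experts at the first layer and class-specific ("what") experts at the second layer.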
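2017: "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer"
The authors include Jeff Dean and Geoffrey Hinton.
A sparse gating function selects two experts to perform the computation, and their outputs are weighted by the gating network's outputs. This resembles the later top-2 MoE.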
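The top-2 gating idea described above can be sketched in plain Python. This is a simplified illustration, not the paper's implementation: it leaves out the noise term and any load balancing, and the function names, toy experts, and logits are all made up for the example.

```python
import math


def top2_gate(logits):
    """Return {expert_index: weight} for the two largest gate logits.

    The weights are a softmax over only the two kept logits, so they sum to 1.
    """
    top2 = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:2]
    m = max(logits[i] for i in top2)  # subtract the max for numerical stability
    exps = {i: math.exp(logits[i] - m) for i in top2}
    total = sum(exps.values())
    return {i: e / total for i, e in exps.items()}


def sparse_moe(x, experts, logits):
    """Run only the two selected experts and mix their outputs by gate weight."""
    gates = top2_gate(logits)
    return sum(w * experts[i](x) for i, w in gates.items())


# Four toy scalar "experts" (hypothetical, for illustration only).
experts = [lambda x: x + 1.0, lambda x: 2.0 * x, lambda x: -x, lambda x: x * x]
logits = [0.1, 2.0, -1.0, 1.0]  # pretend output of a gating network

gates = top2_gate(logits)
y = sparse_moe(3.0, experts, logits)
print(gates, y)
```

Only experts 1 and 3 are evaluated here; the other two contribute nothing and cost nothing, which is the point of sparse gating.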
2020: "GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding"
Building on earlier work, it adds load balancing, an auxiliary loss, and sharding so that the model can run in parallel across multiple accelerators.
An MoE layer is a replacement for the FFN; each expert in it is typically a feed-forward network (FFN).
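```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicExpert(nn.Module):
    """A single expert, simplified here to one linear layer."""

    def __init__(self, feature_in, feature_out):
        super().__init__()
        self.linear = nn.Linear(feature_in, feature_out)

    def forward(self, x):
        return self.linear(x)
```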
```python
class BasicMOE(nn.Module):
    def __init__(self, feature_in, feature_out, expert_number):
        super().__init__()
        self.experts = nn.ModuleList(
            [BasicExpert(feature_in, feature_out) for _ in range(expert_number)]
        )
        # The gate produces one score per expert.
        self.gate = nn.Linear(feature_in, expert_number)

    def forward(self, x):
        # x: (batch, feature_in)
        expert_weight = self.gate(x)  # (batch, expert_number), raw scores

        # Each expert output is (batch, 1, feature_out).
        expert_out_list = [expert(x).unsqueeze(1) for expert in self.experts]

        # Concatenate to (batch, expert_number, feature_out).
        expert_output = torch.cat(expert_out_list, dim=1)

        # (batch, 1, expert_number) @ (batch, expert_number, feature_out)
        expert_weight = expert_weight.unsqueeze(1)
        output = expert_weight @ expert_output  # (batch, 1, feature_out)

        # Squeeze only dim 1 so a batch of size 1 keeps its batch dimension.
        return output.squeeze(1)


def test_basic_moe():
    x = torch.rand(2, 4)
    basic_moe = BasicMOE(4, 3, 2)
    out = basic_moe(x)
    print(out)


test_basic_moe()
```
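SparseMoE (used for large-model training)
The best-known example is Switch Transformer.
It is a 1.6-trillion-parameter MoE with 2048 experts.
The difference from the basic version is that the sparse MoE selects the top-k experts and takes a weighted sum over only those k experts' outputs. The input also takes the shape actually used in large models: (batch, seq_len, hidden_dim).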
```python
class MOERouter(nn.Module):
    def __init__(self, hidden_dim, expert_number, top_k):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, expert_number)
        self.expert_number = expert_number
        self.top_k = top_k

    def forward(self, hidden_states):
        # hidden_states: (batch * seq_len, hidden_dim)
        router_logits = self.gate(hidden_states)  # (batch * seq_len, expert_number)
        routing_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

        # Keep the top_k experts per token; both results are (batch * seq_len, top_k).
        router_weights, selected_experts = torch.topk(
            routing_probs, self.top_k, dim=-1
        )

        # Renormalize so the kept weights sum to 1 for each token.
        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
        router_weights = router_weights.to(hidden_states.dtype)

        # One-hot mask, permuted to (expert_number, top_k, batch * seq_len).
        expert_mask = F.one_hot(selected_experts, num_classes=self.expert_number)
        expert_mask = expert_mask.permute(2, 1, 0)

        return router_logits, router_weights, selected_experts, expert_mask


class MOEConfig:
    def __init__(
        self,
        hidden_dim,
        expert_number,
        top_k,
        shared_experts_number=2,
    ):
        self.hidden_dim = hidden_dim
        self.expert_number = expert_number
        self.top_k = top_k
        self.shared_experts_number = shared_experts_number


class SparseMOE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number
        self.top_k = config.top_k

        self.experts = nn.ModuleList(
            [
                BasicExpert(self.hidden_dim, self.hidden_dim)
                for _ in range(self.expert_number)
            ]
        )
        self.router = MOERouter(self.hidden_dim, self.expert_number, self.top_k)

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim)
        batch_size, seq_len, hidden_dim = x.size()

        # Routing is per token, so flatten to (batch * seq_len, hidden_dim).
        hidden_states = x.reshape(-1, hidden_dim)

        router_logits, router_weights, selected_experts_indices, expert_mask = (
            self.router(hidden_states)
        )

        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        for expert_idx in range(self.expert_number):
            expert_layer = self.experts[expert_idx]

            # idx: which top-k slot chose this expert; top_x: which tokens did.
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Gather the tokens routed to this expert: (selected_tokens, hidden_dim).
            current_state = hidden_states[top_x]

            # Scale the expert output by each token's routing weight.
            current_hidden_states = expert_layer(current_state) * router_weights[
                top_x, idx
            ].unsqueeze(-1)

            # Accumulate the result back into the rows of the selected tokens.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, seq_len, hidden_dim
        )
        return final_hidden_states, router_logits


def test_token_level_moe():
    x = torch.rand(2, 4, 16)
    config = MOEConfig(16, 2, 2)
    token_level_moe = SparseMOE(config)
    out = token_level_moe(x)
    print(out[0].shape, out[1].shape)


test_token_level_moe()
```
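The least obvious step in the code above is how the permuted one-hot mask turns into per-expert token lists. The small sketch below traces it on a hand-written routing result; the `selected_experts` values and the `dispatch` dictionary are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Hypothetical routing result: 3 tokens, 4 experts, top_k = 2.
# Row t lists the experts chosen by token t (slot 0 first, then slot 1).
selected_experts = torch.tensor([[0, 2], [2, 3], [0, 3]])

# (tokens, top_k, experts) -> (experts, top_k, tokens)
expert_mask = F.one_hot(selected_experts, num_classes=4).permute(2, 1, 0)

dispatch = {}
for expert_idx in range(4):
    # slot: which top-k position picked this expert; token: which token did.
    slot, token = torch.where(expert_mask[expert_idx])
    dispatch[expert_idx] = (slot.tolist(), token.tolist())
    print(expert_idx, dispatch[expert_idx])
```

Expert 1 receives no tokens in this example, so its lists are empty. In the loop of `SparseMOE.forward` that case simply produces an empty batch for that expert.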
\begin{equation}\boldsymbol{y} = f(\boldsymbol{x}\boldsymbol{W}^{(A)})\boldsymbol{W}^{(B)}\end{equation}
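One way to connect this dense FFN formula to MoE, offered here as an assumption about where the note is heading and not something the text states: when $f$ is elementwise, splitting $\boldsymbol{W}^{(A)}$ by columns and $\boldsymbol{W}^{(B)}$ by rows into $n$ matching blocks rewrites the FFN as a sum of $n$ smaller FFNs, each of which can be read as one expert. The sketch below checks that identity numerically. The sizes, the choice of ReLU for $f$, and the variable names are all illustrative.

```python
import torch

torch.manual_seed(0)

d, h, n = 8, 12, 3  # model width, FFN hidden width, number of blocks (all arbitrary)
x = torch.randn(4, d, dtype=torch.float64)
W_a = torch.randn(d, h, dtype=torch.float64)
W_b = torch.randn(h, d, dtype=torch.float64)
f = torch.relu  # any elementwise activation works for this identity

# Dense FFN: y = f(x W_a) W_b
dense = f(x @ W_a) @ W_b

# Split W_a by columns and W_b by rows into n matching blocks.
blocks_a = W_a.chunk(n, dim=1)  # each (d, h / n)
blocks_b = W_b.chunk(n, dim=0)  # each (h / n, d)

# Sum of n small FFNs, one per block.
block_sum = sum(f(x @ a) @ b for a, b in zip(blocks_a, blocks_b))

print(torch.allclose(dense, block_sum))
```

A sparse MoE then keeps only a few of these block terms per token instead of summing all of them.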
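Let $\gamma_i = 1 - \lambda_i$, then
Here Su (苏神) simply assumes that the $\boldsymbol{v}_i$ are pairwise orthogonal. (I think this can be tied to the property that two vectors picked at random in a high-dimensional space are almost orthogonal.)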
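The near-orthogonality remark can be checked empirically. The sketch below is only a sanity check under an assumption of its own: the vectors are drawn with i.i.d. standard Gaussian entries, which says nothing about how the $\boldsymbol{v}_i$ of a trained model are actually distributed. The function names and the dimensions tried are illustrative.

```python
import math
import random


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def mean_abs_cosine(dim, pairs=100, seed=0):
    """Average |cos| over random Gaussian vector pairs of the given dimension."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(pairs):
        u = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        total += abs(cosine(u, v))
    return total / pairs


for dim in (2, 32, 1024):
    print(dim, round(mean_abs_cosine(dim), 3))
```

The average absolute cosine shrinks as the dimension grows, roughly like $1/\sqrt{d}$, which is what makes the orthogonality assumption less drastic in high dimensions.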
Normalization means $\boldsymbol{e}_i = \boldsymbol{v}_i/ \Vert\boldsymbol{v}_i\Vert$.
We do not have to use L2 normalization either; RMS Norm with the gamma parameter fixed at 1 would also work, among other options.
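The two normalization options mentioned above differ only by a constant factor: for a vector of dimension $d$, RMS Norm with gamma fixed at 1 equals $\sqrt{d}$ times the L2-normalized vector. A minimal sketch, with illustrative function names and a small `eps` added as an assumption for numerical safety:

```python
import math


def l2_normalize(v, eps=1e-12):
    """Scale v to unit L2 norm: e = v / ||v||."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / (norm + eps) for x in v]


def rms_norm(v, eps=1e-12):
    """RMS Norm with gamma fixed at 1: v / sqrt(mean(v^2))."""
    rms = math.sqrt(sum(x * x for x in v) / len(v))
    return [x / (rms + eps) for x in v]


v = [3.0, 4.0]
e = l2_normalize(v)
r = rms_norm(v)
print(e, r)
```

Because both keep the direction of the vector and differ only in scale, either can serve when only the direction matters.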
ShareExpert SparseMoE (DeepSeek version)
Unlike the SparseMoE above, this version adds shared experts that every token passes through.
```python
class ShareExpertMOE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.moe_model = SparseMOE(config)
        # Shared experts process every token, regardless of routing.
        self.shared_experts = nn.ModuleList(
            [
                BasicExpert(config.hidden_dim, config.hidden_dim)
                for _ in range(config.shared_experts_number)
            ]
        )

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim)
        # Routed part: (batch, seq_len, hidden_dim) plus the router logits.
        sparse_moe_out, router_logits = self.moe_model(x)

        # Every shared expert sees every token: each output is (batch, seq_len, hidden_dim).
        shared_experts_out = [expert(x) for expert in self.shared_experts]

        # Stack along a new leading dim and sum over the shared experts.
        shared_experts_out = torch.stack(shared_experts_out, dim=0).sum(dim=0)

        # Combine the routed and shared outputs.
        return sparse_moe_out + shared_experts_out, router_logits


def test_share_expert_moe():
    x = torch.rand(2, 4, 16)
    config = MOEConfig(16, 2, 2)
    share_expert_moe = ShareExpertMOE(config)
    out = share_expert_moe(x)
    print(out[0].shape, out[1].shape)


test_share_expert_moe()
```
```python
def switch_load_balancing_loss(
    router_logits: torch.Tensor, num_experts: int, top_k: int = 2
) -> torch.Tensor:
    """
    Load-balancing loss in the style of Switch Transformers.

    Args:
        router_logits: shape [batch_size * sequence_length, num_experts]
        num_experts: number of experts
        top_k: number of experts selected per token
    Returns:
        total_loss: auxiliary_loss + z_loss_weight * z_loss
    """
    # Routing probabilities: [tokens, num_experts]
    router_probs = torch.softmax(router_logits, dim=-1)

    # Experts actually selected for each token: [tokens, top_k]
    _, selected_experts = torch.topk(router_probs, k=top_k, dim=-1)

    # One-hot selection mask: [tokens, top_k, num_experts]
    mask = torch.nn.functional.one_hot(selected_experts, num_experts).float()

    # Fraction of tokens sent to each expert, summed over the top_k slots.
    actual_load = mask.sum(dim=1).mean(dim=0)  # [num_experts]

    # Auxiliary loss: penalizes experts that get both many tokens and
    # high average routing probability.
    aux_loss = torch.sum(actual_load * router_probs.mean(dim=0)) * num_experts

    # Simplified z-loss-style penalty that discourages large router logits.
    z_loss = torch.mean(torch.square(router_logits))
    z_loss_weight = 0.001

    total_loss = aux_loss + z_loss * z_loss_weight
    return total_loss


def test_moe_training():
    batch_size = 32
    seq_len = 16
    hidden_dim = 32
    num_batches = 100

    config = MOEConfig(
        hidden_dim=hidden_dim, expert_number=4, top_k=2, shared_experts_number=2
    )
    model = ShareExpertMOE(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for batch in range(num_batches):
        # Random inputs and targets, for demonstration only.
        x = torch.randn(batch_size, seq_len, hidden_dim)
        target = torch.randn(batch_size, seq_len, hidden_dim)

        output, router_logits = model(x)

        # Task loss plus the weighted load-balancing loss.
        mse_loss = F.mse_loss(output, target)
        aux_loss = switch_load_balancing_loss(
            router_logits, config.expert_number, config.top_k
        )
        total_loss = mse_loss + 0.01 * aux_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if batch % 10 == 0:
            print(
                f"Batch {batch}, Loss: {total_loss.item():.4f} "
                f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})"
            )


test_moe_training()
```
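To see what the auxiliary term rewards, it helps to evaluate it by hand on two extreme cases. The sketch below covers only the `num_experts * sum(f_i * P_i)` part, under the simplifying assumption of top-1 routing, with hand-picked token fractions and mean router probabilities standing in for real router statistics. The function name is illustrative.

```python
def load_balancing_term(token_fraction, mean_router_prob):
    """Auxiliary term only: num_experts * sum_i f_i * P_i.

    token_fraction[i]   (f_i): share of tokens dispatched to expert i
    mean_router_prob[i] (P_i): average routing probability of expert i
    """
    num_experts = len(token_fraction)
    return num_experts * sum(
        f * p for f, p in zip(token_fraction, mean_router_prob)
    )


# Perfectly balanced routing across 4 experts.
balanced = load_balancing_term([0.25] * 4, [0.25] * 4)

# Fully collapsed routing: every token goes to expert 0.
collapsed = load_balancing_term([1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0])

print(balanced, collapsed)  # 1.0 4.0
```

Under these assumptions the balanced case gives 1 and the collapsed case gives the number of experts, so minimizing the term pushes the router toward spreading tokens evenly.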
The evolution of MoE in LLMs: from a plain, simplified MoE, to sparse MoE, to the shared-expert sparse MoE used by DeepSeek.