后门攻击简单来说,是攻击者通过在训练过程中嵌入触发器(trigger)来操纵测试时的预测。
但后面攻击普遍用于CV或分类中,时间序列(预测)中比较少。
通用方法
发表于SaTML’23的《Backdoor Attacks on Time Series: A Generative Approach 》,SaTML是比较新的IEEE会议。
既然是通用,其暗示即简单。事实也如此。
为了解决生成器训练的冷启动问题,首先在来自 D 的所有干净样本上对时间序列分类器 f 进行预训练,直到其交叉熵损失 LCE 稳定下
降。
预训练完成后,我们同时训练触发器生成器 g 和部分训练的分类器 f 。
在每次迭代中,两个网络都按照类似的过程逐步更新:
1)在污染样本上训练 g ,以最小化针对后门类别 $y_t$ 的分类损失(第 15-17 行);
2)使用 g 生成污染数据集 D′ ;
3)在污染数据集 D′ 上训练 f ,其中污染样本被重新标记为 $y_t$。
整个过程中,后门触发模式被限制在信号幅度的 10% 以内,即 0.1 ∗ (xmax − xmin) ,以增强隐蔽性。
针对多元时间序列预测
NIP2 24 《BACKTIME: Backdoor Attacks on Multivariate Time Series Forecasting 》
多元时间序列(MTS)数据与单变量时间序列相比,会有变量间相关系数的存在,这会使对MTS 数据的攻击更为复杂。
在多变量时间序列预测中,常见做法是将数据集切分为时间窗口,作为预测模型的输入。
然而,在中毒数据集中,识别这些切分的时间窗口是否被中毒面临两大挑战。
(1)这些时间窗口的长度可能与触发器或目标模式的长度不匹配。
(2)当将数据集切分为时间窗口时,这些窗口可能仅包含触发器或目
标模式的一部分。
为解决这些问题,作者假设只有当输入包含触发器的所有组成部分时,注入的后门才会被激活。
对于后门攻击,我们需要确定三个关键要素:
(1)攻击何处。这个由攻击者指定
(2)何时攻击,即选择攻击的时间点
(3)如何攻击,即指定注入的触发器。
攻击时间选择
预测误差较高的时间戳更容易遭受攻击。
利用预训练的干净模型计算预测与真实值之间的 MAE,并进一步选择MAE 最高的前 α个时间戳。
攻击方法
首先,通过利用MLP 来捕捉目标变量 S 内部的变量间相关性,生成一个加权图。然后,进一步基于学成的加权图利用图卷积网络(GCN)进行触发器生成。
由于时间序列跨度比较大,故作者采取的是使用DFT并取低维特征来降低维度和提取信息。即$z_i=Filter(DFT(x_i),k)$。
并使用MLP来构建相关性的图,即$A _ {i,j}=cos(MLP(z_i),MLP(z_j))$。
使用长度为$t^{BEF}$的时间窗口对触发器之前的历史数据进行切片。随后,我们基于切片后的历史
数据,利用 GCN 进行触发器生成:
$$
\hat g_{t_i}=GCN(X^{ATK}[t_i-t^{BEF}-t^{TGR}:t_i-t^{TGR},S],A),\forall t_i\in T^{ATK}
$$
然而,GCN 倾向于激进地增加输出 $\hat g _ {t_i}$的幅度。作者对这种行为的一个潜在解释是,较大的触发幅度会导致显著的偏差,而具有此类偏差特征的数据点更容易被预测模型学成,尽管它们违背了隐蔽性的要求。
为了解决这一问题,作者提出引入一个非线性缩放函数tanh(·),对输出幅度施加强制性限制来生成隐蔽的触发器。
$$
g_{t_i}=\Delta^{TGR}\cdot tanh(\hat g_{t_i}) \tag{6}
$$
由于原始问题是个双层优化问题,故作者引入了一个替代预测模型$f_s$,以提供精确解的实用近似
同样的,也会使用预热来避免难以训练或难以收敛。在预热阶段,我们仅训练替代模型,使其具备合理的预测能力。一旦预热阶段结束,我们将同时更新替代模型和触发器生成器。
第一阶段,替代模型更新:
$$
l _ {cln}=L _ {CLN}(f_s(X^{ATK} _ {t_i,h}),X^{ATK} _ {t _ i,f})\tag{7}
$$
其中使用平滑L1loss。
第二阶段,触发器生成器更新:
固定代理模型的参数后,攻击损失可以形式化为:
$$
l_{atk}=\sum _ {t_i=t}^{t+t^{PTN}} L _ {ATK}(f_s(X^{ATK}_{t_i,h}),X^{ATK} _ {t_i,f})\eta(t_i)
$$
PTN表示目标。
现实世界数据集中的多变量时间序列(MTS)数据普遍存在高频波动或噪声。然而该方法并不能天然保证触发器具有高频信号。因此,为了弥补这一差距,引入了以下规范化损失:
$$
l_{norm}=AVG\left(\left|\sum_{i=0}^{t^{TGR}} g_{t_i}[i,:]\right|\right)\tag{9}
$$
可能有点云里雾里,不妨我们来看代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 class Trainer : def __init__ (self, config, atk_vars, target_pattern, train_mean, train_std, train_data, test_data, train_data_stamps, test_data_stamps, device ): self.net = MODEL_MAP[self.config.surrogate_name](self.config.Surrogate).to(device) self.optimizer = optim.Adam(self.net.parameters(), lr=config.learning_rate) train_set = TimeDataset(train_data, train_mean, train_std, device, num_for_hist=12 , num_for_futr=12 , timestamps=train_data_stamps) channel_features = fft_compress(train_data, 200 ) self.attacker = Attacker(train_set, channel_features, atk_vars, config, target_pattern, device) self.use_timestamps = config.Dataset.use_timestamps self.prepare_data() def train (self ): self.attacker.train() poison_metrics = [] for epoch in range (self.num_epochs): self.net.train() if epoch > self.warmup: if not hasattr (self.attacker, 'atk_ts' ): self.attacker.select_atk_timestamp(poison_metrics) self.attacker.sparse_inject() poison_metrics = [] self.train_loader = DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True ) pbar = tqdm.tqdm(self.train_loader, desc=f'Training data {epoch} /{self.num_epochs} ' ) for batch_index, batch_data in enumerate (pbar): if not self.use_timestamps: encoder_inputs, labels, clean_labels, idx = batch_data x_mark = torch.zeros(encoder_inputs.shape[0 ], encoder_inputs.shape[-1 ], 4 ).to(self.device) else : encoder_inputs, labels, clean_labels, x_mark, y_mark, idx = batch_data encoder_inputs = torch.squeeze(encoder_inputs).to(self.device).permute(0 , 2 , 1 ) labels = torch.squeeze(labels).to(self.device).permute(0 , 2 , 1 ) self.optimizer.zero_grad() x_des = torch.zeros_like(labels) outputs = self.net(encoder_inputs, x_mark, x_des, None ) outputs = self.train_set.denormalize(outputs) loss_per_sample = F.smooth_l1_loss(outputs, labels, reduction='none' ) loss_per_sample = loss_per_sample.mean(dim=(1 , 2 )) poison_metrics.append(torch.stack([loss_per_sample.cpu().detach(), idx.cpu().detach()], dim=1 )) loss = loss_per_sample.mean() loss.backward() self.optimizer.step() if epoch > self.warmup: self.attacker.update_trigger_generator(self.net, epoch, self.num_epochs, use_timestamps=self.use_timestamps)
其中
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 def sparse_inject (self ): self.dataset.init_poison_data() n, c, T = self.dataset.data.shape n = len (self.atk_vars) trigger_len = self.trigger_generator.output_dim pattern_len = self.target_pattern.shape[-1 ] for beg_idx in self.atk_ts.tolist(): data_bef_tgr = self.dataset.data[self.atk_vars, 0 :1 , beg_idx - self.trigger_generator.input_dim:beg_idx] data_bef_tgr = self.dataset.normalize(data_bef_tgr) data_bef_tgr = data_bef_tgr.reshape(-1 , self.trigger_generator.input_dim) triggers = self.trigger_generator(data_bef_tgr)[0 ] triggers = self.dataset.denormalize(triggers).reshape(n, 1 , -1 ) self.dataset.poisoned_data[self.atk_vars, 0 :1 , beg_idx:beg_idx + trigger_len] = triggers.detach() self.dataset.poisoned_data[self.atk_vars, 0 :1 , beg_idx + trigger_len:beg_idx + trigger_len + pattern_len] = \ self.target_pattern + self.dataset.poisoned_data[self.atk_vars, 0 :1 , beg_idx - 1 :beg_idx]
注入触发器 : self.dataset.poisoned_data[...] = triggers.detach()
这一行将刚刚生成的 triggers
注入到 poisoned_data
(中毒数据副本)中,位置是从 beg_idx
开始,持续 trigger_len
的长度。.detach()
用于切断梯度,因为注入过程本身不需要反向传播。
注入目标模式 : 紧接着触发器之后,代码注入了目标模式 。
self.target_pattern
: 这是预先定义好的、攻击者希望模型在看到触发器后输出的模式。
+ self.dataset.poisoned_data[...]
: 这里有一个值得注意的细节,目标模式并不是直接覆盖,而是加上了 紧邻攻击点前一个时间步的数据**。** 这样做可能是为了让注入的模式在数值上与周围的数据更“平滑”地衔接,减少异常感,使攻击更隐蔽。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 def update_trigger_generator (self, net, epoch, epochs, use_timestamps=False ): """ update the trigger generator using the soft identification. """ if not use_timestamps: tgr_slices = self.get_trigger_slices(self.fct_input_len - self.trigger_len, self.trigger_len + self.pattern_len + self.fct_output_len) else : tgr_slices, tgr_timestamps = self.get_trigger_slices(self.fct_input_len - self.trigger_len, self.trigger_len + self.pattern_len + self.fct_output_len) pbar = tqdm.tqdm(tgr_slices, desc=f'Attacking data {epoch} /{epochs} ' ) for slice_id, slice in enumerate (pbar): slice = slice .to(self.device) slice = slice [:, 0 :1 , :] n, c, l = slice .shape data_bef = slice [self.atk_vars, :, self.fct_input_len - self.trigger_len - self.bef_tgr_len:self.fct_input_len - self.trigger_len] data_bef = data_bef.reshape(-1 , self.bef_tgr_len) triggers, perturbations = self.predict_trigger(data_bef) triggers = triggers.reshape(self.atk_vars.shape[0 ], -1 , self.trigger_len) slice [self.atk_vars, :, self.fct_input_len - self.trigger_len:self.fct_input_len] = triggers slice [self.atk_vars, :, self.fct_input_len:self.fct_input_len + self.pattern_len] = \ self.target_pattern + slice [self.atk_vars, :, self.fct_input_len - self.trigger_len - 1 ].unsqueeze(-1 ) batch_inputs_bkd = [slice [..., i:i + self.fct_input_len] for i in range (self.pattern_len)] batch_labels_bkd = [slice [..., i + self.fct_input_len:i + self.fct_input_len + self.fct_output_len].detach() for i in range (self.pattern_len)] batch_inputs_bkd = torch.stack(batch_inputs_bkd, dim=0 ) batch_labels_bkd = torch.stack(batch_labels_bkd, dim=0 ) batch_inputs_bkd = batch_inputs_bkd[:, :, 0 :1 , :] batch_labels_bkd = batch_labels_bkd[:, :, 0 , :] batch_inputs_bkd = self.dataset.normalize(batch_inputs_bkd) loss_decay = (self.pattern_len - torch.arange(0 , self.pattern_len, dtype=torch.float32).to( self.device)) / self.pattern_len self.attack_optim.zero_grad() batch_inputs_bkd = batch_inputs_bkd.squeeze(2 ).permute(0 , 2 , 1 ) batch_labels_bkd = batch_labels_bkd.permute(0 , 2 , 1 ) if use_timestamps: batch_x_mark = [tgr_timestamps[slice_id][i:i + self.fct_input_len] for i in range (self.pattern_len)] batch_y_mark = [ tgr_timestamps[slice_id][i + self.fct_input_len:i + self.fct_input_len + self.fct_output_len] for i in range (self.pattern_len)] batch_x_mark = torch.stack(batch_x_mark, dim=0 ) batch_y_mark = torch.stack(batch_y_mark, dim=0 ) else : batch_x_mark = torch.zeros(batch_inputs_bkd.shape[0 ], batch_inputs_bkd.shape[1 ], 4 ).to(self.device) x_des = torch.zeros_like(batch_labels_bkd) outputs_bkd = net(batch_inputs_bkd, batch_x_mark, x_des, None ) outputs_bkd = self.dataset.denormalize(outputs_bkd) loss_bkd = F.mse_loss(outputs_bkd[:, :, self.atk_vars], batch_labels_bkd[:, :, self.atk_vars], reduction='none' ) loss_bkd = torch.mean(loss_bkd, dim=(1 , 2 )) loss_bkd = torch.sum (loss_bkd * loss_decay) loss_norm = torch.abs (torch.sum (perturbations, dim=1 )).mean() loss = loss_bkd + self.lam_norm * loss_norm loss.backward() self.attack_optim.step() self.atk_scheduler.step()
题外话
欧洲航天局举办了一场关于时间序列的后门提取比赛 ,形式新颖,但不允许中国人获得奖金。
但也期待获胜方案是如何解决的。AmbrosM目前位列第一,他一向具有很高深的方法。