SC-NAFSSR

该论文来自同级的qzd等同学。论文地址

指导老师为金枝老师,发表于CVPR的wordshop论文。为CVPR的NTIRE 2023双目图像超分辨率挑战赛的第二名解决方案。


模型架构

训练

VGG 感知损失

LPIPS 感知损失

基于视差的监督损失

对抗训练

立体一致性损失

一些细节

在第一阶段训练中,使用 MSE 损失函数训练模型。使用余弦退火策略,初始学习率为 3e − 3,最小学习率为 1e − 7,进行 100000 次迭代。

在第二阶段,利用感知损失和立体一致性损失进行微调,并将初始学习率设置为 5e − 4 进行 100000 次迭代。MSE 损失、感知损失和立体一致性损失的权重分别为 1、1 和 0.01。数据集在线随机裁剪以增强泛化性能,并应用 EMA 来提高模型的鲁棒性。其他训练超参数与第一阶段训练相同。

由于不允许使用在其他训练集上预训练的光流和视差估计模型,我们使用立体一致性损失作为视差监督损失的替代。从表 1 中我们可以看到,用立体一致性损失替代视差监督损失。性能会下降,但是有比没有好。

作者认为,TTA会破坏图像的感知域结构和立体一致性,因此没有使用这些集成策略。

作者尽管提出了GAN的方法,但是是在提交结果后研究的,因此比赛提交的模型是无 GAN 的。

代码

NAFSSR.yml:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
name: NAFNetSR-L_x4
model_type: ImageRestorationModel
scale: 4
num_gpu: 4
manual_seed: 10

datasets:
train:
name: Flickr1024-sr-train
type: PairedStereoImageDataset
dataroot_gt: /data/ntire/Flickr1024/train/HR
dataroot_lq: /data/ntire/Flickr1024/train/LR_x4
io_backend:
type: disk

gt_size_h: 120
gt_size_w: 360
use_hflip: true
use_vflip: true
use_rot: false
flip_RGB: true

# data loader
use_shuffle: true
num_worker_per_gpu: 4
batch_size_per_gpu: 2
dataset_enlarge_ratio: 1
prefetch_mode: ~

val:
name: Flickr1024-sr-test
type: PairedImageSRLRDataset
dataroot_gt: /data/ntire/Flickr1024/val/HR
dataroot_lq: /data/ntire/Flickr1024/val/LR_x4
io_backend:
type: disk

# network structures
network_g:
type: NAFSSRsc
up_scale: 4
width: 128
num_blks: 128
drop_path_rate: 0.3
train_size: [1, 6, 30, 90]
drop_out_rate: 0.

# path
path:
pretrain_network_g: ~
strict_load_g: true
resume_state: ~

# training settings
train:
optim_g:
type: AdamW
lr: !!float 3e-3
weight_decay: !!float 0
betas: [0.9, 0.9]

scheduler:
type: TrueCosineAnnealingLR
T_max: 100000
eta_min: !!float 1e-7

total_iter: 100000
warmup_iter: -1 # no warm up
mixup: false

# losses
pixel_opt:
type: MSELoss
loss_weight: 1.
reduction: mean

# validation settings
val:
val_freq: !!float 2e4
save_img: false
trans_num: 1

max_minibatch: 1

metrics:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 0
test_y_channel: false
ssim:
type: calculate_skimage_ssim

# logging settings
logger:
print_freq: 200
save_checkpoint_freq: !!float 1e4
use_tb_logger: true
wandb:
project: ~
resume_id: ~

# dist training settings
dist_params:
backend: nccl
port: 29500

注意到默认情况下,是只有pixel_opt(像素损失)。

网格化处理大图像: 用于在验证/测试时处理可能无法一次性放入 GPU 显存的大图像。grids 将大图像分割成多个重叠或不重叠的小块(patches/grids),grids_inverse 则将处理后的小块结果拼接回原始大图的尺寸。

优化参数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def optimize_parameters(self, current_iter, tb_logger):
# --- 分阶段训练逻辑 ---
# 固定光流网络参数 (fix_flow_iter)
if self.fix_flow_iter:
logger = get_root_logger()
if current_iter == 1: # 第一次迭代时
logger.info(f'Fix flow network and feature extractor for {self.fix_flow_iter} iters.')
for name, param in self.net_g.named_parameters():
if 'spynet' in name or 'edvr' in name: # 假设光流网络或特征提取器模块名包含 'spynet' 或 'edvr'
param.requires_grad_(False) # 将这些参数的 requires_grad 设置为 False,使其不参与梯度更新
elif current_iter == self.fix_flow_iter: # 达到指定迭代次数后
logger.warning('Train all the parameters.')
self.net_g.requires_grad_(True) # 重新将整个网络设置为可训练

# 只训练 PAM 模块 (train_pam_iter)
if self.train_pam_iter:
if current_iter == 1:
logger = get_root_logger()
logger.info(f'Only train PAM module for {self.train_pam_iter} iters.')
for name, param in self.net_g.named_parameters():
# 假设 PAM 模块相关的参数名包含 'fusion' 或 'patch_embed_ln'
if ('fusion' not in name) and ('patch_embed_ln' not in name):
param.requires_grad = False # 其他参数不训练
elif current_iter == self.train_pam_iter:
self.train_pam_iter = None # 标记结束
logger = get_root_logger()
logger.warning('Train all the parameters.')
for param in self.net_g.parameters():
param.requires_grad = True # 恢复所有参数可训练

self.optimizer_g.zero_grad() # 清空优化器的梯度

# MixUp 数据增强 (可选)
if self.opt['train'].get('mixup', False):
self.mixup_aug() # 假设 BaseModel 或此文件中有 mixup_aug 方法实现

# 前向传播
preds, cri_sc = self.net_g(self.lq) # 将低分辨率图像 self.lq 输入网络 self.net_g
# 网络可能返回多个预测结果 (如多尺度预测) 和一个 stereo consistent loss (cri_sc)
if not isinstance(preds, list): # 确保 preds 是一个列表
preds = [preds]

self.output = preds[-1] # 通常取最后一个预测结果作为最终输出

# --- 损失计算 ---
l_total = 0 # 总损失
loss_dict = OrderedDict() # 用于记录各种损失的值

# 像素损失
if self.cri_pix:
l_pix = 0.
for pred in preds: # 对每个预测结果计算像素损失 (如果网络输出多个预测)
l_pix += self.cri_pix(pred, self.gt) # 计算预测 pred 和真实 self.gt 之间的像素损失
l_total += l_pix
loss_dict['l_pix'] = l_pix

# 感知损失 (针对立体图像分别计算左右眼)
if self.cri_perceptual:
# 假设输出和GT的前3个通道是左眼,后3个通道是右眼
l_percep_l, l_style_l = self.cri_perceptual(self.output[:,:3], self.gt[:,:3])
l_percep_r, l_style_r = self.cri_perceptual(self.output[:,3:], self.gt[:,3:])

if l_percep_l is not None: # 感知内容损失
l_percep = l_percep_l + l_percep_r
l_total += l_percep
loss_dict['l_percep'] = l_percep
if l_style_l is not None: # 感知风格损失
l_style = l_style_l + l_style_r
l_total += l_style
loss_dict['l_style'] = l_style

# 立体一致性损失
if self.cri_sc: # 如果在 init_training_settings 中定义了 self.cri_sc
l_sc = self.cri_sc(self.output, self.gt)
l_total += l_sc
loss_dict['l_sc'] = l_sc
else: # 如果 self.cri_sc 未定义,则使用网络直接返回的 cri_sc
# 这意味着网络内部可能计算了一种一致性损失
l_sc = cri_sc
l_total += l_sc
loss_dict['l_sc'] = l_sc

# 一个小的正则化项,通常用于确保参数有值,实际影响很小 (乘以0.)
l_total = l_total + 0. * sum(p.sum() for p in self.net_g.parameters())

# --- 反向传播和参数更新 ---
l_total.backward() # 计算总损失相对于网络参数的梯度

# 梯度裁剪 (可选)
use_grad_clip = self.opt['train'].get('use_grad_clip', True)
if use_grad_clip:
torch.nn.utils.clip_grad_norm_(self.net_g.parameters(), 0.01) # 将梯度范数裁剪到0.01,防止梯度爆炸

self.optimizer_g.step() # 使用优化器根据梯度更新网络参数

# --- 日志和 EMA 更新 ---
self.log_dict = self.reduce_loss_dict(loss_dict) # (用于分布式训练) 聚合来自不同GPU的损失值并记录

if self.ema_decay > 0: # 如果启用了EMA
self.model_ema(decay=self.ema_decay) # 更新EMA模型的参数

架构:

此外还好使用DropPath和SCAM

SCAM:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class SCAM(nn.Module):
'''
Stereo Cross Attention Module (SCAM)
'''
def __init__(self, c): # c - 通道数
super().__init__()
self.scale = c ** -0.5 # 注意力机制的缩放因子
self.criterion = nn.L1Loss() # L1损失,用于辅助损失函数
self.norm_l = LayerNorm2d(c) # 左输入的层归一化
self.norm_r = LayerNorm2d(c) # 右输入的层归一化
# 用于创建Query (Q), Key (K)的投影层 - Key在这里是从另一个视图的Q隐式形成的
self.l_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
self.r_proj1 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)

# 用于缩放聚合特征的可学习参数
self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

# 用于创建Value (V)的投影层
self.l_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)
self.r_proj2 = nn.Conv2d(c, c, kernel_size=1, stride=1, padding=0)

# forward_ - 一个更简单的版本,没有辅助损失(在NAFSSRsc中未使用)
# def forward_(self, x_l, x_r): ...

def forward(self, x_l, x_r, LR_left, LR_right, loss):
# 1. 为两个视图生成 Q (Query) 和 V (Value)
Q_l = self.l_proj1(self.norm_l(x_l)).permute(0, 2, 3, 1) # B, H, W, c
Q_r_T = self.r_proj1(self.norm_r(x_r)).permute(0, 2, 1, 3) # B, H, c, W (来自右视图Q的转置Key)
V_l = self.l_proj2(x_l).permute(0, 2, 3, 1) # B, H, W, c
V_r = self.r_proj2(x_r).permute(0, 2, 3, 1) # B, H, W, c

b, h, w, c_ = Q_l.shape # 使用c_代替c,以避免与类变量冲突

# 2. 计算注意力图
# 从左到右的注意力:左视图的每个像素关注同一行中右视图的所有像素
attention = torch.matmul(Q_l, Q_r_T) * self.scale # (B, H, W, c) x (B, H, c, W) -> (B, H, W, W)
attention_T = attention.permute(0, 1, 3, 2) # 转置的注意力图(从右到左)

M_right_to_left = torch.softmax(attention, dim=-1) # 转移矩阵:右视图特征如何影响左视图
M_left_to_right = torch.softmax(attention_T, dim=-1) # 转移矩阵:左视图特征如何影响右视图

# 3. 使用形态学处理生成可见性掩码
# 如果像素的注意力权重之和足够大,则认为它是“可见的”或“对应的”
V_left_to_right_mask = torch.sum(M_left_to_right.detach(), 2) > 0.1 # 按行求和(像素关注的位置)
V_left_to_right_mask = morphologic_process(V_left_to_right_mask.view(b, 1, h, w))
V_right_to_left_mask = torch.sum(M_right_to_left.detach(), 2) > 0.1
V_right_to_left_mask = morphologic_process(V_right_to_left_mask.view(b, 1, h, w))

# 4. 注意力的循环一致性
# 理想情况下,如果应用R->L注意力,然后应用L->R注意力,应该得到原始位置
M_left_right_left = torch.matmul(M_right_to_left, M_left_to_right)
M_right_left_right = torch.matmul(M_left_to_right, M_right_to_left)

# 5. 基于注意力的特征聚合
F_r2l = torch.matmul(M_right_to_left, V_r) # 从右视图为左视图聚合的特征
F_l2r = torch.matmul(M_left_to_right, V_l) # 从左视图为右视图聚合的特征

# 缩放并添加到原始特征(残差连接)
F_r2l = F_r2l.permute(0, 3, 1, 2) * self.beta
F_l2r = F_l2r.permute(0, 3, 1, 2) * self.gamma

# 6. 计算辅助损失函数
### 注意力图的平滑度损失 (Smoothness Loss)
# 鼓励 M_right_to_left 和 M_left_to_right 注意力图的平滑性
loss_h = self.criterion(M_right_to_left[:, :-1, :, :], M_right_to_left[:, 1:, :, :]) + \
self.criterion(M_left_to_right[:, :-1, :, :], M_left_to_right[:, 1:, :, :])
loss_w = self.criterion(M_right_to_left[:, :, :-1, :-1], M_right_to_left[:, :, 1:, 1:]) + \
self.criterion(M_left_to_right[:, :, :-1, :-1], M_left_to_right[:, :, 1:, 1:])
loss_smooth = loss_w + loss_h

### 循环一致性损失 (Cycle Consistency Loss)
# M_left_right_left (和 M_right_left_right) 在可见区域应接近单位矩阵
Identity = torch.autograd.Variable(torch.eye(w, w).repeat(b, h, 1, 1).to(Q_l.device), requires_grad=False)
loss_cycle = self.criterion(M_left_right_left * V_left_to_right_mask.permute(0, 2, 1, 3), Identity * V_left_to_right_mask.permute(0, 2, 1, 3)) + \
self.criterion(M_right_left_right * V_right_to_left_mask.permute(0, 2, 1, 3), Identity * V_right_to_left_mask.permute(0, 2, 1, 3))

### 光度损失 (Photometric Loss)
# 使用注意力图对低分辨率 (LR) 图像进行“扭曲”(warping)
# 并在可见区域将它们与原始LR图像进行比较。
# 这有助于训练注意力图找到正确的对应关系。
LR_right_warped = torch.bmm(M_right_to_left.contiguous().view(b * h, w, w), LR_right.permute(0, 2, 3, 1).contiguous().view(b * h, w, 3))
LR_right_warped = LR_right_warped.view(b, h, w, 3).contiguous().permute(0, 3, 1, 2)
LR_left_warped = torch.bmm(M_left_to_right.contiguous().view(b * h, w, w), LR_left.permute(0, 2, 3, 1).contiguous().view(b * h, w, 3))
LR_left_warped = LR_left_warped.view(b, h, w, 3).contiguous().permute(0, 3, 1, 2)

loss_photo = self.criterion(LR_left * V_left_to_right_mask, LR_right_warped * V_left_to_right_mask) + \
self.criterion(LR_right * V_right_to_left_mask, LR_left_warped * V_right_to_left_mask)

# 将辅助损失的加权和添加到模型的总损失中
loss += 0.0025 * (loss_photo + 0.1 * loss_smooth + loss_cycle)
return x_l + F_r2l, x_r + F_l2r, LR_left, LR_right, loss # 返回更新后的特征和总损失

结果


SC-NAFSSR
https://lijianxiong.space/2023/20230802/
作者
LJX
发布于
2023年8月2日
许可协议