The Many Uses of the Gram Matrix

Introduction to the Gram Matrix

Given a set of vectors v1, v2, …, vn in an inner product space, their Gram matrix G is the n×n square matrix whose entry Gij is the inner product of vi and vj.
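As a minimal sketch of this definition (plain Python, with the standard dot product as the inner product; the helper names `dot` and `gram_matrix` are illustrative and not taken from any library):

```python
def dot(u, v):
    """Standard Euclidean inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def gram_matrix(vectors):
    """Return G with G[i][j] = <v_i, v_j> for a list of vectors."""
    return [[dot(vi, vj) for vj in vectors] for vi in vectors]


vectors = [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 2.0, 1.0]]
G = gram_matrix(vectors)
for row in G:
    print(row)
```

The result is symmetric, and its diagonal holds the squared norms of the vectors.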

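It has the following properties:

Volume computation: the square root of the Gram determinant equals the volume of the parallelotope spanned by the vectors.

Linear independence test: a set of vectors is linearly independent if and only if its Gram matrix is invertible (that is, its determinant is nonzero).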
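Both properties can be checked by hand on two vectors. The sketch below is plain Python under the standard dot product; `gram_det_2` is an illustrative helper written for this example, not a library function.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def gram_det_2(u, v):
    """Determinant of the 2x2 Gram matrix of u and v."""
    return dot(u, u) * dot(v, v) - dot(u, v) ** 2


# Volume: a parallelogram with base 3 and height 2 has area 6.
u, v = [3.0, 0.0, 0.0], [1.0, 2.0, 0.0]
area = math.sqrt(gram_det_2(u, v))

# Independence: w is a multiple of u2, so the Gram determinant vanishes.
u2, w = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]
dependent_det = gram_det_2(u2, w)

print(area, dependent_det)
```

In floating point the determinant of nearly dependent vectors is rarely exactly zero, so a practical independence test usually compares it against a small tolerance.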

ICLR 2025: Aligning Multimodal Embeddings with Gram

https://github.com/ispamm/GRAM

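GRAM learns and aligns n modalities directly in the high-dimensional embedding space. By minimizing the Gramian volume of the k-dimensional parallelotope spanned by the modality vectors, it ensures that all modalities are geometrically aligned at once. GRAM can replace cosine similarity in any downstream method, works for 2 to n modalities, and provides more meaningful alignment than previous similarity measures. A new GRAM-based contrastive loss improves the alignment of multimodal models in the high-dimensional embedding space.

Simply put, it uses the volume property of the Gram matrix.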
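To see why a small volume means good alignment, consider three L2-normalized vectors: when they nearly coincide the parallelotope is almost flat, and when they are mutually orthogonal its volume is 1. The sketch below uses plain Python and made-up 3-dimensional "embeddings"; the helper names are illustrative.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def normalize(v):
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]


def det3(m):
    """Determinant of a 3x3 matrix by cofactor expansion."""
    (a, b, c), (d, e, f), (g, h, i) = m
    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)


def gram_volume(vectors):
    """Volume of the parallelotope spanned by three vectors after L2 normalization."""
    unit = [normalize(v) for v in vectors]
    G = [[dot(x, y) for y in unit] for x in unit]
    # Clamp at 0: rounding can make the determinant slightly negative.
    return math.sqrt(max(det3(G), 0.0))


# Hypothetical text / video / audio embeddings of one sample.
aligned = [[1.0, 0.1, 0.0], [1.0, 0.0, 0.1], [1.0, 0.1, 0.1]]
unrelated = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

print(gram_volume(aligned))    # close to 0: the modalities nearly coincide
print(gram_volume(unrelated))  # 1: mutually orthogonal unit vectors
```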

The algorithm is as follows:

The loss function is:
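Judging from the example code later in this post, the contrastive loss takes a symmetric InfoNCE-like form in which the negative volume plays the role of the similarity. This is an inference from that code, not a quotation from the paper. Label smoothing is omitted, and the symbols $V_{ij}$, $\tau$ and $B$ are notation introduced here. $V_{ij}$ is the volume spanned by anchor $i$ together with the other modalities of sample $j$, $\tau$ is the temperature, and $B$ is the batch size:

$$
\mathcal{L}_{D2A} = -\frac{1}{B}\sum_{i=1}^{B}\log\frac{\exp(-V_{ii}/\tau)}{\sum_{j=1}^{B}\exp(-V_{ij}/\tau)},\qquad
\mathcal{L}_{A2D} = -\frac{1}{B}\sum_{i=1}^{B}\log\frac{\exp(-V_{ii}/\tau)}{\sum_{j=1}^{B}\exp(-V_{ji}/\tau)},\qquad
\mathcal{L} = \frac{1}{2}\left(\mathcal{L}_{D2A} + \mathcal{L}_{A2D}\right)
$$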

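Besides the proposed volume-based multimodal loss, an auxiliary data-anchor matching loss is used. It encourages the model to infer whether an anchor and a piece of data form a matching pair.

The data features are obtained by concatenating the unpooled features of all encoders along the sequence dimension. At the bottom of the multimodal encoder, an MLP layer returns a binary prediction $p_{dam}$, on which the cross-entropy is computed.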
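The cross-entropy on a binary match prediction can be sketched as follows. This covers only the final loss term; the MLP, the way non-matching pairs are sampled, and the function and argument names are assumptions made for illustration.

```python
import math


def binary_cross_entropy(p_dam, is_match, eps=1e-12):
    """Cross-entropy between a predicted match probability and a 0/1 label."""
    p = min(max(p_dam, eps), 1.0 - eps)  # avoid log(0)
    return -(is_match * math.log(p) + (1 - is_match) * math.log(1.0 - p))


# Hypothetical MLP outputs for one matching and one non-matching pair.
loss_match = binary_cross_entropy(p_dam=0.9, is_match=1)
loss_mismatch = binary_cross_entropy(p_dam=0.9, is_match=0)
print(loss_match, loss_mismatch)
```

A confident "match" prediction is cheap when the pair really matches and expensive when it does not, which is what pushes the model to tell the two cases apart.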

Code

Simple version:

    import torch

    def simple_volume_computation(language, video, audio):
        # Each input is a single embedding vector of shape (dim,).
        A = torch.stack([language, video, audio])  # (3, dim)
        G = A @ A.T                                # 3x3 Gram matrix
        gramian = torch.linalg.det(G)
        return torch.sqrt(gramian)

Standard version:

    import torch

    def volume_computation(anchor, *inputs):
        """
        General function to compute volume for contrastive learning loss functions.
        Compute the volume metric for each vector in anchor batch and all the other modalities listed in *inputs.

        Args:
        - anchor (torch.Tensor): Tensor of shape (batch_size1, dim)
        - *inputs (torch.Tensor): Variable number of tensors of shape (batch_size2, dim)

        Returns:
        - torch.Tensor: Tensor of shape (batch_size1, batch_size2) representing the volume for each pair.
        """
        batch_size1 = anchor.shape[0]
        batch_size2 = inputs[0].shape[0]

        # Squared norm of each anchor vector, broadcast to (batch_size1, batch_size2)
        aa = torch.einsum('bi,bi->b', anchor, anchor).unsqueeze(1).expand(-1, batch_size2)

        # Pairwise dot products between the anchor and each input, each (batch_size1, batch_size2)
        l_inputs = [anchor @ inp.T for inp in inputs]

        # Row-wise dot products of each input with itself and with every other input
        input_dot_products = []
        for i, input1 in enumerate(inputs):
            row = []
            for j, input2 in enumerate(inputs):
                dot_product = torch.einsum('bi,bi->b', input1, input2).unsqueeze(0).expand(batch_size1, -1)
                row.append(dot_product)
            input_dot_products.append(row)

        # Stack the results to form the Gram matrix for each pair
        G = torch.stack([
            torch.stack([aa] + l_inputs, dim=-1),
            *[torch.stack([l_inputs[i]] + input_dot_products[i], dim=-1) for i in range(len(inputs))]
        ], dim=-2)

        # Compute the determinant for each Gram matrix
        gram_det = torch.det(G.float())

        # Compute the square root of the absolute value of the determinants
        res = torch.sqrt(torch.abs(gram_det))
        return res

Example code:

    import torch
    import torch.nn.functional as F

    # Assumes volume_computation from the standard version above is in scope.

    # Hyperparameters
    bs = 32
    latent_dim = 512
    contrastive_temp = 0.07

    # Output of the encoders
    language = torch.randn((bs, latent_dim))
    video = torch.randn((bs, latent_dim))
    audio = torch.randn((bs, latent_dim))

    # volume[i, j]: volume spanned by anchor i and the video/audio of sample j
    volume = volume_computation(language, video, audio)
    volume = volume / contrastive_temp

    # Transposed view for the opposite direction (already scaled by the temperature)
    volumeT = volume.T

    # The matching sample for row i is column i
    targets = torch.arange(bs)

    # Smaller volume means better alignment, so the negated volume serves as the logit
    loss = (
        F.cross_entropy(-volume, targets, label_smoothing=0.1)  # d2a
        + F.cross_entropy(-volumeT, targets, label_smoothing=0.1)  # a2d
    ) / 2

    print(loss)

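arXiv: Principled Multimodal Representation Learning

This paper improves on the previous one. It was proposed by the National University of Singapore.

Instead of the determinant, it works on the singular values: it maximizes the leading singular value and minimizes the remaining ones.

The ICLR 2025 method drives the determinant to 0, which only requires one eigenvalue to be 0 (the determinant is the product of the eigenvalues).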
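This can be checked numerically. The sketch below is plain Python with illustrative helper names, and power iteration stands in for a full SVD. It compares a case where only two of three unit vectors coincide with a case where all three coincide. Both have a zero Gram determinant, so the volume cannot tell them apart, while the leading singular value is larger when all three agree.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def gram(vectors):
    return [[dot(x, y) for y in vectors] for x in vectors]


def det3(m):
    (a, b, c), (d, e, f), (g, h, i) = m
    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)


def leading_singular_value(vectors, iters=200):
    """Largest singular value of the stacked vectors, via power iteration on G.

    The eigenvalues of G are the squared singular values of the matrix whose
    rows are the vectors.
    """
    G = gram(vectors)
    x = [1.0] * len(G)
    for _ in range(iters):
        y = [dot(row, x) for row in G]
        norm = math.sqrt(dot(y, y))
        x = [t / norm for t in y]
    top_eigenvalue = dot(x, [dot(row, x) for row in G])  # Rayleigh quotient
    return math.sqrt(top_eigenvalue)


# Unit vectors: two modalities coincide, the third is orthogonal to them.
partly_aligned = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
# Unit vectors: all three modalities coincide.
fully_aligned = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]

for name, vecs in [("partly", partly_aligned), ("fully", fully_aligned)]:
    print(name, det3(gram(vecs)), leading_singular_value(vecs))
```

For three unit vectors the squared singular values sum to 3, so pushing the leading one toward its maximum forces the others toward zero, which is full alignment.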

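It also adds a regularizer that encourages separation between different instances, based on the singular vector that corresponds to the leading singular value.

Like the previous method, it uses an MLP layer that returns a binary prediction $p_{dam}$ and computes the cross-entropy on it.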

Detailed code

    # Integrating PMRL with four steps
    # 1. Singular Value Decomposition on Multimodal Representations >>>
    U, S, _ = torch.linalg.svd(
        torch.stack([feat_t, feat_v, feat_a, feat_s], dim=-1)
    )
    # 2. Principled learning via maximum singular values >>>
    loss1 = F.cross_entropy(S / self.tau1, torch.zeros(S.shape[0]).to(S.device).long())
    # Implemented by cross-entropy, and the singular value at the first position is the maximum one
    # 3. Principled regularization via eigenvector corresponding to the maximum singular values >>>
    U1 = U[:, :, 0]
    loss2 = F.cross_entropy((U1 @ U1.T) / self.tau2, torch.arange(U1.shape[0]).to(U1.device).long())
    # ......
    # 4. Combine the loss >>>
    loss = loss1 + self.lambda1 * loss2 + self.lambda2 * loss_IM

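The Gram Matrix in DINOv3

DINOv3 is the latest generation of self-supervised learning (SSL) models from Meta AI, a continuation and major upgrade of DINO and DINOv2. Its goal is to use large-scale unlabeled image data to train a general-purpose, powerful vision backbone that performs well across a wide range of downstream vision tasks.

DINOv3 introduces a method called Gram anchoring.

During their experiments, the authors found that learning strongly discriminative features and maintaining local consistency are relatively independent, as shown by the lack of correlation between global and dense performance. Combining the global DINO loss with the local iBOT loss goes some way toward addressing this, but the balance is unstable: as training progresses, the global representation gradually dominates.

The authors therefore use the Gram matrix to exploit this independence explicitly.

The new loss function operates on the Gram matrix, that is, the matrix of all pairwise dot products of the patch features in an image.

The authors push the student's Gram matrix toward that of an earlier model, called the Gram teacher. The Gram teacher is chosen as an early iteration of the teacher network, one that exhibits superior dense properties. Because the loss acts on the Gram matrix rather than on the features themselves, the local features are free to move as long as the similarity structure stays the same.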

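Suppose an image consists of P patches and the network operates in dimension d. Let $X_S$ (respectively $X_G$) denote the $P \times d$ matrix of $L_2$-normalized local features of the student (respectively the Gram teacher). The loss $L_\text{Gram}$ is defined as:
$$
L_\text{Gram} = \left\| X_S \cdot X_S^T - X_G \cdot X_G^T \right\|_F^2
$$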
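A small numeric sketch shows what this loss does and does not penalize. It uses plain Python and toy 2-dimensional features with P = 3; the helper names and values are made up for illustration. Rotating every student feature by the same angle leaves the Gram matrix unchanged, so the loss stays at zero, whereas collapsing one patch onto another changes the similarity structure and is penalized.

```python
def matmul_t(X):
    """Return X @ X^T for a matrix given as a list of rows."""
    return [[sum(a * b for a, b in zip(r1, r2)) for r2 in X] for r1 in X]


def gram_loss(X_S, X_G):
    """Squared Frobenius norm of the difference between the two Gram matrices."""
    G_S, G_G = matmul_t(X_S), matmul_t(X_G)
    return sum(
        (s - g) ** 2
        for row_s, row_g in zip(G_S, G_G)
        for s, g in zip(row_s, row_g)
    )


# Three unit-norm patch features (P = 3, d = 2) from a hypothetical Gram teacher.
X_G = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]

# Student features rotated by 90 degrees: (x, y) -> (-y, x).
X_rotated = [[-y, x] for x, y in X_G]

# Student features where the third patch has collapsed onto the first.
X_collapsed = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]

print(gram_loss(X_rotated, X_G))    # 0 up to rounding: same similarity structure
print(gram_loss(X_collapsed, X_G))  # positive: similarity structure changed
```

The `GramLoss` implementation later in this post uses a mean-squared error instead of a sum, which for a fixed input shape differs from the squared Frobenius norm only by a constant factor.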

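This loss is computed only on global crops. Although it could be applied early in training, for efficiency it is only started after 1M iterations. Interestingly, the authors observe that applying $L_\text{Gram}$ late can still "repair" severely degraded local features. To further improve performance, the Gram teacher is updated every 10k iterations, at which point it becomes identical to the main EMA teacher.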
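One way to express this schedule in a training loop is sketched below. The function name, its arguments, and the choice to count refreshes from the step at which the loss is switched on are all assumptions made for illustration; they are not taken from the DINOv3 code.

```python
def gram_schedule(step, start_step=1_000_000, refresh_every=10_000):
    """Return (apply_gram_loss, refresh_gram_teacher) for a training step.

    Assumption: the Gram loss is active from start_step onward, and the Gram
    teacher is replaced by a copy of the EMA teacher every refresh_every steps
    after that (not at start_step itself, where the early checkpoint is used).
    """
    apply_gram_loss = step >= start_step
    refresh_gram_teacher = (
        step > start_step and (step - start_step) % refresh_every == 0
    )
    return apply_gram_loss, refresh_gram_teacher


for step in (999_999, 1_000_000, 1_005_000, 1_010_000):
    print(step, gram_schedule(step))
```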

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class GramLoss(nn.Module):
        """Implementation of the gram loss"""

        def __init__(
            self,
            apply_norm=True,
            img_level=True,
            remove_neg=True,
            remove_only_teacher_neg=False,
        ):
            super().__init__()

            # Loss
            self.mse_loss = torch.nn.MSELoss()

            # Parameters
            self.apply_norm = apply_norm
            self.remove_neg = remove_neg
            self.remove_only_teacher_neg = remove_only_teacher_neg

            if self.remove_neg or self.remove_only_teacher_neg:
                assert self.remove_neg != self.remove_only_teacher_neg

        def forward(self, output_feats, target_feats, img_level=True):
            """Compute the MSE loss between the gram matrix of the input and target features.

            Args:
                output_feats: Pytorch tensor (B, N, dim) or (B*N, dim) if img_level == False
                target_feats: Pytorch tensor (B, N, dim) or (B*N, dim) if img_level == False
                img_level: bool, if true gram computed at the image level only else over the entire batch
            Returns:
                loss: scalar
            """

            # Dimensions of the tensor should be (B, N, dim)
            if img_level:
                assert len(target_feats.shape) == 3 and len(output_feats.shape) == 3

            # Float casting
            output_feats = output_feats.float()
            target_feats = target_feats.float()

            # SSL correlation
            if self.apply_norm:
                target_feats = F.normalize(target_feats, dim=-1)

            if not img_level and len(target_feats.shape) == 3:
                # Flatten (B, N, D) into (B*N, D)
                target_feats = target_feats.flatten(0, 1)

            # Compute similarities
            target_sim = torch.matmul(target_feats, target_feats.transpose(-1, -2))

            # Patch correlation
            if self.apply_norm:
                output_feats = F.normalize(output_feats, dim=-1)

            if not img_level and len(output_feats.shape) == 3:
                # Flatten (B, N, D) into (B*N, D)
                output_feats = output_feats.flatten(0, 1)

            # Compute similarities
            student_sim = torch.matmul(output_feats, output_feats.transpose(-1, -2))

            if self.remove_neg:
                target_sim[target_sim < 0] = 0.0
                student_sim[student_sim < 0] = 0.0

            elif self.remove_only_teacher_neg:
                # Remove only the negative sim values of the teacher
                target_sim[target_sim < 0] = 0.0
                # Note: target_sim has no negative entries left at this point,
                # so the mask below is always empty as written.
                student_sim[(student_sim < 0) & (target_sim < 0)] = 0.0

            return self.mse_loss(student_sim, target_sim)

https://lijianxiong.space/2025/20250826/
Author: LJX
Published: August 26, 2025