
自监督学习——对比学习SimCLR框架(原理+代码)
1原理
对比学习
通过比较不同实例之间的相似性和差异性来进行学习
。在对比学习中,我们将输入数据分为不同的类别或组(正负样本对
),并通过比较样本之间的差异来提取特征或进行分类。
样本相似度
对比学习有几种不同的方法,其中最常见的是基于距离度量
的方法。这些方法使用距离函数来度量两个实例之间的相似性,例如欧氏距离
或余弦相似度
。通过计算实例之间的距离,我们可以找到最相似或最不相似的实例,从而进行特征选择、相似性匹配或分类任务。(正样本对相似度越近越好,负样本对相似度越远越好)
SimCLR——对比学习通过度量学习,提供特征提取的能力
- 取一个输入图像:同1张图像进行2种数据增强,形成一个正样本对儿;不同图像之间是负样本对儿。
- 准备2个随机的图像增强:旋转,颜色/饱和度/亮度变化,缩放,裁剪等。文中详细讨论了增强的范围,并分析了哪些增广效果最好。(构造正样本:图像SimCLR-数据增强、文本SimCSE-Dropout、图文CLIP-图像文本对)
- 特征提取:运行一个深度神经网络(最好是卷积神经网络,如ViT、Bert、ResNet50)来获得那些增强图像的
图像特征表示(嵌入)
。 - 特征投影:运行一个小的全连接线性神经网络,将嵌入投影到另一个向量空间。
- 计算loss:计算对比损失并通过两个网络进行反向传播。当来自同一图像的投影相似时,对比损失减少。投影之间的相似度可以是任意的,这里我使用余弦相似度,和论文中一样。
- 下游任务:对比学习
得到Encoder
做为特征提取器,根据下游任务的数据集进行微调Finetuin。
数据要多,batch要大(batchsize=8192)
正负样本对的构建,不需要标注
损失loss函数怎么设计?
l
i
,
j
=
−
l
o
g
e
x
p
(
s
i
m
(
z
i
,
z
j
)
/
t
)
∑
k
=
1
2
N
1
[
k
!
=
i
]
e
x
p
(
s
i
m
(
z
i
,
z
k
)
/
t
l_{i,j}=-log{\frac{exp(sim(z_i,z_j)/t)}{\sum_{k=1}^{2N}1_{[k!=i]}exp(sim(z_i,z_k)/t}}
li,j=−log∑k=12N1[k!=i]exp(sim(zi,zk)/texp(sim(zi,zj)/t)
其中,分子是同类之间的相似度(正样本之间的距离),分母是不同类之间的相似度(负样本对之间的距离)。
t
t
t是temperature(尺度<1)参数,用于调整比列。
2 代码
DALLE2-pytorch 以 CLIP 为例,学习对比学习的过程,loss:=文本(MLM,Mask Language Model),图像(SimCLR对比损失),图文(图像文本对儿对比损失)
class SimCLR(nn.Module):
def __init__(
self,
net,
image_size,
channels = 3,
hidden_layer = -2,
project_hidden = True,
project_dim = 128,
augment_both = True,
use_nt_xent_loss = False,
augment_fn = None,
temperature = 0.1
):
super().__init__()
self.net = NetWrapper(net, project_dim, layer = hidden_layer)
self.augment = default(augment_fn, get_default_aug(image_size, channels))
self.augment_both = augment_both
self.temperature = temperature
# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)
# send a mock image tensor to instantiate parameters
self.forward(torch.randn(1, channels, image_size, image_size))
def forward(self, x):
b, c, h, w, device = *x.shape, x.device
transform_fn = self.augment if self.augment_both else noop
# 把原图使用不同数据增强和ViT提取成两个不同的图像特征(正样本对queries、keys)
queries, _ = self.net(transform_fn(x))
keys, _ = self.net(self.augment(x))
queries, keys = map(flatten, (queries, keys))
# 计算loss
loss = nt_xent_loss(queries, keys, temperature = self.temperature)
return loss
loss
def nt_xent_loss(queries, keys, temperature = 0.1):
b, device = queries.shape[0], queries.device
n = b * 2 # 同一图片内部不同patch也是负样本
projs = torch.cat((queries, keys))
logits = projs @ projs.t()
mask = torch.eye(n, device=device).bool()
logits = logits[~mask].reshape(n, n - 1) # 同一图片内部不同patch也是负样本,除了自己和自己
logits /= temperature
labels = torch.cat(((torch.arange(b, device = device) + b - 1), torch.arange(b, device=device)), dim=0)
loss = F.cross_entropy(logits, labels, reduction = 'sum')
loss /= n
return loss
更多推荐



所有评论(0)