首页 > 关注 > > > 正文
浅谈表征学习中的自监督模型Simsiam
发布时间:2023-02-06 23:13:27   来源:哔哩哔哩  

论文名称为Exploring Simple Siamese Representation Learning[1],是大神Kaiming He的团队在表征学习/对比学习/自监督学习问题上提出的新算法。在之前,对比学习(contrastive learning)已经应用于自监督视觉任务中,其中颇有代表性的是SimCLR[2]、SwAV[3]、BYOL[4]。

上述三个方法均是为了解决Siamese network(孪生网络)带来的主要问题:collapsing solutions(答案崩溃)。Simsiam亦是如此,但Simsiam的贡献在于,没有改变模型结构、不需要增加batch size、也不需要引入负样本提高计算量,只需要通过加入stop-gradient的方法,就可以解决collapsing solutions的问题。作者在文中表示,希望该方法可以重新唤起大家在表征学习中对Siamese network的信心。


(资料图片仅供参考)

本文将先直观展示Simsiam的模型结构,再简单介绍什么叫collapsing solutions,接下来,针对模型结构中的三大设计stop-gradient、predictor 以及损失函数分别进行解释。最后从K-Means的角度来重新审视Simsiam模型架构。

一、模型结构

本文将先给出模型结构的简单阐述,再去用较为通俗易懂的原理解释为什么要这么做。模型的结构图如下图所示:

可以看到,模型非常的简单,主要分为几个部分:

输入:代表数据集中的一张图片,是由对进行数据增强(data augmentations)得到的两张图片。注意,原图片并不作为输入。数据增强的方法有很多,比较重要的是裁剪和改变颜色通道等[2]。一般是通过一系列不同的数据增强方法的叠加而获取的,而并非使用单独的数据增强方法。

接下来,通过同一个编码器(encoder)编码,得到两个表征向量。这个编码器一般使用经典卷积神经网络ResNet。这一步的目的就是通过卷积神经网络提取特征,得到感受野大、维度较小的向量。

然后,我们将经过一个MLP映射得到,通过计算的相似度作为目标函数去学习。在训练中,因为来自同一张图片,所以我们将这两个输入看作正样本对(即同类样本),因此我们需要最大化二者的相似度。至于为什么要加这个MLP层,我会在第三节给出作者的解释

其实我们的目标函数是一个对称的(symmetric)函数,也就是不仅要算的相似度,还要计算的相似度()。最终的损失函数如下:

其中,

上式被称为余弦相似度(cosine similarity)。损失函数的设计思路,我们会在第四节给出通俗解释。

下面我们来说一下为什么模型是这个结构:

先举个例子,如下图所示,比如我们给一个孩子一张孙悟空大闹天宫时的图片(左图),再给这个孩子西天取经时,师徒四人的图片(右图),来让孩子选出右边这四张图片中哪个是孙悟空。相信孩子会很快选出正确答案。这是因为孩子会自觉地从五张图像中,抽样出一些特征(如脸型、毛发、五官等)来进行比对。

孩子脑海中的做法可能是这样的:先从右边第一张图片(唐僧)和左边比对,不是很像,再拿右边第二张图片(孙悟空)和左边比对,比较像,应该是正确答案,再看看后面,然后依次拿猪八戒和沙僧进行比较。最终发现还是第二张图最像,选择第二张。

通过“对比”这个思路,我们可以直接地想到,如果我们将两张(或多张)图片,送入模型中,然后通过目标函数对模型进行优化,使得两张相似图片的特征向量尽可能接近,两张不相似的图片向量尽可能远离,这不是就有可能抽样出较好的特征向量了吗?

如果你能想到这个思路,那说明你已经对孪生模型的认识有了雏形。但这里面会存在一个很重要的问题,就是我们是一个自监督/无监督模型,换言之,我们的数据是没有标签的,我们如何定义、选出两张相似/不相似的图片作为输入呢?

我们想到,如果对同一张图片做数据增强,并且让他们尽可能地发生较大变化,由于我们是对同一个图片进行的增强,所以我们可以合理的认为这两张图片是相似的,至少是同类别的。而由于增强,两张图片之间也形成了较大区分度,也可以使得模型学到更高层级、表征度更高的特征,尽可能的避免过拟合。

然而,在上述方法下,我们的输入只有相似的图片,这会造成一种被称为collapsing solutions的问题。在开篇也提到了,很多最近的研究都是在提出各种方案解决这个问题。那么接下来我们就来看一下collapsing solutions到底是什么。

二、什么是collapsing solutions

依然举一个极端的例子来理解:假设损失函数的最小值为0(即我们要用梯度下降法尽量将损失函数降低到最小值0),那么模型完全可以投机取巧地令所有权重更新为0,这样我们永远都可以让损失函数停在最小点。为什么呢?原因就在于训练的时候每次输入都只有正样本而没有负样本,这样模型自然会“偷懒”,无论输入是什么,就让损失等于0就好了。

在这种情况下,我们的模型学习到的表征(也就是模型的solutions)也将统一为零向量,这种可怕的结果就被称为collapsing solutions。

有人可能会想到一种办法来解决这个问题,其实造成这类问题原因的,是因为我们的输入永远都是正样本对,模型学到这类规矩后才会“投机取巧”。那么我们不妨再引入一些输入,这些输入与输入不同,来让他们作为反例,通过最大化正样本对的相似度+最小化正反样本间的相似度,来学习模型权重。

这个方法是可行的,而且其实就是SimCLR所采用的策略。但是实验显示,这类方法需要很大的batch-size才能取得很不错的效果。而大的batch不仅大大增加了计算量(设batch size为N,增强后有2*N个样本,对于N中任意一个正样本对,都有2N-2个负样本对参与计算),更重要的是,对显存的占有率极高。所以对大多数任务可能并不友好。

而Simsiam采用的方法,只需要正样本对,就可以达到同样甚至更好的效果。这种方法叫做stop-gradient。顾名思义,其实就是停止计算该部分权重的梯度(如图所示只算通路的梯度,而不计算通路的梯度),也就是并不更新该部分权重。因此损失函数实际如下图

作者通过多种实验结果表示,这样就可以解决collapsing solutions的问题。下面来简单讲解一下这种方法为什么有效。

三、为什么stop-gradient可以解决这个问题

首先,我们先继续用通俗的思路和语言来理解一下为什么stop-gradient可以成功解决collapsing solutions的问题。

由于我们的任务是一个无监督学习任务,即没有target作为模型学习的基准,而每次输入又都是正样本对,每次输入都只需要做同一件事情——最大化相似度,因此才会出现collapsing solutions的问题。就好像没有KPI(没有target),而不管多努力到手的奖金都是一样多(不管输入是什么反正都是相似的),大家就会摆烂一样(模型偷懒直接摆烂)。

对于孪生网络,输入,编码成,再去计算损失值,然后对称地对共享的权重进行梯度下降更新。这种训练方式可以看作是并行进行的。那么现在,我们对编码过程使用stop-gradient,也就是不再针对的梯度进行权重更新,这样的话,我们完全可以理解为模型将(也可以认为是)当做了此次权重更新的target。有了target,无监督学习可以被巧妙地视为监督学习,从而避免了collapsing solutions问题的出现。

有人可能会有一个疑问,如果我们先对进行stop-gradient,只计算的梯度来更新权重,再对使用stop-gradient,只计算的梯度来更新权重,这和同时对和求梯度更新权重不是一个意思嘛!

你说的其实很对,我上面讲的只是一个简单易懂的讲法,其实模型中还有一个部分我们在解释中一直没有提到,就是那个MLP层。这就是加入MLP层的意义之一。

作者的实验还发现,的权重一定要是可学习的(learnable),因为如果是固定的随机初始化权重,虽然不会发生collapsing solutions的问题,但是模型可能会无法收敛。

另外,作者发现固定的学习率(fixed learning rate)要比使用一些学习优化器(learning scheduler)的效果要好(比如线性变化、余弦变化)。作者也给出了一个较为直观的原因:因为我们每次的输入都是通过新的数据增强形成的,对于模型来说是新的数据,并且我们的模型任务是进行表征学习而非分类、检测等任务,所以相比于收敛到最优点,我们更希望模型去适应学习最新的表征。

在第五节,我们会对predictor 有更进一步的理解,也会从另一个角度来诠释stop-gradient这种交替更新(alternating updates)的方法为什么效果这么好。

四、为什么用余弦相似度来做损失函数

我们的任务目标是学习到分布较好的表征向量,什么叫分布较好?也就是相似的向量越接近越好,不相似的向量越远越好。向量远近的度量单位一般可以用向量夹角的大小来表示。

回顾余弦相似度的公式,我们可以发现,之所以叫余弦相似度,因为有

而余弦相似度的公式,恰好就是在求两个向量夹角的余弦值。余弦值越大,两向量夹角越小,表示两向量越接近,反之亦然。

另外,有

余弦相似度即为两个归一化向量的内积。[4]中证明了最大化上式等价于最小化两个归一化向量的MSE(Mean Squared Error)。

作者也尝试了交叉熵相似度来做损失函数,其他参数和设置都没有改变,但是结果下降了约5个百分点(68.1->63.2)。这说明了两点:1. 余弦相似度在本任务中优于交叉熵,2. 使用余弦相似度并不是避免collapsing solutions的发生的主要因素(因为换了交叉熵依然不会发生collapsing solutions)。

五、K-Means的角度理解Stop-gradient

这一部分是对Simsiam的另类思考,如果只想了解模型整体概况,此节可以略过。

另,这一节讲的内容可能会和原文有点出入,原因是本文讲的内容是在我本人理解的基础上,修改了原文部分符号以及逻辑顺序,使得更容易理解。当然也可能存在错误理解作者本意的情况,如有发现,希望指正。

首先,我们在上一节中提到,余弦相似度与MSE实际上是等价的[4]。那么我们不妨把损失函数改写为

(1)

其中为数学期望。我们设的参数集合为,又因为我们有,为数据增强,的参数集合为。因此我们有了优化目标:

现在,我们从K-Means聚类的角度来看问题。如果我们将看作聚类中心,将看作赋值函数。

回顾一下K-Means是怎么进行优化的,首先我们先得到一个赋值函数,即对每个样本点,给出它对应的最近聚类中心,赋值函数可以将样本空间进行一个划分。然后根据这个划分,我们再重新计算聚类中心。经过一次次的交替更新,从而得到最优解。

ok,现在我们来思考一下,如何更新与,由K-Means的alternating algorithm,我们可以得到

通过这种算法,其实我们就很自然的得到了stop-gradient的思路。因为在计算时,我们需要是固定的。因此不能从的梯度进行更新。这也就是为什么要进行stop-gradient的原因。在原文中,也给出了MLP层究竟在做什么事情。

回到式(1),我们可以显式地得到的最优解为,对于任意输入图片,

注意,这里面不再代表对某一张,而代表一个随机变量,是对任意的,这个数学期望是数据增强分布上的数学期望。也就是说,这个MLP层实际上是在预测数据增强分布上的数学期望,从而试图矫正由数据增强带来的随机性而导致的误差。这也是为什么第四节中讲到的,作者发现如果的参数无法学习,则会导致模型难以收敛。因为模型并不能预测出数据增强的分布,而是一直给出一个错误的分布,这样的话即使不会出现collapsing solutions的问题,也会导致模型难以学到正确的表征。

六、结语

本文用非常浅显易懂的文字与例子,讲述了Simsiam模型的原理。并在最后基于作者给出的数学解释,结合我自身的理解,对模型和K-means聚类之间的关系进行了讲述。模型借鉴于K-means聚类的alternating algorithm优化方法,不需要大的batch,不需要改变模型结构,只通过引入stop-gradient,就解决了孪生网络的collapsing solutions。这种简单易懂的方法(而非大量堆砌参数与子结构)在机器学习界非常的受欢迎,不得不说Kaiming He大神的优化思路总是如此简洁易懂。另外,通过引入MLP层,对表征向量进行进一步映射,来推测数据增强的分布期望,从而降低了由数据增强带来的随机性。作者也通过另一种方法来近似估计数据增强期望,且移除了,发现结果依然很好,来证明的作用确实是在预估该期望。最后,由于我们stop-gradient是参考的alternating algorithm的方法,因此不难想到,如果我们对输入做次()数据增强,在计算loss时,次均使用stop-gradient,随着的增大,效果也会越来越好。作者的实验表明确实如此,因此也证实了前面说的Simsiam确实是在做alternating optimization。

[1] https://arxiv.org/abs/2011.10566

[2] https://arxiv.org/abs/2002.05709

[3] https://arxiv.org/abs/2006.09882

[4] https://arxiv.org/abs/2006.07733

关键词: 损失函数 模型结构 数学期望

推荐内容

Copyright@  2015-2022 起点器材装备网版权所有  备案号: 皖ICP备2022009963号-12   联系邮箱:295 911 578@qq.com