研究SAC的时候没搞太懂,花了好几天研究这个问题,记录一下
参考:
漫谈重参数:从正态分布到Gumbel Softmax - 科学空间|Scientific Spaces (kexue.fm)
VAE中的重参数化技巧-reparameterization trick - 知乎 (zhihu.com)
引入
考虑形如下形式的损失函数:
Ep(z)[fθ(z)]
在连续问题或z的取值空间很大的离散问题中,我们很难或者不可能遍历所有的z,因此需要采样(Monte Carlo)。
若z的分布与我们需要求梯度的参数θ无关,则:
∇θEp(z)[fθ(z)]=∇θ[∫zp(z)fθ(z)dz]=∫zp(z)[∇θfθ(z)]dz=Ep(z)[∇θfθ(z)]
然而,若问题变为:
Epθ(z)[fθ(z)]
计算梯度:
∇θEpθ(z)[fθ(z)]=∇θ[∫zpθ(z)fθ(z)dz]=∫z∇θ[pθ(z)fθ(z)]dz=∫zfθ(z)∇θpθ(z)dz+∫zpθ(z)∇θfθ(z)dz
由于我们需要计算分布p的梯度,第一项无法变成期望的形式,因此也无法进行采样。
为了解决这个问题,可以使用重参数化技巧与Gumbel-Softmax
Reparameterization
原理
考虑连续情况:
Lθ=∫zpθ(z)f(z)dz
我们需要在进行采样的同时保留θ的梯度,为此,我们考虑先从无参分布q中进行采样,然后通过某种变换生成z:
ϵ∼q(ϵ)z=gθ(ϵ)
此时式子变为:
Lθ=Eϵ∼q(ϵ)[f(gθ(ϵ))]
此时我们把随机采样和梯度传播解耦了,可以直接反向传播loss
实现
以SAC为例,原本需要从$ \mathcal{N} (\mu_\theta, \sigma^2_\theta) $中进行抽样。我们进行重参数化:
ϵ∼N(0,1)z=ϵ×σθ+μθ⇒Lθ=Eϵ∼N(0,1)[f(ϵ×σθ+μθ)]
然后就可以直接进行反向传播更新网络参数
Gumbel-Softmax
原理
现在我们考虑离散情况:
Lθ=y∑pθ(y)f(y)
显然我们是可以通过这个求和操作直接计算出Loss的,
然而若取值空间非常巨大,我们依旧需要通过采样来估算这个期望。
和上文一样,我们考虑如何分离随机采样:
引入Gumbel-Max:
argmaxi(logpi−log(−logϵi))i=1k, ϵi∼U[0,1]
现在已经通过这个一样重参数过程将随机性转移到了均匀分布上,但是由于我们使用了不可导的argmax,还是会丢失梯度信息。
因此,我们引入其光光滑似版本,Gumbel-Softmax:
softmaxi((logpi−log(−logϵi)/τ)i=1k, ϵi∼U[0,1]
tau为退火参数,越小则输出越接近One-Hot输出,然而此时会导致梯度消失。因此训练时可以从1开始,慢慢衰减。
证明
要证明Gumbel-Max抽样和原始分布一样,需要证明输出i的概率为pi,此处证明输出1的概率为p1,即:
logp1−log(−logϵ1)>logpi−log(−logϵi) ,∀i=1
化简得:
ϵi<ϵ1pi/p1≤1
成立概率为:
ϵ1(p2+p3+⋯+pk)/p1=ϵ(1/p1)−1∫01ϵ1(1/p1)−1dϵ1=p1
证毕。
实现
pytorch自带Gumbel-Softmax函数,看看代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| gumbels = ( -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log() ) gumbels = (logits + gumbels) / tau y_soft = gumbels.softmax(dim)
if hard: index = y_soft.max(dim, keepdim=True)[1] y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) ret = y_hard - y_soft.detach() + y_soft else: ret = y_soft return ret
|
我们看到,pytorch除了输出类似One-Hot版本,还支持一个hard模式,这步ret = y_hard - y_soft.detach() + y_soft
通过分离计算图的方式让前向传播和反向传播不同,反向传播时仍然计算的是y_soft
的梯度。