1.简介

在之前的文章里,我们介绍了集成一阶动量和二阶动量的优化器Adam。AdamW其实是在Adam的基础上加入了weight decay正则化,但是我们上一篇文章里也看到了Adam的代码中已经有正则化,那么两者有什么区别呢?


2.AdamW

其实AdamW和Adam唯一的区别,就是weight decay的加入方式。

在Adam当中,weight decay是直接加入到梯度当中:

g_{t} = g_{t}+\lambda \theta _{t-1}

其中g_{t}是当前step的梯度,\theta _{t-1}是上一个step中的模型权重,\lambda是正则化系数。

而在AdamW中,正则化变成了:

\theta _{t} = \theta _{t-1}-\gamma \lambda \theta _{t-1}

其中\gamma是学习率。

所以AdamW的思路特别简单:反正正则化系数加进梯度之后最终也要在权重上进行更新,那为什么还需要加进梯度去呢?因此,AdamW直接在权重上进行衰减,在收敛速度上也能领先于Adam。


3.思考

但仔细一想,Adam+L2正则化和AdamW虽然都可以实现权重衰减,但是两者的实施细节上其实是不一样的。L2正则化是在loss上加入权重的惩罚系数,也可以说是在梯度上进行修改,而AdamW其实是更字面意思的weight decay,就是直接让权重衰减。

这两者其实在SGD上是对等的:

\theta _{t} = \theta _{t-1}-\gamma g_{t}^{weight}

= \theta _{t-1}-\gamma (g_{t}+\lambda \theta _{t-1}))

= \theta _{t-1}-\gamma \lambda \theta _{t-1}-\gamma g_{t}

只不过在Adam这种要考虑一阶和二阶动量时,以上方程已不满足线性关系,所以最终的结果是有区别的。那么AdamW相对于Adam而言,除了收敛速度更快之外,它的正则系数也不再受动量的影响(一般会被除以二阶动量而稀释),因此拥有超参独立正则力度增加的优点,这也是原论文名字中带有decouple的原因。


4.pytorch代码

AdamW的伪代码流程如下:

以下代码为pytorch官方Adam的代码。

def _single_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):

    assert grad_scale is None and found_inf is None

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        if capturable:
            assert (
                param.is_cuda and step_t.is_cuda
            ), "If capturable=True, params and state_steps must be CUDA tensors."

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            param = torch.view_as_real(param)

        # update step
        step_t += 1

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if capturable or differentiable:
            step = step_t

            # 1 - beta1 ** step can't be captured in a CUDA graph, even if step is a CUDA tensor
            # (incurs "RuntimeError: CUDA error: operation not permitted when stream is capturing")
            bias_correction1 = 1 - torch.pow(beta1, step)
            bias_correction2 = 1 - torch.pow(beta2, step)

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                if differentiable:
                    max_exp_avg_sqs_i = max_exp_avg_sqs[i].clone()
                else:
                    max_exp_avg_sqs_i = max_exp_avg_sqs[i]
                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sqs_i, exp_avg_sq))
                # Uses the max. for normalizing running avg. of gradient
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                denom = (
                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
            else:
                denom = (
                    exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)

            param.addcdiv_(exp_avg, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)

业务合作/学习交流+v:lizhiTechnology

 如果想要了解更多优化器相关知识,可以参考我的专栏和其他相关文章:

优化器_Lcm_Tech的博客-CSDN博客

【优化器】(一) SGD原理 & pytorch代码解析_sgd优化器-CSDN博客

【优化器】(二) AdaGrad原理 & pytorch代码解析_adagrad优化器-CSDN博客

【优化器】(三) RMSProp原理 & pytorch代码解析_rmsprop优化器-CSDN博客

【优化器】(四) AdaDelta原理 & pytorch代码解析_adadelta里rho越大越敏感-CSDN博客

【优化器】(五) Adam原理 & pytorch代码解析_adam优化器-CSDN博客

【优化器】(六) AdamW原理 & pytorch代码解析-CSDN博客

【优化器】(七) 优化器统一框架 & 总结分析_mosec优化器优点-CSDN博客

如果想要了解更多深度学习相关知识,可以参考我的其他文章:

【损失函数】(一) L1Loss原理 & pytorch代码解析_l1 loss-CSDN博客

【图像生成】(一) DNN 原理 & pytorch代码实例_pytorch dnn代码-CSDN博客

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐