AdamW_优化算法相较于_Adam_的改进之处

🎓博主介绍:Java、Python、js全栈开发 “多面手”,精通多种编程语言和技术,痴迷于人工智能领域。秉持着对技术的热爱与执着,持续探索创新,愿在此分享交流和学习,与大家共进步。
📖DeepSeek-行业融合之万象视界(附实战案例详解100+)
📖全栈开发环境搭建运行攻略:多语言一站式指南(环境搭建+运行+调试+发布+保姆级详解)
👉感兴趣的可以先收藏起来,希望帮助更多的人
在这里插入图片描述

AdamW 优化算法相较于 Adam 的改进之处

一、引言

在深度学习领域,优化算法对于模型的训练起着至关重要的作用。Adam(Adaptive Moment Estimation)算法自提出以来,因其能自适应地调整每个参数的学习率,在众多深度学习任务中得到了广泛应用。然而,随着研究的深入,人们发现 Adam 算法存在一些问题。AdamW 优化算法应运而生,它是对 Adam 算法的改进,旨在解决 Adam 算法中存在的一些不足,提高模型的泛化能力。本文将详细探讨 AdamW 优化算法相较于 Adam 的改进之处。

二、Adam 算法概述

2.1 算法原理

Adam 算法结合了 AdaGrad 和 RMSProp 算法的优点,通过计算梯度的一阶矩估计(均值)和二阶矩估计(方差)来为不同的参数动态调整学习率。具体步骤如下:

  1. 初始化参数:
    • 初始化梯度的一阶矩估计 m 0 = 0 m_0 = 0 m0=0,二阶矩估计 v 0 = 0 v_0 = 0 v0=0,时间步 t = 0 t = 0 t=0
    • 给定学习率 α \alpha α,衰减率 β 1 \beta_1 β1 β 2 \beta_2 β2,以及小常数 ϵ \epsilon ϵ
  2. 在每个时间步 t t t
    • 计算梯度 g t g_t gt
    • 更新一阶矩估计: m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t - 1} + (1 - \beta_1) g_t mt=β1mt1+(1β1)gt
    • 更新二阶矩估计: v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t - 1} + (1 - \beta_2) g_t^2 vt=β2vt1+(1β2)gt2
    • 修正一阶矩和二阶矩的偏差: m ^ t = m t 1 − β 1 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t} m^t=1β1tmt v ^ t = v t 1 − β 2 t \hat{v}_t = \frac{v_t}{1 - \beta_2^t} v^t=1β2tvt
    • 更新参数: θ t + 1 = θ t − α m ^ t v ^ t + ϵ \theta_{t + 1} = \theta_t - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} θt+1=θtαv^t +ϵm^t

2.2 代码实现(使用 PyTorch)

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的模型
model = nn.Linear(10, 1)

# 使用 Adam 优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)

三、Adam 算法存在的问题

3.1 权重衰减问题

在 Adam 算法中,通常会使用 L2 正则化来进行权重衰减,以防止模型过拟合。然而,在 Adam 中使用 L2 正则化会导致权重衰减和自适应学习率调整之间的相互作用,使得权重衰减的效果变得不稳定。具体来说,在 Adam 中,L2 正则化项会被加到梯度中,然后再进行自适应学习率的调整。这样,自适应学习率的调整会影响权重衰减的效果,使得权重衰减的强度在不同的参数上不一致。

3.2 泛化能力不足

由于权重衰减问题,Adam 算法在一些情况下可能会导致模型的泛化能力不足。在训练过程中,模型可能会过度拟合训练数据,而在测试数据上的表现不佳。

四、AdamW 算法的改进

4.1 修正权重衰减的实现方式

AdamW 算法对权重衰减的实现方式进行了修正。在 AdamW 中,权重衰减是在自适应学习率调整之后进行的,这样可以避免权重衰减和自适应学习率调整之间的相互作用。具体步骤如下:

  1. 初始化参数:同 Adam 算法。
  2. 在每个时间步 t t t
    • 计算梯度 g t g_t gt
    • 更新一阶矩估计: m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t - 1} + (1 - \beta_1) g_t mt=β1mt1+(1β1)gt
    • 更新二阶矩估计: v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t - 1} + (1 - \beta_2) g_t^2 vt=β2vt1+(1β2)gt2
    • 修正一阶矩和二阶矩的偏差: m ^ t = m t 1 − β 1 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t} m^t=1β1tmt v ^ t = v t 1 − β 2 t \hat{v}_t = \frac{v_t}{1 - \beta_2^t} v^t=1β2tvt
    • 计算权重衰减: θ t + 1 = θ t − α m ^ t v ^ t + ϵ − λ θ t \theta_{t + 1} = \theta_t - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \lambda \theta_t θt+1=θtαv^t +ϵm^tλθt,其中 λ \lambda λ是权重衰减系数。

4.2 代码实现(使用 PyTorch)

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的模型
model = nn.Linear(10, 1)

# 使用 AdamW 优化器
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

4.3 提高泛化能力

通过修正权重衰减的实现方式,AdamW 算法可以更有效地控制模型的复杂度,从而提高模型的泛化能力。在一些实验中,使用 AdamW 算法训练的模型在测试数据上的表现优于使用 Adam 算法训练的模型。

五、实验对比

为了更直观地展示 AdamW 算法相较于 Adam 算法的改进,我们可以进行一个简单的实验。假设我们要训练一个简单的全连接神经网络来进行手写数字识别任务。

5.1 实验步骤

  1. 加载数据集:使用 PyTorch 的 torchvision 库加载 MNIST 数据集。
  2. 定义模型:定义一个简单的全连接神经网络。
  3. 训练模型:分别使用 Adam 和 AdamW 优化器训练模型。
  4. 评估模型:在测试集上评估模型的准确率。

5.2 代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# 加载数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False)

# 定义模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model_adam = SimpleNet()
model_adamw = SimpleNet()

# 定义优化器
optimizer_adam = optim.Adam(model_adam.parameters(), lr=0.001)
optimizer_adamw = optim.AdamW(model_adamw.parameters(), lr=0.001, weight_decay=0.01)

# 定义损失函数
criterion = nn.CrossEntropyLoss()

# 训练模型
def train_model(model, optimizer, epochs=5):
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}')

# 训练 Adam 模型
print("Training model with Adam optimizer...")
train_model(model_adam, optimizer_adam)

# 训练 AdamW 模型
print("Training model with AdamW optimizer...")
train_model(model_adamw, optimizer_adamw)

# 评估模型
def evaluate_model(model):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f'Accuracy: {100 * correct / total}%')

# 评估 Adam 模型
print("Evaluating model with Adam optimizer...")
evaluate_model(model_adam)

# 评估 AdamW 模型
print("Evaluating model with AdamW optimizer...")
evaluate_model(model_adamw)

5.3 实验结果分析

通过实验结果可以发现,使用 AdamW 优化器训练的模型在测试集上的准确率通常会高于使用 Adam 优化器训练的模型,这表明 AdamW 算法在提高模型泛化能力方面具有优势。

六、结论

AdamW 优化算法通过修正权重衰减的实现方式,解决了 Adam 算法中权重衰减和自适应学习率调整之间的相互作用问题,从而提高了模型的泛化能力。在实际应用中,当需要训练复杂的深度学习模型时,AdamW 算法是一个更好的选择。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

fanxbl957

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值