1. 为什么我们需要Warmup与Cosine Annealing?
如果你在训练深度学习模型时,感觉模型一开始就“跑偏”了,或者训练损失下降得特别慢,那很可能就是学习率没调好。我刚开始炼丹那会儿,也总在这个问题上栽跟头。后来我发现,很多高手都在用“预热(Warmup)+余弦退火(Cosine Annealing)”这套组合拳,它能让模型训练得更稳、效果更好。
简单来说,Warmup就像运动前的热身。模型刚初始化时,它的参数是随机的,梯度可能非常大。如果一上来就用很大的学习率去更新,步子迈得太大,很容易导致训练不稳定,甚至直接“跑飞”。Warmup就是在训练初期,让学习率从一个很小的值(比如0)开始,线性地(或其他方式)逐渐增加到你设定的初始学习率。这个过程给了模型一个“缓冲期”,让它先适应一下数据,把参数分布调整得平稳一些,后续再用大学习率更新时,就不容易出问题了。
而Cosine Annealing则像是给训练过程安排了一个“优雅的谢幕”。传统的学习率衰减,比如每隔固定轮次减半,是阶梯式下降的,这种突变有时会让模型不适应。余弦退火则不同,它让学习率按照余弦函数的曲线,从最大值平滑地下降到最小值。这个过程非常柔和,能让模型在训练后期,慢慢地、稳定地收敛到一个更优的解,而不是在最小值点附近来回震荡。
PyTorch官方提供了CosineAnnealingLR和CosineAnnealingWarmRestarts等调度器,但它们都没有内置Warmup阶段。想把两者结合起来,最灵活、最强大的工具就是LambdaLR。它允许你通过一个自定义的lambda函数,完全掌控学习率变化的每一个细节。接下来,我就带你彻底搞懂怎么用LambdaLR把这两者无缝衔接起来。
2. 深入理解LambdaLR:你的学习率画笔
torch.optim.lr_scheduler.LambdaLR是PyTorch学习率调度器家族里最“自由”的一个。它的核心参数是lr_lambda,可以是一个函数,也可以是一个函数列表。这个函数的作用是:在每一次调度器更新时,计算一个乘数因子,然后用这个因子去乘以优化器中每个参数组(param_group)的初始学习率(base_lr)。
这里有个关键点必须理解:LambdaLR修改的是学习率乘数,而不是直接设置学习率绝对值。公式是:当前学习率 = base_lr * lr_lambda(current_step)。很多新手在这里踩坑,写lambda函数时直接返回想要的学习率数值,结果发现实际学习率不对,就是因为忘了它是个乘数。
LambdaLR的强大之处在于它能针对不同的参数组设置不同的调整策略。在定义优化器时,我们可以把模型的参数分成多个组,每个组可以有自己的初始学习率和其他超参数。然后,给lr_lambda传递一个函数列表,列表中的每个函数会依次应用到对应的参数组上。这个功能在微调模型时特别有用,比如我们希望骨干网络(backbone)的学习率小一点、慢一点衰减,而新加的分类头(head)学习率大一点、衰减快一点。
下面我们通过一个最简单的例子,看看LambdaLR是怎么工作的:
import torch
import torch.nn as nn
# 定义一个简单的模型
model = nn.Linear(10, 2)
# 创建优化器,并设置两个参数组,它们的基础学习率不同
optimizer = torch.optim.SGD([
{'params': model.weight, 'lr': 0.1}, # 参数组0,base_lr=0.1
{'params': model.bias, 'lr': 0.01} # 参数组1,base_lr=0.01
])
# 定义两个lambda函数作为乘数
# 参数组0:学习率每步变为之前的0.9倍
lambda0 = lambda epoch: 0.9 ** epoch
# 参数组1:学习率每步线性减少到0
lambda1 = lambda epoch: max(1.0 - epoch / 10.0, 0)
# 创建LambdaLR调度器
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda0, lambda1])
# 模拟训练,打印学习率变化
for epoch in range(5):
print(f'Epoch {epoch}:')
for i, group in enumerate(optimizer.param_groups):
print(f' Param Group {i} lr: {group["lr"]:.6f}')
scheduler.step()
运行这段代码,你会看到两个参数组的学习率按照我们定义的规则独立变化。


2247

被折叠的 条评论
为什么被折叠?



