迁移学习实战:用PyTorch冻结ResNet底层权重的3种高效写法

迁移学习实战:用PyTorch冻结ResNet底层权重的3种高效写法

最近在做一个细粒度图像分类的项目,手头数据量不大,从头训练一个深度网络显然不现实。和团队里的几位工程师讨论后,大家一致决定采用迁移学习的策略,拿一个在ImageNet上预训练好的ResNet作为起点。想法很美好,但实际操作时,我们却在一个看似基础的问题上卡了壳:到底怎么“冻结”预训练模型的底层权重,才是最有效、最不容易出错的写法?

一开始,我们只是简单地把所有参数的 requires_grad 设为 False,然后单独解冻最后的全连接层。但在训练过程中,发现模型收敛速度比预期慢,而且偶尔会出现梯度消失的迹象。这促使我们深入探究,发现冻结权重远不止一个布尔开关那么简单。它涉及到优化器的配置效率、内存的占用,甚至会影响模型最终的性能上限。对于计算机视觉工程师而言,选择一种合适的冻结策略,是平衡训练速度、资源消耗和模型效果的关键。

本文将结合我们项目中的实际踩坑经验,分享三种经过实战检验的PyTorch冻结权重高效写法。我们会从最基础的用法讲起,逐步深入到优化器参数组配置和基于网络结构的精细化冻结策略,并附上我们在一个标准数据集上的性能对比实验数据,希望能为你下次的迁移学习任务提供一个清晰的“操作手册”。

1. 基础篇:理解 requires_grad 与优化器的联动

在PyTorch中,requires_grad 是张量的一个属性,它决定了在反向传播过程中是否需要计算该张量的梯度。将其设置为 False,是冻结权重的核心操作。但很多新手容易忽略的一点是,仅仅设置 requires_grad=False 并不足以保证该参数不被更新,优化器的行为同样至关重要。

1.1 标准的“开关”式冻结

这是最常见、最直观的方法。在加载预训练权重后,遍历模型的所有参数,根据参数名(例如,排除包含 fcclassifier 的层)将其 requires_grad 属性关闭。

import torch
import torchvision.models as models

# 加载预训练ResNet50
model = models.resnet50(pretrained=True)

# 方案A:修改模型最后一层,适配新任务(例如5分类)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, 5)

# 冻结除最后一层外的所有权重
for name, param in model.named_parameters():
    if ‘fc’ not in name:  # 仅全连接层可训练
        param.requires_grad = False
        # 或者使用 param.requires_grad_(False)

# 检查冻结效果
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f“可训练参数: {trainable_params}, 总参数: {total_params}, 冻结比例: {(1 - trainable_params/total_params)*100:.2f}%”)

注意:这里使用 param.requires_grad = Falseparam.requires_grad_(False) 是等价的,后者是in-place操作,更符合PyTorch的风格。

这种方法简单有效,但存在一个潜在的效率问题。当你将整个模型传递给优化器时,优化器仍然会为所有参数维护状态(例如SGD中的动量缓冲区),即使它们的 requires_grad=False。对于像ResNet50这样的大模型,这意味着会浪费可观的内存。

1.2 优化器的正确配置:传递可训练参数列表

为了解决上述内存浪费问题,最佳实践是只将需要训练的参数传递给优化器。PyTorch的优化器(如 SGD, Adam)接受一个可迭代的参数列表,我们可以利用列表推导式来筛选

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值