torchvision.models简介
1 torchvision.models介绍
1.1 torchvision介绍
PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets、torchvision.models、torchvision.transforms
该篇主要介绍torchvision.models,关于torchvision.datasets 和 torchvision.transforms 可以看以下几篇:
https://blog.csdn.net/Alexa_/article/details/129408512
1.2 torchvision.models
torchvision.models:这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。
torchvision.models模块的子模块中包含以下模型结构:
- AlexNet
- VGG: VGG-11, VGG-13, VGG-16, VGG-19
- ResNet: ResNet-18, ResNet-34, ResNet-50, ResNet-101, ResNet-152
- SqueezeNet: SqueezeNet 1.0, and SqueezeNet 1.1
- DenseNet:Densenet-121,Densenet-169,Densenet-161,Densenet-201
(1),预训练模型可以通过设置pretrained=True来构建:
import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
- 预训练模型期望的输入是RGB图像的mini-batch:(batch_size, 3, H, W),并且H和W不能低于224。
- 图像的像素值必须在范围[0,1]间,并且用均值mean=[0.485, 0.456, 0.406]和方差std=[0.229, 0.224, 0.225]进行归一化。
下载的模型可以通过state_dict() 来打印状态参数、缓存的字典,如下所示:
import torchvision.models as models
vgg16 = models.vgg16(pretrained=True)
# 返回包含模块所有状态的字典,包括参数和缓存
pretrained_dict = vgg16.state_dict()
(2),只需要网络结构,不加载参数来初始化,可以将pretrained = False
model = torchvision.models.densenet169(pretrained=False)
# 等价于:
model = torchvision.models.densenet169()
2 导入模型举例
应用VGG16模型,并进行改动,以适应CIDIAR10数据集。
- CIFAR10数据集是 10个类别
- VGG16输出是1000个类别
- VGG 加一层输出10个类别
2.1 模型的使用
导入模型,输出查看网络结构:
import torchvision
# 直接调用,实例化模型,pretrained代表是否下载预先训练好的参数
vgg16_false = torchvision.models.vgg16(pretrained = False)
vgg16_ture = torchvision.models.vgg16(pretrained = True)
print(vgg16_ture)
输出结果,可以看到VGG16的结构,可以看出,其最后一行 out_features = 1000.
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace=True)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace=True)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride

本文介绍了torchvision.models包,它包含了预训练的深度学习模型如AlexNet、VGG、ResNet等,并展示了如何导入、使用、修改这些模型。对于CIFAR10数据集,可以通过修改预训练模型的输出层来适应10类分类任务。此外,还讨论了模型的保存和加载方法。

7287

被折叠的 条评论
为什么被折叠?



