10分钟快速上手飞桨

10分钟快速上手飞桨

从完成一个简单的『手写数字识别任务』开始,可快速了解深度学习模型开发的大致流程,并初步掌握飞桨框架 API 的使用方法。

一、快速安装飞桨

../../_images/mnist.png

如果已经安装好飞桨那么可以跳过此步骤。飞桨支持很多种安装方式,这里介绍其中一种简单的安装命令。

注:目前飞桨支持 Python 3.6 ~ 3.9 版本,pip3 要求 20.2.2 或更高版本,请提前安装对应版本的 Python 和 pip 工具。

[2]:
# 使用 pip 工具安装飞桨 CPU 版
! python3 -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple

该命令用于安装 CPU 版本的飞桨。如果要安装其他计算平台或操作系统支持的版本,可点击 快速安装 查看安装引导。

二、导入飞桨

安装完成后,需要在Python解释器中使用 import 导入飞桨,即可开始实践深度学习任务。

若操作成功,会输出飞桨的版本号。

[3]:
import paddle

print(paddle.__version__)
2.2.1

三、实践:手写数字识别任务

『手写数字识别』是深度学习里的 Hello World 任务,用于对 0 ~ 9 的十类数字进行分类,即输入手写数字的图片,可识别出这个图片中的数字。

本任务用到的数据集为 MNIST 手写数字数据集,用于训练和测试模型。该数据集包含 60000 张训练图片、 10000 张测试图片、以及对应的分类标签文件,每张图片上是一个 0 ~ 9 的手写数字,分辨率为 28 * 28。部分图像和对应的分类标签如下图所示。

图 1:MNIST 数据集样例

开始之前,需要使用下面的命令安装 Python 的 matplotlib 库和 numpy 库,matplotlib 库用于可视化图片,numpy 库用于处理数据。

[3]:
# 使用 pip 工具安装 matplotlib 和 numpy
! python3 -m pip install matplotlib numpy -i https://mirror.baidu.com/pypi/simple

下面是手写数字识别任务的完整代码,如果想直接运行代码,可以拷贝下面的完整代码到一个Python文件中运行。

[5]:
import paddle
import numpy as np
from paddle.vision.transforms import Normalize

transform = Normalize(mean=[127.5], std=[127.5], data_format="CHW")
# 下载数据集并初始化 DataSet
train_dataset = paddle.vision.datasets.MNIST(mode="train", transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode="test", transform=transform)

# 模型组网并初始化网络
lenet = paddle.vision.models.LeNet(num_classes=10)
model = paddle.Model(lenet)

# 模型训练的配置准备,准备损失函数,优化器和评价指标
model.prepare(
    paddle.optimizer.Adam(parameters=model.parameters()),
    paddle.nn.CrossEntropyLoss(),
    paddle.metric.Accuracy(),
)

# 模型训练
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)
# 模型评估
model.evaluate(test_dataset, batch_size=64, verbose=1)

# 保存模型
model.save("./output/mnist")
# 加载模型
model.load("output/mnist")

# 从测试集中取出一张图片
img, label = test_dataset[0]
# 将图片shape从1*28*28变为1*1*28*28,增加一个batch维度,以匹配模型输入格式要求
img_batch = np.expand_dims(img.astype("float32"), axis=0)

# 执行推理并打印结果,此处predict_batch返回的是一个list,取出其中数据获得预测结果
out = model.predict_batch(img_batch)[0]
pred_label = out.argmax()
print("true label: {}, pred label: {}".format(label[0], pred_label))
# 可视化图片
from matplotlib import pyplot as plt

plt.imshow(img[0])
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/5
step 938/938 [==============================] - loss: 0.0519 - acc: 0.9344 - 14ms/step
Epoch 2/5
step 938/938 [==============================] - loss: 0.0239 - acc: 0.9767 - 14ms/step
Epoch 3/5
step 938/938 [==============================] - loss: 0.0416 - acc: 0.9811 - 14ms/step
Epoch 4/5
step 938/938 [==============================] 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

奋进学堂

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值