Pytorch学习(一)加载数据

本文详细介绍如何在PyTorch中加载和准备数据,通过Yesno数据集实例演示使用torch.utils.data.DataLoader类进行数据加载的过程,包括数据集访问、数据加载、数据迭代等关键步骤。

在Pytorch中加载数据

pytorch具有广泛的神经网络构建模块和一个简单、直观、稳定的API。Pytorch包括为您的模型准备和加载通用数据集的包。

介绍

Pytorch加载数据的核心是torch.utils.data.DataLoader类。它表示一个在数据集上的一个Python可迭代对象。Pytorch库为我们提供了内置的高质量数据集,去在torch.utils.data.Dataset中使用。数据集可从tochvisiontorchaudiotorchtext中获得。

我们使用来自torchaudio.datasets的Yesno数据集。我们将演示如何有效地将数据从PyTorch数据集加载到PyTorch DataLoader中。

配置

pip install torchaudio

步骤、

1. 导入必须的库,来加载我们的数据

2. 访问数据集中的数据

3. 加载数据

4. 对数据进行迭代

5. 可视化数据(可选择)

1. Import necessary libraries for loading our data

import torch
import torchaudio

2. Access the data in the dataset

torchaudio.datasets.YESNO(
  root,
  url='http://www.openslr.org/resources/1/waves_yesno.tar.gz',
  folder_in_archive='waves_yesno',
  download=False,
  transform=None,
  target_transform=None)

# * ``download``: If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
# * ``transform``: Using transforms on your data allows you to take it from its source state and transform it into data that’s joined together, de-normalized, and ready for training. Each library in PyTorch supports a growing list of transformations.
# * ``target_transform``: A function/transform that takes in the target and transforms it.
#
# Let’s access our Yesno data:
#

# A data point in Yesno is a tuple (waveform, sample_rate, labels) where labels
# is a list of integers with 1 for yes and 0 for no.
yesno_data_trainset = torchaudio.datasets.YESNO('./', download=True)

# Pick data point number 3 to see an example of the the yesno_data:
n = 3
waveform, sample_rate, labels = yesno_data[n]
print("Waveform: {}\nSample rate: {}\nLabels: {}".format(waveform, sample_rate, labels))

3. Loading the data

data_loader = torch.utils.data.DataLoader(yesno_data,
                                          batch_size=1,
                                          shuffle=True)

4. Iterate over the data

for data in data_loader:
  print("Data: ", data)
  print("Waveform: {}\nSample rate: {}\nLabels: {}".format(data[0], data[1], data[2]))
  break

5. [Optional] Visualize the data

import matplotlib.pyplot as plt

print(data[0][0].numpy())

plt.figure()
plt.plot(waveform.t().numpy())

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值