avatar

Catalog
Pytorch Dataloader 用法

DataLoder

Pytorch在训练前一步数据读取时,要使用 DataLoader 加载数据, 可以shuffle 、多线程读取等。记录一下如何使用。

自黑一下,之前我写的第一个工程,完全没写DataLoader,直接把原图,标签放在两个大列表里面来循环,😂🤣,所以当时16g内存都直接溢出😆,所以说不写DataLoader也不是不可以,就是有点不可描述

镇楼

使用步骤

定义Dataset

python
1
torch.utils.data.Dataset

首先自定义dataset类,以上述Dataset为父类,必须重写__getitem__() 方法,即获取数据逻辑;

可选重写__len()__方法,获取数据长度信息。

类似如下,是我们最近做的一个研究中Dataset 片段,重点关注它return的结果。

python
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch.utils.data.Dataset 

# 定义 Dataset
class JointsDataset(Dataset):
def __init__(self, ........):
pass
def __len__(self, ):
return len(self.db)
def __getitem__(self, idx):
'''
省略代码块
'''
return input_data_numpy, input_sup_A_data_numpy, input_sup_B_data_numpy, target_heatmaps, target_weight, meta

创建 DataLoader

先看一下Dataloader类定义

python
1
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

一般用的几个参数:

  • dataset 为上述的自定义的类
  • batch_size
  • shuffle 打乱数据
  • num_workers 多线程
  • pin_memory 更快的发送数据到显存 (不太清楚)
python
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
train_dataset = JointsDataset(
cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
cfg.DATASET.TRAIN_NPY_DIR,
transform=transforms.Compose([
transforms.ToTensor(),
normalize,
])
)

train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
shuffle=cfg.TRAIN.SHUFFLE,
num_workers=cfg.WORKERS,
pin_memory=cfg.PIN_MEMORY
)

迭代获取数据

python
1
2
for i, (input, input_sup_A, input_sup_B, target, target_weight, meta) in enumerate(train_loader):
pass

关于获取的数据可以看到和自定义dataset__getitem()__方法返回的东西是一样的。


总结一下,定义dataset, 创建dataloder, 迭代获取进行训练。

Author: 星星泡饭
Link: https://luoyou.art/2020/07/09/Pytorch-Dataloader-%E7%94%A8%E6%B3%95/
Copyright Notice: All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.
Donate
  • 微信
    微信
  • 支付寶
    支付寶

Comment