DataLoder
Pytorch在训练前一步数据读取时,要使用 DataLoader
加载数据, 可以shuffle
、多线程读取等。记录一下如何使用。
自黑一下,之前我写的第一个工程,完全没写DataLoader,直接把原图,标签放在两个大列表里面来循环,😂🤣,所以当时16g内存都直接溢出😆,所以说不写DataLoader也不是不可以,就是有点不可描述
镇楼
使用步骤
定义Dataset
python
1 | torch.utils.data.Dataset |
首先自定义dataset
类,以上述Dataset
为父类,必须重写__getitem__()
方法,即获取数据逻辑;
可选重写__len()__
方法,获取数据长度信息。
类似如下,是我们最近做的一个研究中Dataset
片段,重点关注它return
的结果。
python
1 | import torch.utils.data.Dataset |
创建 DataLoader
先看一下Dataloader类定义
python
1 | torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None) |
一般用的几个参数:
- dataset 为上述的自定义的类
- batch_size
- shuffle 打乱数据
- num_workers 多线程
- pin_memory 更快的发送数据到显存 (不太清楚)
python
1 | train_dataset = JointsDataset( |
迭代获取数据
python
1 | for i, (input, input_sup_A, input_sup_B, target, target_weight, meta) in enumerate(train_loader): |
关于获取的数据可以看到和自定义dataset
类__getitem()__
方法返回的东西是一样的。
总结一下,定义dataset
, 创建dataloder
, 迭代获取进行训练。