torch.utils.data

torch.utils.data | Torch Docs

torch.utils.data.Dataset

torch.utils.data.Dataset | Torch Docs

自建数据集

所有的自建数据集都需要继承Dataset这个类, 并且重写__getitem__()和__len__()方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from torch.utils.data import Dataset

#构建一个自定义数据集的类,继承Dataset类
class CustomDataset(Dataset):
    """Map-style dataset over two parallel containers of samples and labels."""

    def __init__(self, input, target):
        # Keep references to the aligned containers; input[i] pairs with target[i].
        self.input = input
        self.target = target

    def __getitem__(self, index):
        # A sample is the (feature, label) pair at the same position.
        return self.input[index], self.target[index]

    def __len__(self):
        # The dataset size is the length of the input container.
        return len(self.input)

# Build a toy dataset: 108 feature vectors with integer class labels.
input = torch.rand([108, 64])          # 108 samples, each a 64-dim vector
target = torch.randint(0, 10, (108,))  # 108 labels, each an integer in [0, 10)

# Wrap the raw tensors in the Dataset subclass to get an indexable dataset.
dataset = CustomDataset(input, target)

print(type(dataset))  # <class '__main__.CustomDataset'>
# Inspect only the first item, then stop.
for item in dataset:
    print(type(item))  # <class 'tuple'>
    print(item)        # (tensor([*, *, ..., *]), tensor(*))
    break

从本地加载图片数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Folder Structure
my_dataset.py
flowers : 30
├── daisies : 6
| ├── daisy000.jpg
| ├── ...
| └── daisy005.jpg
|
├── dandelions : 6
| ├── dandelion000.jpg
| ├── ...
| └── dandelion005.jpg
|
├── roses : 6
| ├── rose000.jpg
| ├── ...
| └── rose005.jpg
|
├── sunflowers : 6
| ├── sunflower000.jpg
| ├── ...
| └── sunflower005.jpg
|
├── tulips : 6
| ├── tulip000.jpg
| ├── ...
| └── tulip005.jpg
|
└── LICENSE.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# my_dataset.py
import os
import random
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms

class MyDataSet(Dataset):
    """Custom Dataset that lazily loads images from disk paths.

    image_paths and image_labels are parallel lists: image_paths[i] is the
    file path of the sample whose label is image_labels[i]. transform, when
    given, is applied to the PIL image before it is returned.
    """

    def __init__(self, image_paths: list, image_labels: list, transform=None):
        self.image_paths = image_paths
        self.image_labels = image_labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        label = self.image_labels[index]
        # Open the file only when the sample is requested (lazy loading).
        # NOTE(review): Image.open does not force a color mode; grayscale or
        # RGBA files would produce differently-shaped tensors after ToTensor —
        # confirm the source folders contain only RGB images.
        image = Image.open(self.image_paths[index])
        if self.transform is not None:
            image = self.transform(image)
        return image, label
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def read_image_paths(root: str, val_rate: float = 0.2):
    """Scan a class-per-folder image directory and split it into train/val sets.

    Args:
        root: dataset root; every sub-directory is one class, plain files
            (e.g. LICENSE.txt) are ignored.
        val_rate: fraction of each class randomly sampled into the validation
            split; the per-class count is floored (int()).

    Returns:
        (train_image_paths, train_image_labels, val_image_paths, val_image_labels)
        where labels are the integer indices of the sorted class names.
    """
    # Each sub-directory of root is one class; skip non-directories.
    flower_classes = [cla for cla in os.listdir(root)
                      if os.path.isdir(os.path.join(root, cla))]

    # Sort so the class -> index mapping is identical across platforms.
    flower_classes.sort()

    # Map class name to integer label, e.g. {'daisies': 0, 'dandelions': 1, ...}.
    cla2lab = {cla: lab for lab, cla in enumerate(flower_classes)}

    train_image_paths = []   # paths of all training images
    train_image_labels = []  # labels of all training images
    val_image_paths = []     # paths of all validation images
    val_image_labels = []    # labels of all validation images

    for cla in flower_classes:
        cla_path = os.path.join(root, cla)
        # NOTE(review): every file in the class folder is taken; no extension
        # filtering is applied — confirm the folders contain only image files.
        image_paths = [os.path.join(root, cla, file_name)
                       for file_name in os.listdir(cla_path)]

        # Sort so random sampling sees the same ordering on every platform.
        image_paths.sort()

        image_label = cla2lab[cla]

        # Randomly sample the validation subset; k is floored, so a class with
        # len(image_paths) * val_rate < 1 contributes no validation images.
        val_image_path = random.sample(image_paths, k=int(len(image_paths) * val_rate))
        # Use a set for O(1) membership tests instead of O(n) list scans.
        val_chosen = set(val_image_path)
        train_image_path = [p for p in image_paths if p not in val_chosen]

        val_image_paths += val_image_path
        val_image_labels += [image_label] * len(val_image_path)
        train_image_paths += train_image_path
        train_image_labels += [image_label] * len(train_image_path)

    return train_image_paths, train_image_labels, val_image_paths, val_image_labels
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Load image paths from the local directory and split into train / validation.
(train_image_paths, train_image_labels,
 val_image_paths, val_image_labels) = read_image_paths('./flowers', val_rate=0.2)

# classes : num | calculation | validation | train
# daisies : 6 | 6*0.2=1.2 | 1 | 5
# dandelions : 6 | 6*0.2=1.2 | 1 | 5
# roses : 6 | 6*0.2=1.2 | 1 | 5
# sunflowers : 6 | 6*0.2=1.2 | 1 | 5
# tulips : 6 | 6*0.2=1.2 | 1 | 5
# -----------------------------------------------------
# total : 30 | | 5 | 25

print(len(train_image_paths)) # 25
print(len(train_image_labels)) # 25
print(len(val_image_paths)) # 5
print(len(val_image_labels)) # 5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Build the Dataset instance for the training split.
train_dataset = MyDataSet(image_paths=train_image_paths,
                          image_labels=train_image_labels,
                          transform=transforms.ToTensor())

# Build the Dataset instance for the validation split.
val_dataset = MyDataSet(image_paths=val_image_paths,
                        image_labels=val_image_labels,
                        transform=transforms.ToTensor())

# Peek at the first training sample, then stop.
for image, label in train_dataset:
    print(image.size(), label)  # torch.Size([3, 244, 320]) 0
    break

torch.utils.data.DataLoader

torch.utils.data.DataLoader | Torch Docs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

#创建一个自定义的Dataset类
class CustomDataset(Dataset):
    """Indexable dataset over parallel input/label containers."""

    def __init__(self, input, label):
        self.input, self.label = input, label

    def __len__(self):
        # Dataset size is driven by the input container.
        return len(self.input)

    def __getitem__(self, index):
        # Pair the index-th sample with its index-th label.
        return self.input[index], self.label[index]
1
2
3
4
5
6
7
8
9
10
11
# Build a toy dataset: 64 images with integer class labels.
images = torch.rand([64, 3, 28, 42])    # 64 samples, each a 3*28*42 image
targets = torch.randint(0, 10, (64,))   # 64 labels, each an integer in [0, 10)

# Wrap the tensors so DataLoader can index and batch them.
dataset = CustomDataset(input=images, label=targets)

# Batch and shuffle the dataset; every iteration yields batch_size samples.
loader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=4,
)

默认的batch_size = 1, 指定 batch_size = 4 将一次处理 4 个样本

1
2
3
4
5
6
7
print(type(loader))  # <class 'torch.utils.data.dataloader.DataLoader'>
# Pull a single batch without looping over the whole loader.
batch = next(iter(loader))
# The default collate function returns [stacked_images, stacked_labels].
print(type(batch), len(batch))  # <class 'list'> 2
images, labels = batch
print(images.size())  # torch.Size([4, 3, 28, 42])
print(labels.size())  # torch.Size([4])

关于 num_worker

Single- and Multi-process Data Loading | Torch Docs

torch.utils.data.Subset

How to get a subset of a whole dataset?

torch.utils.data.Subset | Torch Docs

1
2
3
4
5
6
7
8
9
10
11
12
from torchvision.datasets import MNIST
from torch.utils.data import Subset, DataLoader

# Download/load the full MNIST training split (60,000 samples).
trainset = MNIST(root="dataset/", train=True, download=True)

# Index lists for the even- and odd-positioned samples (30,000 each).
all_indices = range(len(trainset))
evens = list(all_indices[0::2])
odds = list(all_indices[1::2])

# Subset wraps the dataset without copying data; only the index lists differ.
trainset_1 = Subset(dataset=trainset, indices=evens)  # len: 30000
trainset_2 = Subset(dataset=trainset, indices=odds)   # len: 30000

# 30,000 samples / batch_size 4 -> 7,500 batches per loader.
trainloader_1 = DataLoader(trainset_1, batch_size=4)  # len: 7500
trainloader_2 = DataLoader(trainset_2, batch_size=4)  # len: 7500