torch.utils.data
torch.utils.data | Torch Docs
torch.utils.data.Dataset
torch.utils.data.Dataset | Torch Docs
自建数据集
所有的自建数据集都需要继承 Dataset 这个类, 并且重写 __getitem__() 方法（通常也应实现 __len__() 方法）。
import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Minimal map-style dataset wrapping paired input/target tensors.

    Subclasses of ``Dataset`` must override ``__getitem__`` (and usually
    ``__len__``) so samples can be indexed and counted.
    """

    def __init__(self, input, target):
        # NOTE: the parameter name ``input`` shadows the builtin; it is kept
        # unchanged for backward compatibility with keyword callers.
        self.input, self.target = input, target

    def __getitem__(self, index):
        # One sample is the (feature, label) pair at ``index``.
        return self.input[index], self.target[index]

    def __len__(self):
        return len(self.input)


# Demo: 108 random 64-dim feature vectors with integer labels in [0, 10).
# Module-level names renamed from ``input``/``target`` so the ``input``
# builtin is not shadowed at module scope.
inputs = torch.rand([108, 64])
targets = torch.randint(0, 10, (108,))
dataset = CustomDataset(inputs, targets)
print(type(dataset))
for item in dataset:
    print(type(item))  # each item is a (feature, label) tuple
    print(item)
    break
从本地加载图片数据集
my_dataset.py
flowers : 30
├── daisies : 6
│   ├── daisy000.jpg
│   ├── ...
│   └── daisy005.jpg
├── dandelions : 6
│   ├── dandelion000.jpg
│   ├── ...
│   └── dandelion005.jpg
├── roses : 6
│   ├── rose000.jpg
│   ├── ...
│   └── rose005.jpg
├── sunflowers : 6
│   ├── sunflower000.jpg
│   ├── ...
│   └── sunflower005.jpg
├── tulips : 6
│   ├── tulip000.jpg
│   ├── ...
│   └── tulip005.jpg
└── LICENSE.txt
import os
import random

from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms


class MyDataSet(Dataset):
    """Map-style Dataset over a list of image files stored on disk."""

    def __init__(self, image_paths: list, image_labels: list, transform=None):
        self.image_paths = image_paths
        self.image_labels = image_labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        # Load lazily: the file is opened only when the sample is requested.
        label = self.image_labels[index]
        image = Image.open(self.image_paths[index])
        if self.transform is not None:
            image = self.transform(image)
        return image, label
def read_image_paths(root: str, val_rate: float = 0.2):
    """Split an ImageFolder-style directory into train/val path+label lists.

    Args:
        root: dataset root; each immediate subdirectory is one class.
        val_rate: fraction of each class sampled (without replacement)
            into the validation split.

    Returns:
        Tuple ``(train_image_paths, train_image_labels,
        val_image_paths, val_image_labels)``.
    """
    # Only subdirectories count as classes (skips stray files such as LICENSE.txt).
    flower_classes = sorted(
        cla for cla in os.listdir(root)
        if os.path.isdir(os.path.join(root, cla))
    )
    # Class name -> integer label; sorted order keeps labels reproducible.
    cla2lab = {cla: lab for lab, cla in enumerate(flower_classes)}

    train_image_paths, train_image_labels = [], []
    val_image_paths, val_image_labels = [], []
    for cla in flower_classes:
        cla_path = os.path.join(root, cla)
        image_paths = sorted(
            os.path.join(cla_path, file_name)
            for file_name in os.listdir(cla_path)
        )
        image_label = cla2lab[cla]

        # Per-class sampling keeps the split stratified across classes.
        # NOTE: ``random.sample`` is unseeded here, so the split differs
        # between runs unless the caller seeds ``random``.
        val_image_path = random.sample(image_paths, k=int(len(image_paths) * val_rate))
        val_set = set(val_image_path)  # O(1) membership instead of O(n) list scans
        train_image_path = [p for p in image_paths if p not in val_set]

        val_image_paths += val_image_path
        val_image_labels += [image_label] * len(val_image_path)
        train_image_paths += train_image_path
        train_image_labels += [image_label] * len(train_image_path)

    return train_image_paths, train_image_labels, val_image_paths, val_image_labels
# Build the 80/20 train/validation split from the local "flowers" folder.
(train_image_paths, train_image_labels,
 val_image_paths, val_image_labels) = read_image_paths('./flowers', val_rate=0.2)

# Print the size of each list; paths and labels should stay pairwise aligned.
for split in (train_image_paths, train_image_labels,
              val_image_paths, val_image_labels):
    print(len(split))
# Wrap the split file lists in Dataset objects. ToTensor converts each
# loaded PIL image into a CHW float tensor (see torchvision docs).
to_tensor = transforms.ToTensor()
train_dataset = MyDataSet(image_paths=train_image_paths,
                          image_labels=train_image_labels,
                          transform=to_tensor)
val_dataset = MyDataSet(image_paths=val_image_paths,
                        image_labels=val_image_labels,
                        transform=to_tensor)

# Peek at the first training sample only.
for image, label in train_dataset:
    print(image.size(), label)
    break
torch.utils.data.DataLoader
torch.utils.data.DataLoader | Torch Docs
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


class CustomDataset(Dataset):
    """Map-style dataset pairing every input sample with its label."""

    def __init__(self, input, label):
        # Parameter name ``input`` is kept to match existing keyword callers.
        self.input, self.label = input, label

    def __getitem__(self, index):
        sample = self.input[index]
        target = self.label[index]
        return sample, target

    def __len__(self):
        return len(self.input)
# 64 random RGB images (3 x 28 x 42) with integer labels in [0, 10).
images = torch.rand([64, 3, 28, 42])
targets = torch.randint(0, 10, (64,))

dataset = CustomDataset(input=images, label=targets)

# Shuffle every epoch and yield the samples 4 at a time.
loader = DataLoader(dataset=dataset, shuffle=True, batch_size=4)
默认的 batch_size = 1, 指定 batch_size = 4 将一次处理 4 个样本。
print(type(loader))

# Pull a single mini-batch instead of loop-and-break: the default
# collate_fn stacks samples into an [inputs, labels] pair of tensors
# with a leading batch dimension.
batch = next(iter(loader))
print(type(batch), len(batch))
images, labels = batch
print(images.size())
print(labels.size())
关于 num_worker
Single- and Multi-process Data Loading
torch.utils.data.Subset
How to get a subset of a whole dataset?
torch.utils.data.Subset
from torchvision.datasets import MNIST
from torch.utils.data import Subset, DataLoader

# Download/load MNIST, then carve it into two disjoint halves by index
# parity. A Subset is a lazy view onto the parent dataset through an
# explicit index list; no data is copied.
trainset = MNIST(root="dataset/", train=True, download=True)

even_indices = list(range(0, len(trainset), 2))
odd_indices = list(range(1, len(trainset), 2))

trainset_1 = Subset(dataset=trainset, indices=even_indices)
trainset_2 = Subset(dataset=trainset, indices=odd_indices)

trainloader_1 = DataLoader(trainset_1, batch_size=4)
trainloader_2 = DataLoader(trainset_2, batch_size=4)