前言:
from torch.utils.data import Dataset
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
1、Dataset 在 torch.utils.data 无论是加载文本还是图像数据集,加载自定义数据集都需要他。官方提供的dataset则从torchvision里import。
2、DataLoader 在torch.utils.data 不管是文本还是图片都用这个包。
3、 对图像的预处理 用torchvision.transforms 包 。
数据预处理部分:
- 数据增强:torchvision中transforms模块自带功能,比较实用
- 数据预处理:torchvision中transforms也帮我们实现好了,直接调用即可
data_transforms = {
'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
transforms.CenterCrop(224),#从中心开始裁剪
transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
]),
'valid': transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
图像数据集加载部分:
- 几种数据集加载的方法区别本质在于文件里的内容,标签的位置之类的。
- 想让PyTorch能读取我们自己的数据,首先要了解pytroch读取图片的机制和流程,然后按流程编写代码。
Dataset类
PyTorch读取图片,主要是通过Dataset类,所以先简单了解一下Dataset类。Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它,类似于C++中的虚基类。
源码如下:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
这里重点看 getitem函数,getitem接收一个index,然后返回一个batch大小的图片数据和标签,其中这个index是一个列表,这个列表是由dataloader里的sampler采样器生成的。感兴趣的可以详细了解这里的数据集的加载Dataset和DataLoader原理。
如bitch_size的值是16,其在pycharm中的表示形式为:
Index={list}<class ‘list’>: [4, 135, 113, 34, 47, 140, 87, 0, 59, 33,144, 43, 83, 133, 1, 78]
self={_SingleProcessDataLoaderlter}<torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x000001F11BF6A7C8>
一、自定义Dataset加载
要让PyTorch能读取自己的数据集,只需要两步:
- 制作图片数据的索引
- 构建Dataset子类
然而,如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中,然后从该txt中读取。
整个读取自己数据的基本流程就是:
- 制作存储了图片的路径和标签信息的txt
- 将这些信息转化为list,该list每一个元素对应一个样本
- 通过getitem函数,读取数据和标签,并返回数据和标签。
首先制作图片数据的索引
就是读取图片路径,标签,保存到txt文件中。
1)一堆相同类别的图片已经在一个文件夹下了,可以用下面这种方法产生一个txt文件。
参考:如何用python生成带图片名称和标签的.txt文件(代码)
2)标签和图片标号都在csv文件里,可以用以下方法。
pytorch 自定义数据集载入(标签在csv文件里)
然后构建Dataset子类
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
fh = open(txt_path, 'r') #读取 制作好的txt文件的 图片路径和标签到imgs里
imgs = []
for line in fh:
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index] #self.imgs是一个list,self.imgs的一个元素是一个str,包含图片路径,图片标签,这些信息是在init函数中从txt文件中读取的
# fn是一个图片路径
img = Image.open(fn).convert('RGB') #利用Image.open对图片进行读取,img类型为 Image ,mode=‘RGB’
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
- 注意到Dataset类里的初始化中还会初始化transform,transform是一个Compose类型,里边有一个list,list中就会定义了各种对图像进行处理的操作,可以设置减均值,除标准差,随机裁剪,旋转,翻转,仿射变换等操作。
- 在这里我们要知道,一张图片读取进来之后,会经过数据处理(数据增强),最终变成模型的输入数据。这里就有一点需要注意,PyTorch的数据增强是将原始图片进行了处理,并不会生成新的一份图片,而是“覆盖”原图,当采用randomcrop之类的随机操作时,每个epoch输入进来的图片几乎不会是一模一样的,这达到了样本多样性的功能。
最后DataLoader加载即可
- 当自定义Dataset构建好,剩下的操作就交给DataLoader了。在DataLoader中,会触发Mydataset中的getiterm函数读取一个batch大小的图片的数据和标签,并返回,(清晰的底层逻辑见该博客)作为模型真正的输入。
- 最后像下面这样,处理好了前面说的两步之后,得到data,交给DataLoader就很简单了。
train_data = MyDataset(txt='../gender/train1.txt',type = "train", transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_data,
batch_size=batch_size,
sampler=train_sampler)
二、用torchvision里的ImageFolder图像分类数据集的加载
1、data_transforms中指定了所有图像预处理操作
2、ImageFolder 假设所有的文件按文件夹保存好,每个文件夹下面存贮同一类别的图片,文件夹的名字为分类的名字。
仍然先制作数据源
举个例子,做的另一个项目,花的类别分类。他的数据集如下图。即同一种花都在一个文件夹中,文件夹的名称即为标签类别。
然后利用torchvision里的ImageFolder类
如下面的代码,ImageFolder已经写好datasets了。就像手写数字的datasets一样该datasets里面init,getitem,len魔法函数已实现了,只要保存数据集的格式符合要求,就可以直接使用。
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'valid']}
上面的代码是花分类的项目的train和valid两个数据集
关于datasets官方给的答案是:
All datasets are subclasses of torch.utils.data.Dataset i.e, they have getitem and len methods implemented(都已实现了getitem和len,不需要像第一种自定义方法自己写dataset类了). Hence, they can all be passed to a torch.utils.data.DataLoader which can load multiple samples in parallel using torch.multiprocessing workers.
这里All datasets还有很多,比如ImageNet等,具体可以去pytorch官网查看。
最后再dataloader加载
使用如下。
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)
三、torch自带图像分类数据集的处理和加载
像手写数字之类等,都可以在官网查看具体还有哪些数据集,都自带相应的dataset。
import torch
from torchvision import datasets, transforms
import helper
import matplotlib.pyplot as plt
import numpy
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))],)
# Download and load the training data
trainset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# Download and load the test data
testset = datasets.FashionMNIST('F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)