PyTorch数据处理
PyTorch对数据处理有一套标准的接口操作,本篇对其进行一个总结,方便在使用PyTorch进行数据预处理和数据提取时使用。
PyTorch的几个重要的和数据处理相关的类都在torch.utils.data包中。
DataLoader
PyTorch数据处理的核心是DataLoader类。通过以下语句引入。
代码中给出了DataLoader类可以提供的参数。
dataset 用于提取数据的dataset,通常是继承torch.utils.data.Dataset的对象。
batch_size 进行训练每个batch的样本数量。
shuffle 是否在每个训练epoch打乱数据顺序。
sampler 采样器,用于定义随机获得每个batch的采样策略,如果设置了sampler就必须设置shuffle=False。
batch_sampler 和batch_size+shuffle+sampler+drop_last起到作用类似,通常不使用。
num_workers 定义是否开启多子进程数据加载。
pin_memory 设置为True时可以提高将cpu上的Tensor转到GPU上的速率,但会提高内存消耗。
collated_fn 将一个样本列表组装成一个batch的中间处理函数,可自定义。
drop_last 是否去除最后一个没有达到batch_size大小的batch
timeout 如果设置成正数,会在超时后停止batch采集作业。
worker_init_fn workers的初始化函数
其中我们主要通过实现dataset sampler collate_fn三个参数来实现自定义数据加载器的功能。
dataset
dataset参数必须继承Dataset类。
其中Dataset类实际上是用于形成Map类型的datasets。另一种IterableDataset用于流式数据加载的场景,不是很常见。
Dataset类的子类需要实现两个函数:
__getitem__(idx) 该函数表示根据输入的索引idx,从数据中提取一个样本返回。
__len__() 该函数表示返回数据样本的总数
一个简单的已经将数据按照样本一条条排布的数据集如下所示,但实际上,这个类赋予了我们很多的灵活性,可以在函数中进行很多额外的检查和操作。
有时候我们也可以调用PyTorch已经实现好的TensorDataset、ConcatDataset等类,避免自己实现的麻烦。
sampler
sampler需要继承torch.utils.data.Sampler类,一般要以data_source作为参数,负责传递索引给Dataset对象。
PyTorch中已经提供的几种采样器名称和功能如下:
torch.utils.data.SequentialSampler(data_source) 按顺序进行采样。
torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None) 根据参数进行一定数量的有或者无放回随机采样。
torch.utils.data.SubsetRandomSampler(indices) 在给定的索引列表中进行再次采样。
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True) 按照权重进行随机采样。
sampler需要实现两个函数 _len_()和__iter\(),作用分别是给出数据总长度和该次迭代返回的索引值。
下面我们实现一个根据数据集大小决定是否进行随机采样的sampler。
collate_fn
collate_fn是一个callable的对象,其实就是一个函数。这个函数的输入是由sampler按照index在dataset中提取出的batch_size个样本组成的list列表。
collate_fn负责将这个list中的元素组装成一个完整的Tensor矩阵。为此默认的collate_fn函数具有以下三个性质:
总是会在将每个样本组合起来后,在最左边添加一个大小为batch_size的新维度。
自动将numpy数组转化成PyTorch的Tensor类型数据。
会保留数据结构,例如在dataset节中,我们将输出结果设置成一个有x和y两个key的字典值。collate_fn将会保留这个字典形式的结构,只是将每个key的value替换成batch_size长度的Tensor。
一个自定义的在第二个维度上做stack操作的collate_fn如下(假设每个样本格式为字典):
总结来说,只需要熟练掌握自定义和使用dataset、sampler、collate_fn的方法,可以满足我们各种实验的需求,也会使得我们的代码更加有条理,提高实验效率。
最后更新于