tech share
  • tech-share
  • Engineering
    • 登录鉴权
    • SSR 页面路由
    • npm 版本号
    • 缓存
    • 数据库容灾
    • 动态效果导出 gif
    • Chrome-devtools
    • C 端 H5 性能优化
    • Docker
    • Monorepo 最佳实践
    • 技术架构演化
    • 项目规范最佳实践
    • snowpack
    • 静态资源重试
    • 前端页面渲染分析
    • Git
    • 前端重构
    • 微前端
    • 项目依赖分析
    • 前端监控原理
    • webpack
    • BS 架构与 CS 架构
    • HTTPS
    • package-lock.json 生成逻辑
    • SVN(Subversion)
    • 数据库分类
    • gulp
    • 前端架构
    • Bundle & Bundless
    • 控制反转 IoC
  • JavaScript
    • Javascript 性能
    • JavaScript 原型(2) - 原型与原型链
    • JavaScript 原型(1) - 构造函数
    • JavaScript - Promise
    • ES6 解构赋值
    • 前端离线化
    • Proxy
    • Object.defineProperty()简介
    • TypeScript
  • MachineLearning
    • GAN生成对抗网络
    • 虚拟对抗训练
    • 深度度量学习
    • 原型网络
    • PyTorch优化器
    • 隐马尔可夫模型2
    • Shapley Value 算法
    • Embarassingly Autoencoder算法
    • AutoRec算法及其后续发展
    • 深度学习常用激活函数
    • 序列预测ConvTran算法
    • 联邦学习
    • 深度学习推荐系统算法整理
    • 隐马尔可夫模型
    • 黎曼优化方法
    • FM算法
    • 机器学习常见评价指标
    • VAE算法
    • Adam优化器详解
    • Transformer算法
    • Self-attention 推荐算法
    • CNN 卷积神经网络
    • 图嵌入
    • 集成学习算法
    • RecBole开源框架
    • NCE-PLRec
    • 深度学习初始化方法
    • RNN循环神经网络
    • PyTorch数据处理
    • PyTorch安装和基本操作
    • XGBoost算法
    • NCF算法与简单MF的对比
    • 计算最佳传输
  • CSS
    • 什么是BFC
    • 纯CSS实现可拖动布局
    • 滚动穿透解决方案
  • React
    • React 生命周期
    • React Ref
    • React Hooks
    • SWR
    • React 数据流
    • React 函数式组件和类组件的区别
  • 可视化
    • OffscreenCanvas
    • Echarts 平滑曲线端点为什么不平滑
    • 颜色空间
    • 词云布局解析
    • 3D 数学基础
    • Canvas 图片处理
    • GLGL ES
    • WebGL 中绘制直线
    • Graphics API
    • 现代计算机图形学基础
    • Canvas 灰度
  • Vue
    • Vue2.x全局挂载整理
    • Vue2.6.x源码阅读
      • Vue2.6.x源码阅读 - 2.目录结构分析
      • Vue2.6.x源码阅读 - 4.源码阅读-platform
      • Vue2.6.x源码阅读 - 1.准备工作
      • Vue2.6.x源码阅读 - 5.源码阅读-core-Vue构造函数
      • Vue2.6.x源码阅读 - 7.源码阅读-core-响应式原理
      • Vue2.6.x源码阅读 - 3.源码阅读-shared
      • Vue2.6.x源码阅读 - 6.源码阅读-core-组件挂载
    • Vue + TypeScript Web应用实践
    • Vue2.x指令
    • nextTick()的使用
    • vue-cli2.x 的使用与项目结构分析
    • Vue响应式原理及总结
    • VueX的使用
    • Electron-Vue + Python 桌面应用实践
    • Vite
    • Vue组件通信整理
    • 记录一个问题的探索过程
  • Linux
    • memcg
  • GameDev
    • 游戏中的几种投影视图
    • 从零开始写软渲染器06
    • 从零开始写软渲染器05
    • 从零开始写软渲染器04
    • 从零开始写软渲染器03
    • 从零开始写软渲染器02
    • 从零开始写软渲染器01
    • 从零开始写软渲染器00
    • 现代游戏常用的几种寻路方案(一)
  • Node
    • NPM Dependency
    • Node 优势
    • Node Stream
    • Node 模块系统
  • HTML
    • html5语义与结构元素
  • 跨端
    • Flutter 介绍
  • Golang
    • Golang 基础
  • AR
    • SceneKit
由 GitBook 提供支持
在本页
  • DataLoader
  • dataset
  • sampler
  • collate_fn

这有帮助吗?

  1. MachineLearning

PyTorch数据处理

PyTorch对数据处理有一套标准的接口操作,本篇对其进行一个总结,方便在使用PyTorch进行数据预处理和数据提取时使用。

PyTorch的几个重要的和数据处理相关的类都在torch.utils.data包中。

DataLoader

PyTorch数据处理的核心是DataLoader类。通过以下语句引入。

from torch.utils.data import DataLoader

# DataLoader初始化参数

dataloader = DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
                        batch_sampler=None, num_workers=0, collate_fn=None, 
                        pin_memory=False, drop_last=False, timeout=0, 
                        worker_init_fn=None)

代码中给出了DataLoader类可以提供的参数。

  • dataset 用于提取数据的dataset,通常是继承torch.utils.data.Dataset的对象。

  • batch_size 进行训练每个batch的样本数量。

  • shuffle 是否在每个训练epoch打乱数据顺序。

  • sampler 采样器,用于定义随机获得每个batch的采样策略,如果设置了sampler就必须设置shuffle=False。

  • batch_sampler 和batch_size+shuffle+sampler+drop_last起到作用类似,通常不使用。

  • num_workers 定义是否开启多子进程数据加载。

  • pin_memory 设置为True时可以提高将cpu上的Tensor转到GPU上的速率,但会提高内存消耗。

  • collated_fn 将一个样本列表组装成一个batch的中间处理函数,可自定义。

  • drop_last 是否去除最后一个没有达到batch_size大小的batch

  • timeout 如果设置成正数,会在超时后停止batch采集作业。

  • worker_init_fn workers的初始化函数

其中我们主要通过实现dataset sampler collate_fn三个参数来实现自定义数据加载器的功能。

dataset

dataset参数必须继承Dataset类。

其中Dataset类实际上是用于形成Map类型的datasets。另一种IterableDataset用于流式数据加载的场景,不是很常见。

Dataset类的子类需要实现两个函数:

  1. __getitem__(idx) 该函数表示根据输入的索引idx,从数据中提取一个样本返回。

  2. __len__() 该函数表示返回数据样本的总数

一个简单的已经将数据按照样本一条条排布的数据集如下所示,但实际上,这个类赋予了我们很多的灵活性,可以在函数中进行很多额外的检查和操作。

from torch.utils.data import Dataset

class MyDataset(Dataset):

    def __init__(self, total_data):
        # total_data是一个10000*10的numpy.ndarray类型数据,其中前9列是特征,最后一列是标签
        self.features = total_data[:, :9]
        self.label = total_data[:, -1]

    def __getitem__(self, idx):
        y = self.features[idx]
        x = self.label[idx]

        return {'x':x, 'y':y}

    def __len__(self):
        return len(self.features)

有时候我们也可以调用PyTorch已经实现好的TensorDataset、ConcatDataset等类,避免自己实现的麻烦。

sampler

sampler需要继承torch.utils.data.Sampler类,一般要以data_source作为参数,负责传递索引给Dataset对象。

PyTorch中已经提供的几种采样器名称和功能如下:

  1. torch.utils.data.SequentialSampler(data_source) 按顺序进行采样。

  2. torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None) 根据参数进行一定数量的有或者无放回随机采样。

  3. torch.utils.data.SubsetRandomSampler(indices) 在给定的索引列表中进行再次采样。

  4. torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True) 按照权重进行随机采样。

sampler需要实现两个函数 _len_()和__iter\(),作用分别是给出数据总长度和该次迭代返回的索引值。

下面我们实现一个根据数据集大小决定是否进行随机采样的sampler。

from torch.utils.data import Sampler
import numpy as np

class MySampler(Sampler):

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        if len(self.data_source) >= 10000:
            return iter(range(len(self.data_source)))
        else:
            return np.random.randint(len(self.data_source))

    def __len__(self):
        return len(self.data_source)

collate_fn

collate_fn是一个callable的对象,其实就是一个函数。这个函数的输入是由sampler按照index在dataset中提取出的batch_size个样本组成的list列表。

collate_fn负责将这个list中的元素组装成一个完整的Tensor矩阵。为此默认的collate_fn函数具有以下三个性质:

  1. 总是会在将每个样本组合起来后,在最左边添加一个大小为batch_size的新维度。

  2. 自动将numpy数组转化成PyTorch的Tensor类型数据。

  3. 会保留数据结构,例如在dataset节中,我们将输出结果设置成一个有x和y两个key的字典值。collate_fn将会保留这个字典形式的结构,只是将每个key的value替换成batch_size长度的Tensor。

一个自定义的在第二个维度上做stack操作的collate_fn如下(假设每个样本格式为字典):

import torch

def my_collate_fn(data)
    x_list = [data[i]['x'] for i in range(len(data))]
    y_list = [data[i]['y'] for i in range(len(data))]

    x_tensor = torch.FloatTensor(x_list)
    y_tensor = torch.FloatTensor(y_list)

    return {'x': x_tensor, 'y': y_tensor}

总结来说,只需要熟练掌握自定义和使用dataset、sampler、collate_fn的方法,可以满足我们各种实验的需求,也会使得我们的代码更加有条理,提高实验效率。

上一页RNN循环神经网络下一页PyTorch安装和基本操作

最后更新于4年前

这有帮助吗?