tech share
  • tech-share
  • Engineering
    • 登录鉴权
    • SSR 页面路由
    • npm 版本号
    • 缓存
    • 数据库容灾
    • 动态效果导出 gif
    • Chrome-devtools
    • C 端 H5 性能优化
    • Docker
    • Monorepo 最佳实践
    • 技术架构演化
    • 项目规范最佳实践
    • snowpack
    • 静态资源重试
    • 前端页面渲染分析
    • Git
    • 前端重构
    • 微前端
    • 项目依赖分析
    • 前端监控原理
    • webpack
    • BS 架构与 CS 架构
    • HTTPS
    • package-lock.json 生成逻辑
    • SVN(Subversion)
    • 数据库分类
    • gulp
    • 前端架构
    • Bundle & Bundless
    • 控制反转 IoC
  • JavaScript
    • Javascript 性能
    • JavaScript 原型(2) - 原型与原型链
    • JavaScript 原型(1) - 构造函数
    • JavaScript - Promise
    • ES6 解构赋值
    • 前端离线化
    • Proxy
    • Object.defineProperty()简介
    • TypeScript
  • MachineLearning
    • GAN生成对抗网络
    • 虚拟对抗训练
    • 深度度量学习
    • 原型网络
    • PyTorch优化器
    • 隐马尔可夫模型2
    • Shapley Value 算法
    • Embarassingly Autoencoder算法
    • AutoRec算法及其后续发展
    • 深度学习常用激活函数
    • 序列预测ConvTran算法
    • 联邦学习
    • 深度学习推荐系统算法整理
    • 隐马尔可夫模型
    • 黎曼优化方法
    • FM算法
    • 机器学习常见评价指标
    • VAE算法
    • Adam优化器详解
    • Transformer算法
    • Self-attention 推荐算法
    • CNN 卷积神经网络
    • 图嵌入
    • 集成学习算法
    • RecBole开源框架
    • NCE-PLRec
    • 深度学习初始化方法
    • RNN循环神经网络
    • PyTorch数据处理
    • PyTorch安装和基本操作
    • XGBoost算法
    • NCF算法与简单MF的对比
    • 计算最佳传输
  • CSS
    • 什么是BFC
    • 纯CSS实现可拖动布局
    • 滚动穿透解决方案
  • React
    • React 生命周期
    • React Ref
    • React Hooks
    • SWR
    • React 数据流
    • React 函数式组件和类组件的区别
  • 可视化
    • OffscreenCanvas
    • Echarts 平滑曲线端点为什么不平滑
    • 颜色空间
    • 词云布局解析
    • 3D 数学基础
    • Canvas 图片处理
    • GLGL ES
    • WebGL 中绘制直线
    • Graphics API
    • 现代计算机图形学基础
    • Canvas 灰度
  • Vue
    • Vue2.x全局挂载整理
    • Vue2.6.x源码阅读
      • Vue2.6.x源码阅读 - 2.目录结构分析
      • Vue2.6.x源码阅读 - 4.源码阅读-platform
      • Vue2.6.x源码阅读 - 1.准备工作
      • Vue2.6.x源码阅读 - 5.源码阅读-core-Vue构造函数
      • Vue2.6.x源码阅读 - 7.源码阅读-core-响应式原理
      • Vue2.6.x源码阅读 - 3.源码阅读-shared
      • Vue2.6.x源码阅读 - 6.源码阅读-core-组件挂载
    • Vue + TypeScript Web应用实践
    • Vue2.x指令
    • nextTick()的使用
    • vue-cli2.x 的使用与项目结构分析
    • Vue响应式原理及总结
    • VueX的使用
    • Electron-Vue + Python 桌面应用实践
    • Vite
    • Vue组件通信整理
    • 记录一个问题的探索过程
  • Linux
    • memcg
  • GameDev
    • 游戏中的几种投影视图
    • 从零开始写软渲染器06
    • 从零开始写软渲染器05
    • 从零开始写软渲染器04
    • 从零开始写软渲染器03
    • 从零开始写软渲染器02
    • 从零开始写软渲染器01
    • 从零开始写软渲染器00
    • 现代游戏常用的几种寻路方案(一)
  • Node
    • NPM Dependency
    • Node 优势
    • Node Stream
    • Node 模块系统
  • HTML
    • html5语义与结构元素
  • 跨端
    • Flutter 介绍
  • Golang
    • Golang 基础
  • AR
    • SceneKit
由 GitBook 提供支持
在本页
  • VAE算法
  • VAE基本原理
  • VAE代码实现

这有帮助吗?

  1. MachineLearning

VAE算法

上一页机器学习常见评价指标下一页Adam优化器详解

最后更新于4年前

这有帮助吗?

VAE算法

在过去我们介绍过VAE是一种用于自动生成图片的算法,可以理解成是对抗生成网络GAN的前身,本次我们对其做详细介绍。

VAE基本原理

VAE即变分自编码器(Variational auto-encoder,VAE),于2013年由Diederik P.Kingma和Max Welling提出。原论文参见

我们知道auto-encoder的基本原理就是将输入x通过encoder映射到一个隐藏层z,然后再将z输入到decoder中重新构建出x',然后通过损失函数让重构出的x'和x尽可能一致。之后我们就可以利用这种方法学到的参数进行后续工作。

而VAE在Auto-encoder的基础上,试图增加学习到的信息量,使得模型具备以下两种能力。

  1. encoder -> 识别模型

  2. decoder -> 生成模型

因此VAE对于auto-encoder 最基本的改动就是在隐藏层,将原本固定的一个隐藏向量替换为了一个隐藏向量的正态分布。

图中很明显可以看出隐藏层的z变成了一个正态分布分布,通过z_mean和z_log_var两组变量作为该正态分布的参数。深度学习时代使得我们可以轻易的用两个分立的全连接层对均值和标准差分别进行参数估计。

此处也有两个细节:

  1. 该处的log_var 表示的是方差的e对数,这是因为方差具有必须大于0的特点,如果直接通过神经网络训练还需要添加额外的激活函数,而对数则使得网络输出的范围扩展到了实数域,不需要添加额外的激活函数。

  2. decoder的输入来自于z分布的采样,由于采样这个步骤不可导,因此使用了reparemerization的技巧。先从标准正态分布N(0, 1)中采样一个epsilon,然后再令z = z_mean + sigma * epsilon。

    这个计算过程中会将epsilon作为常数,从而可以将梯度反向传播到z_mean 和z_log_var上。

损失函数设计方面,主要涉及两个部分:

  1. 基本的auto-encoder loss,这部分和普通AE相同。

  2. KL散度。这部分loss主要目的是为了避免模型的退化。模型有可能会将方差部分全部变为0,从而将模型退化成普通的auto-encoder模型。因此我们限制z的分布和标准正态分布接近。

    即KL(N(μ,σ^2)‖N(0,I)

VAE代码实现

如果省略前后的卷积部分,我们会发现VAE的模型在代码实现上其实非常简单。

z_mean和z_log_var的拟合:

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) List of latent codes
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

reparameter操作:

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Will a single z be enough ti compute the expectation
        for the loss??
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Standard deviation of the latent Gaussian
        :return:
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

损失函数:

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        self.num_iter += 1
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset

        recons_loss =F.mse_loss(recons, input)

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        return {'loss': loss, 'Reconstruction_Loss':recons_loss, 'KLD':kld_loss}

代码主要参考github上AntixK的项目PyTorch_VAE, 参见

https://github.com/AntixK/PyTorch-VAE/blob/master/models/beta_vae.py
Auto-encoding variational bayes
AutoEncoder
VAE
VAE_KL