5.7 Comparison of U-Net and Transformer architectures
这篇主要介绍Unet 与 Transformer架构的区分,方便我们后续的在实际调整参数之中的理解
Last updated
这篇主要介绍Unet 与 Transformer架构的区分,方便我们后续的在实际调整参数之中的理解
Last updated
在过去的几年里,生成扩散模型主要是基于 U-Net 架构发展的。随着 OpenAI 发布 Sora 以及 Stability AI 发布 Stable Diffusion 3,基于 Transformer 架构的扩散模型 (如 DiT) 成为目前备受关注的亮点。
文本生成图像领域蓬勃发展,基于Unet模型和Transformer模型在各自训练范式上大放异彩。
自 21 年 Openai 提出 GLIDE,基于 Unet 网络作为主干网络成为文本生成图像 pipeline 的标准配置。在像素空间亦或隐空间,UNet 通过上/下采样编/解码信息,逐层压缩提升优化效率,逐层解码恢复图像信息,通过注意力机制引入文本信息引导。
Stability AI 开源的 Stable Diffusion 作为经典 baseline,将图像先通过 VAE 编码到隐空间,再使用 Unet 在隐空间进行扩散/逆扩散过程。2023 年 2 月,Stable Diffusion 迎来升级,Stable Diffusion XL 扩大参数量并级联 Unet,获得社区青睐。
随着算力/模型的增加,Unet 模型容易陷入性能瓶颈,且难以灵活适配多模态任务需求(文本/图像/视频/3D)。相比大语言模型的主干网络 Transformer,存在 scaling laws,参数/数据量越多,性能越强。VIT 验证了 Transformer 在视觉任务的强大建模能力后,越来越多的研究人员开始聚焦于文本生成图像领域中的 Transformer 建模方案。
2023 年 Meta 提出 DiT,将 Transformer 用于扩散模型,将图像通过 patchify 序列化为离散的 token,同时在序列上拼接标签 embedding,通过 n 个 DiT Block 进行扩散/逆扩散过程,验证了基于 Transformer 框架下类别控制图像生成的有效性。
Huawei 在 2023 年提出的 PixArt- alpha 进一步优化 DiT 框架,将标签引导升级为文本引导,通过 T5 text encoder 编码文本信息,使用 cross-attention 注意力机制将文本信息与图像信息融合,需要注意的是,为了提升优化效率,PixAer- alpha 使用 class-conditional model 在 imagenet 数据集上的预训练模型作为权重初始化。使用 LLaVA 构建高密度 text-image pairs 训练文本生成图像模型,使用准确且信息丰富的数据提高了优化效率。
Tsinghua 在 2023 年提出的 U-ViT,在串联 Transformer 网络中引入 Long-skip connection,提升了优化效率。同时,U-ViT 将所有的输入,包括 Timesteps,引导条件和图像 patch 均看作 token,展现了 Transformer 建模的可扩展性,为未来引入多模态建模提供了实验基础。
综上,U-Net 和 Transformer 模型成为文生图领域主干网络的两种重要解决方案。前者以 Stable diffusion 模型为代表,保留了 U-Net 的多级上/下采样,借助 Transformer block 进行条件注入控制。当前社区已支持 Billion 级别的模型建模 Billion 级别的数据,极大提升生成图像质量,且支持多分辨率生成,灵活支持应用需求。后者以 DiT 模型为代表,采用一次性 Patchify,将图像信息序列化为离散的 token,通过 Self/Cross Attention 进行条件控制。由于 Transformer 网络天然的易扩展性,在算力/数据不断提升的未来,必然成为社区研究重点和热点。近期引爆 AI 社区的文生视频模型 Sora 和最新的文生图模型 Stable Diffusion 3 展现了惊艳的视频/图像生成效果,均使用 Transformer 模型作为主干网络核心,进一步佐证了 Transformer 模型的发展空间。如何发挥模型的扩展能力,以提升生成图像质感,并支持丰富的应用需求,这对研究者提出了较大的挑战。
这里我们从多个方面对比 U-Net 和Transformer 架构的特点:
U-Net
Transformer
模型结构
● CNN 结构,通常需要加额外的Transformer Block 增强表达能力.
● 多次下采样
● Long skip connection
● 统一的Attention + FFN 结构,易于优化和加速
● Patchify
● Long skip connection
表达能力
● 具有 inductive bias,善于建模局部特征
● 无 inductive bias,善于建模长距离依赖, 可以更好地理解图像上下文
位置编码
● 无 (依赖卷积层的 padding)
● 二维 RoPE 位置编码
多分辨率
● 原生支持
● 位置编码插值
条件控制
● 借助Transformer Block 来实现
● Scale & Shift
● 采用Self-Attention
扩大规模
● 容易遇到性能瓶颈
● 存在scaling laws,参数/数据量越多,性能越强
扩展能力
● 需要为不同模态数据设计扩展方式
● 有较强的扩展性,统一的token 表示,易于实现多模态数据表达