PyTorch 2.2 大更新:集成 FlashAttention-2,性能提升 2 倍

新的一年,PyTorch 也迎来了重大更新,PyTorch 2.2 集成了 FlashAttention-2 和 AOTInductor 等新特性,计算性能翻倍。

继去年十月份的 PyTorch 大会发布了 2.1 版本之后,全世界各地的 521 位开发者贡献了 3628 个提交,由此形成了最新的 PyTorch 2.2 版本。

新的版本集成了 FlashAttention-2,使得 scaled_dot_product_attention (SDPA)相较于之前的版本有了约 2 倍的性能提升。

PyTorch 2.2 还引入了一个新的 TorchInductor 提前扩展,称为 AOTInductor,旨在为非 python 服务器端编译和部署 PyTorch 程序。

PyTorch 中的 torch.distributed 支持了一个叫作 device_mesh 的新抽象,用于初始化和表示 ProcessGroups。

另外,PyTorch 2.2 提供了一个标准化的、可配置的日志记录机制,——TORCH_LOGS。

PyTorch 2.2 还对 torch.compile 做了许多改进,包括改进了对编译优化器的支持,以及 TorchInductor 融合和布局优化。

最后值得注意的是,PyTorch 将放弃对 macOS x86 的支持,PyTorch 2.2.x 是支持 macOS x64 的最后一个版本。

PyTorch 2.2 新特性

首先请注意,如果从源代码构建 PyTorch 2.2,需要 GCC 9.4 或更高版本,PyTorch 代码库已从 C++ 14 迁移到 C++ 17。

FlashAttention-2

FlashAttention-2 通过优化 GPU 上不同线程块和 warps 之间的工作分区,来解决占用率低或不必要的共享内存读写。

FlashAttention-2 调整了算法以减少非 matmul 的计算量,同时提升了 Attention 计算的并行性(即使是单个头,也可以跨不同的线程块,以增加占用率),在每个线程块中,优化 warps 之间的工作分配,以减少通过共享内存的通信。

PyTorch 2.2 将 FlashAttention 内核更新到了 v2 版本,不过需要注意的是,之前的 Flash Attention 内核具有 Windows 实现,Windows 用户可以强制使用 sdp_kernel,仅启用 Flash Attention 的上下文管理器。

而在 2.2 中,如果必须使用 sdp_kernel 上下文管理器,请使用 memory efficient 或 math 内核(在 Windows 上)。

在 FlashAttention-2 的加持之下,torch.nn.functional.scaled_dot_product_attention 的速度提升了大约 2 倍,在 A100 GPU 上达到了理论计算峰值的 50%-73%。

AOTInductor

AOTInductor 是 TorchInductor 的扩展,用于处理导出的 PyTorch 模型,对其进行优化,并生成共享库以及其他相关工件。

这些编译的工件可以部署在非 Python 环境中,经常用于服务器端的推理。

下面的示例演示了如何调用 aot_compile 将模型转换为共享库。

AOTInductor 支持与 Inductor 相同的后端,包括 CUDA、ROCm 和 CPU。

TORCH_LOGS

PyTorch 2.2 提供了一个标准化的、可配置的日志记录机制,可用于分析各种子系统的状态,例如编译和分布式操作可以通过 TORCH_LOGS 环境变量启用日志。比如通过在命令行中修改环境变量:

将 TorchDynamo 的日志级别设置为 logging.ERROR,将 TorchInductor 的日志级别设置为 logging.DEBUG

当然也可以在代码中以 API 的形式使用:

torch.distributed.device_mesh

PyTorch 2.2 引入了一个新的抽象,用于表示分布式并行中涉及的 ProcessGroups,称为 torch.distributed.device_mesh

为分布式训练设置分布式通信器(NCCL)是一件麻烦的事情。用户需要编写不同并行度的工作负载,并为每个并行度手动设置和管理 NCCL 通信器(ProcessGroup )。

这个过程可能很复杂,容易出错。而 DeviceMesh 可以简化此过程,使其更易于管理。

DeviceMesh 是管理 ProcessGroup 的更高级别的抽象。它允许用户毫不费力地创建节点间和节点内进程组,而不必担心如何为不同的子进程组正确设置等级。

例如,数组的其中一个维度可以表示 FSDP 中的数据并行(data parallelism),而另一个维度可以表示 FSDP 中的张量并行(tensor parallelism)。

用户还可以通过 DeviceMesh 轻松管理底层 process_groups,以实现多维并行。

DeviceMesh 在处理多维并行性(如 3D 并行)时很有用。如上图所示,当你的并行解决方案需要跨主机和每个主机内部进行通信时,可以创建一个 2D 网格,用于连接每个主机中的设备,并以同构设置将每个设备与其他主机上的对应设备连接起来。

借助 init_device_mesh () ,我们可以在短短两行内完成上面这个 2D 设置:

而如果不使用 DeviceMesh,我们大概需要自己写下面这一堆代码:

当然,如果需要,我们仍然可以访问底层 ProcessGroup:

优化器的改进

大概有以下几点:

编译优化器在所有基准测试中都提高了性能:HuggingFace +18%、TorchBench +19%、TIMM +8% E2E;

编译的优化器增加对 cudagraphs 的支持;

对测试套件中所有模型进行平均,每个测试套件的基准测试平均编译时间增加约 40 秒;正在进行的优化可能会将其降低到 30 秒以下。

用于多张量优化器编译的 inductor 中缺少的主要功能是 foreach 算子的高效编码生成。

在调度器内部,将所有在下放过程中注册的缓冲区列表凝聚到 ForeachKernelSchedulerNodes 中(FusedSchedulerNode 的子类)。

为了检查融合是否合法,每个内部 SchedulerNode 执行的写操作必须与消费 SchedulerNode 在同一列表索引处的读操作相匹配。

此外,正常的垂直融合规则必须允许在消费者和生产者 SchedulerNode 列表的每个索引处进行融合。

如果满足了这些条件,ForeachKernelSchedulerNode 将垂直融合成一个 ForeachKernelSchedulerNode,其中每个列表上的相应点操作都将被融合。

通过实现这种融合,可以将一系列 foreach 运算融合到单个内核中,从而实现多张量优化器的完全融合。

性能改进

TorchInductor 中添加了许多性能优化,包括对 torch.concat 的水平融合支持、改进的卷积布局优化、以及改进 scaled_dot_product_attention 模式匹配。

PyTorch 2.2 还包括 aarch64 的许多性能增强,包括对 mkldnn 权重预打包的支持、改进的 ideep 基元缓存,以及通过对 OneDNN 的固定格式内核改进,来提高推理速度。

参考资料:

  • https://pytorch.org/blog/pytorch2-2/

Published by

风君子

独自遨游何稽首 揭天掀地慰生平