AI Infra 学习路线


AI Infra(人工智能基础设施)是大模型时代壁垒最高、最核心的技术高地。本文从前置基础到推理部署,系统梳理 AI Infra 的完整学习路线,为每个模块列出需要掌握的知识点、推荐学习资料以及可量化的检验标准,帮助从业者建立体系化的知识树。


🗺️ 全景概览:三层架构

AI Infra 的本质是 "用系统工程释放硬件算力"。自底向上分为三个核心层级加一个前置知识层:

层级名称核心关注点
第零层前置知识编程语言、数学基础、Transformer 架构、PyTorch、通信拓扑
第一层CUDA编程与算子优化GPU架构、存储层次、Kernel编写、FlashAttention、AI编译器
第二层分布式训练数据并行、3D并行、ZeRO、混合精度
第三层推理与部署KV Cache、PagedAttention、量化、Speculative Decoding

所有的优化都是在 "计算、通信、显存" 这个不可能三角中做取舍:ZeRO 是用通信换显存;重计算(Activation Checkpointing)是用计算换显存;量化是用精度换显存和带宽。学习时始终问自己:这个技术牺牲了什么,换取了什么?


📖 第零层:前置知识

0.1 知识点

编程语言

  • Python:熟练使用面向对象、装饰器、生成器、多进程/多线程、性能 profiling
  • C/C++:理解指针、内存管理、编译链接过程,能读懂 C++ 项目代码
  • Linux 基础:命令行操作、Shell 脚本、进程管理、环境变量配置

数学基础

  • 线性代数:矩阵乘法、转置、分块矩阵、特征值分解基本概念。看到 (B, S, H) × (H, V) 能立刻知道结果是 (B, S, V)
  • 基础概率论:概率分布、期望、方差、Softmax 的概率解释、交叉熵损失含义
  • 微积分(了解):链式法则、梯度的含义

Transformer 架构

必须理解:

  • Self-Attention:Q、K、V 的含义与计算过程(QK^T → scale → softmax → PV),Attention 的计算复杂度为 O(N2)O(N^2)
  • 前馈网络(FFN):两层线性变换 + 激活函数
  • 位置编码:Sinusoidal、RoPE 等
  • LayerNorm:Pre-Norm vs Post-Norm 的区别
  • 完整前向过程:从 token embedding 开始,能逐步跟踪数据在一个 Transformer Block 中的流转,说清每一步的输入输出维度

PyTorch 框架

  • Tensor 操作、自动微分(autograd)、Module / Parameter 的组织方式
  • 训练循环:DataLoader → forward → loss → backward → optimizer.step
  • 基本调试:torch.cuda.memory_summary()torch.profiler

通信拓扑

  • 单机内部:NVLink / NVSwitch 带宽与拓扑
  • 多机间:InfiniBand(IB)网络、RoCE 协议
  • 集合通信原语:AllReduce、AllGather、ReduceScatter 的含义与通信量公式
  • NCCL:NVIDIA 集合通信库的基本用法

0.2 推荐资料

类型资料说明
论文Attention Is All You NeedTransformer 原始论文,必读
教程The Illustrated Transformer (Jay Alammar)图文并茂的 Transformer 入门
工具Andrej Karpathy:Let's build GPT from scratch从零手写 GPT,每个模块都过一遍
教程PyTorch 官方教程(60 Minute Blitz)PyTorch 快速入门
书籍3Blue1Brown:线性代数的本质(视频系列)建立线性代数几何直觉
官方文档NVIDIA NCCL 文档集合通信原语与多卡编程

0.3 检验标准

  • Transformer 白板默写:不看资料,能画出一个 Decoder Block 完整结构,标注每步输入输出维度
  • 维度推导:给定 7B 模型配置(hidden_dim=4096, num_heads=32, num_layers=32, vocab_size=32000),能手算总参数量(误差 ≤20%)
  • PyTorch 训练脚本:能独立写出完整训练循环(含 DataLoader、forward、loss、backward、optimizer step、checkpoint),在 GPU 上跑通
  • Linux 日常:SSH 登录、tmux、conda/pip 管理、nvidia-smi、git、bash 脚本

💻 第一层:CUDA编程与算子优化

1.1 知识点

GPU 硬件架构

把 GPU 想象成一座拥有数千个简单工人的超级工厂——每个工人(CUDA Core)只会基本加减乘除,但胜在人多,吞吐量远超 CPU。

  • SM(流多处理器)、Tensor Core、CUDA Core 的区别与协作
  • 主流 GPU 规格:A100 / H100 / H200 的算力、显存带宽、HBM 容量
  • Memory Wall:显存带宽瓶颈往往比算力瓶颈更致命
  • 存储层次:寄存器 > 共享内存 > L1/L2 Cache > HBM > 主机内存

CUDA 编程基础

  • 编程模型:Grid / Block / Thread 层级,线程索引计算
  • 内存模型:全局内存、共享内存、寄存器、常量内存
  • 关键概念:Warp(32线程的最小调度单位)、Bank Conflict、Coalesced Access、Occupancy

常见算子实现与优化

  • Reduce:并行归约(Warp Shuffle、多级归约)
  • GEMM:分块、向量化、Shared Memory Tiling、Tensor Core
  • Softmax:Online normalizer calculation
  • 算子融合:将多个小算子合并为一个 kernel,减少全局内存读写

Attention 算子

  • FlashAttention V1/V2:通过 tiling 减少 HBM 访问——把大桌子上的拼图分成小块,每次只搬一小块到手边,避免把所有碎片一股脑倒出来。HBM 读写从 O(N2)O(N^2) 降到 O(N)O(N)
  • FlashAttention-3:在 Hopper 架构上进一步拉高利用率
  • Flash-Decoding:面向 Decode 阶段的 Attention 加速
  • PagedAttention CUDA Kernel:vLLM 中 PagedAttention 的底层实现

AI 编译器

  • Triton:OpenAI 开源的 GPU 编程语言,大幅降低高效算子编写门槛
  • torch.compile:PyTorch 2.x 的编译模式,理解 Graph Break 与性能收益

1.2 推荐资料

类型资料说明
入门教程小小将:CUDA编程入门极简教程CUDA 零基础入门
官方文档NVIDIA CUDA Programming GuideCUDA 编程权威参考
GEMM猛猿:从啥也不会到CUDA GEMM优化从基础分块到极致优化
AttentionFlashAttention V1/V2 PaperMemory-aware Attention 里程碑
解读猛猿:图解FlashAttention V1/V2 系列适合新手入门的图文解读
编译器Triton 官方教程GPU 编程新范式
工具Nsight Systems User GuideCPU-GPU 交互分析
工具Nsight Compute Profiling GuideKernel 级下钻,定位瓶颈

1.3 检验标准

  • 硬件参数直觉:拿到 H100,不查资料能说出 HBM 容量(80GB)、带宽(~3.35TB/s)的量级
  • Reduce 三连:从全局内存原子加 → 共享内存+树形归约 → Warp Shuffle,三版本跑 Nsight Compute 对比
  • GEMM 分块:实现基于 Shared Memory Tiling 的 GEMM kernel,达到 cuBLAS 50% 以上性能
  • FlashAttention 白板推导:能在白板上画出 tiling 过程,说清为什么 HBM 读写从 O(N2)O(N^2) 降到 O(N)O(N)
  • Profiling 实战:用 Nsight Systems 定位 GPU idle gap 来源;用 Nsight Compute 判断 kernel 是 memory bound 还是 compute bound

🏋️ 第二层:分布式训练

打个比方,训练千亿参数大模型就像抄写一本数万页的百科全书——数据并行是把同一本书复印多份、每人抄不同章节内容然后汇总;张量并行是把每一页拆成几列、每人只抄自己那几列;流水线并行则是第一个人抄完第一章就传给第二人继续。

2.1 知识点

优化器

  • Adam / AdamW:每个参数维护两个状态——一阶动量(梯度的指数移动平均)和二阶动量(梯度平方的指数移动平均)
  • 优化器状态的显存开销:以 AdamW + 混合精度为例,每个参数额外需要 FP32 参数副本(4B)+ 一阶动量(4B)+ 二阶动量(4B)= 12字节/参数。7B 模型的优化器状态就要占 ~84GB

数据并行

  • DDP(DistributedDataParallel):多进程数据并行,理解 AllReduce 梯度同步
  • FSDP:PyTorch 原生的 ZeRO-3 实现

模型并行(3D 并行)

  • 张量并行(TP):将矩阵乘法沿特定维度切分到多卡,通信密集,通常限于单机
  • 流水线并行(PP):将模型不同层切分到不同机器
  • 序列并行(SP):沿序列维度切分,与 TP 配合减少激活显存

显存优化

ZeRO 系列(好比合租房里每人只存自己那份家具,需要时互相借用):

  • ZeRO-1:优化器状态切分
  • ZeRO-2:优化器状态 + 梯度切分
  • ZeRO-3:优化器状态 + 梯度 + 参数切分(用通信换显存)

混合精度训练:FP16 / BF16 / FP8 训练,减少显存占用。BF16 比 FP16 指数位更宽(8位 vs 5位),动态范围接近 FP32,不容易 overflow/underflow。

Activation Checkpointing:只保存部分激活值,需要时重新计算,用计算换显存。

2.2 推荐资料

类型资料说明
论文Megatron-LM PaperTP 与 PP 原理的里程碑论文
论文ZeRO Paper(DeepSpeed)显存优化的核心方法
文档DeepSpeed 官方文档ZeRO 配置与使用
文档PyTorch DDP / FSDP 教程原生分布式训练入门
论文DeepSeek V2 技术报告MLA 注意力机制
论文DeepSeekMoE PaperMoE 架构设计

2.3 检验标准

  • 显存账本:拿到 7B 模型,能口算 FP16 下参数占 ~14GB、Adam 优化器状态占 ~56GB,判断单卡 80GB 能否放下完整训练状态
  • ZeRO 拆解:能一句话讲清 ZeRO-2 和 ZeRO-3 的差异(参数是否切分,通信量差异)
  • DDP 改造:拿到单卡训练脚本,30 分钟内改成 DDP 多卡版本并跑通
  • 3D 并行拓扑:给 64 卡集群(8节点×8卡),能设计出 TP=8(机内)、PP=4(跨机)、DP=2 的并行方案,说明为什么 TP 不能跨机

🚀 第三层:推理与部署

训练是"教会模型知识",推理是"让模型上考场答题"——最重要的是答题速度和同时服务多少考生。

3.1 LLM 推理基础

  • 两阶段:Prefill(处理输入,compute-bound)与 Decode(逐 token 生成,memory-bound)
  • KV Cache:自回归生成的"草稿纸",把已算过的 K/V 缓存起来避免重复计算,但会随序列长度线性增长显存
  • 关键指标:TTFT(首 token 延迟)、TPOT(每 token 延迟)、吞吐量(token/s)、P50/P95 尾延迟

KV Cache 估算:给定 LLaMA-2-7B(32层、32头、head_dim=128),上下文长度 4096,batch_size=16,FP16:

KV=2×32×32×128×4096×16×2B32GBKV = 2 \times 32 \times 32 \times 128 \times 4096 \times 16 \times 2\text{B} \approx 32\text{GB}

3.2 推理引擎

  • PagedAttention:vLLM 提出的虚拟内存分页思想管理 KV Cache,解决碎片化问题
  • Continuous Batching:动态组批,请求随到随处理(类似网约车拼单,随到随拼)
  • Prefix Cache / RadixAttention:复用已计算的 KV Cache,优化重复前缀场景
框架核心特性适用场景
vLLMPagedAttention、Continuous Batching、Prefix Cache通用推理服务,社区活跃
SGLangRadixAttention、cFSM 结构化输出加速复杂 Agent、多轮生成
TensorRT-LLMInflight Batching、深度硬件优化追求极限性能、NVIDIA 生态

3.3 量化

量化的本质是把高清照片压缩成缩略图——用更少的比特位表示权重,省下显存和带宽,代价是精度会有一定损失。

  • W8A8(SmoothQuant):将 activation 的 outlier 难题转移到 weights,工程友好
  • INT4(GPTQ / AWQ):只量化权重到 3/4-bit,减少显存和带宽占用
  • KV Cache 量化(KIVI):2-bit 量化,长上下文场景效果显著
目标:省显存?省带宽?提吞吐?
├─ 通用、工程友好 → W8A8 (SmoothQuant)
├─ 更省显存/带宽 → INT4 weight-only (AWQ/GPTQ)
└─ 长上下文/大并发 → KV Cache 量化 (KIVI)

3.4 Speculative Decoding

好比让实习生先快速起草一段文字,再让资深主编一次性审阅:猜对的直接用,猜错的当场改,比主编逐字逐句从头写快得多。

  • Speculative Sampling:小模型批量猜测 → 大模型一次性验证,保证分布无偏
  • Medusa:多个 Decoding Heads 并行预测多 token
  • EAGLE-2:动态 Draft Tree,更激进地产生可接受 token

正确性保证:rejection sampling 机制保证接受的 token 严格服从 target model 的分布。

3.5 性能分析工具

  • torch.profiler:PyTorch 官方 profiler,定位算子与 shape
  • Nsight Systems:CPU-GPU 交互全链路分析,找到"哪里慢"
  • Nsight Compute:Kernel 级分析,找到"为什么慢"
  • GenAI-Perf:LLM 指标一站式输出(TTFT/TPOT/throughput)

🧭 新人破局指南

推荐学习路径

基础阶段(0-3个月)

  1. 完成第零层全部检验标准
  2. 学习 CUDA 编程基础,能写简单的 Reduce / GEMM kernel
  3. 用 PyTorch DDP 将训练分布到两张卡上

专项深入(3-6个月)

  1. 精读四篇里程碑论文并对照代码:Megatron-LM、ZeRO、FlashAttention、vLLM
  2. 参与开源项目(vLLM、DeepSpeed、SGLang)

工程实践(6个月以上)

  1. 在 GPU 集群上部署百亿/千亿参数模型,优化端到端性能
  2. 建立完整的性能分析与回归体系

核心权衡思维

优化技术牺牲了什么换取了什么
ZeRO通信带宽显存空间
Activation Checkpointing计算时间显存空间
量化精度显存 + 带宽 + 吞吐
Speculative DecodingPrefill 开销Decode 速度
FlashAttention实现复杂度显存 + 速度
Prefill/Decode 解耦系统复杂度尾延迟 + goodput

📚 核心参考论文