手撕DiT

手撕 DiT

实现一个轻量级的 DiT (Diffusion Transformer) 架构,结合了目前前沿的 Flow Matching (流匹配) 范式,并在自定义的模拟数据集上完成了完整的训练与欧拉法采样闭环。

DiT 架构核心原理

DiT (Diffusion Transformer) 的核心思想非常简单粗暴但极其有效:用 Transformer 替换掉传统 Diffusion 模型(如 DDPM)中常用的 U-Net 架构。 它的成功证明了 Transformer 在视觉生成任务上同样具备强大的 Scaling Law。DiT 的架构主要依赖以下几个关键机制:

  1. 图像块化 (Patchify):继承自 ViT 的思想。DiT 将输入的潜变量空间切分成不重叠的 Patch,并展平为一维序列,将图像生成转换为标准的序列建模任务。
  2. 条件注入 (adaLN-Zero):模型需要知道当前时间步 tt 和条件 yy。DiT 抛弃了交叉注意力,使用 adaLN (自适应层归一化)。条件向量被映射为缩放 γ\gamma、平移 β\beta 和残差门控 α\alpha。初始化时,这些参数全为 0,这保证了深层 Transformer 早期训练的极高稳定性。
  3. 极简目标 (Flow Matching):这套代码使用了 Flow Matching 而非 DDPM。它构建了一条从纯噪声 x0N(0,I)x_0 \sim \mathcal{N}(0, I) 到真实数据 x1x_1 的直线路径:xt=x0+t(x1x0)x_t = x_0 + t \cdot (x_1 - x_0)。模型只需预测这个线性过程的向量场 v=x1x0v = x_1 - x_0,推理时用欧拉法积分即可生成图像。

环境准备与全局设置

这部分主要负责导入必要的库,并固定随机种子以保证实验的绝对可复现性。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np
import os

# 新建images文件夹,避免保存图片报错
os.makedirs("images", exist_ok=True)

# 固定随机种子,保证可复现
def set_seed(seed=42):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True

set_seed(42)

时间位置编码

时间位置编码 (SinusoidalPositionEmbeddings)

作用:将连续的时间步 t[0,1]t \in [0, 1] 转换为高维向量,让神经网络能“感受”到时间的流逝。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# -------------------------- 1. 正弦时间位置编码 --------------------------
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim

def forward(self, time):
device = time.device
half_dim = self.dim // 2
# 利用对数空间的指数运算来生成不同频率的 base
embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
# 将时间 t 与频率相乘
embeddings = time[:, None] * embeddings[None, :]
# 拼接 sin 和 cos,形成完整的位置编码
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings

解读:这是经典的 Transformer 位置编码公式实现。通过 sincos 的组合,它能将标量时间转化为具有周期性特征的 dim 维张量,这是后续生成 adaLN 调制参数的原始信号源。

核心机制:DiT Block 与 adaLN-Zero

作用:DiT 的特征提取基石。包含了 Self-Attention、MLP 以及最核心的 6 参数 adaLN-Zero 调制层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -------------------------- 2. 核心DiT Block(6参数调制核心修改) --------------------------
class DiTBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
super().__init__()
# 注意:这里的 LayerNorm 关闭了 elementwise_affine (即可学习的 gamma 和 beta)
# 因为这两个参数将由 adaLN 动态生成
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim),
nn.GELU(),
nn.Linear(mlp_hidden_dim, hidden_size)
)

# ========== 6参数调制层,输出维度为 hidden_size * 6 ==========
# 对应:scale_mha, shift_mha, gate_mha, scale_mlp, shift_mlp, gate_mlp
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, hidden_size * 6)
)

# ========== adaLN-Zero零初始化,保证训练初始稳定性 ==========
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)

def forward(self, x, t_emb):
# ========== 拆分6个调制参数 ==========
modulation = self.adaLN_modulation(t_emb)
scale_mha, shift_mha, gate_mha, scale_mlp, shift_mlp, gate_mlp = modulation.chunk(6, dim=-1)

# 扩展维度,适配序列形状 [B, 1, hidden_size],支持广播
scale_mha = scale_mha[:, None, :]
shift_mha = shift_mha[:, None, :]
gate_mha = gate_mha[:, None, :]
scale_mlp = scale_mlp[:, None, :]
shift_mlp = shift_mlp[:, None, :]
gate_mlp = gate_mlp[:, None, :]

# ========== Self-Attention 分支,使用专属3个参数 ==========
residual = x
x = self.norm1(x)
x = x * (1 + scale_mha) + shift_mha # 标准adaLN缩放偏移
attn_out, _ = self.attn(x, x, x)
x = residual + gate_mha * attn_out # 门控残差连接

# ========== MLP 分支,使用专属3个参数 ==========
residual = x
x = self.norm2(x)
x = x * (1 + scale_mlp) + shift_mlp # 标准adaLN缩放偏移
mlp_out = self.mlp(x)
x = residual + gate_mlp * mlp_out # 门控残差连接

return x

解读:这是全篇最精妙的地方。elementwise_affine=False 是必须的,因为规范化参数交给了 adaLN_modulation 接管。全 0 初始化意味着在训练的第一步,scaleshift 为 0(等价于普通 LayerNorm),gate 也为 0(等价于 x = residual + 0)。这让初始网络成了一个完美的恒等映射,梯度可以毫无阻碍地传导,极大地加速了早期收敛。

宏观架构:DiT 主干网络

作用:处理图像到序列的转换 (Patchify),融合时间和类别条件,并利用堆叠的 DiT Blocks 提取特征,最终还原回图像维度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# -------------------------- 3. 简化版DiT主模型 --------------------------
# 适配小数据集,减小模型规模,加快训练
class DiT(nn.Module):
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=256, # 减小隐藏层维度,适配小数据
depth=6, # 减少层数,加快训练
num_heads=8,
mlp_ratio=4.0,
num_classes=4 # 匹配我们模拟数据的4个类别
):
super().__init__()
self.in_channels = in_channels
self.patch_size = patch_size
self.num_patches = (input_size // patch_size) ** 2

# Patch嵌入 (利用步长等于核大小的卷积极其高效地完成)
self.patch_embed = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size))

# 时间+类别嵌入(条件注入的总来源,输出给每个block的adaLN)
self.time_embed = nn.Sequential(
SinusoidalPositionEmbeddings(hidden_size),
nn.Linear(hidden_size, hidden_size),
nn.GELU(),
nn.Linear(hidden_size, hidden_size)
)
self.class_embed = nn.Embedding(num_classes, hidden_size)

# Transformer Blocks
self.blocks = nn.ModuleList([
DiTBlock(hidden_size, num_heads, mlp_ratio) for _ in range(depth)
])

# 输出头
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# 最终输出层也适配adaLN,增加调制层
self.final_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, hidden_size * 2)
)
# 输出维度为 patch内的像素总数 (p*p*C)
self.head = nn.Linear(hidden_size, patch_size * patch_size * in_channels)

# 初始化权重
self._init_weights()

def _init_weights(self):
nn.init.normal_(self.pos_embed, std=0.02)
nn.init.trunc_normal_(self.head.weight, std=0.02)
nn.init.zeros_(self.head.bias)
nn.init.zeros_(self.final_modulation[-1].weight)
nn.init.zeros_(self.final_modulation[-1].bias)

def forward(self, x, t, y=None):
B = x.shape[0]

# Patchify + 位置编码
x = self.patch_embed(x)
x = x.flatten(2).transpose(1, 2)
x = x + self.pos_embed

# 时间+类别条件嵌入(合并为统一的条件向量)
t_emb = self.time_embed(t)
if y is not None:
t_emb = t_emb + self.class_embed(y)

# 经过DiT Blocks,每个block都接收统一的条件嵌入
for block in self.blocks:
x = block(x, t_emb)

# 最终输出层adaLN调制
scale_final, shift_final = self.final_modulation(t_emb).chunk(2, dim=-1)
x = self.norm_final(x)
x = x * (1 + scale_final[:, None, :]) + shift_final[:, None, :]
x = self.head(x)

# 重组Patch,还原回图像空间 [B, C, H, W]
h = w = int(self.num_patches ** 0.5)
p = self.patch_size
x = x.transpose(1, 2).reshape(B, self.in_channels, p, p, h, w)
x = torch.einsum('n c p q h w -> n c h p w q', x)
x = x.reshape(B, self.in_channels, h * p, w * p)

return x

解读:这里的重点在于特征空间的转换。输入是 [B, C, H, W],经过 patch_embed 变成了 [B, Seq_Len, Hidden],加上了位置编码 pos_embed 后进入 Blocks。在输出阶段,通过一个 Linear 层和 einops 风格的重排(einsum+reshape),将 Transformer 处理完的一维序列,完好无损地拼接回了原始的图像分辨率。

模拟数据集与 Flow Matching 损失函数

作用:为了快速验证模型,使用 Numpy 生成了几何图案作为真实数据分布;利用 Flow Matching 构建从高斯噪声到数据的直线目标。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def generate_simulation_dataset(
num_samples=2000,
img_size=32,
in_channels=4,
num_classes=4,
noise_level=0.05 # 加少量噪声,模拟真实latent分布
):
"""
生成带类别规律的模拟DiT训练数据集
"""
x = np.linspace(-np.pi, np.pi, img_size)
y = np.linspace(-np.pi, np.pi, img_size)
xx, yy = np.meshgrid(x, y)

data, labels = [], []
samples_per_class = num_samples // num_classes

for class_id in range(num_classes):
for _ in range(samples_per_class):
if class_id == 0:
base_pattern = np.sin(2 * xx)
elif class_id == 1:
base_pattern = np.sin(4 * yy)
elif class_id == 2:
base_pattern = np.sign(np.sin(3 * xx) * np.sin(3 * yy))
else:
base_pattern = np.exp(-(xx**2 + yy**2) / 2) * np.cos(4 * np.sqrt(xx**2 + yy**2))

multi_channel_pattern = np.stack([
base_pattern + 0.1 * i for i in range(in_channels)
], axis=0)

multi_channel_pattern += np.random.normal(0, noise_level, multi_channel_pattern.shape)
multi_channel_pattern = multi_channel_pattern / np.max(np.abs(multi_channel_pattern))

data.append(multi_channel_pattern)
labels.append(class_id)

data = torch.tensor(np.array(data), dtype=torch.float32)
labels = torch.tensor(np.array(labels), dtype=torch.long)
return data, labels

def flow_matching_loss(model, x1, y=None):
"""
Flow Matching 核心损失
x1: 真实数据
y: 类别标签
"""
B = x1.shape[0]
device = x1.device

# 1. 均匀采样时间步 t ~ U[0, 1]
t = torch.rand(B, device=device)
t_expand = t[:, None, None, None] # 广播到图像维度

# 2. 采样源噪声 x0 ~ N(0, 1)
x0 = torch.randn_like(x1)

# 3. 构造线性插值路径 x_t = x0 + t*(x1 - x0)
x_t = x0 + t_expand * (x1 - x0)

# 4. 模型预测向量场
v_pred = model(x_t, t, y)

# 5. 损失:MSE(预测向量场, 真实向量场 x1-x0)
v_true = x1 - x0
loss = F.mse_loss(v_pred, v_true)

return loss

解读flow_matching_loss 完美展现了当前主流(如 Stable Diffusion 3 都在用)的 Flow Matching 的优雅之处。相比 DDPM 复杂的 αt\alpha_tβt\beta_t 调度表,FM 直接定义真实状态是 x1x_1,噪声是 x0x_0,模型只需要简单地学习向量差 x1x0x_1 - x_0 即可,代码极其简练且数学背景坚实。

训练与推理采样 (Euler Method)

作用:调度优化器进行模型训练,并通过常微分方程 (ODE) 的欧拉方法,根据预测的向量场逆向生成图像。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def train_dit(epochs=50, batch_size=64, lr=1e-4, device=None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

print("正在生成模拟数据集...")
data, labels = generate_simulation_dataset(num_samples=2000)
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

model = DiT().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
grad_clip = 1.0 # 梯度裁剪

total_params = sum(p.numel() for p in model.parameters())
print(f"模型总参数量: {total_params / 1e6:.2f}M")

print("开始训练...")
model.train()
loss_history = []

for epoch in range(epochs):
total_loss = 0.0
for batch_x1, batch_y in dataloader:
batch_x1, batch_y = batch_x1.to(device), batch_y.to(device)

optimizer.zero_grad()
loss = flow_matching_loss(model, batch_x1, batch_y)

loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()

total_loss += loss.item() * batch_x1.shape[0]

avg_loss = total_loss / len(dataset)
loss_history.append(avg_loss)
scheduler.step()

if (epoch + 1) % 5 == 0:
print(f"Epoch [{epoch+1}/{epochs}], 平均损失: {avg_loss:.6f}, LR: {scheduler.get_last_lr()[0]:.6f}")

print("训练完成!")
return model, loss_history, data, labels

@torch.no_grad()
def dit_sample(model, num_samples=8, class_id=0, img_size=32, in_channels=4, num_steps=20, device=None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.eval()
# 1. 初始化:从标准高斯分布采样x0(t=0)
x = torch.randn(num_samples, in_channels, img_size, img_size, device=device)
y = torch.tensor([class_id] * num_samples, device=device)

# 3. 欧拉法时间步
dt = 1.0 / num_steps
t_list = torch.linspace(0, 1 - dt, num_steps, device=device)

# 4. 逐步积分求解ODE: dx/dt = v(x, t)
for t in t_list:
t_batch = torch.ones(num_samples, device=device) * t
v_pred = model(x, t_batch, y)
# 欧拉法更新
x = x + v_pred * dt

return x.cpu()

解读:采样函数 dit_sample 实现了从纯噪声变回数据的魔法。由于 Flow Matching 的公式是 dxdt=v(x,t)\frac{dx}{dt} = v(x, t),代码中通过 x = x + v_pred * dt 实现了最简单的数值积分(欧拉法),每一步都沿着模型指引的方向往前走一小步 dt,经过 20 步就能逼近真实的图像流形。

可视化与执行入口

作用:将训练结果直观展示,对比真实分布和生成的分布。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# -------------------------- 可视化验证函数 --------------------------
def visualize_all_classes(real_samples, gen_samples, num_classes=4):
"""
对比可视化:所有类别的真实数据 vs 生成数据(放在同一张图)
第一行:真实图像,第二行:生成图像
"""
plt.figure(figsize=(num_classes * 3, 6))

for cls_idx in range(num_classes):
# 真实数据
plt.subplot(2, num_classes, cls_idx + 1)
plt.imshow(real_samples[cls_idx][0], cmap='viridis')
plt.title(f"True Class {cls_idx}")
plt.axis('off')

# 生成数据
plt.subplot(2, num_classes, cls_idx + 1 + num_classes)
plt.imshow(gen_samples[cls_idx][0], cmap='viridis')
plt.title(f"Gen Class {cls_idx}")
plt.axis('off')

plt.tight_layout()
plt.savefig("images/all_classes_comparison.png", dpi=150, bbox_inches='tight')
plt.close()

# 损失曲线可视化
def plot_loss_curve(loss_history):
plt.figure(figsize=(8, 4))
plt.plot(loss_history)
plt.title("training loss curve")
plt.xlabel("Epoch")
plt.ylabel("Flow Matching Loss")
plt.grid(True, alpha=0.3)
plt.savefig("images/loss_curve.png", dpi=150, bbox_inches='tight')
plt.close()

if __name__ == "__main__":
# 1. 训练模型
trained_model, loss_history, real_data, real_labels = train_dit(epochs=50)

# 2. 绘制损失曲线,验证训练收敛性
plot_loss_curve(loss_history)

# 3. 为每个类别生成1个样本,并收集1个真实样本
target_classes = [0, 1, 2, 3]
real_samples = []
gen_samples = []

print("\n正在生成所有类别的样本...")
for cls in target_classes:
# 收集1个真实样本
real_sample = real_data[real_labels == cls][0]
real_samples.append(real_sample)

# 生成1个样本
gen_sample = dit_sample(trained_model, num_samples=1, class_id=cls)
gen_samples.append(gen_sample[0])

# 4. 可视化所有类别在一张图里
visualize_all_classes(real_samples, gen_samples)

print("\n所有任务完成!生成的图片已保存到 images 文件夹")

训练损失和真实图像和生成图像展示

训练损失展示

真实图像和生成图像对比


手撕DiT
https://huan-yin.github.io/2026/04/09/手撕DiT/
作者
李相越
发布于
2026年4月9日
许可协议