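Hand-Rolling a DiT
This post implements a lightweight DiT (Diffusion Transformer), pairs it with the Flow Matching training paradigm, and runs the full loop of training and Euler sampling on a custom synthetic dataset.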
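Core Principles of the DiT Architecture
The core idea of DiT (Diffusion Transformer) is blunt but effective: replace the U-Net backbone commonly used in traditional diffusion models such as DDPM with a Transformer. Its success showed that Transformers also scale well on visual generation tasks. The architecture relies on the following key mechanisms: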
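Patchify: inherited from ViT. DiT splits the input latent into non-overlapping patches and flattens them into a 1-D token sequence, turning image generation into a standard sequence modeling task.
Conditioning (adaLN-Zero): the model needs to know the current timestep $t$ and the condition $y$. Instead of cross-attention, DiT uses adaLN (adaptive layer normalization): the conditioning vector is mapped to a scale $\gamma$, a shift $\beta$, and a residual gate $\alpha$. The layer that produces these parameters is zero-initialized, so all of them start at 0, which helps keep the early training of a deep Transformer stable.
Minimal objective (Flow Matching): this code uses Flow Matching rather than the DDPM objective. It builds a straight path from pure noise $x_0 \sim \mathcal{N}(0, I)$ to real data $x_1$: $x_t = x_0 + t \cdot (x_1 - x_0)$. The model only has to predict the velocity field of this linear process, $v = x_1 - x_0$, and at inference time Euler integration generates the image.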
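The straight-path construction can be written out in three lines. This is a sketch of the objective as the loss code later in this post implements it, up to the averaging constant used by the MSE; the symbol $v_\theta$ for the network is notation introduced here, not taken from the original.

```latex
\begin{aligned}
x_t &= (1 - t)\,x_0 + t\,x_1, \qquad x_0 \sim \mathcal{N}(0, I),\ t \sim \mathcal{U}[0, 1] \\
\frac{\mathrm{d}x_t}{\mathrm{d}t} &= x_1 - x_0 \\
\mathcal{L}(\theta) &= \mathbb{E}_{t,\,x_0,\,x_1}\,\bigl\| v_\theta(x_t, t, y) - (x_1 - x_0) \bigr\|^2
\end{aligned}
```

The first line is the same path as $x_t = x_0 + t \cdot (x_1 - x_0)$, only regrouped; differentiating it with respect to $t$ gives the constant target in the second line.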
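Environment Setup and Global Settings
This part imports the required libraries and fixes the random seeds to make runs as reproducible as possible.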
```python
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Output directory for the figures saved later
os.makedirs("images", exist_ok=True)


def set_seed(seed=42):
    """Seed the PyTorch and NumPy random number generators."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # Ask cuDNN for deterministic kernels and turn off autotuning
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)
```
Time Position Embedding (SinusoidalPositionEmbeddings)
Purpose: convert the continuous timestep $t \in [0, 1]$ into a high-dimensional vector so the network can tell where along the path it currently is.
```python
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        # time: [B] tensor of timesteps
        device = time.device
        half_dim = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000
        embeddings = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        # Outer product: [B, 1] * [1, half_dim] -> [B, half_dim]
        embeddings = time[:, None] * embeddings[None, :]
        # Concatenate the sin and cos halves
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
```
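Notes: this is the classic Transformer positional-encoding formula. By combining sin and cos it turns a scalar time into a dim-dimensional tensor of periodic features, which is the raw signal from which the adaLN modulation parameters are later produced.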
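To make the formula concrete, here is a minimal pure-Python sketch of the same computation for a single scalar timestep. The function name `sinusoidal_embedding` is hypothetical and exists only for this illustration; it mirrors the arithmetic of the module above without PyTorch.

```python
import math


def sinusoidal_embedding(t, dim):
    """Pure-Python mirror of the embedding for one scalar timestep t."""
    half_dim = dim // 2
    step = math.log(10000.0) / (half_dim - 1)
    # Frequencies decay geometrically from 1 down to 1/10000
    freqs = [math.exp(-step * i) for i in range(half_dim)]
    angles = [t * f for f in freqs]
    # First half: sines, second half: cosines
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


emb_start = sinusoidal_embedding(0.0, dim=8)  # sin half is 0, cos half is 1
emb_end = sinusoidal_embedding(1.0, dim=8)
```

Because $t$ only ranges over $[0, 1]$ here, the lowest-frequency components barely move between `emb_start` and `emb_end`. Some implementations rescale $t$ (for example by 1000) before embedding it; that is an optional tweak and not something the original code does.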
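Core Mechanism: DiT Block and adaLN-Zero
Purpose: the feature-extraction building block of DiT. It contains self-attention, an MLP, and, most importantly, the 6-parameter adaLN-Zero modulation layer.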
```python
class DiTBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        # LayerNorm without learnable affine: adaLN supplies scale and shift
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, hidden_size),
        )

        # Maps the condition to 6 vectors: (scale, shift, gate) for attention and MLP
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size * 6),
        )
        # adaLN-Zero: zero-init so every modulation output starts at 0
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, t_emb):
        # x: [B, N, hidden], t_emb: [B, hidden]
        modulation = self.adaLN_modulation(t_emb)
        scale_mha, shift_mha, gate_mha, scale_mlp, shift_mlp, gate_mlp = modulation.chunk(6, dim=-1)

        # Add a token axis so each [B, hidden] vector broadcasts over N tokens
        scale_mha = scale_mha[:, None, :]
        shift_mha = shift_mha[:, None, :]
        gate_mha = gate_mha[:, None, :]
        scale_mlp = scale_mlp[:, None, :]
        shift_mlp = shift_mlp[:, None, :]
        gate_mlp = gate_mlp[:, None, :]

        # Attention branch: norm -> modulate -> attention -> gated residual
        residual = x
        x = self.norm1(x)
        x = x * (1 + scale_mha) + shift_mha
        attn_out, _ = self.attn(x, x, x)
        x = residual + gate_mha * attn_out

        # MLP branch: norm -> modulate -> MLP -> gated residual
        residual = x
        x = self.norm2(x)
        x = x * (1 + scale_mlp) + shift_mlp
        mlp_out = self.mlp(x)
        x = residual + gate_mlp * mlp_out

        return x
```
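Notes: this is the most elegant part of the whole post. elementwise_affine=False is required because the normalization's scale and shift are taken over by adaLN_modulation. Zero initialization means that at the first training step scale and shift are 0 (equivalent to a plain LayerNorm without affine parameters) and gate is also 0 (equivalent to x = residual + 0). Each block therefore starts as an exact identity mapping, so the signal and its gradient pass through the residual path of the whole stack unchanged, which tends to stabilize and speed up early training.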
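The identity-at-initialization property can be checked on a toy example. The sketch below is pure Python and deliberately simplified: it skips the LayerNorm, and `toy_block` uses an arbitrary affine function as a stand-in for attention or the MLP. All names are hypothetical and only illustrate the modulate-then-gate pattern.

```python
def modulate(x_norm, scale, shift):
    """adaLN modulation: x * (1 + scale) + shift, elementwise."""
    return [v * (1.0 + scale) + shift for v in x_norm]


def gated_residual(residual, branch_out, gate):
    """Residual connection whose branch is multiplied by a gate."""
    return [r + gate * b for r, b in zip(residual, branch_out)]


def toy_block(x, scale, shift, gate):
    # Stand-in for attention/MLP: any function of the modulated input
    branch_out = [3.0 * v - 1.0 for v in modulate(x, scale, shift)]
    return gated_residual(x, branch_out, gate)


x = [0.5, -1.0, 2.0]
out_init = toy_block(x, scale=0.0, shift=0.0, gate=0.0)     # zero-init state
out_trained = toy_block(x, scale=0.1, shift=0.2, gate=0.5)  # after some training
```

With all three parameters at zero the branch output is multiplied by a zero gate, so `out_init` equals `x` no matter what the branch computes. Once the gate moves away from zero, the branch starts to contribute.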
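Macro Architecture: the DiT Backbone
Purpose: convert the image to a sequence (patchify), fuse the time and class conditions, extract features with a stack of DiT blocks, and finally map the result back to image dimensions.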
```python
class DiT(nn.Module):
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=256,
        depth=6,
        num_heads=8,
        mlp_ratio=4.0,
        num_classes=4,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = patch_size
        self.num_patches = (input_size // patch_size) ** 2

        # Patchify: a strided conv maps each p x p patch to one token
        self.patch_embed = nn.Conv2d(in_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
        # Learnable positional embedding, one vector per token
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size))

        # Timestep embedding: sinusoidal features followed by an MLP
        self.time_embed = nn.Sequential(
            SinusoidalPositionEmbeddings(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.class_embed = nn.Embedding(num_classes, hidden_size)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio) for _ in range(depth)
        ])

        # Final adaLN (scale and shift only) and projection back to patch pixels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.final_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size * 2),
        )
        self.head = nn.Linear(hidden_size, patch_size * patch_size * in_channels)

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.head.weight, std=0.02)
        nn.init.zeros_(self.head.bias)
        nn.init.zeros_(self.final_modulation[-1].weight)
        nn.init.zeros_(self.final_modulation[-1].bias)

    def forward(self, x, t, y=None):
        B = x.shape[0]

        # [B, C, H, W] -> [B, hidden, H/p, W/p] -> [B, N, hidden]
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
        x = x + self.pos_embed

        # Condition vector: time embedding plus optional class embedding
        t_emb = self.time_embed(t)
        if y is not None:
            t_emb = t_emb + self.class_embed(y)

        for block in self.blocks:
            x = block(x, t_emb)

        scale_final, shift_final = self.final_modulation(t_emb).chunk(2, dim=-1)
        x = self.norm_final(x)
        x = x * (1 + scale_final[:, None, :]) + shift_final[:, None, :]
        x = self.head(x)  # [B, N, p*p*C]

        # Unpatchify: fold the token sequence back into [B, C, H, W]
        h = w = int(self.num_patches ** 0.5)
        p = self.patch_size
        x = x.transpose(1, 2).reshape(B, self.in_channels, p, p, h, w)
        x = torch.einsum('ncpqhw->nchpwq', x)
        x = x.reshape(B, self.in_channels, h * p, w * p)

        return x
```
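Notes: the key point is the change of feature space. The input is [B, C, H, W]; patch_embed turns it into [B, Seq_Len, Hidden], the positional embedding pos_embed is added, and the sequence enters the blocks. At the output, a Linear layer followed by an einops-style rearrangement (einsum + reshape) folds the token sequence produced by the Transformer back into the original image resolution.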
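The shape bookkeeping can be verified with plain index arithmetic. The sketch below is pure Python and uses the default sizes from the constructor above; `token_to_pixel` is a hypothetical helper that assumes the feature ordering implied by the transpose, reshape, and einsum in `forward`.

```python
def token_to_pixel(n, f, p, w):
    """Map token index n and feature index f to (channel, row, col)."""
    c, rem = divmod(f, p * p)  # feature axis is laid out as (c, pi, qi)
    pi, qi = divmod(rem, p)
    hi, wi = divmod(n, w)      # token axis is laid out as (hi, wi), row-major
    return c, hi * p + pi, wi * p + qi


input_size, patch_size, in_channels = 32, 2, 4
grid = input_size // patch_size                         # patches per side
num_patches = grid * grid                               # sequence length
token_dim = patch_size * patch_size * in_channels       # output width of head

# Every (token, feature) pair should land on a distinct output pixel
pixels = {
    token_to_pixel(n, f, patch_size, grid)
    for n in range(num_patches)
    for f in range(token_dim)
}
```

If the mapping is a bijection, `pixels` contains exactly `in_channels * input_size * input_size` entries, which is what "folding back without loss" means in index terms.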
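Synthetic Dataset and Flow Matching Loss
Purpose: to validate the model quickly, NumPy generates geometric patterns that serve as the real data distribution, and Flow Matching defines a straight-line target from Gaussian noise to data.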
```python
def generate_simulation_dataset(
    num_samples=2000,
    img_size=32,
    in_channels=4,
    num_classes=4,
    noise_level=0.05,
):
    """
    Generate a synthetic DiT training set with one pattern per class.
    """
    x = np.linspace(-np.pi, np.pi, img_size)
    y = np.linspace(-np.pi, np.pi, img_size)
    xx, yy = np.meshgrid(x, y)

    data, labels = [], []
    samples_per_class = num_samples // num_classes

    for class_id in range(num_classes):
        for _ in range(samples_per_class):
            if class_id == 0:
                # Vertical stripes
                base_pattern = np.sin(2 * xx)
            elif class_id == 1:
                # Horizontal stripes
                base_pattern = np.sin(4 * yy)
            elif class_id == 2:
                # Checkerboard
                base_pattern = np.sign(np.sin(3 * xx) * np.sin(3 * yy))
            else:
                # Damped concentric rings
                base_pattern = np.exp(-(xx**2 + yy**2) / 2) * np.cos(4 * np.sqrt(xx**2 + yy**2))

            # Each channel is the base pattern with a small per-channel offset
            multi_channel_pattern = np.stack([
                base_pattern + 0.1 * i for i in range(in_channels)
            ], axis=0)

            multi_channel_pattern += np.random.normal(0, noise_level, multi_channel_pattern.shape)
            # Normalize to [-1, 1]
            multi_channel_pattern = multi_channel_pattern / np.max(np.abs(multi_channel_pattern))

            data.append(multi_channel_pattern)
            labels.append(class_id)

    data = torch.tensor(np.array(data), dtype=torch.float32)
    labels = torch.tensor(np.array(labels), dtype=torch.long)

    return data, labels


def flow_matching_loss(model, x1, y=None):
    """
    Core Flow Matching loss.
    x1: real data
    y: class labels
    """
    B = x1.shape[0]
    device = x1.device

    # One timestep per sample, uniform in [0, 1)
    t = torch.rand(B, device=device)
    t_expand = t[:, None, None, None]

    # Noise endpoint and the point on the straight path at time t
    x0 = torch.randn_like(x1)
    x_t = x0 + t_expand * (x1 - x0)

    # Regress the predicted velocity onto the constant target x1 - x0
    v_pred = model(x_t, t, y)
    v_true = x1 - x0
    loss = F.mse_loss(v_pred, v_true)

    return loss
```
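Notes: flow_matching_loss shows what makes Flow Matching, now a mainstream choice (Stable Diffusion 3 uses it, for example), so appealing. Instead of DDPM's $\alpha_t$ and $\beta_t$ noise schedules, FM simply declares the real data to be $x_1$ and the noise to be $x_0$, and the model only learns the difference vector $x_1 - x_0$. The code is very short and the math behind it is well founded.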
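A hand-computable example may help fix the quantities involved. The sketch below is pure Python on a three-element "image", with a fixed noise vector instead of a random draw so the numbers are reproducible; the helper names `interpolate` and `mse` are hypothetical.

```python
def interpolate(x0, x1, t):
    """Point on the straight path: x_t = x0 + t * (x1 - x0)."""
    return [a + t * (b - a) for a, b in zip(x0, x1)]


def mse(pred, target):
    """Mean squared error over all elements."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)


x0 = [0.5, -1.0, 2.0]  # noise sample (fixed here instead of random)
x1 = [1.0, 0.0, -1.0]  # data sample
t = 0.25

x_t = interpolate(x0, x1, t)
v_true = [b - a for a, b in zip(x0, x1)]

loss_perfect = mse(v_true, v_true)             # a perfect predictor scores 0
loss_zero = mse([0.0] * len(v_true), v_true)   # predicting "no motion"

# Following v_true for the remaining time 1 - t lands exactly on x1
endpoint = [xt + (1.0 - t) * v for xt, v in zip(x_t, v_true)]
```

The last line is the reason a single target vector is enough: from any point on the straight path, moving along `v_true` for the remaining time reaches the data sample.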
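Training and Inference Sampling (Euler Method)
Purpose: run the optimizer and scheduler to train the model, then generate images by integrating the predicted velocity field from noise toward data with the Euler method for ordinary differential equations (ODEs).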
```python
def train_dit(epochs=50, batch_size=64, lr=1e-4, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Generating synthetic dataset...")
    data, labels = generate_simulation_dataset(num_samples=2000)
    dataset = TensorDataset(data, labels)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    model = DiT().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    grad_clip = 1.0

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total model parameters: {total_params / 1e6:.2f} M")

    print("Starting training...")
    model.train()
    loss_history = []

    for epoch in range(epochs):
        total_loss = 0.0
        num_seen = 0

        for batch_x1, batch_y in dataloader:
            batch_x1, batch_y = batch_x1.to(device), batch_y.to(device)

            optimizer.zero_grad()
            loss = flow_matching_loss(model, batch_x1, batch_y)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

            total_loss += loss.item() * batch_x1.shape[0]
            num_seen += batch_x1.shape[0]

        # drop_last=True skips the final partial batch, so average over
        # the samples actually seen rather than len(dataset)
        avg_loss = total_loss / num_seen
        loss_history.append(avg_loss)
        scheduler.step()

        if (epoch + 1) % 5 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], avg loss: {avg_loss:.6f}, LR: {scheduler.get_last_lr()[0]:.6f}")

    print("Training finished!")
    return model, loss_history, data, labels


@torch.no_grad()
def dit_sample(model, num_samples=8, class_id=0, img_size=32, in_channels=4, num_steps=20, device=None):
    if device is None:
        # Default to the device the model already lives on
        device = next(model.parameters()).device

    model.eval()
    # Start from pure Gaussian noise at t = 0
    x = torch.randn(num_samples, in_channels, img_size, img_size, device=device)
    y = torch.tensor([class_id] * num_samples, device=device)

    # Uniform time grid: t = 0, dt, ..., 1 - dt
    dt = 1.0 / num_steps
    t_list = torch.linspace(0, 1 - dt, num_steps, device=device)

    for t in t_list:
        t_batch = torch.ones(num_samples, device=device) * t
        v_pred = model(x, t_batch, y)
        # Euler step along the predicted velocity
        x = x + v_pred * dt

    return x.cpu()
```
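Notes: the sampling function dit_sample turns pure noise back into data. Since Flow Matching defines the ODE $\frac{dx}{dt} = v(x, t)$, the line x = x + v_pred * dt is the simplest numerical integrator (the Euler method): each step moves a small increment dt in the direction the model predicts. With the default num_steps=20, the sample travels from noise at t=0 to an approximation of the data distribution at t=1.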
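How many steps are needed depends on how curved the velocity field is. The following pure-Python sketch uses a hypothetical helper `euler_integrate` on two scalar toy fields that are not from the original post: a constant field, which matches the Flow Matching target for a single noise and data pair, and the field dx/dt = x, whose exact solution at t=1 is e times the start value.

```python
import math


def euler_integrate(velocity, x_start, num_steps):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with fixed steps."""
    dt = 1.0 / num_steps
    x = x_start
    for k in range(num_steps):
        t = k * dt  # same grid as linspace(0, 1 - dt, num_steps)
        x = x + velocity(x, t) * dt
    return x


# Straight path: constant velocity x1 - x0
x0, x1 = -2.0, 3.0
straight_1 = euler_integrate(lambda x, t: x1 - x0, x0, num_steps=1)
straight_20 = euler_integrate(lambda x, t: x1 - x0, x0, num_steps=20)

# Curved toy field dx/dt = x, exact solution x(1) = e * x(0)
err_20 = abs(euler_integrate(lambda x, t: x, 1.0, 20) - math.e)
err_200 = abs(euler_integrate(lambda x, t: x, 1.0, 200) - math.e)
```

For the constant field even a single Euler step lands on `x1`, while for the curved field the error shrinks only as the step count grows. A trained model's field is generally not perfectly straight, so `num_steps` remains a trade-off between speed and accuracy.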
Visualization and Entry Point
Purpose: display the training results and compare the real distribution with the generated one.
```python
def visualize_all_classes(real_samples, gen_samples, num_classes=4):
    """
    Side-by-side comparison of real and generated data for all classes in one figure.
    Row 1: real images, row 2: generated images.
    """
    plt.figure(figsize=(num_classes * 3, 6))

    for cls_idx in range(num_classes):
        # Top row: channel 0 of a real sample
        plt.subplot(2, num_classes, cls_idx + 1)
        plt.imshow(real_samples[cls_idx][0], cmap='viridis')
        plt.title(f"True Class {cls_idx}")
        plt.axis('off')

        # Bottom row: channel 0 of a generated sample
        plt.subplot(2, num_classes, cls_idx + 1 + num_classes)
        plt.imshow(gen_samples[cls_idx][0], cmap='viridis')
        plt.title(f"Gen Class {cls_idx}")
        plt.axis('off')

    plt.tight_layout()
    plt.savefig("images/all_classes_comparison.png", dpi=150, bbox_inches='tight')
    plt.close()


def plot_loss_curve(loss_history):
    plt.figure(figsize=(8, 4))
    plt.plot(loss_history)
    plt.title("training loss curve")
    plt.xlabel("Epoch")
    plt.ylabel("Flow Matching Loss")
    plt.grid(True, alpha=0.3)
    plt.savefig("images/loss_curve.png", dpi=150, bbox_inches='tight')
    plt.close()


if __name__ == "__main__":
    trained_model, loss_history, real_data, real_labels = train_dit(epochs=50)

    plot_loss_curve(loss_history)

    target_classes = [0, 1, 2, 3]
    real_samples = []
    gen_samples = []

    print("\nGenerating samples for all classes...")
    for cls in target_classes:
        # First real sample of this class
        real_sample = real_data[real_labels == cls][0]
        real_samples.append(real_sample)

        # One generated sample conditioned on this class
        gen_sample = dit_sample(trained_model, num_samples=1, class_id=cls)
        gen_samples.append(gen_sample[0])

    visualize_all_classes(real_samples, gen_samples)

    print("\nAll tasks finished! Generated figures are saved in the images folder")
```
Figures: training loss curve, and real versus generated images.