Latent Diffusion & Stable Diffusion
Pixel-space diffusion on 512×512 images is a computational war crime. Rombach et al. (2022) noticed that you do not need all 786k dimensions to generate an image — you need enough to capture semantic structure, and a separate decoder for the rest. Run diffusion inside a VAE's latent space. That one idea is Stable Diffusion.
Type: Build Languages: Python Prerequisites: Phase 8 · 02 (VAE), Phase 8 · 06 (DDPM), Phase 7 · 09 (ViT) Time: ~75 minutes
The Problem
Pixel-space diffusion at 512² means the U-Net runs on tensors of shape [B, 3, 512, 512]. Each sampling step is ~100 GFLOPS for a 500M-param U-Net. Fifty steps is 5 TFLOPS per image. Train on a billion images and the compute bill is absurd.
Most of those FLOPs go to pushing perceptually unimportant details through the net — the high-frequency texture that a lossy VAE could compress away. Rombach's idea: train a VAE once (the first stage), freeze it, and run diffusion entirely in the 4-channel 64×64 latent space (the second stage). Same U-Net. 1/16th the pixels. ~64x fewer FLOPs for comparable quality.
This is the Stable Diffusion recipe. SD 1.x / 2.x used an 860M U-Net over 64×64×4 latents, SDXL used a 2.6B U-Net over 128×128×4, SD3 swapped the U-Net for a Diffusion Transformer (DiT) with flow matching. Flux.1-dev (Black Forest Labs, 2024) ships a 12B-param DiT-MMDiT. All run on the same two-stage substrate.
The Concept
Latent diffusion: VAE compression + diffusion in latent space
Two stages, separately trained.
Stage 1 — VAE. Encoder
E(x) → z, decoderD(z) → x. Target compression: 8× downsample in each spatial axis + adjust channels so total latent size is ~1/16th of pixel count. Loss = reconstruction (L1 + LPIPS perceptual) + KL (small weight sozisn't forced too Gaussian, because we do not need exact sampling fromz). Often trained with an adversarial loss so decoded images are sharp.Stage 2 — diffusion on
z. Treatz = E(x_real)as the data. Train a U-Net (or DiT) to denoisez_t. At inference: samplez_0via diffusion, thenx = D(z_0).
Text conditioning. Two additional components. A frozen text encoder (CLIP-L for SD 1.x, CLIP-L+OpenCLIP-G for SD 2/XL, T5-XXL for SD3 and Flux). A cross-attention injection: every U-Net block takes [Q = image features, K = V = text tokens] and mixes them in. The tokens are the only way text influences the image.
The loss function is identical to Lesson 06. Same DDPM / flow matching MSE on noise. You just swap the data domain.
Architecture variants
| Model | Year | Backbone | Latent shape | Text encoder | Params |
|---|---|---|---|---|---|
| SD 1.5 | 2022 | U-Net | 64×64×4 | CLIP-L (77 tokens) | 860M |
| SD 2.1 | 2022 | U-Net | 64×64×4 | OpenCLIP-H | 865M |
| SDXL | 2023 | U-Net + refiner | 128×128×4 | CLIP-L + OpenCLIP-G | 2.6B + 6.6B |
| SDXL-Turbo | 2023 | Distilled | 128×128×4 | same | 1-4 step sampling |
| SD3 | 2024 | MMDiT (multimodal DiT) | 128×128×16 | T5-XXL + CLIP-L + CLIP-G | 2B / 8B |
| Flux.1-dev | 2024 | MMDiT | 128×128×16 | T5-XXL + CLIP-L | 12B |
| Flux.1-schnell | 2024 | MMDiT distilled | 128×128×16 | T5-XXL + CLIP-L | 12B, 1-4 step |
The trend: replace U-Net with DiT (transformer over latent patches), scale the text encoder (T5 beats CLIP for prompt adherence), increase latent channels (4 → 16 gives more detail headroom).
Build It
code/main.py stacks a toy 1-D "VAE" (identity encoder + decoder, for demonstration; a real VAE would be a conv net) on top of the DDPM from Lesson 06 and adds class conditioning with classifier-free guidance. It shows that the same diffusion loss works whether you run on raw 1-D values or on encoded values — the key insight.
Step 1: encoder/decoder
def encode(x): return x * 0.5 # toy "compression" to smaller scale
def decode(z): return z * 2.0A real VAE has trained weights. For pedagogy, this linear map is enough to show that diffusion operates on z without caring about the original data space.
Step 2: diffusion in z-space
Same DDPM as Lesson 06. The data the net sees is z = E(x). After sampling z_0, decode with D(z_0).
Step 3: classifier-free guidance
During training, drop the class label 10% of the time (replace with a null token). At inference, compute both ε_cond and ε_uncond, then:
eps_cfg = (1 + w) * eps_cond - w * eps_uncondw = 0 = no guidance (full diversity), w = 3 = default, w = 7+ = saturated / over-sharp.
Step 4: text conditioning (concept, not code)
Replace the class label with a frozen text encoder output. Feed the text embedding to the U-Net via cross-attention:
h = h + CrossAttention(Q=h, K=text_embed, V=text_embed)This is the only substantive difference between a class-conditional diffusion model and Stable Diffusion.
Pitfalls
- VAE-scale mismatch. SD 1.x VAEs have a scaling constant (
scaling_factor ≈ 0.18215) applied after encoding. Forgetting this makes the U-Net train on latents with wildly wrong variance. Every checkpoint ships one. - Text encoder silently wrong. SD3 needs T5-XXL with >=128 tokens, and the fallback to CLIP-only is lossy. Always check
use_t5=Trueor prompt fidelity craters. - Mixing latent spaces. SDXL, SD3, Flux all use different VAEs. A LoRA trained on SDXL latents will not work on SD3. Hugging Face diffusers 0.30+ refuses to load mismatched checkpoints.
- CFG too high.
w > 10produces saturated, oily images and over-fits the prompt at the cost of diversity. The sweet spot isw = 3-7. - Negative prompts leaking. Empty negative prompt becomes the null token; a filled negative prompt becomes the
ε_uncond. These are not the same; some pipelines silently default to the null.
Use It
Production stacks in 2026:
| Target | Recommended backbone |
|---|---|
| Narrow domain, paired data, training a model from scratch | SDXL fine-tune (LoRA / full) — fastest to ship |
| Open-domain text-to-image, open weights | Flux.1-dev (12B, Apache / non-commercial) or SD3.5-Large |
| Fastest inference, open weights | Flux.1-schnell (1-4 step, Apache) or SDXL-Lightning |
| Best prompt adherence, hosted | GPT-Image / DALL-E 3 (still), Midjourney v7, Imagen 4 |
| Edit workflows | Flux.1-Kontext (Dec 2024) — natively accepts image + text |
| Research, baseline | SD 1.5 — ancient but well-studied |
Ship It
Save outputs/skill-sd-prompter.md. Skill takes a text prompt + target style and outputs: model + checkpoint, CFG scale, sampler, negative prompt, resolution, optional ControlNet/IP-Adapter combo, and a per-step QA checklist.
Exercises
- Easy. Run
code/main.pywith guidancew ∈ {0, 1, 3, 7, 15}. Record mean sample by class. At whatwdo the class means diverge past the real data means? - Medium. Swap the toy linear encoder for a tanh-MLP encoder/decoder pair with a reconstruction loss. Retrain diffusion on the new latents. Does sample quality change?
- Hard. Set up a real Stable Diffusion inference with diffusers: load
sdxl-base, run 30 Euler steps with CFG=7, time it. Now switch tosdxl-turbowith 4 steps and CFG=0. Same subject, different quality — describe what changed and why.
Key Terms
| Term | What people say | What it actually means |
|---|---|---|
| First stage | "The VAE" | Trained encoder/decoder pair; compresses 512² to 64². |
| Second stage | "The U-Net" | Diffusion model over the latent space. |
| CFG | "Guidance scale" | (1+w)·ε_cond - w·ε_uncond; tunes conditioning strength. |
| Null token | "Empty prompt embed" | Unconditional embed used for ε_uncond. |
| Cross-attention | "How text gets in" | Each U-Net block attends to text tokens as K and V. |
| DiT | "Diffusion Transformer" | Replace U-Net with a transformer over latent patches; scales better. |
| MMDiT | "Multi-modal DiT" | SD3's architecture: text and image streams with joint attention. |
| VAE scaling factor | "Magic number" | Divides latents by ~5.4 so diffusion operates in unit-variance space. |
Production note: running Flux-12B on an 8GB consumer GPU
the reference Flux integration is the canonical "I have a consumer GPU, can I ship this?" recipe. The trick is the same three-knob recipe production inference literature lists applied to a diffusion DiT:
- Staggered loading. Flux has three networks that never need to coexist in VRAM: T5-XXL text encoder (~10 GB in fp32), CLIP-L (small), the 12B MMDiT, and the VAE. Encode the prompt first, delete the encoders, load the DiT, denoise, delete the DiT, load the VAE, decode. Consumer 8GB GPUs only fit one stage at a time.
- 4-bit quantization via bitsandbytes.
BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)on both the T5 encoder and the DiT. Cuts memory 8×, quality drop is imperceptible for text-to-image per Aritra's benchmarks (linked in the notebook). - CPU offload.
pipe.enable_model_cpu_offload()auto-swaps modules between CPU and GPU as each forward pass advances. Adds 10-20% latency but makes the pipeline run at all.
The memory accounting is: 10 GB T5 / 8 = 1.25 GB quantized, 12 B params × 0.5 bytes = ~6 GB quantized DiT, plus activations. In stas00's terms this is the extreme-end of TP=1 inference — no model parallelism, maximum quantization. For production you'd run TP=2 or TP=4 on H100s; for a single dev laptop, this is the recipe.
Further Reading
- Rombach et al. (2022). High-Resolution Image Synthesis with Latent Diffusion Models — Stable Diffusion.
- Podell et al. (2023). SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis — SDXL.
- Peebles & Xie (2023). Scalable Diffusion Models with Transformers (DiT) — DiT.
- Esser et al. (2024). Scaling Rectified Flow Transformers for High-Resolution Image Synthesis — SD3, MMDiT.
- Ho & Salimans (2022). Classifier-Free Diffusion Guidance — CFG.
- Labs (2024). Flux.1 — Black Forest Labs announcement — Flux.1 family.
- Hugging Face Diffusers docs — reference implementation for every checkpoint above.