DiffSeg — Unsupervised Zero-Shot Segmentation with Stable Diffusion
Computer Vision · Deep Learning · Featured · March 2025


Unofficial PyTorch reimplementation of DiffSeg (CVPR 2024) — segment any image with zero labels, zero training, zero text prompts, using only the self-attention maps of a frozen Stable Diffusion UNet.

PyTorch · Diffusers · Gradio · HuggingFace Spaces · Stable Diffusion

Overview

This project is an unofficial PyTorch reimplementation of DiffSeg, presented at CVPR 2024 by Tian et al. from Google and Georgia Tech. The original codebase was written in TensorFlow 2 + KerasCV — this port brings the full algorithm to the PyTorch / HuggingFace diffusers ecosystem, with a live interactive Gradio demo deployed on HuggingFace Spaces.

A full paper breakdown is coming soon on my blog and Medium. Here I focus on the implementation decisions and what makes this project technically interesting.


What Problem Does It Solve?

Most segmentation models need one of:

  • Dense annotations (pixel-level labels — expensive, time-consuming)
  • A text prompt (like SAM or CLIP-based methods)
  • Domain-specific training (not generalisable)

DiffSeg needs none of these. It segments any image — photos, illustrations, medical scans, satellite imagery — using only a frozen pre-trained Stable Diffusion model and zero additional training.


The Core Idea

Stable Diffusion's UNet contains self-attention layers that have implicitly learned to group objects together during generative training. The key observation from the paper:

Pixels belonging to the same object attend to the same set of other pixels.

This means the self-attention maps are already a form of unsupervised segmentation — they just need to be extracted and organised.


Pipeline

The pipeline (shown in the diagram below) has three main stages:

[Pipeline diagram: attention extraction → attention aggregation → iterative KL-divergence merging]

1 · Attention Extraction

A single denoising step is run at a chosen timestep (typically 50–200). During this forward pass, a custom AttnProcessor taps every self-attention layer of the UNet, capturing 16 attention tensors at 4 spatial resolutions: 64×64, 32×32, 16×16, and 8×8.

class CapturingAttnProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None):
        # Self-attention: queries, keys, and values all come from hidden_states
        q = attn.head_to_batch_dim(attn.to_q(hidden_states))
        k = attn.head_to_batch_dim(attn.to_k(hidden_states))
        v = attn.head_to_batch_dim(attn.to_v(hidden_states))
        # Compute softmax(q kᵀ · scale) once, capture it, and reuse it for
        # the output — no second Q/K computation
        weights = attn.get_attention_scores(q, k, attention_mask)
        attn._captured_weights = weights.detach()
        out = attn.batch_to_head_dim(weights @ v)
        return attn.to_out[0](out)

The key implementation challenge here was avoiding the bug present in naive monkey-patching approaches: computing Q/K twice causes shape mismatches like (32768×64) vs (512×512). Using a proper AttnProcessor solves this cleanly.
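Wiring the processor in is a one-liner once the right layers are identified. The helper below is a hypothetical sketch (not from the original repo): in diffusers, self-attention processors are registered under names containing "attn1" and cross-attention ones under "attn2", and only the former carry the spatial grouping signal DiffSeg needs.

```python
def build_processor_map(attn_processors, capture_cls):
    """Swap a capturing processor into every self-attention layer.

    attn_processors: the dict returned by unet.attn_processors.
    capture_cls: the capturing processor class (e.g. CapturingAttnProcessor).
    Cross-attention ("attn2") processors are left untouched.
    """
    return {
        name: capture_cls() if "attn1" in name else proc
        for name, proc in attn_processors.items()
    }
```

The resulting dict is passed to `unet.set_attn_processor(...)`, the supported diffusers entry point for swapping processors, which avoids monkey-patching entirely.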

2 · Attention Aggregation

All 16 maps are aggregated into a single 4D tensor Af ∈ R^(64×64×64×64) using the paper's exact upsampling strategy — not bilinear interpolation on both axes (a common mistake), but:

  • Key axis (what a pixel attends to): bilinear upsample
  • Query axis (which pixel is looking): tile/repeat
# Key axis — bilinearly interpolate each query's attention distribution
attn_4d = attn.reshape(1, n, src_res, src_res)   # n = src_res ** 2 queries
key_up = F.interpolate(attn_4d, size=(64, 64), mode="bilinear")

# Query axis — tile each low-res query position over its high-res patch
scale = 64 // src_res
query_up = (key_up
            .reshape(src_res, src_res, 64, 64)
            .repeat_interleave(scale, dim=0)
            .repeat_interleave(scale, dim=1))

This asymmetric treatment is critical — bilinearly interpolating the query axis creates false intermediate attention values that corrupt the segmentation boundaries.
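Putting both axes together, a complete upsampling helper might look like the sketch below (my own consolidation of the snippets above, not code from the repo):

```python
import torch
import torch.nn.functional as F

def upsample_attention(attn: torch.Tensor, target: int = 64) -> torch.Tensor:
    """Upsample an (n, n) self-attention map, n = src_res**2, to a
    (target, target, target, target) tensor: bilinear on the key axis,
    tile/repeat on the query axis."""
    n = attn.shape[0]
    src_res = int(n ** 0.5)
    scale = target // src_res

    # Key axis: each query's attention row is a (src_res, src_res)
    # probability map — bilinearly interpolate it to (target, target).
    key_up = F.interpolate(
        attn.reshape(1, n, src_res, src_res),
        size=(target, target), mode="bilinear", align_corners=False,
    ).reshape(src_res, src_res, target, target)

    # Query axis: each low-res query covers a scale x scale patch of
    # high-res pixels, so tile rather than interpolate.
    return key_up.repeat_interleave(scale, dim=0).repeat_interleave(scale, dim=1)
```

Note that all high-res queries inside one low-res patch share an identical attention map, which is exactly the behavior the tiling is meant to preserve.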

3 · Iterative KL-Divergence Merging

Starting from an M×M grid of anchor points, attention maps are iteratively merged when their symmetric KL divergence falls below a threshold:

½ (KL(pᵢ ‖ pⱼ) + KL(pⱼ ‖ pᵢ)) < threshold  →  merge pᵢ and pⱼ into one group

This is the "slider" in the live demo. Lower threshold = more fine-grained segments. Higher threshold = fewer coarse regions. Unlike K-means, no cluster count is needed upfront.

After merging, Non-Maximum Suppression converts the proposal maps into a clean integer label map, upsampled back to original image resolution using nearest-neighbor interpolation.
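A minimal sketch of the merging criterion and one greedy merging pass is shown below. This is a simplification for illustration: the real algorithm starts from an M×M anchor grid and iterates until the proposal set is stable.

```python
import torch

def symmetric_kl(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """0.5 * (KL(p||q) + KL(q||p)) between two attention distributions."""
    p = p.flatten() + eps
    q = q.flatten() + eps
    kl_pq = (p * (p / q).log()).sum()
    kl_qp = (q * (q / p).log()).sum()
    return 0.5 * (kl_pq + kl_qp)

def merge_once(maps, threshold):
    """One greedy pass: average together proposal maps whose symmetric
    KL divergence falls below the threshold."""
    merged, used = [], [False] * len(maps)
    for i, m in enumerate(maps):
        if used[i]:
            continue
        group = [m]
        for j in range(i + 1, len(maps)):
            if not used[j] and symmetric_kl(m, maps[j]) < threshold:
                used[j] = True
                group.append(maps[j])
        merged.append(torch.stack(group).mean(dim=0))
    return merged
```

Because the threshold, not a cluster count, controls granularity, the same cached attention maps can be re-merged at interactive speed.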


Live Demo Controls

The Gradio app exposes all algorithm parameters as interactive controls:

Control               | What it does                                          | Requires re-encode?
----------------------|-------------------------------------------------------|--------------------
Timestep              | Which denoising step to extract attention from        | ✅ Yes
KL threshold          | Merge aggressiveness (more/fewer segments)            | ❌ Instant
Resolution preset     | Which attention resolutions to use (64²/32²/16²/8²)   | ❌ Instant
w64 / w32 / w16 / w8  | Custom per-resolution weights                         | ❌ Instant
Semantic labels       | BLIP caption + noun assignment per segment            | ❌ Instant

The key UX decision was caching the attention bundle after encode. The expensive SD forward pass (~3–5s on GPU) runs once. All slider interactions only re-run the cheap merging step (~0.1s).
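The split can be sketched as a small session object. The class and function names here are hypothetical stand-ins (the real `encode` runs the SD forward pass; the real `segment` runs the KL merging):

```python
class DiffSegSession:
    """Cache the expensive encode; re-run only the cheap merge step."""

    def __init__(self, encode_fn, segment_fn):
        self._encode_fn = encode_fn      # slow: SD forward pass (~3-5 s on GPU)
        self._segment_fn = segment_fn    # fast: KL merging (~0.1 s)
        self._bundle = None
        self._key = None

    def segment(self, image_id, timestep, kl_threshold):
        key = (image_id, timestep)       # only these changes force a re-encode
        if key != self._key:
            self._bundle = self._encode_fn(image_id, timestep)
            self._key = key
        return self._segment_fn(self._bundle, kl_threshold)
```

In the Gradio app the cached bundle would live in per-user state so each slider callback hits only the fast path.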

[Screenshot of the live demo]


Implementation Differences from Original

Aspect                  | Original (TF/KerasCV)       | This repo (PyTorch)
------------------------|-----------------------------|-------------------------------------------
Framework               | TensorFlow 2.14 + KerasCV   | PyTorch 2.x + diffusers
Attention extraction    | KerasCV layer hooks         | AttnProcessor (correct diffusers pattern)
Aggregation upsampling  | Bilinear key + tile query   | Same — paper-exact ✅
Live demo               | Jupyter notebook            | Gradio app on HF Spaces
Parameters              | Hardcoded                   | All interactive sliders

Results

For a simple airplane test image, using timestep=100, KL threshold=1.2, and resolutions 64×64 + 32×32, the model cleanly separates the fuselage, wings, engines, tail, and sky — with no labels, no text, and no fine-tuning.

The proportional aggregation strategy (higher weight to higher resolution maps) balances detail and coherence — matching the paper's Figure 4 ablation exactly.
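One plausible reading of "proportional" weighting (my interpretation, exposed in the demo as the w64/w32/w16/w8 sliders) is simply weight proportional to resolution:

```python
def proportional_weights(resolutions=(64, 32, 16, 8)):
    """Weight each resolution's aggregated maps by its resolution,
    so the 64x64 maps contribute the most spatial detail."""
    total = sum(resolutions)
    return {r: r / total for r in resolutions}
```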


What I Learned

  • AttnProcessor is the correct way to intercept attention in modern diffusers — monkey-patching forward() causes double-computation bugs that are hard to debug
  • Asymmetric upsampling (bilinear on key, tile on query) is a subtle but critical detail not obvious from the paper text alone — it took reading the actual formula carefully
  • Caching architecture is essential for interactive ML demos — separating the expensive encode from the cheap segment step makes sliders feel instant
  • Git LFS + HuggingFace binary file handling requires careful setup — the HuggingFace API upload_folder() is more reliable than git push for Spaces deployment


Citation

@inproceedings{tian2024diffuse,
  title={Diffuse, Attend, and Segment: Unsupervised Zero-Shot Segmentation
         using Stable Diffusion},
  author={Tian, Junjiao and Aggarwal, Lavisha and Colaco, Andrea
          and Kira, Zsolt and Gonzalez-Franco, Mar},
  booktitle={CVPR},
  year={2024}
}