This document explains the SimpleScoreNetwork architecture from the SDE formulation notebook, focusing on the SiLU activation function and how architecture choices differ for toy 2D data vs. real images.
- Notebook: `notebooks/diffusion/02_sde_formulation/02_sde_formulation.ipynb`
- Module: The notebook imports from `genailab.diffusion`:

```python
from genailab.diffusion import SimpleScoreNetwork as ScoreNet
from genailab.diffusion import VPSDE as VPSDE_Module
from genailab.diffusion import train_score_network, sample_reverse_sde
```
SiLU = Sigmoid Linear Unit, also known as Swish:

$$\text{SiLU}(x) = x \cdot \sigma(x)$$

where $\sigma(x) = \frac{1}{1 + e^{-x}}$ is the sigmoid function.

In PyTorch: `nn.SiLU()` or `F.silu()`
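The formula is easy to verify by hand; a small pure-Python sketch (no PyTorch needed) computing $x \cdot \sigma(x)$ directly:

```python
import math

def sigmoid(x):
    """Logistic sigmoid: 1 / (1 + e^{-x})."""
    return 1.0 / (1.0 + math.exp(-x))

def silu(x):
    """SiLU(x) = x * sigmoid(x)."""
    return x * sigmoid(x)

for x in [-2.0, 0.0, 2.0]:
    print(f"SiLU({x:+.1f}) = {silu(x):+.4f}")
# SiLU(-2.0) = -0.2384
# SiLU(+0.0) = +0.0000
# SiLU(+2.0) = +1.7616
```

These values match `F.silu()` output shown later in this document.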
```
ReLU(x):                    SiLU(x):
    |                           |
    |      /                    |      /
    |     /                     |     /
____|____/_______ x       ______|____/_______ x
    |                       \___/
```
ReLU: Sharp corner at 0, completely kills negative values
SiLU: Smooth everywhere, allows small negative values through
| Property | ReLU | SiLU |
|---|---|---|
| Formula | $\max(0, x)$ | $x \cdot \sigma(x)$ |
| Range | $[0, \infty)$ | $[\approx -0.278, \infty)$ |
| For $x \to +\infty$ | $x$ | $\approx x$ |
| For $x \to -\infty$ | $0$ | $\approx 0$ (but non-zero) |
| For $x < 0$ | $0$ (dead) | Small negative value (alive) |
| Smoothness | Not smooth at 0 | Smooth everywhere |
| Derivative | $\begin{cases}0 & x < 0 \\ 1 & x > 0\end{cases}$ | $\sigma(x)\bigl(1 + x(1 - \sigma(x))\bigr)$ |
ReLU issue: When a neuron's pre-activation is negative, ReLU outputs 0 and its gradient is 0:

- Neuron can get "stuck" with negative input
- No gradient flows back → can't recover
- Known as the "dying ReLU" problem
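The contrast is visible directly in the closed-form derivatives; a small pure-Python sketch (the derivative formulas follow from the definitions above):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def relu_grad(x):
    """ReLU'(x): 0 for x < 0, 1 for x > 0."""
    return 0.0 if x < 0 else 1.0

def silu_grad(x):
    """SiLU'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))."""
    s = sigmoid(x)
    return s * (1.0 + x * (1.0 - s))

for x in [-3.0, -1.0, 2.0]:
    print(f"x={x:+.1f}  ReLU'={relu_grad(x):.4f}  SiLU'={silu_grad(x):+.4f}")
# For every negative x, ReLU' is exactly 0 (no learning signal),
# while SiLU' is small but non-zero.
```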
Example:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])

relu_output = F.relu(x)
# tensor([0., 0., 0., 0., 0.5, 1., 2.])
# All negative values become 0 - information lost!

silu_output = F.silu(x)
# tensor([-0.2384, -0.2689, -0.1888, 0.0000, 0.3113, 0.7311, 1.7616])
# Negative values still contribute (small but non-zero)
```

Why SiLU helps:

- No dead neurons: Even for negative inputs, SiLU outputs non-zero values and has non-zero gradients
- Smooth gradients: Better for optimization (no sharp corners)
- Self-gating: The $x \cdot \sigma(x)$ structure allows the network to learn adaptive gating
- Empirically better: Often improves performance in deep networks
```python
import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, 11)
silu_x = F.silu(x)
for xi, si in zip(x, silu_x):
    print(f"x={xi:5.1f} → SiLU(x)={si:7.4f}")
# Output:
# x= -5.0 → SiLU(x)=-0.0335   (small negative)
# x= -4.0 → SiLU(x)=-0.0719   (slightly larger)
# x= -3.0 → SiLU(x)=-0.1423   (getting bigger)
# x= -2.0 → SiLU(x)=-0.2384   (approaching the minimum)
# x= -1.0 → SiLU(x)=-0.2689   (most negative of these samples)
# x=  0.0 → SiLU(x)= 0.0000   (zero)
# x=  1.0 → SiLU(x)= 0.7311   (positive)
# x=  2.0 → SiLU(x)= 1.7616   (approaching x)
# x=  3.0 → SiLU(x)= 2.8577   (close to x)
# x=  4.0 → SiLU(x)= 3.9281   (almost x)
# x=  5.0 → SiLU(x)= 4.9665   (≈ x)
```

Key observations:

- For large positive $x$: $\text{SiLU}(x) \approx x$ (acts like the identity)
- For large negative $x$: $\text{SiLU}(x) \approx 0$ (but not exactly 0)
- Minimum value: $\approx -0.278$ (at $x \approx -1.278$)
Key point: The derivative is non-zero almost everywhere (it vanishes only at the single minimum point $x \approx -1.278$), so gradients keep flowing even for negative inputs:

```python
x = torch.tensor([-2.0, 0.0, 2.0], requires_grad=True)
y = F.silu(x)
y.sum().backward()
print(x.grad)
# tensor([-0.0908, 0.5000, 1.0908])
# Even x=-2.0 has a non-zero gradient (-0.0908)
```

Why `y.sum().backward()` instead of `y.backward()`?
In PyTorch, backward() can only be called on a scalar (single value), not on a tensor with multiple elements.
The issue:

```python
x = torch.tensor([-2.0, 0.0, 2.0], requires_grad=True)
y = F.silu(x)  # y has 3 elements: tensor([-0.2384, 0.0000, 1.7616])

# This would fail:
# y.backward()  # RuntimeError: grad can be implicitly created only for scalar outputs
```

Why? Because `backward()` computes gradients of a scalar loss; when `y` is a vector, PyTorch doesn't know which scalar loss you want to differentiate.
Solutions:

1. Sum to scalar (most common for demos):

   ```python
   y.sum().backward()  # Computes gradient of sum(y) w.r.t. x
   ```

   This treats the "loss" as $L = \sum_i y_i$, so $\frac{\partial L}{\partial x_i} = \frac{\partial y_i}{\partial x_i}$.

2. Mean to scalar:

   ```python
   y.mean().backward()  # Computes gradient of mean(y) w.r.t. x
   ```

3. Select one element:

   ```python
   y[0].backward()  # Only compute gradient for first element
   ```

4. Provide gradient weights explicitly:

   ```python
   y.backward(torch.ones_like(y))  # Equivalent to y.sum().backward()
   ```

   The argument to `backward()` specifies $\frac{\partial L}{\partial y}$ (gradient of loss w.r.t. output).
In neural network training, you don't need `.sum()` because:

- Loss functions already return scalars
- `loss.backward()` works directly

```python
# Real training example
loss = F.mse_loss(pred, target)  # loss is already scalar
loss.backward()  # Works!
```

For our demo, we use `.sum()` just to show that gradients exist and are non-zero, even for negative inputs.
Note: The notebook now imports `SimpleScoreNetwork` from the `genailab.diffusion` module:

```python
from genailab.diffusion import SimpleScoreNetwork as ScoreNet
```

The architecture shown below is the same as the module implementation, reproduced here for educational purposes:
```python
class SimpleScoreNetwork(nn.Module):
    """Simple MLP score network for 2D data."""

    def __init__(self, data_dim=2, hidden_dim=128, time_dim=32):
        super().__init__()

        # Time embedding (sinusoidal)
        self.time_dim = time_dim

        # Network
        self.net = nn.Sequential(
            nn.Linear(data_dim + time_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, data_dim),
        )
```

Design choices:
- MLP (Multi-Layer Perceptron): Simple fully connected layers
- 3 hidden layers: Deep enough to model complex score functions
- SiLU activations: Smooth, non-linear transformations
- Time embedding: Sinusoidal encoding of time (like Transformer positional encoding)
Why it works for 2D toy data:
- 2D data has no spatial structure (just points in plane)
- MLPs are sufficient for learning score functions of 2D distributions
- No need for convolutions (no spatial locality to exploit)
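A forward pass through an MLP with these shapes can be sketched as follows. The sinusoidal embedding here is a plausible stand-in (its frequency range is an assumption; the module's actual embedding may differ):

```python
import torch
import torch.nn as nn

def sinusoidal_embedding(t, dim):
    """Sinusoidal time embedding, as in Transformer positional encodings.
    t: (B,) tensor of times in [0, 1]; returns (B, dim)."""
    half = dim // 2
    freqs = torch.exp(torch.linspace(0, 8, half))  # hypothetical frequency range
    args = t[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

# Same layer shapes as SimpleScoreNetwork above
data_dim, time_dim, hidden_dim = 2, 32, 128
net = nn.Sequential(
    nn.Linear(data_dim + time_dim, hidden_dim), nn.SiLU(),
    nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
    nn.Linear(hidden_dim, hidden_dim), nn.SiLU(),
    nn.Linear(hidden_dim, data_dim),
)

x = torch.randn(16, data_dim)   # batch of 2D points
t = torch.rand(16)              # one diffusion time per sample
t_emb = sinusoidal_embedding(t, time_dim)
score = net(torch.cat([x, t_emb], dim=-1))
print(score.shape)  # torch.Size([16, 2]) - one score vector per point
```

The output has the same dimensionality as the data, because the score $\nabla_x \log p_t(x)$ is a vector field on data space.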
Problem: Images have:
- Spatial structure: Nearby pixels are correlated
- Local patterns: Edges, textures, shapes
- High dimensionality: 256×256×3 = 196,608 dimensions
- Translation invariance: A cat is a cat regardless of position
MLP issues:
- Treats every pixel independently
- Ignores spatial relationships
- Requires massive number of parameters
- Can't generalize across positions
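The parameter-count problem is easy to quantify with back-of-envelope arithmetic:

```python
# One dense layer mapping a flattened 256x256x3 image to a hidden layer
# of the same size would need ~39 billion weights:
d = 256 * 256 * 3            # 196,608 input dimensions
dense_params = d * d         # weight matrix alone, ignoring biases
print(f"{dense_params:,}")   # 38,654,705,664

# A 3x3 convolution with 64 output channels needs only:
conv_params = 3 * 3 * 3 * 64  # kernel h x w x in_channels x out_channels
print(conv_params)            # 1728
# ...and it is applied at every spatial position (weight sharing),
# so it generalizes across locations in a way the MLP cannot.
```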
U-Net is the standard architecture for image diffusion models (DDPM, Stable Diffusion, etc.)
Input image (+ time embedding)
↓
[Encoder - Downsampling]
Conv → Conv → ↓ (downsample)
Conv → Conv → ↓
Conv → Conv → ↓
↓
[Bottleneck - Latent Representation]
↓
[Decoder - Upsampling]
Conv → Conv → ↑ (upsample) ← skip connection
Conv → Conv → ↑ ← skip connection
Conv → Conv → ↑ ← skip connection
↓
Output (predicted noise or score)
Key features:
- Convolutional layers: Exploit spatial locality
- Downsampling: Capture features at multiple scales
- Skip connections: Preserve fine details from encoder
- Time conditioning: Time embedding injected at each resolution
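The U-Net sketch below calls a `ConvBlock` helper that is not defined in these notes. One plausible implementation (an assumption for illustration, not the actual module code): two 3×3 convs with GroupNorm and SiLU, with the time embedding added as a per-channel bias:

```python
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Two 3x3 convs with GroupNorm + SiLU; the time embedding is
    projected to a per-channel bias (a simple form of time conditioning)."""
    def __init__(self, in_ch, out_ch, time_dim=64):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.act = nn.SiLU()
        self.time_proj = nn.Linear(time_dim, out_ch)

    def forward(self, x, t_emb):
        h = self.act(self.norm1(self.conv1(x)))
        h = h + self.time_proj(t_emb)[:, :, None, None]  # broadcast over H, W
        h = self.act(self.norm2(self.conv2(h)))
        return h

block = ConvBlock(3, 64)
x = torch.randn(2, 3, 32, 32)
t_emb = torch.randn(2, 64)
print(block(x, t_emb).shape)  # torch.Size([2, 64, 32, 32])
```

Production implementations typically also add residual connections and richer conditioning (e.g. FiLM-style scale and shift).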
```python
class UNetScoreNetwork(nn.Module):
    # ConvBlock (convs + norm + SiLU with time conditioning) not shown
    def __init__(self, in_channels=3, base_channels=64):
        super().__init__()

        # Time embedding
        self.time_embed = nn.Sequential(
            nn.Linear(1, 128),
            nn.SiLU(),
            nn.Linear(128, base_channels)
        )

        # Encoder (downsampling)
        self.enc1 = ConvBlock(in_channels, base_channels)        # 256×256
        self.enc2 = ConvBlock(base_channels, base_channels*2)    # 128×128
        self.enc3 = ConvBlock(base_channels*2, base_channels*4)  # 64×64

        # Bottleneck
        self.bottleneck = ConvBlock(base_channels*4, base_channels*8)  # 32×32

        # Decoder (upsampling); each block takes the upsampled features
        # concatenated with the matching encoder features
        self.dec3 = ConvBlock(base_channels*8 + base_channels*4, base_channels*4)  # 64×64
        self.dec2 = ConvBlock(base_channels*4 + base_channels*2, base_channels*2)  # 128×128
        self.dec1 = ConvBlock(base_channels*2 + base_channels, base_channels)      # 256×256

        # Output
        self.out = nn.Conv2d(base_channels, in_channels, 3, padding=1)

    def forward(self, x, t):
        # Time embedding
        t_emb = self.time_embed(t.unsqueeze(-1))

        # Encoder
        e1 = self.enc1(x, t_emb)
        e2 = self.enc2(F.avg_pool2d(e1, 2), t_emb)
        e3 = self.enc3(F.avg_pool2d(e2, 2), t_emb)

        # Bottleneck
        b = self.bottleneck(F.avg_pool2d(e3, 2), t_emb)

        # Decoder with skip connections (channel-wise concatenation,
        # which is why the decoder blocks expect the summed channel counts)
        d3 = self.dec3(torch.cat([F.interpolate(b, scale_factor=2), e3], dim=1), t_emb)
        d2 = self.dec2(torch.cat([F.interpolate(d3, scale_factor=2), e2], dim=1), t_emb)
        d1 = self.dec1(torch.cat([F.interpolate(d2, scale_factor=2), e1], dim=1), t_emb)

        return self.out(d1)
```

Why U-Net?
- Preserves spatial structure
- Multi-scale feature extraction
- Efficient (reuses features via skip connections)
- Proven effective for image generation
Modern U-Nets often use:
- SiLU (or Swish): For most layers
- GroupNorm: Instead of BatchNorm (works better for small batches)
- Attention: At lower resolutions (e.g., 16×16) for global context
DiT = Diffusion Transformer (Peebles & Xie, 2023)
Input image → Patch Embedding
↓
[Transformer Blocks × N]
Self-Attention (with time conditioning)
Feed-Forward (with time conditioning)
↓
Unpatch → Output image
Key features:
- Patch-based: Treats image as sequence of patches
- Self-attention: Captures global relationships
- Adaptive LayerNorm: Conditions on time via AdaLN
- Scalability: Larger models → better quality
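The `DiTBlock` used in the sketch below is not defined in these notes. A simplified guess at its structure (real DiT predicts six AdaLN parameters per block and adds gating; here one scale/shift pair is shared by both sublayers):

```python
import torch
import torch.nn as nn

class DiTBlock(nn.Module):
    """Transformer block with adaptive LayerNorm (AdaLN) time conditioning:
    the time embedding predicts a per-channel scale and shift that
    modulate the normalized activations."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x, t_emb):
        # t_emb: (B, 2*dim) -> scale and shift, each (B, 1, dim)
        scale, shift = t_emb.unsqueeze(1).chunk(2, dim=-1)
        h = self.norm1(x) * (1 + scale) + shift   # AdaLN modulation
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.norm2(x))
        return x

block = DiTBlock(dim=64)
x = torch.randn(2, 16, 64)     # (batch, num_patches, dim)
t_emb = torch.randn(2, 128)    # (batch, 2*dim), e.g. from a time MLP
print(block(x, t_emb).shape)   # torch.Size([2, 16, 64])
```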
```python
class DiTScoreNetwork(nn.Module):
    # DiTBlock (attention + MLP with AdaLN time conditioning) not shown
    def __init__(self, img_size=256, patch_size=16, dim=512, depth=12):
        super().__init__()

        # Patch embedding
        self.patch_embed = nn.Conv2d(3, dim, patch_size, stride=patch_size)
        num_patches = (img_size // patch_size) ** 2

        # Positional embedding
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, dim))

        # Time embedding
        self.time_embed = nn.Sequential(
            nn.Linear(1, dim),
            nn.SiLU(),
            nn.Linear(dim, dim*2)  # For AdaLN (scale & shift)
        )

        # Transformer blocks
        self.blocks = nn.ModuleList([
            DiTBlock(dim, num_heads=8) for _ in range(depth)
        ])

        # Unpatch
        self.unpatch = nn.Linear(dim, 3 * patch_size ** 2)

    def forward(self, x, t):
        # Patch embedding
        x = self.patch_embed(x)           # (B, C, H, W) → (B, dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # → (B, num_patches, dim)
        x = x + self.pos_embed

        # Time embedding
        t_emb = self.time_embed(t.unsqueeze(-1))  # (B, dim*2)

        # Transformer blocks with time conditioning
        for block in self.blocks:
            x = block(x, t_emb)

        # Unpatch
        x = self.unpatch(x)  # (B, num_patches, 3*P*P)
        # Reshape to image...
        return x
```

Why DiT?
- Scalability: Can scale to billions of parameters
- Global context: Self-attention captures long-range dependencies
- State-of-the-art: Achieves best FID scores on ImageNet
- Flexible: Easy to condition on additional inputs
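The "Reshape to image" step elided in the forward pass above can be written with `view`/`permute`; a sketch assuming square images and square patches:

```python
import torch

B, img_size, patch_size = 2, 256, 16
n = img_size // patch_size                       # patches per side (16)
x = torch.randn(B, n * n, 3 * patch_size ** 2)   # output of self.unpatch

# (B, n*n, 3*P*P) -> (B, n, n, 3, P, P) -> (B, 3, n, P, n, P) -> (B, 3, H, W)
x = x.view(B, n, n, 3, patch_size, patch_size)
x = x.permute(0, 3, 1, 4, 2, 5)
img = x.reshape(B, 3, img_size, img_size)
print(img.shape)  # torch.Size([2, 3, 256, 256])
```

The permute interleaves patch-grid and within-patch axes so that each patch lands in its original spatial position.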
Trade-offs:
- Compute: More expensive than U-Net
- Data: Requires more data to train effectively
- Memory: Attention is $O(n^2)$ in sequence length
| Aspect | MLP (2D data) | U-Net (Images) | DiT (Transformers) |
|---|---|---|---|
| Data type | Low-dim (e.g., 2D points) | Images, medical scans | Images (large-scale) |
| Spatial structure | None | Local (convolutions) | Global (attention) |
| Parameters | ~100K | ~100M | ~100M-1B |
| Compute | Low | Medium | High |
| Training data | Small (10K samples) | Medium (100K-1M) | Large (1M-100M) |
| Use case | Toy problems, demos | Practical image generation | State-of-the-art generation |
| Activation | SiLU | SiLU + GroupNorm | SiLU (in FFN) |
For medical images (X-rays, MRIs, CT scans), you would typically use a U-Net (2D or 3D).
Why U-Net is popular for medical imaging:
- Proven effective: Originally designed for medical image segmentation
- Works with small datasets: Medical data is often limited
- Preserves details: Skip connections maintain fine structures
- 3D extensions: Easy to adapt for 3D volumes (MRI, CT)
Modifications for medical imaging:
```python
class Medical3DUNet(nn.Module):
    """3D U-Net for medical volumes."""

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Use 3D convolutions instead of 2D
        self.enc1 = nn.Conv3d(in_channels, 64, 3, padding=1)
        # ... encoder blocks ...
        # Same U-Net structure but in 3D
        # ...
```

Special considerations:
- Normalization: GroupNorm or InstanceNorm (works better than BatchNorm for small medical datasets)
- Activation: SiLU or GELU
- Conditioning: Can condition on patient metadata (age, sex, etc.)
- Domain-specific: May need to preserve certain properties (Hounsfield units for CT)
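The normalization point can be demonstrated directly: GroupNorm computes statistics within each sample, so it behaves identically at any batch size, while BatchNorm's statistics are computed across the batch and become unreliable for the tiny batches common in 3D medical imaging:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 8, 8, 8)  # a single 3D volume (batch size 1)

# GroupNorm normalizes over channel groups *within each sample*,
# so its statistics are well-defined even for batch size 1:
gn = nn.GroupNorm(8, 64)
print(gn(x).shape)  # torch.Size([1, 64, 8, 8, 8])

# BatchNorm's running statistics are estimated *across the batch*;
# with one volume per batch the estimates are noisy and train/eval
# behaviour can diverge:
bn = nn.BatchNorm3d(64)
print(bn(x).shape)  # same shape, but batch-dependent statistics
```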
Modern approaches combine best of both:
- U-ViT: U-Net with Vision Transformer in bottleneck
- Swin Transformer: Hierarchical transformer (like U-Net structure)
- CoAtNet: Convolutions early, attention later
SiLU advantages:
- Smooth (no sharp corners) → better gradients
- Non-zero for negative values → no dying neurons
- Self-gating mechanism → adaptive feature selection
- Empirically better than ReLU in deep networks
When to use:
- ✅ Deep networks (diffusion models, transformers)
- ✅ When gradient flow is critical
- ✅ Modern architectures (post-2017)
| Task | Architecture | Activation | Why |
|---|---|---|---|
| 2D toy data | MLP | SiLU | No spatial structure, simple |
| Images (general) | U-Net | SiLU + GroupNorm | Spatial structure, proven effective |
| Images (SOTA) | DiT / U-ViT | SiLU (+ GELU in attention) | Scale, global context |
| Medical 2D | U-Net | SiLU + InstanceNorm | Small datasets, detail preservation |
| Medical 3D | 3D U-Net | SiLU + GroupNorm | Volumetric data, isotropic structure |
```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Compare activations
x = torch.linspace(-5, 5, 200)
relu = F.relu(x)
silu = F.silu(x)

plt.figure(figsize=(10, 4))
plt.plot(x, relu, label='ReLU', linewidth=2)
plt.plot(x, silu, label='SiLU', linewidth=2)
plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
plt.axvline(x=0, color='k', linestyle='--', alpha=0.3)
plt.grid(True, alpha=0.3)
plt.legend()
plt.xlabel('x')
plt.ylabel('Activation(x)')
plt.title('ReLU vs SiLU')
plt.show()
```

See the simplified example above, or check out:
- DDPM paper: github.com/hojonathanho/diffusion
- Hugging Face: github.com/huggingface/diffusers
- SiLU/Swish: Ramachandran et al. (2017), "Searching for Activation Functions"
- GELU: Hendrycks & Gimpel (2016), "Gaussian Error Linear Units"
- U-Net: Ronneberger et al. (2015), "U-Net: Convolutional Networks for Biomedical Image Segmentation"
- DDPM: Ho et al. (2020), "Denoising Diffusion Probabilistic Models"
- DiT: Peebles & Xie (2023), "Scalable Diffusion Models with Transformers"
- Stable Diffusion: Rombach et al. (2022), "High-Resolution Image Synthesis with Latent Diffusion Models"
- Advanced Architectures: advanced_architectures.md
- Time Embedding and FiLM: time_embedding_and_film.md
- Training Loss and Denoising: ../02_sde_formulation/supplements/04_training_loss_and_denoising.md
- Notebook:
notebooks/diffusion/02_sde_formulation/02_sde_formulation.ipynb