Diffusion Policies for Robotics¶
Diffusion Policies represent a breakthrough in imitation learning: by modeling the full distribution over expert actions rather than a single point estimate, they handle multi-modal behavior and achieve superior performance on complex manipulation tasks.
Paper: "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" (Chi et al., RSS 2023)
Why Diffusion for Robot Learning?¶
Traditional behavioral cloning fails on multi-modal action distributions:
# Problem: Averaging diverse expert actions
expert_demo_1: grasp_from_top() # Action: [0, 0, -1, ...]
expert_demo_2: grasp_from_side() # Action: [1, 0, 0, ...]
# Behavioral Cloning averages:
bc_action = mean([demo_1, demo_2])  # Result: [0.5, 0, -0.5, ...] ✗ Invalid: matches neither grasp!
Diffusion Policies model the full distribution instead of collapsing to a single mode.
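You can see the difference directly at inference time: sampling a trained policy several times for the same observation returns distinct, individually valid modes rather than their average. A hypothetical usage sketch (assuming the DiffusionPolicy class defined below):

policy.eval()
# Each eval-mode forward pass draws a fresh sample from p(actions | obs):
# a top grasp OR a side grasp, never an invalid blend of the two.
samples = [policy(obs) for _ in range(10)]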
Core Concept¶
Diffusion models learn to denoise:
graph LR
A[Pure Noise] -->|Denoise| B[Slightly Less Noisy]
B -->|Denoise| C[Recognizable Action]
C -->|Denoise| D[Clean Action]
style A fill:#ff9999
style B fill:#ffcc99
style C fill:#99ccff
style D fill:#99ff99
Mathematical Foundation¶
Forward Process (add noise):

\[
q(a_t \mid a_{t-1}) = \mathcal{N}\!\left(a_t;\ \sqrt{1-\beta_t}\,a_{t-1},\ \beta_t I\right)
\]

Reverse Process (denoise):

\[
p_\theta(a_{t-1} \mid a_t, o) = \mathcal{N}\!\left(a_{t-1};\ \mu_\theta(a_t, t, o),\ \Sigma_t\right)
\]

Where:

- \(a_t\): Noisy action at diffusion step \(t\)
- \(o\): Observation (image, state, etc.)
- \(\beta_t\): Noise schedule
- \(\theta\): Neural network parameters

These are the standard DDPM processes (Ho et al., 2020), with the reverse model additionally conditioned on the observation \(o\).
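A useful consequence of the Gaussian forward process is that \(a_t\) can be sampled from \(a_0\) in closed form, \(a_t = \sqrt{\bar\alpha_t}\,a_0 + \sqrt{1-\bar\alpha_t}\,\epsilon\); this is exactly what add_noise implements in the code below. A minimal numeric sketch (a linear schedule is used here only for brevity):

import torch

betas = torch.linspace(1e-4, 0.02, 100)           # noise schedule beta_t
alphas_cumprod = torch.cumprod(1 - betas, dim=0)  # alpha-bar_t
a0 = torch.tensor([0.0, 0.0, -1.0])               # clean "grasp from top" action
eps = torch.randn_like(a0)
t = 99
# As t grows, the signal coefficient sqrt(alpha-bar_t) shrinks toward 0,
# so a_t approaches pure noise.
a_t = alphas_cumprod[t].sqrt() * a0 + (1 - alphas_cumprod[t]).sqrt() * eps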
Architecture¶
import torch
import torch.nn as nn
import torch.nn.functional as F
class DiffusionPolicy(nn.Module):
"""
Diffusion Policy for visuomotor control
Adapted from Chi et al., RSS 2023
"""
def __init__(self, obs_dim, action_dim, action_horizon=16, diffusion_steps=100):
super().__init__()
self.action_dim = action_dim
self.action_horizon = action_horizon
self.diffusion_steps = diffusion_steps
# Vision encoder (ResNet or ViT)
self.vision_encoder = nn.Sequential(
nn.Conv2d(3, 64, 7, stride=2, padding=3),
nn.GroupNorm(8, 64),
nn.ReLU(),
            ResNetBlocks(num_blocks=3),  # residual conv stack (sketch after this class)
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
nn.Linear(512, 256)
)
# Observation encoder (for proprioception)
self.obs_encoder = nn.Sequential(
nn.Linear(obs_dim, 128),
nn.ReLU(),
nn.Linear(128, 128)
)
# Noise prediction network (U-Net style)
self.noise_pred_net = ConditionalUNet1D(
input_dim=action_dim,
global_cond_dim=256 + 128, # vision + obs
diffusion_step_embed_dim=128,
down_dims=[256, 512, 1024],
kernel_size=5,
n_groups=8
)
        # Noise schedule (DDPM), registered as buffers so .to(device)/.cuda() moves them
        betas = self.make_beta_schedule(diffusion_steps)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', 1 - betas)
        self.register_buffer('alphas_cumprod', torch.cumprod(1 - betas, dim=0))
def make_beta_schedule(self, num_steps, schedule='cosine'):
"""Create noise schedule"""
if schedule == 'linear':
return torch.linspace(0.0001, 0.02, num_steps)
elif schedule == 'cosine':
# Improved cosine schedule from Nichol & Dhariwal 2021
s = 0.008
steps = num_steps + 1
x = torch.linspace(0, num_steps, steps)
alphas_cumprod = torch.cos(((x / num_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
            return torch.clip(betas, 0.0001, 0.9999)
        else:
            raise ValueError(f"Unknown noise schedule: {schedule}")
def forward(self, obs, actions=None):
"""
Training: predict noise given noisy actions
Inference: iteratively denoise random noise
"""
# Encode observations
if 'image' in obs:
vision_features = self.vision_encoder(obs['image'])
else:
vision_features = torch.zeros(obs['state'].shape[0], 256).to(obs['state'].device)
obs_features = self.obs_encoder(obs['state'])
cond = torch.cat([vision_features, obs_features], dim=-1)
if self.training:
# Training: denoise actions
return self.compute_loss(cond, actions)
else:
# Inference: generate actions
return self.generate_actions(cond)
def compute_loss(self, cond, actions):
"""Compute diffusion training loss"""
batch_size = actions.shape[0]
# Sample random diffusion timesteps
t = torch.randint(0, self.diffusion_steps, (batch_size,), device=actions.device)
# Add noise to actions
noise = torch.randn_like(actions)
noisy_actions = self.add_noise(actions, noise, t)
# Predict the noise
predicted_noise = self.noise_pred_net(
noisy_actions,
timestep=t,
global_cond=cond
)
# MSE loss
loss = F.mse_loss(predicted_noise, noise)
return loss
def add_noise(self, actions, noise, t):
"""Add noise to actions according to schedule"""
sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod[t])
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod[t])
# Expand dimensions for broadcasting
sqrt_alphas_cumprod = sqrt_alphas_cumprod.view(-1, 1, 1)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.view(-1, 1, 1)
noisy_actions = sqrt_alphas_cumprod * actions + sqrt_one_minus_alphas_cumprod * noise
return noisy_actions
@torch.no_grad()
    def generate_actions(self, cond):
"""Generate action sequence through iterative denoising"""
batch_size = cond.shape[0]
# Start from random noise
actions = torch.randn(
batch_size,
self.action_horizon,
self.action_dim,
device=cond.device
)
# Iteratively denoise
for t in reversed(range(self.diffusion_steps)):
# Predict noise
t_batch = torch.full((batch_size,), t, device=cond.device, dtype=torch.long)
predicted_noise = self.noise_pred_net(
actions,
timestep=t_batch,
global_cond=cond
)
# Denoise step
actions = self.denoise_step(actions, predicted_noise, t)
return actions
def denoise_step(self, noisy_actions, predicted_noise, t):
"""Single denoising step (DDPM)"""
alpha = self.alphas[t]
alpha_cumprod = self.alphas_cumprod[t]
beta = self.betas[t]
# Predict x_0 (clean actions)
predicted_actions = (
noisy_actions - torch.sqrt(1 - alpha_cumprod) * predicted_noise
) / torch.sqrt(alpha_cumprod)
# Compute mean for p(x_{t-1} | x_t)
if t > 0:
alpha_cumprod_prev = self.alphas_cumprod[t - 1]
else:
alpha_cumprod_prev = torch.tensor(1.0)
predicted_mean = (
torch.sqrt(alpha_cumprod_prev) * beta / (1 - alpha_cumprod) * predicted_actions +
torch.sqrt(alpha) * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod) * noisy_actions
)
# Add noise (except at last step)
if t > 0:
noise = torch.randn_like(noisy_actions)
variance = beta * (1 - alpha_cumprod_prev) / (1 - alpha_cumprod)
predicted_mean = predicted_mean + torch.sqrt(variance) * noise
return predicted_mean
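The vision encoder above references a ResNetBlocks helper that is not defined in this snippet. A minimal sketch of what such a residual downsampling stack could look like (an illustrative assumption, not the paper's encoder; it takes the 64-channel feature map to the 512 channels the final Linear expects):

class ResNetBlocks(nn.Module):
    """Hypothetical residual conv stack: 64 -> 128 -> 256 -> 512 channels."""
    def __init__(self, num_blocks=3, base_channels=64):
        super().__init__()
        chans = [base_channels * 2 ** i for i in range(num_blocks + 1)]
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, stride=2, padding=1),
                nn.GroupNorm(8, c_out),
                nn.ReLU(),
                nn.Conv2d(c_out, c_out, 3, padding=1),
                nn.GroupNorm(8, c_out),
            )
            for c_in, c_out in zip(chans[:-1], chans[1:])
        ])
        # 1x1 strided convs so the residual path matches the block's output shape
        self.skips = nn.ModuleList([
            nn.Conv2d(c_in, c_out, 1, stride=2)
            for c_in, c_out in zip(chans[:-1], chans[1:])
        ])

    def forward(self, x):
        for block, skip in zip(self.blocks, self.skips):
            x = torch.relu(block(x) + skip(x))
        return x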
Key Innovations¶
1. Action Chunking with Receding Horizon¶
Predict multiple future actions, execute only the first few:
import time

# Predict a 16-step action sequence
action_sequence = diffusion_policy.generate_actions(obs)  # Shape: (1, 16, 7)

# Execute only the first 8 steps
for i in range(8):
    robot.execute(action_sequence[0, i])
    time.sleep(1 / control_frequency)  # control_frequency: your robot's control rate

# Re-predict for the next window
Benefits:

- Temporal coherence
- Smoother trajectories
- Better long-horizon planning
2. Conditional U-Net Architecture¶
class ConditionalUNet1D(nn.Module):
"""U-Net for 1D action sequences with global conditioning"""
    def __init__(self, input_dim, global_cond_dim, diffusion_step_embed_dim,
                 down_dims, kernel_size=5, n_groups=8):
super().__init__()
# Time embedding (sinusoidal)
self.time_embedding = SinusoidalPosEmb(diffusion_step_embed_dim)
        # Global conditioning projection, matched to the time-embedding width so
        # the two embeddings can be summed inside each residual block
        self.global_cond_proj = nn.Linear(global_cond_dim, diffusion_step_embed_dim)
# Downsampling path
self.down_blocks = nn.ModuleList()
in_dim = input_dim
for out_dim in down_dims:
            self.down_blocks.append(
                ResidualBlock1D(in_dim, out_dim, diffusion_step_embed_dim,
                                kernel_size, n_groups)
            )
in_dim = out_dim
# Upsampling path
self.up_blocks = nn.ModuleList()
for out_dim in reversed(down_dims[:-1]):
            self.up_blocks.append(
                ResidualBlock1D(in_dim + out_dim, out_dim, diffusion_step_embed_dim,
                                kernel_size, n_groups)  # skip connection widens the input
            )
in_dim = out_dim
# Final layer
self.final = nn.Conv1d(in_dim, input_dim, 1)
    def forward(self, x, timestep, global_cond):
        # x: (B, horizon, action_dim) -> channels-first for Conv1d
        x = x.transpose(1, 2)
        # Time embedding
        t_emb = self.time_embedding(timestep)
# Global condition
g_emb = self.global_cond_proj(global_cond)
# Downsampling with skip connections
skips = []
for block in self.down_blocks:
x = block(x, t_emb, g_emb)
skips.append(x)
# Upsampling with skip connections
for block, skip in zip(self.up_blocks, reversed(skips[:-1])):
x = torch.cat([x, skip], dim=1)
x = block(x, t_emb, g_emb)
        # Final projection, back to (B, horizon, action_dim)
        return self.final(x).transpose(1, 2)
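SinusoidalPosEmb and ResidualBlock1D are likewise used above without definitions. Minimal sketches under the same caveat (illustrative assumptions: the conditioning enters as a FiLM-style per-channel bias from the summed time and global embeddings, which is why global_cond_proj above projects to the time-embedding width):

import math

class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of the diffusion step index."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / (half - 1))
        args = t.float()[:, None] * freqs[None, :]
        return torch.cat([args.sin(), args.cos()], dim=-1)


class ResidualBlock1D(nn.Module):
    """Conv1d residual block conditioned on time + global embeddings."""
    def __init__(self, in_dim, out_dim, embed_dim, kernel_size=5, n_groups=8):
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv1d(in_dim, out_dim, kernel_size, padding=pad)
        self.conv2 = nn.Conv1d(out_dim, out_dim, kernel_size, padding=pad)
        self.norm1 = nn.GroupNorm(n_groups, out_dim)
        self.norm2 = nn.GroupNorm(n_groups, out_dim)
        self.cond_proj = nn.Linear(embed_dim, out_dim)
        self.skip = nn.Conv1d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    def forward(self, x, t_emb, g_emb):
        # x: (B, C, T); conditioning enters as a per-channel bias
        h = torch.relu(self.norm1(self.conv1(x)))
        h = h + self.cond_proj(t_emb + g_emb)[:, :, None]
        h = torch.relu(self.norm2(self.conv2(h)))
        return h + self.skip(x)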
3. DDPM vs DDIM Sampling¶
DDPM (Denoising Diffusion Probabilistic Models):

- Stochastic sampling
- Requires all T diffusion steps
- Higher quality but slower

DDIM (Denoising Diffusion Implicit Models):

- Deterministic sampling
- Can skip steps (e.g., use only 10 steps instead of 100)
- Faster inference
@torch.no_grad()
def ddim_sample(self, cond, ddim_steps=10):
"""Fast sampling with DDIM"""
# Select subset of timesteps
timesteps = torch.linspace(0, self.diffusion_steps-1, ddim_steps, dtype=torch.long)
# Start from noise
        actions = torch.randn(cond.shape[0], self.action_horizon, self.action_dim, device=cond.device)
# Iterative denoising (only ddim_steps iterations)
        for i in reversed(range(ddim_steps)):
            t = timesteps[i]
            # Predict noise (batched timestep, matching generate_actions)
            t_batch = torch.full((actions.shape[0],), int(t), device=cond.device, dtype=torch.long)
            predicted_noise = self.noise_pred_net(actions, timestep=t_batch, global_cond=cond)
            # DDIM update (deterministic); by convention alpha-bar before step 0 is 1,
            # so the final update returns the clean prediction x_0
            alpha_t = self.alphas_cumprod[t]
            alpha_t_prev = self.alphas_cumprod[timesteps[i - 1]] if i > 0 else torch.tensor(1.0)
predicted_x0 = (actions - torch.sqrt(1 - alpha_t) * predicted_noise) / torch.sqrt(alpha_t)
# Direction pointing to x_t
dir_xt = torch.sqrt(1 - alpha_t_prev) * predicted_noise
# Deterministic step
actions = torch.sqrt(alpha_t_prev) * predicted_x0 + dir_xt
return actions
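Both samplers consume the same trained network; latency scales with the number of U-Net evaluations. A hypothetical side-by-side call (method names as defined above):

actions_ddpm = policy.generate_actions(cond)            # 100 network evaluations
actions_ddim = policy.ddim_sample(cond, ddim_steps=10)  # 10 network evaluations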
Training¶
import torch
from torch.utils.data import DataLoader
def train_diffusion_policy(model, dataset, config):
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)
for epoch in range(config.num_epochs):
for batch in dataloader:
obs = {
'image': batch['images'].cuda(),
'state': batch['states'].cuda()
}
actions = batch['actions'].cuda() # Shape: (B, action_horizon, action_dim)
# Forward pass
loss = model(obs, actions)
# Backward
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
Performance¶
Benchmark Results¶
On challenging manipulation tasks:
| Task | Behavioral Cloning | Diffusion Policy | Improvement |
|---|---|---|---|
| Can Sorting | 43% | 86% | +100% |
| Tool Use | 31% | 78% | +152% |
| Bimanual | 28% | 72% | +157% |
| Square Insertion | 12% | 48% | +300% |
Why So Much Better?¶
- Multi-modal actions: Handles multiple valid solutions
- Temporal coherence: Action chunking provides smooth trajectories
- Expressiveness: Can model complex action distributions
- Robustness: Less sensitive to distribution shift
Advanced: Image-Based Diffusion Policy¶
class ImageDiffusionPolicy(nn.Module):
"""Diffusion policy with vision transformer"""
def __init__(self, action_dim=7, action_horizon=16):
super().__init__()
# Vision encoder: DINOv2 ViT
self.vision_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
# Temporal aggregation over observations
self.obs_horizon = 3 # Use last 3 observations
        self.temporal_aggregator = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=768, nhead=12, batch_first=True),
            num_layers=2
        )
# Diffusion U-Net
self.diffusion_net = ConditionalUNet1D(
input_dim=action_dim,
global_cond_dim=768,
diffusion_step_embed_dim=128,
down_dims=[256, 512, 1024]
)
# ... rest of diffusion machinery
def encode_observations(self, image_sequence):
"""Encode sequence of images"""
# image_sequence: (B, T, C, H, W) where T = obs_horizon
B, T = image_sequence.shape[:2]
# Encode each image
features = []
for t in range(T):
            with torch.no_grad():  # DINOv2 backbone kept frozen
feat = self.vision_encoder(image_sequence[:, t])
features.append(feat)
features = torch.stack(features, dim=1) # (B, T, 768)
# Temporal aggregation
aggregated = self.temporal_aggregator(features)
# Use last timestep
return aggregated[:, -1, :]
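A hypothetical shape check for the encoder (DINOv2 ViT-B/14 requires side lengths divisible by 14, so 224x224 works):

policy = ImageDiffusionPolicy()
imgs = torch.randn(2, 3, 3, 224, 224)    # (B, T, C, H, W) with obs_horizon T = 3
cond = policy.encode_observations(imgs)  # -> (2, 768)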
Practical Tips¶
1. Hyperparameter Tuning¶
# Good default hyperparameters
diffusion:
num_diffusion_steps: 100 # Training
ddim_steps: 10 # Inference (much faster)
action_horizon: 16 # Predict 16 steps ahead
obs_horizon: 3 # Use last 3 observations
noise_schedule: 'cosine' # Better than linear
training:
learning_rate: 1e-4
weight_decay: 1e-6
batch_size: 64
epochs: 1000
ema_decay: 0.995 # Exponential moving average
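These values map directly onto the constructor arguments used earlier. A hypothetical loader, assuming the block above is saved as config.yaml (obs_dim and action_dim are placeholders):

import yaml

with open('config.yaml') as f:
    cfg = yaml.safe_load(f)

policy = DiffusionPolicy(
    obs_dim=14, action_dim=7,  # placeholders for your setup
    action_horizon=cfg['diffusion']['action_horizon'],
    diffusion_steps=cfg['diffusion']['num_diffusion_steps'],
)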
2. Action Normalization¶
Critical for stable training:
# Scale actions to [-1, 1] using per-dimension min/max statistics
def normalize_actions(actions, stats):
    return 2 * (actions - stats['min']) / (stats['max'] - stats['min']) - 1

def denormalize_actions(actions_normalized, stats):
    return (actions_normalized + 1) / 2 * (stats['max'] - stats['min']) + stats['min']
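The stats dictionary is computed once over the demonstration set; a minimal sketch, assuming all_actions is an (N, action_dim) tensor stacking every demonstrated action:

stats = {
    'min': all_actions.min(dim=0).values,
    'max': all_actions.max(dim=0).values,
}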
3. Efficient Inference¶
import copy

# Maintain an exponential moving average (EMA) of the weights
ema_model = copy.deepcopy(model)

# During training (inside the usual optimization loop)
for batch in dataloader:
    loss = model(obs, actions)  # forward/loss as in train_diffusion_policy
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Update EMA: ema = 0.995 * ema + 0.005 * current
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(0.995).add_(param.data, alpha=0.005)

# Use the EMA weights for inference
ema_model.eval()
actions = ema_model(obs)  # eval mode routes to generate_actions
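PyTorch also ships an averaging helper that can replace the manual loop; a sketch using torch.optim.swa_utils, where the custom avg_fn reproduces the 0.995 decay:

from torch.optim.swa_utils import AveragedModel

# avg_fn(averaged_param, current_param, num_averaged) -> new averaged value
ema_model = AveragedModel(model, avg_fn=lambda avg, p, n: 0.995 * avg + 0.005 * p)

# After each optimizer.step():
ema_model.update_parameters(model)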
Comparisons¶
| Method | Multi-modal | Sample Efficiency | Inference Speed |
|---|---|---|---|
| BC | ✗ | ⭐⭐⭐ | ⭐⭐⭐ |
| VAE Policy | ✓ | ⭐⭐ | ⭐⭐⭐ |
| Normalizing Flows | ✓ | ⭐⭐ | ⭐⭐ |
| Diffusion Policy | ✓ | ⭐⭐⭐ | ⭐ (DDPM) / ⭐⭐ (DDIM) |
References¶
- Chi et al., "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion", RSS 2023
- Ho et al., "Denoising Diffusion Probabilistic Models", NeurIPS 2020
- Song et al., "Denoising Diffusion Implicit Models", ICLR 2021
Code¶
Full implementation: https://github.com/real-stanford/diffusion_policy
Next Steps¶
- ACT (Action Chunking) - Another approach to temporal action prediction
- Transformers for IL - Transformer-based imitation learning
- Training Guide - Train your diffusion policy