ACT: Action Chunking with Transformers¶
ACT (Action Chunking with Transformers) is a state-of-the-art imitation learning method that dramatically improves success rates on complex manipulation tasks.
Paper: "Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware" (Zhao et al., 2023)
GitHub: https://github.com/tonyzhaozh/act
Overview¶
ACT addresses critical limitations of traditional behavioral cloning:
- ✗ Problem: Standard BC predicts one action at a time → compounding errors
- ✓ Solution: Predict action sequences (chunks) → temporal consistency (sketched below)
- ✗ Problem: Deterministic policies can't handle multimodality
- ✓ Solution: Use a CVAE (Conditional VAE) → model action distributions
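To see why chunking helps, compare the two control loops below. This is a minimal sketch; `bc_policy`, `act_policy`, `env`, and `horizon` are hypothetical stand-ins, not part of ACT's API:

# Standard BC: one action per policy query; prediction errors feed back
# into the next observation and compound over the episode
obs = env.reset()
for t in range(horizon):
    action = bc_policy(obs)
    obs, reward, done, info = env.step(action)

# Action chunking: predict k actions at once and execute them open-loop,
# so the policy is queried (and can drift) far less often
obs = env.reset()
t = 0
while t < horizon:
    chunk = act_policy(obs)  # (k, action_dim)
    for action in chunk:
        obs, reward, done, info = env.step(action)
        t += 1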
Key Results:
- Success rate: 80-90% on complex bimanual tasks
- Sample efficiency: works with 50-100 demonstrations
- Hardware: runs on low-cost robot arms (<$20K total)
Architecture¶
ACT combines transformers with a CVAE for temporal action prediction:
graph TD
A[Images<br/>Multiple Cameras] --> B[ResNet Encoders]
C[Proprio State] --> D[MLP Encoder]
E[Actions<br/>t-4 to t-1] --> F[Action Encoder]
B --> G[Transformer<br/>Encoder]
D --> G
F --> G
H[Latent z] --> I[Transformer<br/>Decoder]
G --> I
I --> J[Action Chunk<br/>t to t+k]
style H fill:#f9f,stroke:#333
style J fill:#9f9,stroke:#333
Complete Implementation¶
import torch
import torch.nn as nn
import torchvision
from torch.nn import functional as F
class ACT(nn.Module):
"""
Action Chunking with Transformers
Predicts sequences of actions conditioned on observations
Uses CVAE to model multimodal action distributions
"""
def __init__(
self,
state_dim,
action_dim,
chunk_size=100, # Predict 100 future actions
hidden_dim=512,
num_encoder_layers=4,
num_decoder_layers=7,
num_heads=8,
latent_dim=32, # CVAE latent dimension
):
super().__init__()
self.chunk_size = chunk_size
self.latent_dim = latent_dim
self.action_dim = action_dim
# Vision encoder (per camera)
self.vision_encoder = ResNetEncoder(output_dim=hidden_dim)
# State encoder
self.state_encoder = nn.Linear(state_dim, hidden_dim)
# Action encoder (for past actions)
self.action_encoder = nn.Linear(action_dim, hidden_dim)
# Transformer encoder (processes observations)
self.encoder = TransformerEncoder(
dim=hidden_dim,
depth=num_encoder_layers,
heads=num_heads
)
# CVAE components
# Encoder: (obs, actions) -> latent z
self.latent_encoder = nn.Sequential(
nn.Linear(hidden_dim + action_dim * chunk_size, 256),
nn.ReLU(),
nn.Linear(256, latent_dim * 2) # mean and logvar
)
# Prior: obs -> latent z (for inference without actions)
self.latent_prior = nn.Sequential(
nn.Linear(hidden_dim, 256),
nn.ReLU(),
nn.Linear(256, latent_dim * 2)
)
# Decoder: (obs, latent) -> action chunk
self.latent_projection = nn.Linear(latent_dim, hidden_dim)
self.decoder = TransformerDecoder(
dim=hidden_dim,
depth=num_decoder_layers,
heads=num_heads
)
# Action prediction head
self.action_head = nn.Linear(hidden_dim, action_dim)
# Learnable query tokens for decoder
self.query_embed = nn.Embedding(chunk_size, hidden_dim)
def encode_observations(self, images, state, past_actions=None):
"""
Encode multi-modal observations
Args:
images: list of (B, 3, H, W) from different cameras
state: (B, state_dim) proprioceptive state
past_actions: (B, history_len, action_dim) recent actions
Returns:
encoded: (B, seq_len, hidden_dim)
"""
B = state.shape[0]
tokens = []
# Encode images from all cameras
for img in images:
vis_feat = self.vision_encoder(img) # (B, hidden_dim)
tokens.append(vis_feat.unsqueeze(1)) # (B, 1, hidden_dim)
# Encode proprioceptive state
state_feat = self.state_encoder(state).unsqueeze(1) # (B, 1, hidden_dim)
tokens.append(state_feat)
# Encode past actions if provided
if past_actions is not None:
for t in range(past_actions.shape[1]):
action_feat = self.action_encoder(past_actions[:, t]).unsqueeze(1)
tokens.append(action_feat)
# Concatenate all tokens
all_tokens = torch.cat(tokens, dim=1) # (B, seq_len, hidden_dim)
# Process with transformer encoder
encoded = self.encoder(all_tokens)
return encoded
def encode_latent(self, obs_encoded, actions=None):
"""
Encode latent variable z
During training: encode from (obs, actions)
During inference: sample from prior p(z|obs)
Args:
obs_encoded: (B, seq_len, hidden_dim)
actions: (B, chunk_size, action_dim) ground truth actions (training only)
Returns:
z: (B, latent_dim) sampled latent
mu, logvar: distribution parameters
"""
# Global observation representation (use CLS token or mean)
obs_global = obs_encoded.mean(dim=1) # (B, hidden_dim)
if actions is not None:
# Training: encode from (obs, actions)
actions_flat = actions.reshape(actions.shape[0], -1) # (B, chunk_size * action_dim)
combined = torch.cat([obs_global, actions_flat], dim=-1)
mu_logvar = self.latent_encoder(combined)
else:
# Inference: use prior
mu_logvar = self.latent_prior(obs_global)
mu, logvar = mu_logvar.chunk(2, dim=-1)
# Reparameterization trick
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std
return z, mu, logvar
def decode_actions(self, obs_encoded, latent):
"""
Decode action chunk from observations and latent
Args:
obs_encoded: (B, seq_len, hidden_dim)
latent: (B, latent_dim)
Returns:
actions: (B, chunk_size, action_dim)
"""
B = obs_encoded.shape[0]
# Project latent to hidden dim
latent_proj = self.latent_projection(latent).unsqueeze(1) # (B, 1, hidden_dim)
# Create query tokens for decoder
query_tokens = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1) # (B, chunk_size, hidden_dim)
# Add latent to each query
query_tokens = query_tokens + latent_proj
# Transformer decoder: cross-attend to observations
decoded = self.decoder(query_tokens, obs_encoded) # (B, chunk_size, hidden_dim)
# Predict actions
actions = self.action_head(decoded) # (B, chunk_size, action_dim)
return actions
def forward(self, images, state, past_actions=None, ground_truth_actions=None):
"""
Full forward pass
Training: use ground truth actions to compute CVAE loss
Inference: sample from prior
"""
# Encode observations
obs_encoded = self.encode_observations(images, state, past_actions)
# Encode latent
latent, mu, logvar = self.encode_latent(obs_encoded, ground_truth_actions)
# Decode actions
predicted_actions = self.decode_actions(obs_encoded, latent)
return predicted_actions, mu, logvar
def compute_loss(self, batch, kl_weight=10.0):
"""
Compute CVAE loss
L = reconstruction loss + kl_weight * KL divergence
"""
images = batch['images'] # List of camera views
state = batch['state']
actions = batch['actions'] # (B, chunk_size, action_dim)
# Forward pass
predicted_actions, mu, logvar = self.forward(
images, state,
ground_truth_actions=actions
)
# Reconstruction loss (the ACT paper uses L1; L2 also works)
recon_loss = F.l1_loss(predicted_actions, actions)
# KL divergence
# KL(q(z|x,a) || p(z|x)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.shape[0]
# Total loss; the KL weight matters for stability (10.0 in the paper, tunable)
total_loss = recon_loss + kl_weight * kl_loss
return total_loss, {
'recon_loss': recon_loss.item(),
'kl_loss': kl_loss.item(),
'total_loss': total_loss.item()
}
@torch.no_grad()
def predict(self, images, state, past_actions=None):
"""
Inference: predict action chunk
Sample latent from prior and decode
"""
self.eval()
# Encode observations
obs_encoded = self.encode_observations(images, state, past_actions)
# Sample from prior
latent, _, _ = self.encode_latent(obs_encoded, actions=None)
# Decode actions
actions = self.decode_actions(obs_encoded, latent)
return actions
class ResNetEncoder(nn.Module):
"""Vision encoder using ResNet"""
def __init__(self, output_dim=512):
super().__init__()
# Load pre-trained ResNet18 (the `weights` API replaces the deprecated `pretrained=True`)
resnet = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
# Remove final FC layer
self.encoder = nn.Sequential(*list(resnet.children())[:-1])
# Project to output dim
self.projection = nn.Linear(512, output_dim)
def forward(self, x):
"""
Args:
x: (B, 3, H, W)
Returns:
features: (B, output_dim)
"""
features = self.encoder(x)
features = features.flatten(1)
features = self.projection(features)
return features
class TransformerEncoder(nn.Module):
"""Standard transformer encoder"""
def __init__(self, dim=512, depth=4, heads=8):
super().__init__()
self.layers = nn.ModuleList([
TransformerEncoderLayer(dim, heads)
for _ in range(depth)
])
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class TransformerDecoder(nn.Module):
"""Transformer decoder with cross-attention"""
def __init__(self, dim=512, depth=7, heads=8):
super().__init__()
self.layers = nn.ModuleList([
TransformerDecoderLayer(dim, heads)
for _ in range(depth)
])
def forward(self, queries, context):
"""
Args:
queries: (B, query_len, dim)
context: (B, context_len, dim)
"""
for layer in self.layers:
queries = layer(queries, context)
return queries
class TransformerEncoderLayer(nn.Module):
def __init__(self, dim, heads):
super().__init__()
self.attention = nn.MultiheadAttention(dim, heads, batch_first=True)
self.ff = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
def forward(self, x):
# Self-attention
attn_out, _ = self.attention(x, x, x)
x = self.norm1(x + attn_out)
# Feed-forward
ff_out = self.ff(x)
x = self.norm2(x + ff_out)
return x
class TransformerDecoderLayer(nn.Module):
def __init__(self, dim, heads):
super().__init__()
self.self_attention = nn.MultiheadAttention(dim, heads, batch_first=True)
self.cross_attention = nn.MultiheadAttention(dim, heads, batch_first=True)
self.ff = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
def forward(self, queries, context):
# Self-attention on queries
self_attn_out, _ = self.self_attention(queries, queries, queries)
queries = self.norm1(queries + self_attn_out)
# Cross-attention to context
cross_attn_out, _ = self.cross_attention(queries, context, context)
queries = self.norm2(queries + cross_attn_out)
# Feed-forward
ff_out = self.ff(queries)
queries = self.norm3(queries + ff_out)
return queries
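A minimal smoke test of the pieces above, with illustrative dimensions (the 14-dim state/action roughly matches a bimanual setup like ALOHA; the first run downloads the ResNet18 weights):

model = ACT(state_dim=14, action_dim=14, chunk_size=100)

# Two dummy camera views, a proprioceptive state, and a ground-truth chunk
images = [torch.randn(2, 3, 224, 224) for _ in range(2)]
state = torch.randn(2, 14)
gt_actions = torch.randn(2, 100, 14)

# Training-style pass: the latent comes from the CVAE encoder q(z | obs, actions)
pred, mu, logvar = model(images, state, ground_truth_actions=gt_actions)
print(pred.shape)  # torch.Size([2, 100, 14])

# Inference-style pass: the latent is sampled from the prior p(z | obs)
actions = model.predict(images, state)
print(actions.shape)  # torch.Size([2, 100, 14])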
Training ACT¶
Data Collection¶
ACT works best with multi-camera observations:
import numpy as np

class ACTDataset(torch.utils.data.Dataset):
"""Dataset for ACT training"""
def __init__(self, demonstrations, chunk_size=100):
self.demonstrations = demonstrations
self.chunk_size = chunk_size
def __len__(self):
# Count valid chunks across all demonstrations
total_chunks = 0
for demo in self.demonstrations:
total_chunks += max(0, len(demo) - self.chunk_size)
return total_chunks
def __getitem__(self, idx):
# Find demonstration and chunk index
demo_idx = 0
chunk_idx = idx
while chunk_idx >= len(self.demonstrations[demo_idx]) - self.chunk_size:
chunk_idx -= len(self.demonstrations[demo_idx]) - self.chunk_size
demo_idx += 1
demo = self.demonstrations[demo_idx]
# Extract observation at chunk start
obs = demo[chunk_idx]
# Extract action chunk
actions = np.array([
demo[chunk_idx + i]['action']
for i in range(self.chunk_size)
])
return {
'images': [obs['camera_1'], obs['camera_2'], obs['camera_3']], # Multi-camera
'state': obs['state'],
'actions': actions # (chunk_size, action_dim)
}
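A quick wiring check, assuming `demonstrations` is a list of trajectories where each timestep is a dict with 'camera_1'/'camera_2'/'camera_3' image tensors, 'state', and 'action' (this format is illustrative). Note that PyTorch's default collate transposes the per-sample image list into a list of batched camera tensors:

from torch.utils.data import DataLoader

dataset = ACTDataset(demonstrations, chunk_size=100)
loader = DataLoader(dataset, batch_size=8, shuffle=True)

batch = next(iter(loader))
# batch['images'] is a list of 3 tensors, each (8, 3, H, W)
# batch['state'] is (8, state_dim)
# batch['actions'] is (8, 100, action_dim)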
Training Loop¶
import numpy as np
from torch.utils.data import DataLoader

def train_act(model, dataset, config):
"""Train ACT model"""
dataloader = DataLoader(
dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=4
)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay
)
# KL annealing schedule (gradually increase KL weight)
kl_weight_schedule = np.linspace(0, 10.0, config.kl_warmup_steps)
global_step = 0
for epoch in range(config.num_epochs):
for batch in dataloader:
# Move to GPU
batch = {k: v.cuda() if isinstance(v, torch.Tensor) else [img.cuda() for img in v]
for k, v in batch.items()}
# KL annealing: ramp the weight up during warmup, then hold it at 10.0
if global_step < config.kl_warmup_steps:
    kl_weight = kl_weight_schedule[global_step]
else:
    kl_weight = 10.0
# Forward pass (kl_weight is applied inside compute_loss, so the
# reweighted loss stays a tensor and can be backpropagated)
total_loss, loss_dict = model.compute_loss(batch, kl_weight=kl_weight)
# Backward pass
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
optimizer.step()
global_step += 1
# Logging
if global_step % 100 == 0:
print(f"Step {global_step}:")
print(f" Recon Loss: {loss_dict['recon_loss']:.4f}")
print(f" KL Loss: {loss_dict['kl_loss']:.4f}")
print(f" Total Loss: {total_loss.item():.4f}")
return model
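The `config` object above is not defined on this page; here is a minimal sketch of the fields it needs. The values are illustrative starting points in line with the ranges discussed here, not a fixed recipe:

from dataclasses import dataclass

@dataclass
class ACTTrainConfig:
    batch_size: int = 8           # per-GPU batch size
    learning_rate: float = 1e-5   # AdamW learning rate
    weight_decay: float = 1e-4
    num_epochs: int = 2000        # ACT trains for many epochs on small datasets
    kl_warmup_steps: int = 1000   # steps to anneal the KL weight from 0 to 10
    grad_clip: float = 1.0        # max gradient norm

model = train_act(model, dataset, ACTTrainConfig())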
Deployment: Temporal Ensembling¶
ACT's key innovation at deployment: temporal ensembling
import collections

import numpy as np
import torch

class ACTController:
"""ACT controller with temporal ensembling"""
def __init__(self, model, chunk_size=100, k_obs=1):
self.model = model
self.chunk_size = chunk_size
self.k_obs = k_obs # Number of observations to ensemble
# Rolling buffer of predicted action chunks
self.action_buffer = collections.deque(maxlen=k_obs)
def predict_action(self, observation, timestep):
"""
Predict action using temporal ensembling
Key idea: Average predictions from last k_obs observations
"""
# Predict action chunk from current observation
with torch.no_grad():
action_chunk = self.model.predict(
images=observation['images'],
state=observation['state']
) # (1, chunk_size, action_dim)
# Add to buffer
self.action_buffer.append({
'chunk': action_chunk.squeeze(0).cpu().numpy(),
'query_timestep': timestep
})
# Temporal ensembling: average predictions
action = self.ensemble_actions(timestep)
return action
def ensemble_actions(self, current_timestep):
"""
Ensemble action predictions from recent observations
For each buffered prediction, extract the action at current timestep
and average
"""
if len(self.action_buffer) == 0:
return None
actions_to_average = []
for prediction in self.action_buffer:
# How many steps since this prediction was made?
steps_since_query = current_timestep - prediction['query_timestep']
# Extract action at current timestep
if steps_since_query < self.chunk_size:
action = prediction['chunk'][steps_since_query]
actions_to_average.append(action)
# Exponentially weighted average: index 0 is the oldest buffered
# prediction, which gets the largest weight (w_i = exp(-m * i), m = 0.1)
weights = np.exp(-0.1 * np.arange(len(actions_to_average)))
weights = weights / weights.sum()
ensembled_action = np.average(actions_to_average, axis=0, weights=weights)
return ensembled_action
# Deployment: query the policy every step and ensemble the last 10
# overlapping predictions (k_obs=1 would disable ensembling)
controller = ACTController(model, chunk_size=100, k_obs=10)
obs = env.reset()
for t in range(500):
action = controller.predict_action(obs, timestep=t)
obs, reward, done, info = env.step(action)
if done:
break
Why temporal ensembling works:
- Reduces variance in predictions
- Smoother action sequences
- Better handling of observation noise
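To make the weighting concrete, here is what the scheme in `ensemble_actions` produces with four buffered predictions (m = 0.1, index 0 = oldest):

import numpy as np

weights = np.exp(-0.1 * np.arange(4))  # [1.000, 0.905, 0.819, 0.741]
weights = weights / weights.sum()      # ≈ [0.289, 0.261, 0.236, 0.214]
# The oldest prediction gets the largest weight, matching the paper's
# w_i = exp(-m * i) scheme where w_0 corresponds to the oldest action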
Comparison with Other Methods¶
| Method | Action Horizon | Multimodal | Success Rate | Sample Efficiency |
|---|---|---|---|---|
| BC | 1 | ✗ | 40% | Poor |
| Diffusion Policy | 16 | ✓ | 75% | Good |
| ACT | 100 | ✓ | 85% | Very Good |
| VLA (RT-2) | 1 | ✓ | 90%* | Excellent* |
*With web-scale pre-training
Key Advantages¶
vs Standard BC:
- ✓ Longer action horizons → less compounding error
- ✓ CVAE → handles multimodal distributions
- ✓ Temporal ensembling → smoother execution

vs Diffusion Policy:
- ✓ Longer chunks (100 vs 16)
- ✓ Faster inference (single forward pass vs iterative denoising)
- ✓ Better for long-horizon tasks
Tips for Success¶
1. Chunk Size Selection¶
# Rule of thumb: chunk_size should cover ~3-5 seconds of execution
fps = 30 # Control frequency
duration_seconds = 3
chunk_size = fps * duration_seconds # 90
# Or: make it cover typical subtask duration
# E.g., reaching takes ~2 sec, so chunk_size >= 60
2. KL Weight Tuning¶
# Start low and increase gradually (ACT settles on 10.0)
kl_weights = [1.0, 5.0, 10.0, 20.0]
for kl_weight in kl_weights:
    train_epoch(model, kl_weight=kl_weight)  # pass through to compute_loss
# Check: if KL loss collapses to ~0, the latent is being ignored
#   (posterior collapse) → lower the weight or anneal it from 0
# If recon loss degrades a lot → the weight is too high, lower it
3. Data Collection¶
- Use multiple camera views (3+ recommended)
- Ensure diverse demonstrations (different initial conditions)
- Collect 50-100 demos minimum per task
- Include failure recovery demonstrations
Resources¶
- Paper: https://arxiv.org/abs/2304.13705
- Code: https://github.com/tonyzhaozh/act
- Project Page: https://tonyzhaozh.github.io/aloha/
Next Steps¶
- Diffusion Policies - Compare with diffusion-based IL
- Transformers for IL - Other transformer approaches
- Training Guide - General IL training tips
- Data Collection - Collect quality demonstrations