ACT: Action Chunking with Transformers

ACT (Action Chunking with Transformers) is a state-of-the-art imitation learning method that dramatically improves success rates on complex manipulation tasks.

Paper: "Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware" (Zhao et al., 2023)

GitHub: https://github.com/tonyzhaozh/act

Overview

ACT addresses critical limitations of traditional behavioral cloning:

  • Problem: Standard BC predicts one action at a time → compounding errors
  • Solution: Predict action sequences (chunks) → temporal consistency
  • Problem: Deterministic policies can't handle multimodality
  • Solution: Use CVAE (Conditional VAE) → model action distributions

Key Results:

  • Success rate: 80-90% on complex bimanual tasks
  • Sample efficiency: works with 50-100 demonstrations
  • Hardware: runs on low-cost robot arms (<$20K total)
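
To make the chunking idea concrete: instead of querying the policy at every control step, the policy is queried once every k steps and the returned chunk is executed open-loop (temporal ensembling, covered later, refines this). A minimal runnable sketch with stand-in policy and environment functions (the names here are illustrative, not from the ACT codebase):

import numpy as np

horizon, k, action_dim = 300, 100, 14

def policy_chunk(obs):
    # stand-in for a trained ACT policy: returns k future actions
    return np.zeros((k, action_dim))

def step(action):
    # stand-in for the environment: returns the next observation
    return np.zeros(14)

obs = np.zeros(14)
t = 0
while t < horizon:
    chunk = policy_chunk(obs)   # one policy query yields k actions
    for action in chunk:        # execute the whole chunk open-loop
        obs = step(action)
        t += 1
# 300 / 100 = 3 policy queries instead of 300, so far fewer opportunities
# for per-step prediction errors to compound.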

Architecture

ACT combines transformers with a CVAE for temporal action prediction:

graph TD
    A[Images<br/>Multiple Cameras] --> B[ResNet Encoders]
    C[Proprio State] --> D[MLP Encoder]
    E[Actions<br/>t-4 to t-1] --> F[Action Encoder]

    B --> G[Transformer<br/>Encoder]
    D --> G
    F --> G

    H[Latent z] --> I[Transformer<br/>Decoder]
    G --> I

    I --> J[Action Chunk<br/>t to t+k]

    style H fill:#f9f,stroke:#333
    style J fill:#9f9,stroke:#333

Complete Implementation

import torch
import torch.nn as nn
import torchvision
from torch.nn import functional as F

class ACT(nn.Module):
    """
    Action Chunking with Transformers

    Predicts sequences of actions conditioned on observations
    Uses CVAE to model multimodal action distributions
    """
    def __init__(
        self,
        state_dim,
        action_dim,
        chunk_size=100,  # Predict 100 future actions
        hidden_dim=512,
        num_encoder_layers=4,
        num_decoder_layers=7,
        num_heads=8,
        latent_dim=32,  # CVAE latent dimension
    ):
        super().__init__()

        self.chunk_size = chunk_size
        self.latent_dim = latent_dim
        self.action_dim = action_dim

        # Vision encoder (per camera)
        self.vision_encoder = ResNetEncoder(output_dim=hidden_dim)

        # State encoder
        self.state_encoder = nn.Linear(state_dim, hidden_dim)

        # Action encoder (for past actions)
        self.action_encoder = nn.Linear(action_dim, hidden_dim)

        # Transformer encoder (processes observations)
        self.encoder = TransformerEncoder(
            dim=hidden_dim,
            depth=num_encoder_layers,
            heads=num_heads
        )

        # CVAE components
        # Encoder: (obs, actions) -> latent z
        self.latent_encoder = nn.Sequential(
            nn.Linear(hidden_dim + action_dim * chunk_size, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim * 2)  # mean and logvar
        )

        # Prior: obs -> latent z (for inference without actions)
        self.latent_prior = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim * 2)
        )

        # Decoder: (obs, latent) -> action chunk
        self.latent_projection = nn.Linear(latent_dim, hidden_dim)

        self.decoder = TransformerDecoder(
            dim=hidden_dim,
            depth=num_decoder_layers,
            heads=num_heads
        )

        # Action prediction head
        self.action_head = nn.Linear(hidden_dim, action_dim)

        # Learnable query tokens for decoder
        self.query_embed = nn.Embedding(chunk_size, hidden_dim)

    def encode_observations(self, images, state, past_actions=None):
        """
        Encode multi-modal observations

        Args:
            images: list of (B, 3, H, W) from different cameras
            state: (B, state_dim) proprioceptive state
            past_actions: (B, history_len, action_dim) recent actions
        Returns:
            encoded: (B, seq_len, hidden_dim)
        """
        B = state.shape[0]
        tokens = []

        # Encode images from all cameras
        for img in images:
            vis_feat = self.vision_encoder(img)  # (B, hidden_dim)
            tokens.append(vis_feat.unsqueeze(1))  # (B, 1, hidden_dim)

        # Encode proprioceptive state
        state_feat = self.state_encoder(state).unsqueeze(1)  # (B, 1, hidden_dim)
        tokens.append(state_feat)

        # Encode past actions if provided
        if past_actions is not None:
            for t in range(past_actions.shape[1]):
                action_feat = self.action_encoder(past_actions[:, t]).unsqueeze(1)
                tokens.append(action_feat)

        # Concatenate all tokens
        all_tokens = torch.cat(tokens, dim=1)  # (B, seq_len, hidden_dim)

        # Process with transformer encoder
        encoded = self.encoder(all_tokens)

        return encoded

    def encode_latent(self, obs_encoded, actions=None):
        """
        Encode latent variable z

        During training: encode from (obs, actions)
        During inference: sample from prior p(z|obs)

        Args:
            obs_encoded: (B, seq_len, hidden_dim)
            actions: (B, chunk_size, action_dim) ground truth actions (training only)
        Returns:
            z: (B, latent_dim) sampled latent
            mu, logvar: distribution parameters
        """
        # Global observation representation (use CLS token or mean)
        obs_global = obs_encoded.mean(dim=1)  # (B, hidden_dim)

        if actions is not None:
            # Training: encode from (obs, actions)
            actions_flat = actions.reshape(actions.shape[0], -1)  # (B, chunk_size * action_dim)
            combined = torch.cat([obs_global, actions_flat], dim=-1)

            mu_logvar = self.latent_encoder(combined)
        else:
            # Inference: use prior
            mu_logvar = self.latent_prior(obs_global)

        mu, logvar = mu_logvar.chunk(2, dim=-1)

        # Reparameterization trick
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std

        return z, mu, logvar

    def decode_actions(self, obs_encoded, latent):
        """
        Decode action chunk from observations and latent

        Args:
            obs_encoded: (B, seq_len, hidden_dim)
            latent: (B, latent_dim)
        Returns:
            actions: (B, chunk_size, action_dim)
        """
        B = obs_encoded.shape[0]

        # Project latent to hidden dim
        latent_proj = self.latent_projection(latent).unsqueeze(1)  # (B, 1, hidden_dim)

        # Create query tokens for decoder
        query_tokens = self.query_embed.weight.unsqueeze(0).repeat(B, 1, 1)  # (B, chunk_size, hidden_dim)

        # Add latent to each query
        query_tokens = query_tokens + latent_proj

        # Transformer decoder: cross-attend to observations
        decoded = self.decoder(query_tokens, obs_encoded)  # (B, chunk_size, hidden_dim)

        # Predict actions
        actions = self.action_head(decoded)  # (B, chunk_size, action_dim)

        return actions

    def forward(self, images, state, past_actions=None, ground_truth_actions=None):
        """
        Full forward pass

        Training: use ground truth actions to compute CVAE loss
        Inference: sample from prior
        """
        # Encode observations
        obs_encoded = self.encode_observations(images, state, past_actions)

        # Encode latent
        latent, mu, logvar = self.encode_latent(obs_encoded, ground_truth_actions)

        # Decode actions
        predicted_actions = self.decode_actions(obs_encoded, latent)

        return predicted_actions, mu, logvar

    def compute_loss(self, batch, kl_weight=10.0):
        """
        Compute CVAE loss

        L = Reconstruction loss + kl_weight * KL divergence
        """
        images = batch['images']  # List of camera views
        state = batch['state']
        actions = batch['actions']  # (B, chunk_size, action_dim)

        # Forward pass
        predicted_actions, mu, logvar = self.forward(
            images, state,
            ground_truth_actions=actions
        )

        # Reconstruction loss (L2)
        recon_loss = F.mse_loss(predicted_actions, actions)

        # KL divergence against a unit Gaussian:
        # KL(q(z|x,a) || N(0, I)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
        # Note: with this term the learned latent_prior receives no training
        # signal; the original ACT uses a fixed N(0, I) prior (z = 0 at test time).
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / mu.shape[0]

        # Total loss; the KL weight is important for stability -- tune it per task
        total_loss = recon_loss + kl_weight * kl_loss

        return total_loss, {
            'recon_loss': recon_loss.item(),
            'kl_loss': kl_loss.item(),
            'total_loss': total_loss.item()
        }

    @torch.no_grad()
    def predict(self, images, state, past_actions=None):
        """
        Inference: predict action chunk

        Sample latent from prior and decode
        """
        self.eval()

        # Encode observations
        obs_encoded = self.encode_observations(images, state, past_actions)

        # Sample from prior
        latent, _, _ = self.encode_latent(obs_encoded, actions=None)

        # Decode actions
        actions = self.decode_actions(obs_encoded, latent)

        return actions


class ResNetEncoder(nn.Module):
    """Vision encoder using ResNet"""
    def __init__(self, output_dim=512):
        super().__init__()

        # Load pre-trained ResNet18
        resnet = torchvision.models.resnet18(pretrained=True)

        # Remove final FC layer
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])

        # Project to output dim
        self.projection = nn.Linear(512, output_dim)

    def forward(self, x):
        """
        Args:
            x: (B, 3, H, W)
        Returns:
            features: (B, output_dim)
        """
        features = self.encoder(x)
        features = features.flatten(1)
        features = self.projection(features)
        return features


class TransformerEncoder(nn.Module):
    """Standard transformer encoder"""
    def __init__(self, dim=512, depth=4, heads=8):
        super().__init__()

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(dim, heads)
            for _ in range(depth)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class TransformerDecoder(nn.Module):
    """Transformer decoder with cross-attention"""
    def __init__(self, dim=512, depth=7, heads=8):
        super().__init__()

        self.layers = nn.ModuleList([
            TransformerDecoderLayer(dim, heads)
            for _ in range(depth)
        ])

    def forward(self, queries, context):
        """
        Args:
            queries: (B, query_len, dim)
            context: (B, context_len, dim)
        """
        for layer in self.layers:
            queries = layer(queries, context)
        return queries


class TransformerEncoderLayer(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        # Self-attention
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)

        # Feed-forward
        ff_out = self.ff(x)
        x = self.norm2(x + ff_out)

        return x


class TransformerDecoderLayer(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.cross_attention = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, queries, context):
        # Self-attention on queries
        self_attn_out, _ = self.self_attention(queries, queries, queries)
        queries = self.norm1(queries + self_attn_out)

        # Cross-attention to context
        cross_attn_out, _ = self.cross_attention(queries, context, context)
        queries = self.norm2(queries + cross_attn_out)

        # Feed-forward
        ff_out = self.ff(queries)
        queries = self.norm3(queries + ff_out)

        return queries
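
A quick shape check of the module defined above; the dimensions are illustrative examples, not the paper's configuration, and the ResNet18 weights are downloaded on first use:

# Illustrative shape check for the ACT module above (dimensions are examples)
model = ACT(state_dim=14, action_dim=14, chunk_size=100)

images = [torch.randn(2, 3, 224, 224) for _ in range(3)]  # 3 cameras, batch of 2
state = torch.randn(2, 14)
gt_actions = torch.randn(2, 100, 14)

# Training-style pass: latent encoded from (obs, actions)
pred, mu, logvar = model(images, state, ground_truth_actions=gt_actions)
print(pred.shape)   # torch.Size([2, 100, 14])
print(mu.shape)     # torch.Size([2, 32])

# Inference-style pass: latent sampled from the prior network
chunk = model.predict(images, state)
print(chunk.shape)  # torch.Size([2, 100, 14])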

Training ACT

Data Collection

ACT works best with multi-camera observations:

import numpy as np
import torch


class ACTDataset(torch.utils.data.Dataset):
    """Dataset for ACT training"""
    def __init__(self, demonstrations, chunk_size=100):
        self.demonstrations = demonstrations
        self.chunk_size = chunk_size

    def __len__(self):
        # Count valid chunks across all demonstrations
        total_chunks = 0
        for demo in self.demonstrations:
            total_chunks += max(0, len(demo) - self.chunk_size)
        return total_chunks

    def __getitem__(self, idx):
        # Find demonstration and chunk index
        demo_idx = 0
        chunk_idx = idx

        while chunk_idx >= len(self.demonstrations[demo_idx]) - self.chunk_size:
            chunk_idx -= len(self.demonstrations[demo_idx]) - self.chunk_size
            demo_idx += 1

        demo = self.demonstrations[demo_idx]

        # Extract observation at chunk start
        obs = demo[chunk_idx]

        # Extract action chunk
        actions = np.array([
            demo[chunk_idx + i]['action']
            for i in range(self.chunk_size)
        ])

        return {
            'images': [obs['camera_1'], obs['camera_2'], obs['camera_3']],  # Multi-camera
            'state': obs['state'],
            'actions': actions  # (chunk_size, action_dim)
        }
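
The dataset above assumes each demonstration is a list of per-timestep dicts holding camera images, proprioceptive state, and the action taken. A synthetic example of that assumed format:

import numpy as np

# Synthetic demonstration in the format ACTDataset expects
# (keys and shapes are assumptions based on the class above)
def make_fake_demo(length=200, state_dim=14, action_dim=14):
    return [
        {
            'camera_1': np.zeros((3, 224, 224), dtype=np.float32),
            'camera_2': np.zeros((3, 224, 224), dtype=np.float32),
            'camera_3': np.zeros((3, 224, 224), dtype=np.float32),
            'state': np.zeros(state_dim, dtype=np.float32),
            'action': np.zeros(action_dim, dtype=np.float32),
        }
        for _ in range(length)
    ]

dataset = ACTDataset([make_fake_demo() for _ in range(5)], chunk_size=100)
print(len(dataset))                 # 5 * (200 - 100) = 500 chunks
print(dataset[0]['actions'].shape)  # (100, 14)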

Training Loop

import numpy as np
import torch
from torch.utils.data import DataLoader


def train_act(model, dataset, config):
    """Train ACT model"""

    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # KL annealing schedule (gradually increase KL weight)
    kl_weight_schedule = np.linspace(0, 10.0, config.kl_warmup_steps)

    global_step = 0

    for epoch in range(config.num_epochs):
        for batch in dataloader:
            # Move to GPU
            batch = {k: v.cuda() if isinstance(v, torch.Tensor) else [img.cuda() for img in v]
                    for k, v in batch.items()}

            # KL annealing: ramp the KL weight up over the warmup steps
            if global_step < config.kl_warmup_steps:
                kl_weight = kl_weight_schedule[global_step]
            else:
                kl_weight = 10.0

            # Forward pass (compute_loss builds the weighted total from tensors,
            # so the annealed weight participates in the backward pass)
            total_loss, loss_dict = model.compute_loss(batch, kl_weight=kl_weight)

            # Backward pass
            optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            optimizer.step()

            global_step += 1

            # Logging
            if global_step % 100 == 0:
                print(f"Step {global_step}:")
                print(f"  Recon Loss: {loss_dict['recon_loss']:.4f}")
                print(f"  KL Loss: {loss_dict['kl_loss']:.4f}")
                print(f"  Total Loss: {total_loss.item():.4f}")

    return model
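
train_act expects a small config object carrying the hyperparameters referenced above. A minimal sketch using a dataclass; the field names match the code above, but the values are illustrative, not tuned:

from dataclasses import dataclass

@dataclass
class TrainConfig:
    # Illustrative values, not tuned hyperparameters
    batch_size: int = 8
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    num_epochs: int = 500
    kl_warmup_steps: int = 1000
    grad_clip: float = 1.0

config = TrainConfig()
model = ACT(state_dim=14, action_dim=14, chunk_size=100).cuda()
dataset = ACTDataset(demonstrations, chunk_size=100)  # demonstrations: recorded trajectories
trained_model = train_act(model, dataset, config)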

Deployment: Temporal Ensembling

ACT's key innovation at deployment: temporal ensembling

import collections

import numpy as np
import torch


class ACTController:
    """ACT controller with temporal ensembling"""
    def __init__(self, model, chunk_size=100, k_obs=1):
        self.model = model
        self.chunk_size = chunk_size
        self.k_obs = k_obs  # Number of observations to ensemble

        # Rolling buffer of predicted action chunks
        self.action_buffer = collections.deque(maxlen=k_obs)

    def predict_action(self, observation, timestep):
        """
        Predict action using temporal ensembling

        Key idea: Average predictions from last k_obs observations
        """
        # Predict action chunk from current observation
        with torch.no_grad():
            action_chunk = self.model.predict(
                images=observation['images'],
                state=observation['state']
            )  # (1, chunk_size, action_dim)

        # Add to buffer
        self.action_buffer.append({
            'chunk': action_chunk.squeeze(0).cpu().numpy(),
            'query_timestep': timestep
        })

        # Temporal ensembling: average predictions
        action = self.ensemble_actions(timestep)

        return action

    def ensemble_actions(self, current_timestep):
        """
        Ensemble action predictions from recent observations

        For each buffered prediction, extract the action at current timestep
        and average
        """
        if len(self.action_buffer) == 0:
            return None

        actions_to_average = []

        for prediction in self.action_buffer:
            # How many steps since this prediction was made?
            steps_since_query = current_timestep - prediction['query_timestep']

            # Extract action at current timestep
            if steps_since_query < self.chunk_size:
                action = prediction['chunk'][steps_since_query]
                actions_to_average.append(action)

        # Exponential weighting: index 0 is the oldest buffered chunk and gets
        # the highest weight (the ACT paper's scheme w_i = exp(-m * i))
        weights = np.exp(-0.1 * np.arange(len(actions_to_average)))
        weights = weights / weights.sum()

        ensembled_action = np.average(actions_to_average, axis=0, weights=weights)

        return ensembled_action

# Deployment (k_obs > 1 is needed for the ensembling to have any effect)
controller = ACTController(model, chunk_size=100, k_obs=10)

obs = env.reset()
for t in range(500):
    action = controller.predict_action(obs, timestep=t)
    obs, reward, done, info = env.step(action)

    if done:
        break

Why temporal ensembling works:

  • Reduces variance in the predictions
  • Produces smoother action sequences
  • Handles observation noise better
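
In the code above, index 0 of actions_to_average is the oldest buffered chunk, so the exp(-0.1 * i) scheme weights older predictions slightly higher, matching the w_i = exp(-m * i) weighting described in the ACT paper. A quick numeric check for four buffered chunks:

import numpy as np

# Normalized ensembling weights for 4 buffered chunks (index 0 = oldest)
weights = np.exp(-0.1 * np.arange(4))
weights = weights / weights.sum()
print(weights.round(3))   # [0.289 0.261 0.236 0.214]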

Comparison with Other Methods

Method              Action Horizon   Multimodal   Success Rate   Sample Efficiency
BC                  1                No           40%            Poor
Diffusion Policy    16               Yes          75%            Good
ACT                 100              Yes          85%            Very Good
VLA (RT-2)          1                —            90%*           Excellent*

*With web-scale pre-training

Key Advantages

vs Standard BC:

  • ✓ Longer action horizons → less compounding error
  • ✓ CVAE → handles multimodal distributions
  • ✓ Temporal ensembling → smoother execution

vs Diffusion Policy:

  • ✓ Longer chunks (100 vs 16)
  • ✓ Faster inference (single forward pass vs iterative denoising)
  • ✓ Better suited to long-horizon tasks

Tips for Success

1. Chunk Size Selection

# Rule of thumb: chunk_size should cover ~3-5 seconds of execution
fps = 30  # Control frequency
duration_seconds = 3
chunk_size = fps * duration_seconds  # 90

# Or: make it cover typical subtask duration
# E.g., reaching takes ~2 sec, so chunk_size >= 60

2. KL Weight Tuning

# Start low, increase gradually (schematic: train_epoch stands in for one pass
# over the dataloader, forwarding kl_weight to model.compute_loss)
kl_weights = [1.0, 5.0, 10.0, 20.0]

for kl_weight in kl_weights:
    train_epoch(model, kl_weight=kl_weight)

    # Check: if the KL loss collapses toward 0, increase the weight;
    # if the reconstruction loss rises sharply, decrease it

3. Data Collection

  • Use multiple camera views (3+ recommended)
  • Ensure diverse demonstrations (different initial conditions)
  • Collect 50-100 demos minimum per task
  • Include failure recovery demonstrations
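
A minimal recording-loop sketch that produces demonstrations in the per-timestep dict format used by ACTDataset above; env and get_teleop_action are stand-ins for whatever robot and teleoperation stack is in use:

# Sketch of recording one teleoperated demonstration (env and
# get_teleop_action are hypothetical stand-ins for the real stack)
def record_demo(env, get_teleop_action, max_steps=1000):
    demo = []
    obs = env.reset()
    for t in range(max_steps):
        action = get_teleop_action()      # operator command at this step
        demo.append({
            'camera_1': obs['camera_1'],  # per-camera images
            'camera_2': obs['camera_2'],
            'camera_3': obs['camera_3'],
            'state': obs['state'],        # proprioceptive state
            'action': action,
        })
        obs, done = env.step(action)
        if done:
            break
    return demo

# demonstrations = [record_demo(env, get_teleop_action) for _ in range(50)]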

Resources

  • Paper: https://arxiv.org/abs/2304.13705
  • Code: https://github.com/tonyzhaozh/act
  • Project Page: https://tonyzhaozh.github.io/aloha/

Next Steps