# Training VLA Models

This guide covers the complete training pipeline for Vision-Language-Action models.

Training Pipeline Overview

graph TD
    A[Collect Data] --> B[Prepare Dataset]
    B --> C[Configure Model]
    C --> D[Train Model]
    D --> E{Validation}
    E -->|Poor| F[Adjust Hyperparameters]
    F --> D
    E -->|Good| G[Simulation Testing]
    G --> H{Sim Performance OK?}
    H -->|No| F
    H -->|Yes| I[Real Robot Testing]
    I --> J{Real Performance OK?}
    J -->|No| K[Collect More Data]
    K --> A
    J -->|Yes| L[Deploy]

Prerequisites

Environment Setup

# Create conda environment
conda create -n vla python=3.10
conda activate vla

# Install dependencies
pip install torch torchvision
pip install transformers accelerate
pip install lerobot  # LeRobot dataset tools
pip install wandb tensorboard  # Logging

# For simulation
pip install isaacgym  # NOTE: Isaac Gym is not on PyPI — download it from NVIDIA and install the local package, or use Isaac Lab instead

Hardware Requirements

Component Minimum Recommended
GPU 1x RTX 3090 (24GB) 4x A100 (80GB)
RAM 32GB 128GB+
Storage 500GB SSD 2TB+ NVMe
CPU 8 cores 32+ cores

Data Preparation

Using LeRobot Format

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Load dataset
dataset = LeRobotDataset(
    repo_id="lerobot/pusht",
    root="/path/to/data"
)

# Dataset structure
"""
data/
├── meta/
│   ├── info.json
│   └── tasks.json
├── episodes/
│   ├── episode_000000.parquet
│   ├── episode_000001.parquet
│   └── ...
└── videos/
    ├── observation.image.mp4
    └── ...
"""

Custom Dataset Creation

import torch
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd

class VLADataset(Dataset):
    """Dataset of (image, instruction, state, action) steps stored as one
    parquet file per episode under ``<data_dir>/episodes``.

    Parameters
    ----------
    data_dir : str | Path
        Root directory containing an ``episodes/`` subdirectory of
        ``*.parquet`` files (one per episode).
    transform : callable, optional
        Applied to the PIL image of each sample.
    """

    def __init__(self, data_dir, transform=None):
        # Local import keeps this snippet self-contained (pathlib was not
        # imported at the top of the original snippet).
        from pathlib import Path

        self.data_dir = data_dir
        self.transform = transform

        # Load episode data.  Bug fix: the original `Path(data_dir / "episodes")`
        # crashed when `data_dir` was a plain string; convert first, then join.
        self.episodes = []
        for episode_file in sorted((Path(data_dir) / "episodes").glob("*.parquet")):
            self.episodes.append(pd.read_parquet(episode_file))

        # Cumulative step counts so a flat dataset index can be mapped to an
        # (episode, step) pair in O(log n) — the original referenced
        # `_get_episode_step` but never defined it.
        self._cumulative = []
        total = 0
        for ep in self.episodes:
            total += len(ep)
            self._cumulative.append(total)

    def __len__(self):
        # Total number of steps across all episodes.
        return self._cumulative[-1] if self._cumulative else 0

    def _get_episode_step(self, idx):
        """Map a flat index to ``(episode_idx, step_idx)``."""
        import bisect

        if idx < 0 or idx >= len(self):
            raise IndexError(f"index {idx} out of range for {len(self)} steps")
        episode_idx = bisect.bisect_right(self._cumulative, idx)
        prev_total = self._cumulative[episode_idx - 1] if episode_idx > 0 else 0
        return episode_idx, idx - prev_total

    def __getitem__(self, idx):
        # Find episode and step
        episode_idx, step_idx = self._get_episode_step(idx)
        row = self.episodes[episode_idx].iloc[step_idx]

        # Load observation image (the parquet row stores a file path).
        image = Image.open(row['observation.image'])
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'instruction': row['instruction'],
            'robot_state': torch.tensor(row['observation.state']),
            'action': torch.tensor(row['action']),
        }

Data Augmentation

from torchvision import transforms

# Vision augmentation
vision_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Language augmentation
import random  # bug fix: `random` was used below without being imported


def augment_instruction(instruction):
    """Return a randomly chosen paraphrase of *instruction* for robustness.

    The original wording is one of the four candidates, so roughly a quarter
    of calls return the input unchanged.

    Parameters
    ----------
    instruction : str
        Imperative instruction, e.g. ``"Pick up the red block"``.

    Returns
    -------
    str
        One of four phrasings, chosen uniformly at random.
    """
    templates = [
        instruction,  # Original
        f"Please {instruction.lower()}",
        f"Can you {instruction.lower()}",
        f"I need you to {instruction.lower()}"
    ]
    return random.choice(templates)

Training Configuration

Configuration File

# config/vla_training.yaml

model:
  vision_encoder: "clip-vit-base"
  language_encoder: "t5-base"
  hidden_dim: 512
  num_fusion_layers: 4
  action_dim: 7  # 6 DoF + gripper

training:
  batch_size: 32
  num_epochs: 100
  learning_rate: 1e-4
  weight_decay: 0.01
  warmup_steps: 1000
  gradient_clip: 1.0

  # Loss weights
  action_loss_weight: 1.0
  auxiliary_loss_weight: 0.1

data:
  train_split: 0.9
  val_split: 0.1
  num_workers: 8
  prefetch_factor: 2

optimizer:
  type: "adamw"
  betas: [0.9, 0.999]
  eps: 1e-8

scheduler:
  type: "cosine"
  min_lr: 1e-6

logging:
  wandb_project: "vla-training"
  log_interval: 100
  eval_interval: 1000
  save_interval: 5000

Training Script

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import wandb
from tqdm import tqdm

class VLATrainer:
    """Supervised behavior-cloning trainer for a VLA model.

    Expects ``config`` to expose ``learning_rate``, ``weight_decay``,
    ``num_epochs``, ``gradient_clip`` and ``wandb_project``.  Batches are
    dicts of tensors consumed by the model directly; the regression target
    is ``batch['action']``.
    """

    def __init__(self, model, config):
        self.model = model
        self.config = config

        # AdamW: decoupled weight decay, standard for transformer-style models.
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Cosine decay over the whole run (stepped once per epoch in train()).
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.num_epochs
        )

        # Plain MSE on the continuous action vector.
        self.criterion = nn.MSELoss()

        # Experiment tracking.
        wandb.init(project=config.wandb_project, config=config)

    def train_epoch(self, dataloader):
        """Run one optimization pass over *dataloader*; return the mean loss."""
        self.model.train()
        total_loss = 0

        for batch in tqdm(dataloader, desc="Training"):
            # Move every tensor in the batch to the default GPU.
            batch = {k: v.cuda() for k, v in batch.items()}

            # Forward pass
            predicted_actions = self.model(batch)
            loss = self.criterion(predicted_actions, batch['action'])

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()

            # Clip the global gradient norm before the update for stability.
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.gradient_clip
            )

            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / len(dataloader)

    def validate(self, dataloader):
        """Return the mean validation loss (eval mode, no gradients)."""
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.cuda() for k, v in batch.items()}
                predicted_actions = self.model(batch)
                loss = self.criterion(predicted_actions, batch['action'])
                total_loss += loss.item()

        return total_loss / len(dataloader)

    def train(self, train_loader, val_loader):
        """Full loop: train and validate each epoch, checkpoint the best model."""
        import os  # local import: this snippet does not import os at top level

        # Bug fix: torch.save below failed when checkpoints/ did not exist.
        os.makedirs('checkpoints', exist_ok=True)

        best_val_loss = float('inf')

        for epoch in range(self.config.num_epochs):
            # Train
            train_loss = self.train_epoch(train_loader)

            # Validate
            val_loss = self.validate(val_loader)

            # Log
            wandb.log({
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'learning_rate': self.scheduler.get_last_lr()[0]
            })

            # Keep only the best-validation checkpoint.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(
                    self.model.state_dict(),
                    'checkpoints/best_model.pt'  # was an f-string with no placeholder
                )

            # Per-epoch LR schedule step.
            self.scheduler.step()

            print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

# Usage
if __name__ == "__main__":
    # Load config
    config = load_config("config/vla_training.yaml")

    # Create datasets
    train_dataset = VLADataset(config.train_data_path)
    val_dataset = VLADataset(config.val_data_path)

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers
    )

    # Create model
    model = VLAModel(config).cuda()

    # Train
    trainer = VLATrainer(model, config)
    trainer.train(train_loader, val_loader)

Advanced Training Techniques

Mixed Precision Training

from torch.cuda.amp import autocast, GradScaler

class VLATrainerAMP(VLATrainer):
    """VLATrainer variant using automatic mixed precision (fp16 autocast).

    NOTE: the statement order in the backward path is load-bearing —
    ``scaler.unscale_`` must run before gradient clipping so the clip
    threshold applies to true (unscaled) gradient norms, and ``scaler.step``
    / ``scaler.update`` must follow in that order.
    """

    def __init__(self, model, config):
        super().__init__(model, config)
        # GradScaler dynamically scales the loss to avoid fp16 gradient underflow.
        self.scaler = GradScaler()

    def train_epoch(self, dataloader):
        """Run one AMP training pass over *dataloader*; return the mean loss."""
        self.model.train()
        total_loss = 0

        for batch in tqdm(dataloader):
            batch = {k: v.cuda() for k, v in batch.items()}

            # Mixed precision forward pass
            with autocast():
                predicted_actions = self.model(batch)
                loss = self.criterion(predicted_actions, batch['action'])

            # Scaled backward pass
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            # Unscale so clipping sees real gradient norms (see class docstring).
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)
            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()

        return total_loss / len(dataloader)

Distributed Training

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed():
    """Initialize torch.distributed (NCCL) for single-node multi-GPU training.

    Reads the ``LOCAL_RANK`` environment variable (set by ``torchrun``) and
    binds this process to the matching CUDA device.

    Returns
    -------
    int
        This process's local rank.
    """
    import os  # bug fix: `os` was used below without being imported in this snippet

    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return local_rank

# Launch with: torchrun --nproc_per_node=4 train.py
local_rank = setup_distributed()
model = VLAModel(config).cuda()
model = DDP(model, device_ids=[local_rank])

Curriculum Learning

class CurriculumScheduler:
    """Blend easy and hard tasks, shifting weight toward hard over 50 epochs."""

    def __init__(self, easy_tasks, hard_tasks):
        self.easy_tasks = easy_tasks
        self.hard_tasks = hard_tasks
        self.current_difficulty = 0.0

    def get_task_distribution(self, epoch):
        """Return sampling weights ``{'easy': w, 'hard': 1 - w}`` for *epoch*.

        Difficulty ramps linearly from 0 at epoch 0 to 1 at epoch 50,
        then saturates at 1.
        """
        self.current_difficulty = min(1.0, epoch / 50)
        hard = self.current_difficulty
        return {'easy': 1.0 - hard, 'hard': hard}

Monitoring and Debugging

Key Metrics to Track

import torch.nn.functional as F  # bug fix: `F` was used without an import in this snippet


def compute_metrics(predicted_actions, target_actions):
    """Compute regression metrics for 7-DoF actions.

    Parameters
    ----------
    predicted_actions, target_actions : torch.Tensor
        Shape ``(batch, 7)``: indices 0-2 position, 3-5 rotation,
        6 gripper (per the layout in the config's ``action_dim`` comment).

    Returns
    -------
    dict[str, torch.Tensor]
        Scalar tensors for each metric.
    """
    metrics = {
        # Overall action prediction error
        'mse': F.mse_loss(predicted_actions, target_actions),
        'mae': F.l1_loss(predicted_actions, target_actions),

        # Per-dimension errors
        'position_error': F.mse_loss(predicted_actions[:, :3], target_actions[:, :3]),
        'rotation_error': F.mse_loss(predicted_actions[:, 3:6], target_actions[:, 3:6]),
        'gripper_error': F.mse_loss(predicted_actions[:, 6:], target_actions[:, 6:]),

        # Binary open/close agreement at a 0.5 threshold
        'gripper_accuracy': (
            (predicted_actions[:, 6] > 0.5) == (target_actions[:, 6] > 0.5)
        ).float().mean()
    }
    return metrics

Visualization

import matplotlib.pyplot as plt

def visualize_predictions(model, val_dataset, num_samples=4):
    """Log a figure comparing predicted vs. ground-truth actions for a few samples.

    For each of the first *num_samples* items of *val_dataset*, shows the
    observation image, overlays predicted vs. actual action vectors, and a
    per-dimension absolute-error bar chart; the figure is sent to wandb.
    """
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))

    for i in range(num_samples):
        sample = val_dataset[i]

        # Show image (CHW tensor -> HWC for imshow)
        axes[i, 0].imshow(sample['image'].permute(1, 2, 0))
        axes[i, 0].set_title(f"Instruction: {sample['instruction']}")

        # Predict action.  Bug fix: only tensors get unsqueeze(0) — the
        # instruction is a str and is batched by wrapping it in a list.
        with torch.no_grad():
            batch = {
                k: v.unsqueeze(0) if torch.is_tensor(v) else [v]
                for k, v in sample.items()
            }
            pred_action = model(batch)

        # Plot predicted vs actual actions
        axes[i, 1].plot(sample['action'].numpy(), label='Actual', marker='o')
        axes[i, 1].plot(pred_action[0].cpu().numpy(), label='Predicted', marker='x')
        axes[i, 1].legend()
        axes[i, 1].set_title("Actions")

        # Show per-dimension absolute error
        error = (pred_action[0].cpu() - sample['action']).abs()
        axes[i, 2].bar(range(len(error)), error.numpy())
        axes[i, 2].set_title("Absolute Error")

    plt.tight_layout()
    wandb.log({"predictions": wandb.Image(fig)})
    plt.close()

Hyperparameter Tuning

Learning Rate Finder

def find_learning_rate(model, dataloader, min_lr=1e-7, max_lr=10, num_iter=100):
    """LR range test: exponentially sweep the learning rate and record the loss.

    Runs up to *num_iter* optimization steps, multiplying the LR by a constant
    factor each step so it grows from *min_lr* to *max_lr*, then plots loss
    vs. LR on a log axis.  Assumes ``model(batch)`` returns a scalar loss
    tensor — TODO(review): confirm against the actual model interface.

    Returns
    -------
    tuple[list[float], list[float]]
        ``(lrs, losses)`` recorded per iteration, for programmatic use in
        addition to the plot.
    """
    from itertools import islice  # bug fix: `islice` was used without an import

    optimizer = torch.optim.AdamW(model.parameters(), lr=min_lr)
    # gamma is chosen so that min_lr * gamma**num_iter == max_lr.
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=(max_lr / min_lr) ** (1 / num_iter)
    )

    losses = []
    lrs = []

    for batch in islice(dataloader, num_iter):
        batch = {k: v.cuda() for k, v in batch.items()}

        # Forward (model is assumed to return the training loss directly)
        loss = model(batch)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record the LR actually used for this step, before the schedule advances.
        losses.append(loss.item())
        lrs.append(optimizer.param_groups[0]['lr'])

        # Step LR
        lr_scheduler.step()

    # Plot: pick an LR just before the loss minimum / divergence point.
    plt.plot(lrs, losses)
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.show()
    return lrs, losses

Troubleshooting

Common Issues

Issue Symptoms Solution
Overfitting Train loss ↓, Val loss ↑ More data, regularization, augmentation
Underfitting Both losses high Larger model, more training
Gradient explosion Loss → NaN Lower LR, gradient clipping
Slow convergence Loss decreases slowly Higher LR, better initialization
Poor generalization Good sim, bad real More diverse data, domain randomization

Next Steps