Training VLA Models
This guide covers the complete training pipeline for Vision-Language-Action models.
Training Pipeline Overview
graph TD
A[Collect Data] --> B[Prepare Dataset]
B --> C[Configure Model]
C --> D[Train Model]
D --> E{Validation}
E -->|Poor| F[Adjust Hyperparameters]
F --> D
E -->|Good| G[Simulation Testing]
G --> H{Sim Performance OK?}
H -->|No| F
H -->|Yes| I[Real Robot Testing]
I --> J{Real Performance OK?}
J -->|No| K[Collect More Data]
K --> A
J -->|Yes| L[Deploy]
Prerequisites
Environment Setup
# Create conda environment
conda create -n vla python=3.10
conda activate vla
# Install dependencies
pip install torch torchvision
pip install transformers accelerate
pip install lerobot # LeRobot dataset tools
pip install wandb tensorboard # Logging
# For simulation
pip install isaacgym # or IsaacLab
Hardware Requirements
| Component | Minimum | Recommended |
|-----------|---------|-------------|
| GPU | 1x RTX 3090 (24GB) | 4x A100 (80GB) |
| RAM | 32GB | 128GB+ |
| Storage | 500GB SSD | 2TB+ NVMe |
| CPU | 8 cores | 32+ cores |
Data Preparation
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# Load dataset
# Load a dataset by its hub id; `root` is the local storage directory.
dataset = LeRobotDataset(
    repo_id="lerobot/pusht",
    root="/path/to/data"
)

# Dataset structure on disk (metadata, per-episode parquet tables, videos):
"""
data/
├── meta/
│ ├── info.json
│ └── tasks.json
├── episodes/
│ ├── episode_000000.parquet
│ ├── episode_000001.parquet
│ └── ...
└── videos/
├── observation.image.mp4
└── ...
"""
Custom Dataset Creation
from pathlib import Path

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
class VLADataset(Dataset):
    """Dataset of robot episodes stored as one parquet file per episode.

    Each row of an episode holds an image path, a language instruction,
    the robot proprioceptive state, and the target action.
    """

    def __init__(self, data_dir, transform=None):
        """
        Args:
            data_dir: Directory (str or Path) containing an ``episodes/``
                subfolder of ``*.parquet`` files.
            transform: Optional callable applied to the PIL image.
        """
        # BUG FIX: original did Path(data_dir / "episodes"), which fails
        # when data_dir is a plain string; convert to Path first.
        self.data_dir = Path(data_dir)
        self.transform = transform

        # Load every episode eagerly; sorted() keeps episode order stable.
        self.episodes = []
        for episode_file in sorted((self.data_dir / "episodes").glob("*.parquet")):
            df = pd.read_parquet(episode_file)
            self.episodes.append(df)

    def __len__(self):
        # Total number of (observation, action) steps across all episodes.
        return sum(len(ep) for ep in self.episodes)

    def _get_episode_step(self, idx):
        """Map a flat dataset index to an (episode_idx, step_idx) pair.

        BUG FIX: this helper was called by __getitem__ but never defined.
        """
        if idx < 0:
            idx += len(self)
        for episode_idx, episode in enumerate(self.episodes):
            if idx < len(episode):
                return episode_idx, idx
            idx -= len(episode)
        raise IndexError("dataset index out of range")

    def __getitem__(self, idx):
        """Return one sample: image, instruction, robot state, and action."""
        episode_idx, step_idx = self._get_episode_step(idx)
        row = self.episodes[episode_idx].iloc[step_idx]

        # Observation image: the parquet row stores the image's file path.
        image = Image.open(row['observation.image'])
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'instruction': row['instruction'],
            'robot_state': torch.tensor(row['observation.state']),
            'action': torch.tensor(row['action'])
        }
Data Augmentation
from torchvision import transforms
# Vision augmentation: geometric + photometric jitter, then normalization.
vision_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # match the vision encoder's input size
    transforms.RandomRotation(15),  # random rotation within +/- 15 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    # ImageNet statistics — standard for CLIP/ResNet-style encoders.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Language augmentation
def augment_instruction(instruction):
    """Return a random paraphrase of *instruction* for language robustness.

    The unmodified instruction is one of the candidates, so augmentation
    never removes the original wording from the training distribution.

    Args:
        instruction: The raw task instruction string.

    Returns:
        One of four equivalent phrasings, chosen uniformly at random.
    """
    import random  # BUG FIX: `random` was used without an import in this snippet

    templates = [
        instruction,  # Original
        f"Please {instruction.lower()}",
        f"Can you {instruction.lower()}",
        f"I need you to {instruction.lower()}"
    ]
    return random.choice(templates)
Training Configuration
Configuration File
# config/vla_training.yaml
model:
vision_encoder: "clip-vit-base"
language_encoder: "t5-base"
hidden_dim: 512
num_fusion_layers: 4
action_dim: 7 # 6 DoF + gripper
training:
batch_size: 32
num_epochs: 100
learning_rate: 1e-4
weight_decay: 0.01
warmup_steps: 1000
gradient_clip: 1.0
# Loss weights
action_loss_weight: 1.0
auxiliary_loss_weight: 0.1
data:
train_split: 0.9
val_split: 0.1
num_workers: 8
prefetch_factor: 2
optimizer:
type: "adamw"
betas: [0.9, 0.999]
eps: 1e-8
scheduler:
type: "cosine"
min_lr: 1e-6
logging:
wandb_project: "vla-training"
log_interval: 100
eval_interval: 1000
save_interval: 5000
Training Script
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import wandb
from tqdm import tqdm
class VLATrainer:
    """Supervised trainer for a VLA policy: MSE regression on actions."""

    def __init__(self, model, config):
        """
        Args:
            model: Policy whose forward takes a batch dict and returns actions.
            config: Namespace-like object with learning_rate, weight_decay,
                num_epochs, gradient_clip, wandb_project, etc.
        """
        self.model = model
        self.config = config

        # Setup optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Cosine decay over the whole run; stepped once per epoch in train().
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=config.num_epochs
        )

        # Loss function: regression on continuous action vectors.
        self.criterion = nn.MSELoss()

        # Logging
        wandb.init(project=config.wandb_project, config=config)

    def train_epoch(self, dataloader):
        """Run one optimization epoch and return the mean batch loss."""
        self.model.train()
        total_loss = 0

        for batch in tqdm(dataloader, desc="Training"):
            # Move to GPU
            batch = {k: v.cuda() for k, v in batch.items()}

            # Forward pass
            predicted_actions = self.model(batch)
            loss = self.criterion(predicted_actions, batch['action'])

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping guards against exploding gradients.
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config.gradient_clip
            )

            self.optimizer.step()
            total_loss += loss.item()

        return total_loss / len(dataloader)

    def validate(self, dataloader):
        """Return the mean validation loss without updating weights."""
        self.model.eval()
        total_loss = 0

        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.cuda() for k, v in batch.items()}
                predicted_actions = self.model(batch)
                loss = self.criterion(predicted_actions, batch['action'])
                total_loss += loss.item()

        return total_loss / len(dataloader)

    def train(self, train_loader, val_loader):
        """Full training loop with per-epoch validation and checkpointing."""
        import os  # local import keeps this snippet self-contained

        # BUG FIX: torch.save fails if the checkpoint directory is missing.
        os.makedirs('checkpoints', exist_ok=True)

        best_val_loss = float('inf')

        for epoch in range(self.config.num_epochs):
            # Train
            train_loss = self.train_epoch(train_loader)

            # Validate
            val_loss = self.validate(val_loader)

            # Log
            wandb.log({
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'learning_rate': self.scheduler.get_last_lr()[0]
            })

            # Save best model (lowest validation loss seen so far).
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(
                    self.model.state_dict(),
                    'checkpoints/best_model.pt'
                )

            # Step scheduler once per epoch (matches T_max=num_epochs).
            self.scheduler.step()

            print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
# Usage
if __name__ == "__main__":
# Load config
config = load_config("config/vla_training.yaml")
# Create datasets
train_dataset = VLADataset(config.train_data_path)
val_dataset = VLADataset(config.val_data_path)
train_loader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=config.num_workers
)
val_loader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers
)
# Create model
model = VLAModel(config).cuda()
# Train
trainer = VLATrainer(model, config)
trainer.train(train_loader, val_loader)
Advanced Training Techniques
Mixed Precision Training
from torch.cuda.amp import autocast, GradScaler
class VLATrainerAMP(VLATrainer):
    """VLATrainer variant using automatic mixed precision (fp16 autocast).

    Reduces memory use and speeds up training on tensor-core GPUs; the
    GradScaler rescales the loss so fp16 gradients do not underflow.
    """

    def __init__(self, model, config):
        super().__init__(model, config)
        # Scales losses/gradients to keep fp16 values in representable range.
        self.scaler = GradScaler()

    def train_epoch(self, dataloader):
        """Run one AMP training epoch and return the mean batch loss."""
        self.model.train()
        total_loss = 0
        for batch in tqdm(dataloader):
            batch = {k: v.cuda() for k, v in batch.items()}
            # Mixed precision forward pass
            with autocast():
                predicted_actions = self.model(batch)
                loss = self.criterion(predicted_actions, batch['action'])
            # Scaled backward pass
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            # Unscale first so the clip threshold applies to the true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)
            # step() skips the update if grads have inf/NaN; update() retunes the scale.
            self.scaler.step(self.optimizer)
            self.scaler.update()
            total_loss += loss.item()
        return total_loss / len(dataloader)
Distributed Training
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_distributed():
    """Initialize torch.distributed (NCCL backend) and pin this process to its GPU.

    Returns:
        The process's local rank, read from the LOCAL_RANK env var set by torchrun.
    """
    import os  # BUG FIX: `os` was used without an import in this snippet

    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return local_rank

# Launch with: torchrun --nproc_per_node=4 train.py
local_rank = setup_distributed()
model = VLAModel(config).cuda()
# Wrap for data-parallel training; gradients are all-reduced across ranks.
model = DDP(model, device_ids=[local_rank])
Curriculum Learning
class CurriculumScheduler:
    """Linearly shifts sampling weight from easy to hard tasks over 50 epochs."""

    def __init__(self, easy_tasks, hard_tasks):
        self.easy_tasks = easy_tasks
        self.hard_tasks = hard_tasks
        self.current_difficulty = 0.0

    def get_task_distribution(self, epoch):
        """Return sampling weights for 'easy' and 'hard' tasks at this epoch."""
        # Ramp difficulty from 0 to 1 across the first 50 epochs, then hold.
        self.current_difficulty = min(1.0, epoch / 50)
        return {
            'easy': 1.0 - self.current_difficulty,
            'hard': self.current_difficulty,
        }
Monitoring and Debugging
Key Metrics to Track
def compute_metrics(predicted_actions, target_actions):
    """Compute action-regression metrics for 7-dim actions.

    Layout assumption: dims 0-2 = position, 3-5 = rotation, dim 6 = gripper
    (matches ``action_dim: 7`` in the training config) — confirm against the
    robot's action space.

    Args:
        predicted_actions: (batch, 7) tensor of model outputs.
        target_actions: (batch, 7) tensor of ground-truth actions.

    Returns:
        Dict of scalar tensors.
    """
    import torch.nn.functional as F  # BUG FIX: `F` was used without an import

    metrics = {
        # Action prediction error
        'mse': F.mse_loss(predicted_actions, target_actions),
        'mae': F.l1_loss(predicted_actions, target_actions),
        # Per-dimension errors
        'position_error': F.mse_loss(predicted_actions[:, :3], target_actions[:, :3]),
        'rotation_error': F.mse_loss(predicted_actions[:, 3:6], target_actions[:, 3:6]),
        'gripper_error': F.mse_loss(predicted_actions[:, 6:], target_actions[:, 6:]),
        # Success metrics (if available): gripper treated as binary at 0.5
        'gripper_accuracy': (
            (predicted_actions[:, 6] > 0.5) == (target_actions[:, 6] > 0.5)
        ).float().mean()
    }
    return metrics
Visualization
import matplotlib.pyplot as plt
def visualize_predictions(model, val_dataset, num_samples=4):
    """Plot image, predicted-vs-actual actions, and per-dim error for a few samples.

    Logs the resulting figure to wandb under the key "predictions".

    Args:
        model: Policy whose forward takes a batch dict and returns actions.
        val_dataset: Dataset yielding dicts with 'image', 'instruction', 'action'.
        num_samples: Number of validation samples to visualize.
    """
    fig, axes = plt.subplots(num_samples, 3, figsize=(15, 5*num_samples))

    for i in range(num_samples):
        sample = val_dataset[i]

        # Show image (tensor is CHW; matplotlib expects HWC).
        axes[i, 0].imshow(sample['image'].permute(1, 2, 0))
        axes[i, 0].set_title(f"Instruction: {sample['instruction']}")

        # Predict action. BUG FIX: only tensors can be batched with unsqueeze;
        # the instruction string is wrapped in a length-1 list instead.
        batch = {
            k: (v.unsqueeze(0) if torch.is_tensor(v) else [v])
            for k, v in sample.items()
        }
        with torch.no_grad():
            pred_action = model(batch)

        # Plot predicted vs actual actions
        axes[i, 1].plot(sample['action'].numpy(), label='Actual', marker='o')
        axes[i, 1].plot(pred_action[0].cpu().numpy(), label='Predicted', marker='x')
        axes[i, 1].legend()
        axes[i, 1].set_title("Actions")

        # Show per-dimension absolute error as a bar chart.
        error = (pred_action[0].cpu() - sample['action']).abs()
        axes[i, 2].bar(range(len(error)), error.numpy())
        axes[i, 2].set_title("Absolute Error")

    plt.tight_layout()
    wandb.log({"predictions": wandb.Image(fig)})
    plt.close()
Hyperparameter Tuning
Learning Rate Finder
def find_learning_rate(model, dataloader, min_lr=1e-7, max_lr=10, num_iter=100):
    """LR range test: sweep the learning rate exponentially and plot loss vs LR.

    The model's forward is assumed to return a scalar loss for a batch here.
    NOTE(review): the sweep updates the model's weights — run it on a
    throwaway copy of the model.

    Args:
        model: Callable returning a scalar loss for a batch dict.
        dataloader: Yields batch dicts of CUDA-movable tensors.
        min_lr: Starting learning rate of the sweep.
        max_lr: Final learning rate of the sweep.
        num_iter: Number of batches (= LR steps) to run.

    Returns:
        (lrs, losses): learning rates tried and the losses observed.
    """
    from itertools import islice  # BUG FIX: `islice` was used without an import

    optimizer = torch.optim.AdamW(model.parameters(), lr=min_lr)
    # Constant multiplicative factor so the LR spans [min_lr, max_lr]
    # in exactly num_iter steps.
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer,
        gamma=(max_lr/min_lr)**(1/num_iter)
    )

    losses = []
    lrs = []

    for batch in islice(dataloader, num_iter):
        batch = {k: v.cuda() for k, v in batch.items()}

        # Forward
        loss = model(batch)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log
        losses.append(loss.item())
        lrs.append(optimizer.param_groups[0]['lr'])

        # Step LR
        lr_scheduler.step()

    # Plot loss vs LR on a log axis; pick an LR just before the loss minimum.
    plt.plot(lrs, losses)
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.show()

    return lrs, losses
Troubleshooting
Common Issues
| Issue | Symptoms | Solution |
|-------|----------|----------|
| Overfitting | Train loss ↓, Val loss ↑ | More data, regularization, augmentation |
| Underfitting | Both losses high | Larger model, more training |
| Gradient explosion | Loss → NaN | Lower LR, gradient clipping |
| Slow convergence | Loss decreases slowly | Higher LR, better initialization |
| Poor generalization | Good sim, bad real | More diverse data, domain randomization |
Next Steps