Multi-Task Learning for VLA Models¶
Multi-task learning enables a single VLA model to perform diverse robotic tasks, improving generalization and sample efficiency.
Why Multi-Task Learning?¶
Single-task models learn one skill at a time:
Model_1: "Pick red block" → Trained on 1000 demos
Model_2: "Open drawer" → Trained on 1000 demos
Model_3: "Wipe table" → Trained on 1000 demos
Total: 3 models, 3000 demos
Multi-task models learn all skills simultaneously:
Benefits:

- ✓ Shared representations: common skills transfer across tasks
- ✓ Better generalization: exposure to diverse scenarios
- ✓ Sample efficiency: each task benefits from the others
- ✓ Single deployment: one model serves all tasks
- ✓ Compositional skills: learned primitives can be combined
Architecture for Multi-Task VLA¶
Task-Conditioned Policy¶
Language naturally provides task conditioning:
import torch
import torch.nn as nn

class MultiTaskVLA(nn.Module):
    """VLA model for multiple tasks."""

    def __init__(self, vision_encoder, language_encoder, action_head):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.language_encoder = language_encoder
        self.action_head = action_head

        # Shared fusion layer (512-dim vision + 768-dim language)
        self.fusion = nn.Linear(512 + 768, 512)

    def forward(self, image, task_instruction, state):
        """
        Single forward pass for any task.

        Args:
            image: (B, 3, H, W)
            task_instruction: "pick up red block" / "open drawer" / etc.
            state: (B, state_dim) proprioceptive state (unused in this minimal sketch)

        Returns:
            action: (B, action_dim)
        """
        # Extract features
        vision_features = self.vision_encoder(image)
        task_embedding = self.language_encoder(task_instruction)

        # Fuse (task conditioning happens here)
        combined = torch.cat([vision_features, task_embedding], dim=-1)
        features = self.fusion(combined)

        # Predict action
        action = self.action_head(features)
        return action
Key insight: Language instruction specifies the task, so no architectural changes needed!
Task Embeddings¶
Alternative: Learn task embeddings:
class TaskEmbeddingVLA(nn.Module):
    """VLA with learned task embeddings."""

    def __init__(self, num_tasks=10, embedding_dim=128):
        super().__init__()
        # Discrete task embeddings, one per task ID
        self.task_embeddings = nn.Embedding(num_tasks, embedding_dim)

        # Task ID to name mapping
        self.task_names = [
            "pick_red_block",
            "open_drawer",
            "close_drawer",
            "wipe_table",
            # ... more tasks
        ]

    def forward(self, image, task_id, state):
        """
        Args:
            task_id: integer tensor in [0, num_tasks - 1]
        """
        # Get task embedding
        task_emb = self.task_embeddings(task_id)

        # Rest is the same as MultiTaskVLA, with the task embedding
        # replacing the language embedding (encode/action_head omitted here)
        features = self.encode(image, task_emb, state)
        action = self.action_head(features)
        return action
Pros: efficient, no language encoder needed.
Cons: not compositional; cannot generalize to new task descriptions.
Training Strategies¶
1. Uniform Sampling¶
Sample tasks uniformly during training:
from torch.utils.data import DataLoader

def train_multitask_uniform(model, task_datasets, config):
    """Train with uniform sampling over the pooled data."""
    # Combine all datasets, tagging each demo with its task
    all_data = []
    for task_name, dataset in task_datasets.items():
        for demo in dataset:
            demo['task'] = task_name
            all_data.append(demo)

    # DataLoader already shuffles, so no separate shuffle pass is needed
    dataloader = DataLoader(all_data, batch_size=config.batch_size, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

    for epoch in range(config.num_epochs):
        for batch in dataloader:
            # Forward pass (task specified by language instruction)
            loss = model.compute_loss(batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return model
Issue: Tasks with more data dominate training
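To see the imbalance concretely, suppose three tasks are pooled with very different demo counts (the numbers below are hypothetical). Under uniform sampling over the pooled data, each task is visited in proportion to its share of demonstrations:

```python
# Hypothetical demo counts for three pooled tasks
demo_counts = {'pick_red_block': 5000, 'open_drawer': 800, 'wipe_table': 200}
total = sum(demo_counts.values())

# Under uniform sampling, each task's share of training batches
# equals its share of the pooled demos
shares = {task: n / total for task, n in demo_counts.items()}
for task, share in shares.items():
    print(f"{task}: {share:.1%}")
# pick_red_block receives roughly 83% of all gradient updates
```

The dominant task soaks up most of the gradient signal, which motivates the balanced sampling below.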
2. Balanced Sampling¶
Balance samples across tasks:
class BalancedTaskSampler:
    """Sample equally from each task."""

    def __init__(self, task_datasets):
        self.task_datasets = task_datasets
        self.task_names = list(task_datasets.keys())

        # One shuffled iterator per task
        self.iterators = {
            task: iter(DataLoader(dataset, batch_size=1, shuffle=True))
            for task, dataset in task_datasets.items()
        }

    def sample_batch(self, batch_size):
        """Sample a batch with an equal number of samples per task."""
        batch = []
        samples_per_task = batch_size // len(self.task_names)

        for task_name in self.task_names:
            for _ in range(samples_per_task):
                try:
                    sample = next(self.iterators[task_name])
                except StopIteration:
                    # Dataset exhausted: restart its iterator
                    self.iterators[task_name] = iter(DataLoader(
                        self.task_datasets[task_name],
                        batch_size=1,
                        shuffle=True
                    ))
                    sample = next(self.iterators[task_name])
                batch.append(sample)

        return batch

# Training loop
sampler = BalancedTaskSampler(task_datasets)

for step in range(config.max_steps):
    batch = sampler.sample_batch(config.batch_size)
    loss = model.compute_loss(batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Benefit: Each task gets equal training signal
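As a lighter-weight alternative to the custom sampler above, the same balancing can be sketched with PyTorch's built-in `WeightedRandomSampler`, weighting each sample by the inverse of its task's dataset size. The tiny `TensorDataset`s below are stand-ins for real demo datasets:

```python
import torch
from torch.utils.data import (ConcatDataset, DataLoader, TensorDataset,
                              WeightedRandomSampler)

# Toy stand-ins for per-task demo datasets of very different sizes
task_datasets = {
    'pick_red_block': TensorDataset(torch.zeros(900, 4)),
    'open_drawer': TensorDataset(torch.ones(100, 4)),
}

combined = ConcatDataset(list(task_datasets.values()))

# Weight each sample by 1 / len(its task's dataset), so every task
# is drawn with equal probability regardless of dataset size
weights = torch.cat([
    torch.full((len(ds),), 1.0 / len(ds))
    for ds in task_datasets.values()
])

sampler = WeightedRandomSampler(weights, num_samples=len(combined), replacement=True)
loader = DataLoader(combined, batch_size=32, sampler=sampler)
```

Each task's weights sum to 1, so a draw is equally likely to come from either task even though one has nine times the data.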
3. Curriculum Learning¶
Start with easy tasks, gradually increase difficulty:
class CurriculumScheduler:
    """Schedule task difficulty over training."""

    def __init__(self, task_difficulties):
        """
        Args:
            task_difficulties: dict mapping task -> difficulty score in [0, 1]
        """
        self.task_difficulties = task_difficulties
        self.current_step = 0

    def get_task_weights(self, total_steps):
        """Get (unnormalized) sampling weights for each task."""
        # Progress through the curriculum
        progress = self.current_step / total_steps

        weights = {}
        for task, difficulty in self.task_difficulties.items():
            # Easy tasks (low difficulty) early, hard tasks later
            if difficulty <= progress:
                weights[task] = 1.0   # Active
            else:
                weights[task] = 0.1   # Reduced probability
        return weights

# Example
import numpy as np

task_difficulties = {
    'reach_target': 0.1,      # Easy
    'pick_object': 0.3,
    'place_object': 0.5,
    'stack_blocks': 0.7,
    'open_drawer': 0.8,
    'complex_assembly': 0.9,  # Hard
}

scheduler = CurriculumScheduler(task_difficulties)

for step in range(total_steps):
    # Get the current task distribution
    task_weights = scheduler.get_task_weights(total_steps)

    # Sample a task (weights normalized into probabilities)
    task = np.random.choice(
        list(task_weights.keys()),
        p=np.array(list(task_weights.values())) / sum(task_weights.values())
    )

    # Train on this task
    batch = task_datasets[task].sample()
    loss = model.compute_loss(batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    scheduler.current_step += 1
4. Task-Specific Losses¶
Different losses for different tasks:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskLoss(nn.Module):
    """Combined loss for multi-task learning."""

    def __init__(self, task_weights=None):
        super().__init__()
        self.task_weights = task_weights or {}

    def forward(self, model, batch):
        """Compute the weighted multi-task loss."""
        total_loss = 0
        task_losses = {}

        # Iterate over the distinct tasks present in the batch
        task_names = np.asarray(batch['task_names'])
        for task_name in np.unique(task_names):
            # Boolean mask selecting this task's samples
            task_indices = task_names == task_name
            task_batch = {
                k: v[task_indices]
                for k, v in batch.items() if k != 'task_names'
            }

            # Task-specific loss
            if task_name == 'pick_place':
                # Action loss plus grasp-success auxiliary
                loss = self.pick_place_loss(model, task_batch)
            elif task_name == 'navigation':
                # Distance-based waypoint loss
                loss = self.navigation_loss(model, task_batch)
            else:
                # Default: action prediction
                loss = F.mse_loss(
                    model.predict(task_batch),
                    task_batch['actions']
                )

            # Weight and accumulate
            weight = self.task_weights.get(task_name, 1.0)
            total_loss += weight * loss
            task_losses[task_name] = loss.item()

        return total_loss, task_losses

    def pick_place_loss(self, model, batch):
        """Loss for pick-and-place."""
        # Action prediction
        action_loss = F.mse_loss(
            model.predict(batch),
            batch['actions']
        )

        # Auxiliary: grasp-success prediction
        grasp_logits = model.predict_grasp_success(batch)
        grasp_loss = F.binary_cross_entropy_with_logits(
            grasp_logits,
            batch['grasp_success']
        )

        return action_loss + 0.1 * grasp_loss

    def navigation_loss(self, model, batch):
        """Loss for navigation."""
        # Predict waypoints
        waypoints_pred = model.predict_waypoints(batch)
        waypoints_true = batch['waypoints']

        # Distance-weighted loss (later waypoints matter less)
        weights = torch.exp(-torch.arange(len(waypoints_true)) * 0.1)
        loss = (weights * F.mse_loss(
            waypoints_pred,
            waypoints_true,
            reduction='none'
        ).mean(dim=-1)).mean()

        return loss
Task Composition¶
Learn primitives that can be combined:
class HierarchicalMultiTaskVLA(nn.Module):
    """Hierarchical model with task primitives."""

    def __init__(self):
        super().__init__()
        # Low-level controllers (primitives); PrimitiveController and
        # HighLevelPolicy are application-specific modules
        self.primitives = nn.ModuleDict({
            'reach': PrimitiveController(),
            'grasp': PrimitiveController(),
            'move': PrimitiveController(),
            'release': PrimitiveController(),
            'pull': PrimitiveController(),
        })

        # High-level policy (selects primitives)
        self.high_level_policy = HighLevelPolicy()

    def forward(self, observation, task_instruction):
        """
        Two-level policy:
        1. High-level: which primitive to execute
        2. Low-level: how to execute the primitive
        """
        # Parse the task into a sequence of primitives,
        # e.g. "pick and place" -> [reach, grasp, move, release]
        primitive_sequence = self.parse_task(task_instruction)

        # Execute the primitives sequentially
        actions = []
        for primitive_name in primitive_sequence:
            # Get the primitive controller
            primitive = self.primitives[primitive_name]

            # Execute the primitive
            action = primitive(observation)
            actions.append(action)

            # Update state (in simulation or on the real robot)
            observation = self.step(action)

        return actions

    def parse_task(self, task_instruction):
        """
        Parse language into a primitive sequence.
        Can be rule-based or learned (e.g. a seq2seq model).
        """
        # Option 1: rule-based
        if "pick and place" in task_instruction:
            return ['reach', 'grasp', 'move', 'release']
        elif "open drawer" in task_instruction:
            return ['reach', 'grasp', 'pull']

        # Option 2: learned
        return self.high_level_policy.predict_primitives(task_instruction)
Multi-Task Evaluation¶
Evaluate on all tasks simultaneously:
class MultiTaskEvaluator:
    """Evaluate a multi-task model."""

    def __init__(self, model, task_envs):
        self.model = model
        self.task_envs = task_envs  # Dict: task_name -> env

    def evaluate(self, num_episodes_per_task=10):
        """Evaluate on all tasks."""
        results = {}

        for task_name, env in self.task_envs.items():
            task_results = self.evaluate_task(task_name, env, num_episodes_per_task)
            results[task_name] = task_results

        # Aggregate metrics
        avg_success_rate = np.mean([r['success_rate'] for r in results.values()])
        results['average_success_rate'] = avg_success_rate

        # Print report
        print("=" * 60)
        print("MULTI-TASK EVALUATION RESULTS")
        print("=" * 60)
        for task_name, task_results in results.items():
            if task_name != 'average_success_rate':
                print(f"{task_name}:")
                print(f"  Success Rate: {task_results['success_rate'] * 100:.1f}%")
                print(f"  Avg Steps: {task_results['avg_steps']:.1f}")
        print(f"\nOverall Success Rate: {avg_success_rate * 100:.1f}%")
        print("=" * 60)

        return results

    def evaluate_task(self, task_name, env, num_episodes):
        """Evaluate on a single task."""
        successes = 0
        episode_lengths = []

        for _ in range(num_episodes):
            obs = env.reset()
            done = False
            steps = 0

            # Task instruction
            instruction = env.get_instruction()

            while not done and steps < 500:
                # Predict action
                with torch.no_grad():
                    action = self.model.predict(obs, instruction)

                # Execute
                obs, reward, done, info = env.step(action)
                steps += 1

            success = info.get('success', False)
            successes += int(success)
            episode_lengths.append(steps)

        return {
            'success_rate': successes / num_episodes,
            'avg_steps': np.mean(episode_lengths)
        }
Handling Task Interference¶
Tasks can interfere with each other during training:
1. Gradient Surgery¶
Prevent conflicting gradients:
def gradient_surgery(losses, shared_params):
    """
    PCGrad: project away conflicting gradient components.
    Paper: "Gradient Surgery for Multi-Task Learning"
    """
    # Compute gradients of each task loss w.r.t. the shared parameters
    grads = {}
    for task_name, loss in losses.items():
        task_grads = torch.autograd.grad(
            loss, shared_params,
            retain_graph=True,
            create_graph=False
        )
        grads[task_name] = list(task_grads)

    # Project conflicting gradients
    for task1 in grads:
        for task2 in grads:
            if task1 == task2:
                continue
            grad1, grad2 = grads[task1], grads[task2]

            # Gradients conflict when their dot product is negative;
            # both the dot product and the squared norm are summed
            # across all shared parameter tensors
            dot_product = sum((g1 * g2).sum() for g1, g2 in zip(grad1, grad2))
            grad2_sq_norm = sum((g2 ** 2).sum() for g2 in grad2)

            if dot_product < 0:
                # Project grad1 onto the normal plane of grad2
                grads[task1] = [
                    g1 - (dot_product / (grad2_sq_norm + 1e-8)) * g2
                    for g1, g2 in zip(grad1, grad2)
                ]

    # Average the projected gradients
    avg_grads = [
        sum(grads[task][i] for task in grads) / len(grads)
        for i in range(len(shared_params))
    ]

    # Write the averaged gradients back to the parameters
    for param, grad in zip(shared_params, avg_grads):
        param.grad = grad
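The projection rule can be checked on toy two-dimensional gradients. This is a standalone sketch of the same step, not tied to any model:

```python
import torch

def project_conflict(g1, g2):
    """PCGrad projection: remove from g1 its component along a conflicting g2."""
    dot = torch.dot(g1, g2)
    if dot < 0:  # gradients conflict
        g1 = g1 - (dot / (g2.norm() ** 2 + 1e-8)) * g2
    return g1

# Toy gradients pulling in conflicting directions
g1 = torch.tensor([1.0, 0.0])
g2 = torch.tensor([-1.0, 1.0])

g1_proj = project_conflict(g1, g2)
print(g1_proj)                 # tensor([0.5000, 0.5000])
print(torch.dot(g1_proj, g2))  # ~0: the conflicting component is gone
```

After projection, `g1_proj` is orthogonal to `g2`, so applying it no longer undoes the other task's progress.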
2. Task-Specific Parameters¶
Give each task some dedicated parameters:
class MultiHeadVLA(nn.Module):
    """VLA with task-specific heads."""

    def __init__(self, num_tasks):
        super().__init__()
        # Shared encoder
        self.shared_encoder = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 256)
        )

        # Task-specific heads
        self.task_heads = nn.ModuleList([
            nn.Linear(256, 7)  # 7-DoF action
            for _ in range(num_tasks)
        ])

    def forward(self, observation, task_id):
        # Shared features
        features = self.shared_encoder(observation)

        # Task-specific prediction
        action = self.task_heads[task_id](features)
        return action
Advanced: Meta-Learning for Multi-Task¶
Learn to quickly adapt to new tasks:
import copy
import random

class MAMLMultiTask:
    """First-order Model-Agnostic Meta-Learning for multi-task VLA"""

    def __init__(self, model, inner_lr=0.01, outer_lr=0.001):
        self.model = model
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr
        self.meta_optimizer = torch.optim.Adam(model.parameters(), lr=outer_lr)

    def meta_train(self, task_datasets, num_iterations=1000):
        """Meta-training loop."""
        for iteration in range(num_iterations):
            # Sample a batch of tasks
            task_batch = random.sample(list(task_datasets.keys()), k=4)

            self.meta_optimizer.zero_grad()
            meta_loss = 0.0

            for task_name in task_batch:
                # Clone the model for task-specific adaptation
                task_model = copy.deepcopy(self.model)

                # Inner loop: adapt to the task with a few examples
                support_data = task_datasets[task_name].sample(k=5)  # 5-shot
                for _ in range(5):  # 5 gradient steps
                    loss = task_model.compute_loss(support_data)
                    grads = torch.autograd.grad(loss, task_model.parameters())
                    with torch.no_grad():
                        for param, grad in zip(task_model.parameters(), grads):
                            param -= self.inner_lr * grad

                # Query loss after adaptation
                query_data = task_datasets[task_name].sample(k=10)
                query_loss = task_model.compute_loss(query_data)
                meta_loss += query_loss.item()

                # First-order MAML: the adapted model's query gradients are
                # accumulated directly into the meta-parameters (the clone's
                # graph does not reach self.model, so backward() alone would
                # never update it)
                task_grads = torch.autograd.grad(query_loss, task_model.parameters())
                for meta_param, grad in zip(self.model.parameters(), task_grads):
                    if meta_param.grad is None:
                        meta_param.grad = grad.clone()
                    else:
                        meta_param.grad += grad

            # Outer update: improve the shared initialization
            self.meta_optimizer.step()

            if iteration % 100 == 0:
                print(f"Meta-iteration {iteration}, Meta-loss: {meta_loss:.4f}")

        return self.model
Best Practices¶
DO:¶
- ✓ Use language for task conditioning (most flexible)
- ✓ Balance task sampling during training
- ✓ Monitor per-task performance separately
- ✓ Start with curriculum learning for complex task sets
- ✓ Share as much as possible (encoder, fusion)
- ✓ Use task-specific losses when appropriate
DON'T:¶
- ✗ Train on an imbalanced task distribution
- ✗ Ignore task interference (use gradient surgery if needed)
- ✗ Use separate models when multi-task would work
- ✗ Forget to evaluate on task composition
- ✗ Over-parameterize task-specific components
Example: Complete Multi-Task Training¶
# Define tasks
task_datasets = {
    'pick_red_block': load_dataset('pick_red'),
    'pick_blue_block': load_dataset('pick_blue'),
    'open_drawer': load_dataset('drawer'),
    'close_drawer': load_dataset('drawer_close'),
    'wipe_table': load_dataset('wipe'),
}

# Create model
model = MultiTaskVLA(
    vision_encoder=CLIPVisionEncoder(),
    language_encoder=CLIPTextEncoder(),
    action_head=ActionHead(action_dim=7)
)

# Balanced sampler
sampler = BalancedTaskSampler(task_datasets)

# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(10000):
    # Sample a balanced batch
    batch = sampler.sample_batch(batch_size=32)

    # Compute loss
    loss = model.compute_loss(batch)

    # Update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluate periodically
    if step % 1000 == 0:
        evaluator = MultiTaskEvaluator(model, task_envs)
        results = evaluator.evaluate()

# Save model
torch.save(model.state_dict(), 'multitask_vla.pt')
Next Steps¶
- Training Guide - General VLA training
- Fine-tuning - Adapt to new tasks
- Curriculum Learning - Task sequencing
- Hierarchical RL - Task decomposition