
VLA Model Architectures

This page explores the architectural components and design patterns for Vision-Language-Action (VLA) models.

Overview

A typical VLA architecture consists of four main components (vision encoder, language encoder, multi-modal fusion, and action decoder), plus a state encoder for proprioceptive input:

graph TD
    A[Visual Inputs] --> B[Vision Encoder]
    C[Language Instructions] --> D[Language Encoder]
    E[Robot State] --> F[State Encoder]
    B --> G[Multi-Modal Fusion]
    D --> G
    F --> G
    G --> H[Action Decoder]
    H --> I[Robot Actions]

Vision Encoder

The vision encoder processes visual observations from cameras and sensors.

import torch.nn as nn
from transformers import ViTModel

class ViTVisionEncoder(nn.Module):
    def __init__(self, hidden_dim, pretrained='google/vit-base-patch16-224'):
        super().__init__()
        self.encoder = ViTModel.from_pretrained(pretrained)
        self.projection = nn.Linear(768, hidden_dim)  # ViT-Base hidden size is 768

    def forward(self, images):
        # images: (batch, channels, height, width)
        outputs = self.encoder(images)
        features = outputs.last_hidden_state  # (batch, num_patches + 1, 768)
        return self.projection(features)

Advantages:

  • Strong pre-trained representations
  • Captures global context
  • Scalable to large datasets

import torch.nn as nn
import torchvision.models as models

class ResNetVisionEncoder(nn.Module):
    def __init__(self, hidden_dim, pretrained=True):
        super().__init__()
        resnet = models.resnet50(pretrained=pretrained)
        # Remove the final average-pool and classification layers
        self.encoder = nn.Sequential(*list(resnet.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.projection = nn.Linear(2048, hidden_dim)  # ResNet-50 outputs 2048 channels

    def forward(self, images):
        features = self.encoder(images)  # (batch, 2048, H/32, W/32)
        pooled = self.pool(features).squeeze(-1).squeeze(-1)  # (batch, 2048)
        return self.projection(pooled)

Advantages:

  • Proven architecture
  • Efficient inference
  • Strong spatial features

import torch.nn as nn
import open_clip

class CLIPVisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            'ViT-B-32',
            pretrained='laion2b_s34b_b79k'
        )
        self.encoder = model.visual

    def forward(self, images):
        # Leverages vision-language pre-training; ViT-B/32 pools to 512-dim embeddings
        return self.encoder(images)

Advantages:

  • Vision-language alignment
  • Strong zero-shot capabilities
  • Internet-scale pre-training

Multi-View Processing

For robots with multiple cameras:

import torch
import torch.nn as nn

class MultiViewVisionEncoder(nn.Module):
    def __init__(self, make_view_encoder, num_views, hidden_dim):
        super().__init__()
        # One encoder per camera view, built from a factory callable
        self.view_encoders = nn.ModuleList([
            make_view_encoder() for _ in range(num_views)
        ])
        self.fusion = nn.MultiheadAttention(hidden_dim, num_heads=8, batch_first=True)

    def forward(self, multi_view_images):
        # multi_view_images: dict with keys like 'front', 'wrist', 'top'
        # Each encoder is assumed to return pooled features of shape (batch, hidden_dim)
        view_features = []
        for view_name, encoder in zip(multi_view_images.keys(), self.view_encoders):
            features = encoder(multi_view_images[view_name])
            view_features.append(features)

        # Stack views into a sequence and fuse with attention
        stacked_features = torch.stack(view_features, dim=1)  # (batch, num_views, hidden_dim)
        fused, _ = self.fusion(stacked_features, stacked_features, stacked_features)
        return fused
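
A usage sketch under the assumptions above (three cameras, a pooled ResNet encoder at width 512; the view names and factory are illustrative):

import torch

# Hypothetical setup: three camera views, pooled 512-dim features per view
encoder = MultiViewVisionEncoder(
    make_view_encoder=lambda: ResNetVisionEncoder(hidden_dim=512),
    num_views=3,
    hidden_dim=512,
)
images = {name: torch.randn(2, 3, 224, 224) for name in ('front', 'wrist', 'top')}
fused = encoder(images)  # (2, 3, 512)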

Language Encoder

Processes natural language instructions and task descriptions.

Architecture Options

import torch.nn as nn
from transformers import T5EncoderModel, T5Tokenizer

class T5LanguageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained('t5-base')
        self.encoder = T5EncoderModel.from_pretrained('t5-base')

    def forward(self, instructions):
        # instructions: List[str]
        inputs = self.tokenizer(
            instructions,
            padding=True,
            truncation=True,
            return_tensors='pt'
        )
        outputs = self.encoder(**inputs)
        return outputs.last_hidden_state  # (batch, seq_len, 768)

import torch.nn as nn
from transformers import BertModel, BertTokenizer

class BERTLanguageEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.encoder = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, instructions):
        inputs = self.tokenizer(
            instructions,
            padding=True,
            truncation=True,
            return_tensors='pt'
        )
        outputs = self.encoder(**inputs)
        # Use the [CLS] token as a pooled sentence representation
        return outputs.last_hidden_state[:, 0, :]  # (batch, 768)

import torch.nn as nn
import open_clip

class CLIPTextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            'ViT-B-32',
            pretrained='laion2b_s34b_b79k'
        )
        self.model = model
        self.tokenizer = open_clip.get_tokenizer('ViT-B-32')

    def forward(self, instructions):
        tokens = self.tokenizer(instructions)
        return self.model.encode_text(tokens)  # (batch, 512)

Multi-Modal Fusion

Combining visual and language representations is critical for VLA performance.

Cross-Attention Fusion

import torch
import torch.nn as nn

class CrossAttentionFusion(nn.Module):
    def __init__(self, hidden_dim, num_heads=8, num_layers=4):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 4,
                batch_first=True
            )
            for _ in range(num_layers)
        ])

    def forward(self, visual_features, language_features, state_features):
        # visual_features: (batch, num_patches, hidden_dim)
        # language_features: (batch, seq_len, hidden_dim)
        # state_features: (batch, hidden_dim), already projected by the state encoder

        # Concatenate all modalities into one token sequence
        tokens = torch.cat([
            visual_features,
            language_features,
            state_features.unsqueeze(1)
        ], dim=1)

        # Each decoder layer applies self-attention and cross-attention
        # over the combined sequence (query and memory are the same here)
        for layer in self.layers:
            tokens = layer(tokens, tokens)

        return tokens

Gated Fusion

import torch
import torch.nn as nn

class GatedFusion(nn.Module):
    def __init__(self, visual_dim, language_dim, output_dim):
        super().__init__()
        self.visual_proj = nn.Linear(visual_dim, output_dim)
        self.language_proj = nn.Linear(language_dim, output_dim)
        self.gate = nn.Sequential(
            nn.Linear(visual_dim + language_dim, output_dim),
            nn.Sigmoid()
        )

    def forward(self, visual_features, language_features):
        v = self.visual_proj(visual_features)
        lang = self.language_proj(language_features)

        # Gate in [0, 1] decides how much each modality contributes
        gate = self.gate(torch.cat([visual_features, language_features], dim=-1))

        # Convex combination of the projected modalities
        fused = gate * v + (1 - gate) * lang
        return fused

FiLM (Feature-wise Linear Modulation)

import torch.nn as nn

class FiLMFusion(nn.Module):
    def __init__(self, visual_dim, language_dim):
        super().__init__()
        # Language generates per-channel scaling and shifting parameters
        self.gamma_net = nn.Linear(language_dim, visual_dim)
        self.beta_net = nn.Linear(language_dim, visual_dim)

    def forward(self, visual_features, language_features):
        # visual_features: (batch, visual_dim); language_features: (batch, language_dim)
        gamma = self.gamma_net(language_features)  # Scaling
        beta = self.beta_net(language_features)    # Shifting

        # Modulate visual features channel-wise; for (batch, N, visual_dim)
        # inputs, unsqueeze gamma and beta along dim=1 to broadcast
        modulated = gamma * visual_features + beta
        return modulated

Action Decoder

Generates robot actions from fused multi-modal representations.

Action Prediction Heads

import torch.nn as nn

class MLPActionDecoder(nn.Module):
    def __init__(self, input_dim, action_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim)
        )

    def forward(self, fused_features):
        return self.network(fused_features)

import torch
import torch.nn as nn

class DiffusionActionDecoder(nn.Module):
    def __init__(self, input_dim, action_dim, num_steps=100):
        super().__init__()
        self.action_dim = action_dim
        self.num_steps = num_steps
        # UNet1D, add_noise, and denoise_step implement the chosen
        # denoising network and noise schedule (not shown here)
        self.noise_predictor = UNet1D(
            input_dim=input_dim + action_dim,
            output_dim=action_dim
        )

    def forward(self, fused_features, actions=None):
        batch_size = fused_features.shape[0]
        if self.training:
            # Training: corrupt the ground-truth actions and predict the noise
            timesteps = torch.randint(0, self.num_steps, (batch_size,))
            noise = torch.randn_like(actions)
            noisy_actions = self.add_noise(actions, noise, timesteps)

            predicted_noise = self.noise_predictor(
                torch.cat([fused_features, noisy_actions], dim=-1),
                timesteps
            )
            return predicted_noise
        else:
            # Inference: iteratively denoise from pure Gaussian noise
            actions = torch.randn(batch_size, self.action_dim)
            for t in reversed(range(self.num_steps)):
                predicted_noise = self.noise_predictor(
                    torch.cat([fused_features, actions], dim=-1),
                    t
                )
                actions = self.denoise_step(actions, predicted_noise, t)
            return actions

import torch
import torch.nn as nn

class AutoregressiveActionDecoder(nn.Module):
    def __init__(self, input_dim, action_dim, hidden_dim=256, chunk_size=10):
        super().__init__()
        self.action_dim = action_dim
        self.chunk_size = chunk_size
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=8, batch_first=True),
            num_layers=6
        )
        self.feature_projection = nn.Linear(input_dim, hidden_dim)
        self.action_embedding = nn.Linear(action_dim, hidden_dim)
        self.output_projection = nn.Linear(hidden_dim, action_dim)

    def forward(self, fused_features, past_actions=None):
        # Generate the next actions conditioned on previous actions
        batch_size = fused_features.shape[0]
        if past_actions is None:
            past_actions = torch.zeros(batch_size, 1, self.action_dim)

        action_embeds = self.action_embedding(past_actions)
        # Fused features serve as the cross-attention memory
        memory = self.feature_projection(fused_features).unsqueeze(1)
        decoder_output = self.transformer(action_embeds, memory)
        predicted_actions = self.output_projection(decoder_output)
        return predicted_actions

Complete VLA Architecture

Putting it all together:

import torch.nn as nn

class VLAModel(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Encoders
        self.vision_encoder = CLIPVisionEncoder()
        self.language_encoder = T5LanguageEncoder()
        self.state_encoder = nn.Linear(config.state_dim, config.hidden_dim)

        # Project encoder outputs to a shared width
        # (CLIP ViT-B/32 emits 512-dim embeddings, T5-base 768-dim)
        self.visual_proj = nn.Linear(512, config.hidden_dim)
        self.language_proj = nn.Linear(768, config.hidden_dim)

        # Fusion
        self.fusion = CrossAttentionFusion(
            hidden_dim=config.hidden_dim,
            num_heads=8,
            num_layers=4
        )

        # Action decoder
        self.action_decoder = MLPActionDecoder(
            input_dim=config.hidden_dim,
            action_dim=config.action_dim
        )

    def forward(self, observations):
        # Encode each modality
        visual = self.visual_proj(self.vision_encoder(observations['image']))
        visual_features = visual.unsqueeze(1)  # CLIP pools to (batch, 512) -> (batch, 1, hidden_dim)
        language_features = self.language_proj(
            self.language_encoder(observations['instruction'])
        )
        state_features = self.state_encoder(observations['robot_state'])

        # Fuse modalities into one token sequence
        fused = self.fusion(visual_features, language_features, state_features)

        # Decode actions from the first fused token
        actions = self.action_decoder(fused[:, 0, :])
        return actions
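
A minimal smoke-test sketch, assuming dummy inputs and a simple namespace config (batch size, shapes, and the instruction text are illustrative):

from types import SimpleNamespace

import torch

# Hypothetical config and dummy inputs for a quick forward pass (batch of 2)
config = SimpleNamespace(state_dim=7, hidden_dim=512, action_dim=7)
model = VLAModel(config)

observations = {
    'image': torch.randn(2, 3, 224, 224),          # RGB frames
    'instruction': ['pick up the red block'] * 2,  # natural-language commands
    'robot_state': torch.randn(2, 7),              # proprioceptive state
}
actions = model(observations)  # (2, 7)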

Architecture Variants

RT-1 Style

  • Vision: EfficientNet
  • Language: Universal Sentence Encoder
  • Fusion: Token-based Transformer
  • Action: Discretized action space (see the sketch after this list)
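
RT-1's discretized action space maps each continuous action dimension onto a fixed number of bins, so actions can be predicted as classification targets. A minimal sketch of the idea (256 bins follows RT-1's convention; the low/high bounds are assumptions you would take from your robot):

import torch

def discretize_actions(actions, low, high, num_bins=256):
    # Map continuous actions in [low, high] to integer bin indices per dimension
    normalized = (actions - low) / (high - low)          # -> [0, 1]
    bins = (normalized * (num_bins - 1)).round().long()
    return bins.clamp(0, num_bins - 1)

def undiscretize_actions(bins, low, high, num_bins=256):
    # Map bin indices back to continuous values
    return low + bins.float() / (num_bins - 1) * (high - low)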

RT-2 Style

  • Vision: ViT (from pre-trained VLM)
  • Language: T5
  • Fusion: Vision-Language Model backbone
  • Action: Co-fine-tuned with VLM

OpenVLA Style

  • Vision: SigLIP
  • Language: Llama-based
  • Fusion: Integrated multi-modal transformer
  • Action: Continuous action prediction

Design Considerations

Model Size vs. Performance

| Size   | Parameters | Inference Speed   | Performance           |
|--------|------------|-------------------|-----------------------|
| Small  | < 100M     | Fast (>30 Hz)     | Good for simple tasks |
| Medium | 100M-1B    | Medium (10-30 Hz) | General purpose       |
| Large  | > 1B       | Slow (<10 Hz)     | Best generalization   |

Action Representation

Choose based on your robot and task:

  • End-effector: More intuitive, easier sim-to-real
  • Joint space: More precise control
  • Delta actions: More stable, less prone to drift (see the sketch after this list)
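
To make the delta option concrete, here is a minimal sketch that converts a trajectory of absolute end-effector poses into delta actions (the 7-D [x, y, z, roll, pitch, yaw, gripper] layout is an assumption; adapt it to your robot's convention):

import torch

def poses_to_deltas(poses):
    # poses: (T, 7) absolute [x, y, z, roll, pitch, yaw, gripper] per step
    deltas = poses[1:] - poses[:-1]   # frame-to-frame differences
    deltas[:, 6] = poses[1:, 6]       # keep the gripper command absolute
    # Note: angular channels may need wrapping into [-pi, pi]
    return deltas                     # (T - 1, 7)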

Next Steps