
Edge Deployment for Robot Learning

Deploying learned policies on edge devices (robots, embedded systems) requires optimization for resource-constrained environments.

Why Edge Deployment?

Cloud deployment challenges:

- ✗ Latency: Network round-trip delays (50-200ms)
- ✗ Reliability: Requires a stable internet connection
- ✗ Privacy: Camera feeds are sent to the cloud
- ✗ Cost: Continuous cloud compute is expensive

Edge deployment benefits:

- ✓ Low latency: <10ms inference
- ✓ Reliability: Works offline
- ✓ Privacy: Data stays on device
- ✓ Cost: One-time hardware cost
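The latency argument can be made concrete with a small budget check. This is a minimal sketch: the function names are illustrative, and the numbers are simply taken from the ranges quoted above rather than measured.

```python
def control_period_ms(rate_hz: float) -> float:
    """Time available for one control step, in milliseconds."""
    return 1000.0 / rate_hz


def fits_budget(inference_ms: float, rate_hz: float, network_rtt_ms: float = 0.0) -> bool:
    """True if inference plus any network round trip fits in one control period."""
    return inference_ms + network_rtt_ms <= control_period_ms(rate_hz)


# Illustrative numbers only: 10 ms inference, 30 Hz control loop.
on_device = fits_budget(inference_ms=10, rate_hz=30)
via_cloud = fits_budget(inference_ms=10, rate_hz=30, network_rtt_ms=50)
print(f"on-device fits: {on_device}, via cloud fits: {via_cloud}")
```

Even the low end of the quoted round-trip range exceeds the roughly 33 ms period of a 30 Hz loop once inference time is added.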

Target Hardware

Common Edge Devices for Robotics

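| Device                  | Compute     | Memory | Power | Price | Use Case                        |
|-------------------------|-------------|--------|-------|-------|---------------------------------|
| NVIDIA Jetson Orin Nano | 40 TOPS     | 8GB    | 15W   | $500  | Mobile robots, drones           |
| NVIDIA Jetson AGX Orin  | 275 TOPS    | 64GB   | 60W   | $2000 | Humanoids, complex manipulation |
| Raspberry Pi 4/5        | CPU only    | 8GB    | 5W    | $75   | Simple control, sensors         |
| Google Coral            | 4 TOPS (TPU)| -      | 2W    | $150  | Vision-only tasks               |
| Intel NUC               | CPU+GPU     | 32GB   | 65W   | $800  | Desktop replacement             |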
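When shortlisting a device from the table, a rough first check is whether the model's weights fit in memory at a given precision. The sketch below is a back-of-the-envelope estimate only: the helper name and the 25M-parameter example are assumptions, and activations, framework overhead, and the OS all need additional memory.

```python
# Bytes needed to store one parameter at each precision.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}


def weight_memory_mb(num_params: int, precision: str) -> float:
    """Approximate weight storage in MB (weights only, no activations)."""
    return num_params * BYTES_PER_PARAM[precision] / 1e6


# Hypothetical 25M-parameter policy network.
for precision in ("fp32", "fp16", "int8"):
    print(f"{precision}: {weight_memory_mb(25_000_000, precision):.0f} MB")
```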

The Jetson Orin Nano offers the best balance of cost, performance, and power:

# Setup Jetson Orin Nano
# Flash JetPack 5.1+ (includes CUDA, cuDNN, TensorRT)

# Install PyTorch for Jetson
wget https://nvidia.box.com/shared/static/...pytorch-2.0.0-cp38-cp38m-linux_aarch64.whl
pip3 install pytorch-2.0.0-cp38-cp38m-linux_aarch64.whl

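# Install torchvision (the release must match the PyTorch version, e.g. 0.15.x for PyTorch 2.0;
# if the PyPI wheel pulls in a different torch build, build torchvision from source instead)
pip3 install torchvision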

# Verify
python3 -c "import torch; print(torch.cuda.is_available())"
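A slightly fuller version of the verification step can be scripted so it also reports the CPU architecture and degrades gracefully when PyTorch is missing. This is a sketch; `describe_environment` is a name chosen here for illustration.

```python
import importlib.util
import platform


def describe_environment() -> dict:
    """Collect basic facts about the device's Python and PyTorch setup."""
    info = {
        "machine": platform.machine(),          # e.g. 'aarch64' on Jetson
        "python": platform.python_version(),
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
    }
    if info["torch_installed"]:
        import torch  # imported lazily so the script runs without PyTorch

        info["torch_version"] = torch.__version__
        info["cuda_available"] = bool(torch.cuda.is_available())
    return info


for key, value in describe_environment().items():
    print(f"{key}: {value}")
```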

Model Optimization Pipeline

1. Quantization

Reduce precision from FP32 → INT8:

import torch
from torch.quantization import quantize_dynamic

class ModelQuantizer:
    """Quantize model for edge deployment"""

    def __init__(self, model):
        self.model = model

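    def dynamic_quantization(self):
        """
        Dynamic quantization (easiest)

        - Weights: INT8, converted ahead of time
        - Activations: quantized on the fly at inference time
        - No calibration needed
        - Runs on CPU; up to ~2x speedup and ~4x smaller weights
          for the quantized layers (measure on your own model)
        """
        quantized_model = quantize_dynamic(
            self.model,
            # Dynamic quantization covers Linear and recurrent layers;
            # Conv2d needs static quantization or QAT instead
            {torch.nn.Linear},
            dtype=torch.qint8
        )

        return quantized_model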

    def static_quantization(self, calibration_dataloader):
        """
        Static quantization (best performance)

        - Weights: INT8
        - Activations: INT8 (pre-computed scale/zero-point)
        - Requires calibration data
        - Up to ~4x speedup and ~4x memory reduction, depending on
          the model and CPU
        - Eager mode expects QuantStub/DeQuantStub around the
          quantized region of the model
        """
        # Pick the backend for the target CPU: 'qnnpack' for ARM
        # (Jetson, Raspberry Pi), 'fbgemm' for x86
        backend = 'qnnpack'
        torch.backends.quantized.engine = backend

        # Prepare model for quantization (must be in eval mode)
        self.model.eval()
        self.model.qconfig = torch.quantization.get_default_qconfig(backend)
        torch.quantization.prepare(self.model, inplace=True)

        # Calibrate with representative data
        with torch.no_grad():
            for batch in calibration_dataloader:
                self.model(batch)

        # Convert to quantized model
        quantized_model = torch.quantization.convert(self.model, inplace=False)

        return quantized_model

    def quantization_aware_training(self, train_dataloader, num_epochs=5):
        """
        Quantization-Aware Training (QAT)

        - Train with fake quantization
        - Model learns to compensate for quantization errors
        - Best accuracy, but requires retraining
        """
        # Pick the backend for the target CPU: 'qnnpack' for ARM, 'fbgemm' for x86
        backend = 'qnnpack'
        torch.backends.quantized.engine = backend

        # Insert fake quantization nodes (prepare_qat expects train mode)
        self.model.train()
        self.model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
        torch.quantization.prepare_qat(self.model, inplace=True)

        # Fine-tune with quantization
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-5)

        for epoch in range(num_epochs):
            for batch in train_dataloader:
                # compute_loss is assumed to be provided by the policy class
                loss = self.model.compute_loss(batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Convert to quantized model
        quantized_model = torch.quantization.convert(self.model.eval(), inplace=False)

        return quantized_model

    def evaluate_quantization(self, model_fp32, model_int8, test_dataloader):
        """Compare FP32 vs INT8 performance"""
        import os
        import time

        # Accuracy (evaluate_accuracy is a user-supplied helper that
        # returns a float for a model and a dataloader)
        acc_fp32 = evaluate_accuracy(model_fp32, test_dataloader)
        acc_int8 = evaluate_accuracy(model_int8, test_dataloader)

        # Speed (gradients disabled; perf_counter is a monotonic timer)
        with torch.no_grad():
            start = time.perf_counter()
            for batch in test_dataloader:
                model_fp32(batch)
            time_fp32 = time.perf_counter() - start

            start = time.perf_counter()
            for batch in test_dataloader:
                model_int8(batch)
            time_int8 = time.perf_counter() - start

        # Size
        torch.save(model_fp32.state_dict(), 'fp32.pt')
        torch.save(model_int8.state_dict(), 'int8.pt')

        size_fp32 = os.path.getsize('fp32.pt') / 1e6  # MB
        size_int8 = os.path.getsize('int8.pt') / 1e6

        print("="*60)
        print("QUANTIZATION COMPARISON")
        print("="*60)
        print(f"Accuracy: FP32={acc_fp32:.3f}, INT8={acc_int8:.3f} (Δ={acc_fp32-acc_int8:.3f})")
        print(f"Speed: FP32={time_fp32:.2f}s, INT8={time_int8:.2f}s ({time_fp32/time_int8:.1f}x speedup)")
        print(f"Size: FP32={size_fp32:.1f}MB, INT8={size_int8:.1f}MB ({size_fp32/size_int8:.1f}x reduction)")
        print("="*60)
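The "pre-computed scale/zero-point" mentioned for static quantization refers to an affine mapping between floats and 8-bit integers. The sketch below works through that arithmetic in plain Python for a signed 8-bit range; it illustrates the idea and is not PyTorch's internal implementation, and the function names are chosen here for illustration.

```python
def quant_params(x_min: float, x_max: float, qmin: int = -128, qmax: int = 127):
    """Derive scale and zero-point from an observed value range."""
    # Extend the range to include 0.0 so that zero is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate case: every observed value was 0
    zero_point = round(qmin - x_min / scale)
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(x: float, scale: float, zero_point: int, qmin: int = -128, qmax: int = 127) -> int:
    """Map a float to the nearest integer code, clamped to the INT8 range."""
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Map an integer code back to its approximate float value."""
    return (q - zero_point) * scale


# "Calibration": observe the range of some example activations.
values = [-1.0, -0.25, 0.0, 0.5, 2.0]
scale, zero_point = quant_params(min(values), max(values))
for v in values:
    q = quantize(v, scale, zero_point)
    print(f"{v:+.3f} -> {q:4d} -> {dequantize(q, scale, zero_point):+.4f}")
```

Values inside the calibrated range are reconstructed to within half a quantization step (`scale / 2`); values outside it are clamped, which is why representative calibration data matters.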

2. Pruning

Zero out low-importance weights:

import torch
import torch.nn.utils.prune as prune

def prune_model(model, amount=0.3):
    """
    Magnitude-based (unstructured) pruning

    Zero out the smallest-magnitude weights (least important).
    Tensors keep their dense shape, so size and latency only
    improve with sparse-aware storage or kernels.
    """
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            # Prune the given fraction of weights (30% by default)
            prune.l1_unstructured(module, name='weight', amount=amount)

            # Make pruning permanent
            prune.remove(module, 'weight')

    return model


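def structured_pruning(model, amount=0.5):
    """
    Structured pruning - zero out entire channels/neurons

    Better for hardware acceleration (maintains regular structure).
    torch.nn.utils.prune only masks the channels; the layers must be
    rebuilt without them to actually shrink the model.
    """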
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            # Prune entire channels
            prune.ln_structured(
                module,
                name='weight',
                amount=amount,
                n=2,  # L2 norm
                dim=0  # Output channels
            )
            prune.remove(module, 'weight')

    return model
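To see what magnitude-based pruning does to a weight tensor, the same rule can be applied to a plain list. This is a sketch for intuition only, with illustrative function names; in practice the `torch.nn.utils.prune` calls above operate on the module's tensors.

```python
def magnitude_prune(weights, amount):
    """Return a copy with the `amount` fraction of smallest-|w| entries set to 0."""
    if not 0.0 <= amount <= 1.0:
        raise ValueError("amount must be between 0 and 1")
    n_prune = round(amount * len(weights))
    # Indices ordered by absolute value, smallest first.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = list(weights)
    for i in order[:n_prune]:
        pruned[i] = 0.0
    return pruned


def sparsity(weights):
    """Fraction of entries that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)


weights = [0.9, -0.05, 0.3, 0.01, -0.7, 0.2, -0.02, 0.5, 0.1, -0.4]
pruned = magnitude_prune(weights, amount=0.3)
print(pruned)
print(f"sparsity: {sparsity(pruned):.0%}, length unchanged: {len(pruned) == len(weights)}")
```

The list keeps its length after pruning, which mirrors why unstructured pruning alone does not reduce dense storage or compute.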

3. TensorRT Optimization

NVIDIA's high-performance inference engine:

import torch
import tensorrt as trt
from torch2trt import torch2trt

class TensorRTOptimizer:
    """Optimize model with TensorRT"""

    def __init__(self, model, input_shape):
        self.model = model
        self.input_shape = input_shape

    def convert_to_tensorrt(self, fp16_mode=True):
        """
        Convert PyTorch model to TensorRT

        Args:
            fp16_mode: Use FP16 precision (often around 2x faster on
                Jetson; validate accuracy after conversion)
        """
        # Create example input
        x = torch.ones(self.input_shape).cuda()

        # Convert
        model_trt = torch2trt(
            self.model,
            [x],
            fp16_mode=fp16_mode,
            max_workspace_size=1 << 30  # 1GB
        )

        return model_trt

    def benchmark(self, model_pytorch, model_trt, num_runs=1000):
        """Benchmark PyTorch vs TensorRT"""
        import time

        x = torch.ones(self.input_shape).cuda()

        with torch.no_grad():
            # Warm up
            for _ in range(10):
                model_pytorch(x)
                model_trt(x)

            # Benchmark PyTorch (synchronize so queued GPU work is included)
            torch.cuda.synchronize()
            start = time.perf_counter()
            for _ in range(num_runs):
                model_pytorch(x)
            torch.cuda.synchronize()
            time_pytorch = (time.perf_counter() - start) / num_runs * 1000  # ms

            # Benchmark TensorRT
            torch.cuda.synchronize()
            start = time.perf_counter()
            for _ in range(num_runs):
                model_trt(x)
            torch.cuda.synchronize()
            time_trt = (time.perf_counter() - start) / num_runs * 1000  # ms

        print(f"PyTorch: {time_pytorch:.2f}ms")
        print(f"TensorRT: {time_trt:.2f}ms")
        print(f"Speedup: {time_pytorch/time_trt:.1f}x")

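# Example usage (YourRobotPolicy is a placeholder for your own policy class)
model = YourRobotPolicy().cuda().eval()

optimizer = TensorRTOptimizer(model, input_shape=(1, 3, 224, 224))
model_trt = optimizer.convert_to_tensorrt(fp16_mode=True)

# Save TensorRT model (the engine is tied to this GPU and TensorRT version)
torch.save(model_trt.state_dict(), 'model_trt.pth')

# Reload later with torch2trt's TRTModule
from torch2trt import TRTModule

model_trt = TRTModule()
model_trt.load_state_dict(torch.load('model_trt.pth'))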
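The benchmark above reports a mean latency, but control loops are usually limited by the slowest iterations. The following sketch times any callable and reports tail percentiles; the `benchmark` name and its defaults are assumptions for illustration. For GPU models the callable should include `torch.cuda.synchronize()` so that queued work is counted.

```python
import statistics
import time


def benchmark(fn, num_runs=200, warmup=10):
    """Time `fn` per call and return latency statistics in milliseconds."""
    for _ in range(warmup):
        fn()

    samples_ms = []
    for _ in range(num_runs):
        start = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - start) * 1000.0)
    samples_ms.sort()

    def percentile(p):
        # Nearest-rank style lookup on the sorted samples.
        index = min(len(samples_ms) - 1, int(p / 100 * len(samples_ms)))
        return samples_ms[index]

    return {
        "mean": statistics.fmean(samples_ms),
        "p50": percentile(50),
        "p95": percentile(95),
        "p99": percentile(99),
        "max": samples_ms[-1],
    }


# Stand-in workload; replace with a call to the model under test.
stats = benchmark(lambda: sum(range(10_000)), num_runs=100)
print({name: round(value, 4) for name, value in stats.items()})
```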

4. ONNX Export

For deployment on non-NVIDIA hardware:

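import torch

def export_to_onnx(model, input_shape, filename='model.onnx'):
    """Export PyTorch model to ONNX"""

    model.eval()

    # Dummy input on the same device as the model
    device = next(model.parameters()).device
    dummy_input = torch.randn(input_shape, device=device)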

    # Export
    torch.onnx.export(
        model,
        dummy_input,
        filename,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )

    print(f"Exported to {filename}")

    # Verify
    import onnx
    onnx_model = onnx.load(filename)
    onnx.checker.check_model(onnx_model)
    print("✓ ONNX model valid")

# Use ONNX Runtime for inference
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])

# Inference
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {'input': input_data})
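After export it is worth confirming that ONNX Runtime reproduces the PyTorch outputs for the same input. The helper below is a minimal sketch that compares two flat lists of floats; the tolerances are illustrative assumptions, and in practice the lists would come from flattening the PyTorch output and `outputs[0]` from `session.run`.

```python
import math


def max_abs_diff(reference, candidate):
    """Largest element-wise absolute difference between two equal-length sequences."""
    if len(reference) != len(candidate):
        raise ValueError("outputs have different lengths")
    return max((abs(r - c) for r, c in zip(reference, candidate)), default=0.0)


def outputs_match(reference, candidate, rtol=1e-3, atol=1e-5):
    """True if every element agrees within relative or absolute tolerance."""
    return len(reference) == len(candidate) and all(
        math.isclose(r, c, rel_tol=rtol, abs_tol=atol)
        for r, c in zip(reference, candidate)
    )


# Made-up action vectors standing in for PyTorch and ONNX Runtime outputs.
pytorch_out = [0.1200, -0.5300, 0.9800]
onnx_out = [0.1200, -0.5301, 0.9800]
print(outputs_match(pytorch_out, onnx_out), max_abs_diff(pytorch_out, onnx_out))
```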

Complete Deployment Pipeline

class EdgeDeploymentPipeline:
    """End-to-end pipeline for edge deployment"""

    def __init__(self, model, target_device='jetson'):
        self.model = model
        self.target_device = target_device

    def optimize(self, calibration_data=None):
        """Apply all optimizations"""

        print("Starting optimization pipeline...")

        # Step 1: Pruning
        print("1. Pruning...")
        self.model = prune_model(self.model, amount=0.3)

        # Step 2: Reduce precision
        # PyTorch INT8 quantized models run on CPU and cannot be converted
        # by torch2trt, so Jetson targets go straight to TensorRT FP16
        if self.target_device == 'jetson':
            print("2. TensorRT conversion (FP16)...")
            trt_optimizer = TensorRTOptimizer(
                self.model.cuda().eval(), input_shape=(1, 3, 224, 224)
            )
            self.model = trt_optimizer.convert_to_tensorrt(fp16_mode=True)
        elif calibration_data is not None:
            print("2. Static quantization...")
            quantizer = ModelQuantizer(self.model)
            self.model = quantizer.static_quantization(calibration_data)
        else:
            print("2. Dynamic quantization...")
            self.model = quantize_dynamic(self.model, {torch.nn.Linear}, dtype=torch.qint8)

        print("✓ Optimization complete!")

        return self.model

    def deploy(self, output_dir='./deployed_model'):
        """Package model for deployment"""
        import os
        os.makedirs(output_dir, exist_ok=True)

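        # Save model. A torch2trt TRTModule is restored from its state_dict;
        # any other model is pickled whole, so its class must be importable
        # on the device when the inference script loads it.
        is_trt = type(self.model).__name__ == 'TRTModule'
        if is_trt:
            torch.save(self.model.state_dict(), f'{output_dir}/model.pth')
        else:
            torch.save(self.model, f'{output_dir}/model.pth')

        # Save config
        config = {
            'model_type': type(self.model).__name__,
            'input_shape': (1, 3, 224, 224),
            'target_device': self.target_device
        }

        import json
        with open(f'{output_dir}/config.json', 'w') as f:
            json.dump(config, f)

        # Create inference script
        inference_script = '''
import json
from pathlib import Path

import torch

HERE = Path(__file__).parent

# Load config
with open(HERE / 'config.json') as f:
    config = json.load(f)

# Load model
if config['model_type'] == 'TRTModule':
    from torch2trt import TRTModule
    model = TRTModule()
    model.load_state_dict(torch.load(HERE / 'model.pth'))
else:
    model = torch.load(HERE / 'model.pth', weights_only=False)
model.eval()

def predict(image):
    """Run inference"""
    with torch.no_grad():
        action = model(image)
    return action.cpu().numpy()
'''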

        with open(f'{output_dir}/inference.py', 'w') as f:
            f.write(inference_script)

        print(f"✓ Model deployed to {output_dir}/")

Real-Time Inference

Keep inference latency consistent and detect missed deadlines:

class RealTimeController:
    """Fixed-rate robot controller that tracks missed deadlines (soft real-time)"""

    def __init__(self, model, target_fps=30):
        self.model = model
        self.target_fps = target_fps
        self.target_dt = 1.0 / target_fps

        # Statistics
        self.inference_times = []
        self.missed_deadlines = 0

    def run(self, env):
        """Control loop that holds the target rate and counts missed deadlines"""
        import time

        obs = env.reset()
        done = False
        step_count = 0

        while not done:
            # perf_counter is monotonic, so intervals ignore clock adjustments
            step_start = time.perf_counter()

            # 1. Inference
            inference_start = time.perf_counter()
            action = self.predict(obs)
            inference_time = time.perf_counter() - inference_start

            self.inference_times.append(inference_time)

            # 2. Execute action (classic Gym 4-tuple step API)
            obs, reward, done, info = env.step(action)

            # 3. Maintain target FPS
            elapsed = time.perf_counter() - step_start

            if elapsed < self.target_dt:
                time.sleep(self.target_dt - elapsed)
            else:
                self.missed_deadlines += 1
                print(f"⚠ Missed deadline at step {step_count}: {elapsed*1000:.1f}ms > {self.target_dt*1000:.1f}ms")

            step_count += 1

        # Report
        self.print_stats()

    def predict(self, obs):
        """Run inference with timing"""
        with torch.no_grad():
            obs_tensor = torch.from_numpy(obs).float()
            if torch.cuda.is_available():
                obs_tensor = obs_tensor.cuda()

            action = self.model(obs_tensor)

        return action.cpu().numpy()

    def print_stats(self):
        """Print timing statistics"""
        import numpy as np

        print("="*60)
        print("REAL-TIME PERFORMANCE")
        print("="*60)
        print(f"Target FPS: {self.target_fps}")
        print(f"Missed deadlines: {self.missed_deadlines}")
        print(f"\nInference time (ms):")
        print(f"  Mean: {np.mean(self.inference_times)*1000:.2f}")
        print(f"  Std: {np.std(self.inference_times)*1000:.2f}")
        print(f"  Max: {np.max(self.inference_times)*1000:.2f}")
        print(f"  P95: {np.percentile(self.inference_times, 95)*1000:.2f}")
        print(f"  P99: {np.percentile(self.inference_times, 99)*1000:.2f}")
        print("="*60)
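Sleeping for "period minus elapsed" on every step lets small overruns accumulate into drift. One common alternative is to schedule against absolute deadlines, sketched below. `FixedRateLoop` is a name chosen for illustration, and like the controller above it offers only soft real-time behavior because the operating system may still delay wake-ups.

```python
import time


class FixedRateLoop:
    """Paces a loop against absolute deadlines to avoid cumulative drift."""

    def __init__(self, rate_hz: float):
        self.period = 1.0 / rate_hz
        self.next_deadline = time.perf_counter() + self.period
        self.missed = 0

    def sleep(self) -> bool:
        """Wait for the next deadline; return False if it had already passed."""
        now = time.perf_counter()
        remaining = self.next_deadline - now
        if remaining > 0:
            time.sleep(remaining)
            self.next_deadline += self.period
            return True
        # Overran: count it and re-anchor instead of trying to catch up.
        self.missed += 1
        self.next_deadline = now + self.period
        return False


loop_start = time.perf_counter()
loop = FixedRateLoop(rate_hz=200)
met = 0
for _ in range(10):
    # Inference and actuation would run here.
    met += loop.sleep()
loop_elapsed = time.perf_counter() - loop_start
print(f"10 ticks in {loop_elapsed * 1000:.1f} ms, missed {loop.missed}")
```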

Power Optimization

Maximize battery life on mobile robots:

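def set_power_mode(mode='balanced'):
    """
    Set Jetson power mode

    Modes:
    - max_performance (60W): Highest speed
    - balanced (30W): Good speed, reasonable power
    - power_save (15W): Lower speed, longer battery life

    nvpmodel mode IDs and wattages differ between Jetson modules and
    JetPack releases; confirm them with `nvpmodel -q` or
    /etc/nvpmodel.conf before relying on this mapping.
    """
    import subprocess

    mode_map = {
        'max_performance': '0',
        'balanced': '1',
        'power_save': '2'
    }

    if mode not in mode_map:
        raise ValueError(f"Unknown power mode: {mode!r}")

    # nvpmodel needs root privileges
    subprocess.run(['sudo', 'nvpmodel', '-m', mode_map[mode]], check=True)
    print(f"✓ Set power mode to {mode}")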


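def optimize_for_battery_life(model):
    """Optimize model for battery-powered robots"""
    import subprocess

    import torch
    from torch.quantization import quantize_dynamic

    # 1. Quantize Linear layers to INT8 (dynamic quantization does not
    #    cover Conv2d; use static quantization or TensorRT INT8 for those)
    model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

    # 2. Lower FPS (if acceptable)
    # 30fps → 15fps roughly halves inference energy; total battery life
    # improves less because motors and sensors also draw power

    # 3. Prefer INT8 over FP16 where accuracy allows
    # (INT8 usually draws less power than FP16 on edge devices)

    # 4. Re-enable DVFS (Dynamic Voltage/Frequency Scaling) by restoring
    #    the clock settings saved earlier with `jetson_clocks --store`
    subprocess.run(['sudo', 'jetson_clocks', '--restore'])

    return model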
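How much battery life a compute saving buys depends on what else draws power. The arithmetic below is a sketch with made-up numbers (a 100 Wh battery, 15 W of compute, 35 W for motors and sensors), and it assumes compute power scales linearly with frame rate, which ignores idle draw.

```python
def runtime_hours(battery_wh: float, compute_w: float, other_w: float) -> float:
    """Estimated runtime given average compute and non-compute power draw."""
    return battery_wh / (compute_w + other_w)


# Hypothetical robot: all three numbers are illustrative assumptions.
baseline = runtime_hours(battery_wh=100, compute_w=15.0, other_w=35.0)
half_rate = runtime_hours(battery_wh=100, compute_w=7.5, other_w=35.0)
print(f"baseline: {baseline:.2f} h, half-rate inference: {half_rate:.2f} h "
      f"({half_rate / baseline:.2f}x)")
```

With these numbers, halving compute power extends runtime by under 20%, so the gain from a lower frame rate is largest on robots where compute dominates the power budget.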

Monitoring & Diagnostics

Monitor deployed models on edge:

class EdgeModelMonitor:
    """Monitor model performance on edge device"""

    def __init__(self):
        self.metrics = {
            'inference_times': [],
            'cpu_usage': [],
            'gpu_usage': [],
            'memory_usage': [],
            'temperature': [],
            'power_consumption': []
        }

    def log_step(self, inference_time):
        """Log metrics for current step"""
        import psutil

        # Inference time
        self.metrics['inference_times'].append(inference_time)

        # System metrics
        self.metrics['cpu_usage'].append(psutil.cpu_percent())
        self.metrics['memory_usage'].append(psutil.virtual_memory().percent)

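        # GPU (Jetson-specific, via the jetson-stats package)
        # Field names vary between jetson-stats versions; adjust the
        # keys below to match the installed release
        try:
            from jtop import jtop
            with jtop() as jetson:
                self.metrics['gpu_usage'].append(jetson.gpu['usage'])
                self.metrics['temperature'].append(jetson.temperature['thermal'])
                self.metrics['power_consumption'].append(jetson.power['total'])
        except Exception:
            pass  # Not on Jetson, or jetson-stats unavailable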

    def report(self):
        """Generate diagnostics report"""
        import numpy as np

        print("="*60)
        print("EDGE DEPLOYMENT DIAGNOSTICS")
        print("="*60)

        print(f"\nInference Time: {np.mean(self.metrics['inference_times'])*1000:.2f}ms ± {np.std(self.metrics['inference_times'])*1000:.2f}ms")
        print(f"CPU Usage: {np.mean(self.metrics['cpu_usage']):.1f}%")
        print(f"Memory Usage: {np.mean(self.metrics['memory_usage']):.1f}%")

        if self.metrics['gpu_usage']:
            print(f"GPU Usage: {np.mean(self.metrics['gpu_usage']):.1f}%")
            print(f"Temperature: {np.mean(self.metrics['temperature']):.1f}°C")
            print(f"Power Consumption: {np.mean(self.metrics['power_consumption']):.1f}W")

        print("="*60)
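On Linux boards without jetson-stats, temperatures can usually be read from the kernel's thermal sysfs interface, which reports millidegrees Celsius. The sketch below returns an empty dict when the interface is absent; the 85 °C limit is an illustrative assumption, not a vendor specification.

```python
from pathlib import Path


def parse_millidegrees(text: str) -> float:
    """Convert a sysfs temperature string such as '45500' to degrees Celsius."""
    return int(text.strip()) / 1000.0


def read_thermal_zones(root: str = "/sys/class/thermal") -> dict:
    """Return {zone type: temperature in °C} for every readable thermal zone."""
    root_path = Path(root)
    temps = {}
    if not root_path.is_dir():
        return temps  # not Linux, or thermal sysfs not exposed
    for zone in sorted(root_path.glob("thermal_zone*")):
        try:
            name = (zone / "type").read_text().strip()
            temps[name] = parse_millidegrees((zone / "temp").read_text())
        except (OSError, ValueError):
            continue  # some zones are unreadable or report no value
    return temps


def is_overheating(temps: dict, limit_c: float = 85.0) -> bool:
    """True if any zone is at or above the (assumed) temperature limit."""
    return any(t >= limit_c for t in temps.values())


zones = read_thermal_zones()
print(zones, "overheating:", is_overheating(zones))
```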

Best Practices

DO:

✓ Profile before optimization (find bottlenecks)
✓ Use FP16 on Jetson (2x speedup, minimal accuracy loss)
✓ Quantize to INT8 for maximum speed
✓ Use TensorRT for NVIDIA hardware
✓ Monitor temperature and power consumption
✓ Test extensively on target hardware

DON'T:

✗ Assume cloud latencies are acceptable for control
✗ Over-optimize (diminishing returns)
✗ Skip accuracy validation after optimization
✗ Ignore thermal throttling
✗ Deploy without real-time guarantees

Checklist for Deployment

  • Model fits in device memory
  • Inference time < control period, with headroom (e.g., under 33ms for 30Hz control)
  • Accuracy drop < 5% after optimization
  • Tested at sustained load (30+ minutes)
  • Thermal throttling handled
  • Power consumption acceptable
  • Graceful degradation on errors
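The measurable items in this checklist can be turned into an automated gate that runs before a model is shipped. The sketch below is one possible shape for such a check; the `DeploymentReport` fields and the reading of "5%" as absolute accuracy points are assumptions made for illustration.

```python
from dataclasses import dataclass


@dataclass
class DeploymentReport:
    """Measurements collected on the target device (all fields hypothetical)."""
    model_mb: float
    device_memory_mb: float
    p99_inference_ms: float
    control_rate_hz: float
    accuracy_fp32: float
    accuracy_optimized: float


def failed_checks(report: DeploymentReport, max_accuracy_drop: float = 0.05) -> list:
    """Return a description of every checklist item the report fails."""
    failures = []
    if report.model_mb >= report.device_memory_mb:
        failures.append("model does not fit in device memory")
    period_ms = 1000.0 / report.control_rate_hz
    if report.p99_inference_ms >= period_ms:
        failures.append(f"p99 inference time exceeds the {period_ms:.1f} ms control period")
    if report.accuracy_fp32 - report.accuracy_optimized >= max_accuracy_drop:
        failures.append("accuracy drop after optimization is too large")
    return failures


report = DeploymentReport(
    model_mb=25, device_memory_mb=8000,
    p99_inference_ms=12.0, control_rate_hz=30,
    accuracy_fp32=0.92, accuracy_optimized=0.90,
)
print(failed_checks(report) or "all measurable checks passed")
```

Sustained-load, thermal, and error-handling items still need to be verified on the robot itself.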

Resources

  • NVIDIA TensorRT: https://developer.nvidia.com/tensorrt
  • PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html
  • ONNX Runtime: https://onnxruntime.ai
  • Jetson Inference: https://github.com/dusty-nv/jetson-inference

Next Steps