
Offline Reinforcement Learning

Offline RL (also called batch RL) learns policies from a fixed, pre-collected dataset, with no environment interaction during training - critical for robotics, where online exploration is expensive or unsafe.

Why Offline RL?

Problem with Online RL:

- Requires millions of environment interactions
- Unsafe during exploration (real robots can break)
- Expensive (time, wear-and-tear)
- May not be allowed (medical, industrial settings)

Offline RL Solution:

- Learn from pre-collected datasets
- No environment interaction during training
- Safe (no exploration)
- Can leverage existing demonstration data

# Online RL
for episode in range(1_000_000):  # Millions of episodes!
    obs = env.reset()
    while not done:
        action = policy.explore(obs)  # Can be unsafe
        next_obs, reward, done, _ = env.step(action)
        buffer.add(obs, action, reward, next_obs)

# Offline RL
dataset = load_fixed_dataset()  # Pre-collected
for epoch in range(100):
    batch = dataset.sample()
    policy.update(batch)  # No environment interaction!

The Challenge: Distributional Shift

Problem: Policy may select out-of-distribution (OOD) actions at test time

# Training: policy only sees actions from dataset distribution
dataset_actions = [a1, a2, a3, ...]  # From demonstrations

# Test: policy might choose unseen action
test_action = policy(obs)  # Could be OOD!

# If OOD: Q-values unreliable → poor performance

Solution: Constrain policy to stay close to data distribution
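
One common way to implement this constraint, in the spirit of the TD3+BC approach of Fujimoto & Gu (cited in the resources below), is to add a behavior-cloning penalty to the actor objective. The sketch below is illustrative, not the paper's exact formulation: the toy networks, dimensions, and the `lam` weight are made up for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
state_dim, action_dim = 4, 2

# Toy actor and Q-function (stand-ins for trained networks)
actor = nn.Sequential(nn.Linear(state_dim, 32), nn.ReLU(),
                      nn.Linear(32, action_dim), nn.Tanh())
q_net = nn.Sequential(nn.Linear(state_dim + action_dim, 32), nn.ReLU(),
                      nn.Linear(32, 1))
opt = torch.optim.Adam(actor.parameters(), lr=1e-3)

states = torch.randn(64, state_dim)
dataset_actions = torch.randn(64, action_dim).clamp(-1, 1)

lam = 2.5  # BC weight; larger = stay closer to the data distribution
pi = actor(states)
q_pi = q_net(torch.cat([states, pi], dim=-1))

# Maximize Q while penalizing deviation from dataset actions
actor_loss = -q_pi.mean() + lam * F.mse_loss(pi, dataset_actions)
opt.zero_grad()
actor_loss.backward()
opt.step()
```

With `lam` large the objective collapses toward behavior cloning; with `lam` near zero the policy is free to drift into OOD actions, which is exactly the failure mode described above.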

Offline RL Algorithms

1. Conservative Q-Learning (CQL)

Idea: Penalize Q-values for out-of-distribution actions

import numpy as np  # used later by the inference and dataset-analysis code
import torch
import torch.nn as nn
import torch.nn.functional as F

class CQL:
    """
    Conservative Q-Learning

    Paper: Kumar et al., "Conservative Q-Learning for Offline RL", NeurIPS 2020
    """
    def __init__(self, state_dim, action_dim, alpha=1.0):
        self.state_dim = state_dim
        self.action_dim = action_dim  # needed by select_action

        self.q_network = QNetwork(state_dim, action_dim)
        self.target_q = QNetwork(state_dim, action_dim)
        self.target_q.load_state_dict(self.q_network.state_dict())  # sync at init

        self.alpha = alpha  # CQL penalty weight

        self.optimizer = torch.optim.Adam(
            self.q_network.parameters(),
            lr=3e-4
        )

    def compute_cql_loss(self, batch):
        """
        CQL loss = TD error + penalty for high OOD Q-values

        L = L_TD + alpha * (Q_OOD - Q_data)
        """
        states = batch['states']
        actions = batch['actions']
        rewards = batch['rewards']
        next_states = batch['next_states']
        dones = batch['dones']

        # 1. Standard TD loss
        with torch.no_grad():
            # Target Q-value
            next_actions = self.select_action(next_states)
            target_q = self.target_q(next_states, next_actions)
            target = rewards + (1 - dones) * 0.99 * target_q

        current_q = self.q_network(states, actions)
        td_loss = F.mse_loss(current_q, target)

        # 2. CQL penalty
        # Sample random actions (OOD)
        random_actions = torch.rand_like(actions) * 2 - 1  # [-1, 1]
        ood_q_values = self.q_network(states, random_actions)

        # Q-values for dataset actions
        data_q_values = self.q_network(states, actions)

        # Penalty: push OOD Q-values down
        cql_penalty = (ood_q_values.mean() - data_q_values.mean())

        # Total loss
        total_loss = td_loss + self.alpha * cql_penalty

        return total_loss, {
            'td_loss': td_loss.item(),
            'cql_penalty': cql_penalty.item()
        }

    def train_step(self, batch):
        """Single training step"""
        loss, metrics = self.compute_cql_loss(batch)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Soft-update target network (Polyak averaging, every step)
        self.soft_update_target()

        return metrics

    def soft_update_target(self, tau=0.005):
        """Polyak averaging for target network"""
        for target_param, param in zip(
            self.target_q.parameters(),
            self.q_network.parameters()
        ):
            target_param.data.copy_(
                tau * param.data + (1 - tau) * target_param.data
            )

    def select_action(self, state):
        """Select action using learned Q-function"""
        # Discrete: argmax over actions
        # Continuous: optimize action to maximize Q
        # (simplified - usually need actor network)

        with torch.no_grad():
            # Sample candidate actions
            num_samples = 100
            sampled_actions = torch.rand(num_samples, self.action_dim) * 2 - 1

            # Evaluate Q-values
            states_expanded = state.unsqueeze(0).repeat(num_samples, 1)
            q_values = self.q_network(states_expanded, sampled_actions)

            # Select best
            best_idx = q_values.argmax()
            best_action = sampled_actions[best_idx]

        return best_action


class QNetwork(nn.Module):
    """Q-network"""
    def __init__(self, state_dim, action_dim):
        super().__init__()

        self.network = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)  # Q-value
        )

    def forward(self, states, actions):
        x = torch.cat([states, actions], dim=-1)
        return self.network(x).squeeze(-1)
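
To see why the penalty in `compute_cql_loss` matters, here is a minimal, self-contained sketch (toy network and dimensions are made up) that trains on the CQL regularizer alone. In practice it is always combined with the TD loss, which anchors the absolute Q-values; the point here is only the direction of the effect: OOD actions get pushed down relative to dataset actions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
state_dim, action_dim = 3, 2

q = nn.Sequential(nn.Linear(state_dim + action_dim, 64), nn.ReLU(),
                  nn.Linear(64, 1))
opt = torch.optim.Adam(q.parameters(), lr=1e-2)

states = torch.randn(128, state_dim)
data_actions = 0.1 * torch.randn(128, action_dim)   # narrow data distribution

for _ in range(200):
    ood_actions = torch.rand(128, action_dim) * 2 - 1  # uniform in [-1, 1]
    q_ood = q(torch.cat([states, ood_actions], dim=-1))
    q_data = q(torch.cat([states, data_actions], dim=-1))
    penalty = q_ood.mean() - q_data.mean()  # the CQL regularizer alone
    opt.zero_grad()
    penalty.backward()
    opt.step()

# After training on the penalty alone, in-distribution actions score higher
print(q_data.mean().item() > q_ood.mean().item())  # True
```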

2. Implicit Q-Learning (IQL)

Idea: Avoid querying Q-values for out-of-sample actions - learn a state-value function with expectile regression, then extract the policy via advantage-weighted regression

class IQL:
    """
    Implicit Q-Learning

    Paper: Kostrikov et al., "Offline RL with Implicit Q-Learning", ICLR 2022

    Key insight: Decouple value learning from policy extraction
    """
    def __init__(self, state_dim, action_dim, tau=0.7):
        # Two Q-networks (twin Q)
        self.q1 = QNetwork(state_dim, action_dim)
        self.q2 = QNetwork(state_dim, action_dim)

        # Value network V(s) - expectile regression
        self.v_network = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

        # Actor network (for deployment)
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
            nn.Tanh()
        )

        self.tau = tau  # Expectile parameter

        self.q_optimizer = torch.optim.Adam(
            list(self.q1.parameters()) + list(self.q2.parameters()),
            lr=3e-4
        )
        self.v_optimizer = torch.optim.Adam(self.v_network.parameters(), lr=3e-4)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

    def train_step(self, batch):
        """IQL training step"""
        states = batch['states']
        actions = batch['actions']
        rewards = batch['rewards']
        next_states = batch['next_states']
        dones = batch['dones']

        # 1. Update V-network with expectile regression
        with torch.no_grad():
            q1_val = self.q1(states, actions)
            q2_val = self.q2(states, actions)
            q_val = torch.min(q1_val, q2_val)

        v_val = self.v_network(states).squeeze(-1)

        # Expectile loss (asymmetric L2)
        v_loss = self.expectile_loss(q_val - v_val, self.tau)

        self.v_optimizer.zero_grad()
        v_loss.backward()
        self.v_optimizer.step()

        # 2. Update Q-networks
        with torch.no_grad():
            next_v = self.v_network(next_states).squeeze(-1)
            target_q = rewards + (1 - dones) * 0.99 * next_v

        q1_pred = self.q1(states, actions)
        q2_pred = self.q2(states, actions)

        q_loss = F.mse_loss(q1_pred, target_q) + F.mse_loss(q2_pred, target_q)

        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        # 3. Update actor with advantage-weighted regression
        with torch.no_grad():
            q_val = torch.min(
                self.q1(states, actions),
                self.q2(states, actions)
            )
            v_val = self.v_network(states).squeeze(-1)
            advantage = q_val - v_val

            # Advantage weights
            weights = torch.exp(advantage / 0.1)  # Temperature = 0.1
            weights = torch.clamp(weights, max=100.0)  # Clip for stability

        # Actor loss: weighted behavior cloning
        predicted_actions = self.actor(states)
        actor_loss = (weights * F.mse_loss(
            predicted_actions, actions, reduction='none'
        ).mean(dim=-1)).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        return {
            'v_loss': v_loss.item(),
            'q_loss': q_loss.item(),
            'actor_loss': actor_loss.item()
        }

    def expectile_loss(self, diff, expectile):
        """
        Expectile regression loss

        Asymmetric L2 loss that learns different quantiles
        tau=0.5: mean, tau=0.7: ~upper 30%, tau=0.9: ~upper 10%
        """
        weight = torch.where(diff > 0, expectile, 1 - expectile)
        return (weight * (diff ** 2)).mean()

    def select_action(self, state):
        """Select action using learned actor"""
        with torch.no_grad():
            action = self.actor(state)
        return action.cpu().numpy()
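
The expectile loss is the heart of IQL, and it is easy to check in isolation. The sketch below (toy data, `expectile_loss` reproduced from the class so the snippet is self-contained) fits a scalar value to samples standing in for Q(s, a) at one state: with tau = 0.9 the fitted value lands above the mean but below the maximum, i.e. an optimistic estimate that never exceeds the best action seen in the data.

```python
import torch

torch.manual_seed(0)

def expectile_loss(diff, expectile):
    # Asymmetric L2: over-estimates and under-estimates weighted differently
    weight = torch.where(diff > 0, expectile, 1 - expectile)
    return (weight * diff ** 2).mean()

# Samples standing in for Q(s, a) values at a single state
q_samples = torch.randn(1000) * 2.0 + 5.0

v = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([v], lr=0.1)

for _ in range(500):
    loss = expectile_loss(q_samples - v, expectile=0.9)
    opt.zero_grad()
    loss.backward()
    opt.step()

# tau = 0.9: above the mean, below the max
print(q_samples.mean().item() < v.item() < q_samples.max().item())  # True
```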

3. Decision Transformer

Idea: Frame RL as sequence modeling - predict actions conditioned on desired returns

class DecisionTransformer(nn.Module):
    """
    Decision Transformer

    Paper: Chen et al., "Decision Transformer: Reinforcement Learning via Sequence Modeling", NeurIPS 2021

    Treats RL as conditional sequence modeling:
    (R, s1, a1, R, s2, a2, ...) → predict actions
    """
    def __init__(self, state_dim, action_dim, max_ep_len=1000, hidden_dim=128):
        super().__init__()

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim

        # Embedding layers
        self.embed_state = nn.Linear(state_dim, hidden_dim)
        self.embed_action = nn.Linear(action_dim, hidden_dim)
        self.embed_return = nn.Linear(1, hidden_dim)

        # Positional embedding
        self.embed_timestep = nn.Embedding(max_ep_len, hidden_dim)

        # Transformer
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=4,
                dim_feedforward=hidden_dim * 4,
                batch_first=True
            ),
            num_layers=6
        )

        # Prediction heads
        self.predict_action = nn.Linear(hidden_dim, action_dim)

    def forward(self, states, actions, returns_to_go, timesteps):
        """
        Args:
            states: (B, T, state_dim)
            actions: (B, T, action_dim)
            returns_to_go: (B, T, 1) - target cumulative return
            timesteps: (B, T) - timestep indices
        Returns:
            predicted_actions: (B, T, action_dim)
        """
        B, T, _ = states.shape

        # Embed each modality
        state_embeddings = self.embed_state(states)  # (B, T, hidden_dim)
        action_embeddings = self.embed_action(actions)
        return_embeddings = self.embed_return(returns_to_go)

        # Add positional encoding
        time_embeddings = self.embed_timestep(timesteps)

        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings
        return_embeddings = return_embeddings + time_embeddings

        # Interleave: (R_1, s_1, a_1, R_2, s_2, a_2, ...)
        # Stack: (B, 3*T, hidden_dim)
        sequence = torch.stack([
            return_embeddings,
            state_embeddings,
            action_embeddings
        ], dim=2).reshape(B, 3 * T, self.hidden_dim)

        # Transformer forward pass with a causal mask, so each token
        # attends only to earlier tokens in the trajectory
        causal_mask = nn.Transformer.generate_square_subsequent_mask(3 * T)
        transformer_out = self.transformer(sequence, mask=causal_mask)

        # Extract state embeddings (every 3rd token, offset by 1)
        state_hidden = transformer_out[:, 1::3, :]

        # Predict actions
        predicted_actions = self.predict_action(state_hidden)

        return predicted_actions

    def get_action(self, states, actions, returns_to_go, timesteps):
        """
        Inference: predict next action

        Condition on desired return to guide behavior
        """
        # Ensure correct shapes
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.action_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        # Forward pass
        predicted_actions = self.forward(states, actions, returns_to_go, timesteps)

        # Return last predicted action
        return predicted_actions[0, -1]
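
The interleaving and token-extraction logic in `forward` is the easiest part to get wrong, so here is a tiny sanity check with readable token values (1-dim "embeddings", B=1, T=2): stacking on dim=2 and reshaping yields the (R_1, s_1, a_1, R_2, s_2, a_2) order, and slicing `[1::3]` recovers exactly the state positions.

```python
import torch

# 10x = return tokens, 11x = state tokens, 12x = action tokens
r = torch.tensor([[[10.], [20.]]])   # (B, T, 1)
s = torch.tensor([[[11.], [21.]]])
a = torch.tensor([[[12.], [22.]]])

B, T, H = 1, 2, 1
seq = torch.stack([r, s, a], dim=2).reshape(B, 3 * T, H)
print(seq.flatten().tolist())        # [10.0, 11.0, 12.0, 20.0, 21.0, 22.0]

# Every 3rd token starting at index 1 recovers the state positions,
# whose hidden states are used to predict actions
print(seq[:, 1::3, :].flatten().tolist())  # [11.0, 21.0]
```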


def train_decision_transformer(model, dataset, config):
    """Train Decision Transformer"""

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    for epoch in range(config.num_epochs):
        for batch in dataset:
            # Batch contains full trajectories
            states = batch['states']  # (B, T, state_dim)
            actions = batch['actions']  # (B, T, action_dim)
            rewards = batch['rewards']  # (B, T)
            timesteps = batch['timesteps']  # (B, T)

            # Compute returns-to-go (undiscounted, as in the DT paper)
            returns_to_go = torch.zeros_like(rewards)
            returns_to_go[:, -1] = rewards[:, -1]
            for t in reversed(range(rewards.shape[1] - 1)):
                returns_to_go[:, t] = rewards[:, t] + returns_to_go[:, t + 1]

            returns_to_go = returns_to_go.unsqueeze(-1)  # (B, T, 1)

            # Predict actions
            predicted_actions = model(states, actions, returns_to_go, timesteps)

            # Loss: MSE between predicted and true actions
            loss = F.mse_loss(predicted_actions, actions)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    return model
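
The returns-to-go computation can be sanity-checked in isolation. Decision Transformer conditions on undiscounted returns-to-go; the backward recursion and a vectorized flip/cumsum/flip give the same result:

```python
import torch

rewards = torch.tensor([[1.0, 2.0, 3.0]])  # (B=1, T=3)

# Backward recursion: rtg[t] = r[t] + rtg[t+1]
rtg = torch.zeros_like(rewards)
rtg[:, -1] = rewards[:, -1]
for t in reversed(range(rewards.shape[1] - 1)):
    rtg[:, t] = rewards[:, t] + rtg[:, t + 1]
print(rtg.tolist())  # [[6.0, 5.0, 3.0]]

# Equivalent vectorized form: reverse, cumulative-sum, reverse back
rtg_vec = torch.flip(torch.cumsum(torch.flip(rewards, [1]), dim=1), [1])
print(rtg_vec.tolist())  # [[6.0, 5.0, 3.0]]
```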


# Inference: condition on desired return
def inference_with_desired_return(model, env, target_return=500):
    """Deploy Decision Transformer conditioned on target return"""

    obs = env.reset()
    done = False

    # Trajectory history
    states = [obs]
    actions = [np.zeros(model.action_dim)]  # Dummy first action
    returns_to_go = [target_return]
    timesteps = [0]

    t = 0
    while not done:
        # Get action from model
        action = model.get_action(
            states=torch.FloatTensor(states),
            actions=torch.FloatTensor(actions),
            returns_to_go=torch.FloatTensor(returns_to_go),
            timesteps=torch.LongTensor(timesteps)
        )

        # Execute
        next_obs, reward, done, _ = env.step(action.numpy())

        # Update history
        states.append(next_obs)
        actions.append(action.numpy())
        returns_to_go.append(returns_to_go[-1] - reward)  # Decrement return
        timesteps.append(t + 1)

        t += 1

    print(f"Episode return: {target_return - returns_to_go[-1]}")

Dataset Quality Matters

Offline RL performance heavily depends on dataset quality:

class OfflineDatasetAnalyzer:
    """Analyze offline RL dataset quality"""
    def __init__(self, dataset):
        self.dataset = dataset

    def analyze(self):
        """Comprehensive dataset analysis"""
        print("="*60)
        print("OFFLINE RL DATASET ANALYSIS")
        print("="*60)

        # 1. Size
        print(f"\nDataset Size: {len(self.dataset)} transitions")

        # 2. Return distribution
        returns = self.compute_episode_returns()
        print(f"\nReturn Statistics:")
        print(f"  Mean: {np.mean(returns):.2f}")
        print(f"  Std: {np.std(returns):.2f}")
        print(f"  Min: {np.min(returns):.2f}")
        print(f"  Max: {np.max(returns):.2f}")

        # 3. Coverage
        state_coverage = self.estimate_state_coverage()
        print(f"\nState Space Coverage: {state_coverage:.2%}")

        # 4. Quality score
        quality = self.compute_quality_score()
        print(f"\nDataset Quality Score: {quality:.2f} / 10")

        if quality >= 7:
            print("✓ High-quality dataset - offline RL should work well")
        elif quality >= 5:
            print("⚠️  Medium-quality - consider data augmentation")
        else:
            print("✗ Low-quality - offline RL may struggle")

        print("="*60)

    def compute_episode_returns(self):
        """Compute return for each episode"""
        returns = []
        current_return = 0

        for transition in self.dataset:
            current_return += transition['reward']

            if transition['done']:
                returns.append(current_return)
                current_return = 0

        return np.array(returns)

    def estimate_state_coverage(self):
        """Estimate what fraction of state space is covered"""
        states = np.array([t['state'] for t in self.dataset])

        # Discretize state space
        bins_per_dim = 10
        covered_bins = set()

        for state in states:
            bin_coords = tuple(
                # Assume states in [-1, 1]; clamp so s == 1.0 lands in the last bin
                min(int((s + 1) / 2 * bins_per_dim), bins_per_dim - 1)
                for s in state
            )
            covered_bins.add(bin_coords)
        total_bins = bins_per_dim ** states.shape[1]
        coverage = len(covered_bins) / total_bins

        return coverage

    def compute_quality_score(self):
        """Overall dataset quality score [0-10]"""
        scores = []

        # 1. Return score (higher is better)
        returns = self.compute_episode_returns()
        return_score = min(10, np.mean(returns) / 100)  # Normalize (assumes returns on a ~0-1000 scale)
        scores.append(return_score)

        # 2. Diversity score
        coverage = self.estimate_state_coverage()
        diversity_score = min(10, coverage * 100)
        scores.append(diversity_score)

        # 3. Size score
        size_score = min(10, len(self.dataset) / 100000)
        scores.append(size_score)

        return np.mean(scores)
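
The coverage estimate boils down to counting occupied bins in a discretized state space. A tiny worked example (toy 2-D states, assuming the same [-1, 1] range the analyzer assumes):

```python
import numpy as np

# Toy 2-D states in [-1, 1], discretized into 10 bins per dimension
states = np.array([[-1.0, -1.0], [0.0, 0.0], [0.05, 0.0], [1.0, 1.0]])
bins_per_dim = 10

covered = set()
for state in states:
    covered.add(tuple(
        # clamp so s == 1.0 falls in the last bin instead of bin 10
        min(int((s + 1) / 2 * bins_per_dim), bins_per_dim - 1)
        for s in state
    ))

# (0.0, 0.0) and (0.05, 0.0) fall in the same bin, so 3 of 100 bins are covered
coverage = len(covered) / bins_per_dim ** states.shape[1]
print(sorted(covered))  # [(0, 0), (5, 5), (9, 9)]
print(coverage)         # 0.03
```

Note the curse of dimensionality: with 10 bins per dimension, the bin count grows as 10^d, so raw coverage numbers are only comparable between datasets of the same state dimension.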

Best Practices

DO:

✓ Use high-quality datasets (expert or near-expert)
✓ Check dataset coverage before training
✓ Start with CQL or IQL (most robust)
✓ Use conservative hyperparameters (low learning rate)
✓ Evaluate on the same distribution as the training data
✓ Monitor Q-value overestimation

DON'T:

✗ Expect to significantly outperform the dataset's behavior policy
✗ Use tiny datasets (<1000 transitions)
✗ Skip dataset quality analysis
✗ Use high exploration during evaluation
✗ Ignore distributional shift warnings

When to Use Offline RL

Use when:

- An existing dataset of demonstrations is available
- Online interaction is expensive or unsafe
- You need to leverage sub-optimal data
- Exploration is dangerous

Don't use when:

- Online data can be collected cheaply
- You need to significantly exceed dataset performance
- Dataset quality is very poor
- You have very little data (<1000 transitions)

Resources

Key Papers:

- Kumar et al., "Conservative Q-Learning for Offline Reinforcement Learning", NeurIPS 2020
- Kostrikov et al., "Offline Reinforcement Learning with Implicit Q-Learning", ICLR 2022
- Chen et al., "Decision Transformer: Reinforcement Learning via Sequence Modeling", NeurIPS 2021
- Fujimoto & Gu, "A Minimalist Approach to Offline Reinforcement Learning", NeurIPS 2021
