Stable-Baselines3: Easy-to-Use RL¶
Stable-Baselines3 (SB3) is one of the most user-friendly and best-documented RL libraries, well suited to research and rapid prototyping.
Overview¶
Stable-Baselines3 is the successor to Stable-Baselines (based on OpenAI Baselines), providing:
- Production-ready implementations of major RL algorithms
- Extensive documentation and tutorials
- Easy customization and extension
- Active maintenance and community support
Key Features:

- ✓ Complete implementations (PPO, A2C, SAC, TD3, DQN, DDPG)
- ✓ Extensive callbacks and logging
- ✓ Pre-trained models (RL Zoo)
- ✓ Hyperparameter optimization
- ✓ TensorBoard and WandB integration
- ✓ Excellent documentation
Official Repository: https://github.com/DLR-RM/stable-baselines3
Installation¶
Basic Installation¶
# Install from PyPI
pip install "stable-baselines3[extra]"  # Quotes needed in zsh
# Core only (minimal dependencies)
pip install stable-baselines3
# Latest from GitHub
pip install git+https://github.com/DLR-RM/stable-baselines3
With Extra Dependencies¶
# Full installation with all features
pip install "stable-baselines3[extra]"
# Includes:
# - TensorBoard
# - rich (progress bars)
# - matplotlib (plotting)
# - pandas (logging)
RL Zoo (Pre-trained Models)¶
# Install RL Zoo
pip install rl-zoo3
# Download pre-trained models
python -m rl_zoo3.load_from_hub --algo ppo --env CartPole-v1
Quick Start¶
Basic Training¶
import gymnasium as gym
from stable_baselines3 import PPO
# Create environment
env = gym.make("CartPole-v1")
# Create model
model = PPO("MlpPolicy", env, verbose=1)
# Train
model.learn(total_timesteps=10_000)
# Save
model.save("ppo_cartpole")
# Load
model = PPO.load("ppo_cartpole")
# Evaluate
obs, info = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
Available Algorithms¶
from stable_baselines3 import PPO, A2C, SAC, TD3, DQN, DDPG
# On-policy
ppo = PPO("MlpPolicy", env)
a2c = A2C("MlpPolicy", env)
# Off-policy (continuous)
sac = SAC("MlpPolicy", env)
td3 = TD3("MlpPolicy", env)
ddpg = DDPG("MlpPolicy", env)
# Off-policy (discrete)
dqn = DQN("MlpPolicy", env)
Core Components¶
1. Policies¶
SB3 provides several pre-built policy networks:
from stable_baselines3 import PPO
# Multi-Layer Perceptron (MLP) policy
model = PPO("MlpPolicy", env)
# Convolutional Neural Network (CNN) policy (for images)
model = PPO("CnnPolicy", env)
# Multi-input policy (for dict observations)
model = PPO("MultiInputPolicy", env)
# Custom policy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import torch
import torch.nn as nn

class CustomCNN(BaseFeaturesExtractor):
    """Custom CNN feature extractor"""

    def __init__(self, observation_space, features_dim=256):
        super().__init__(observation_space, features_dim)
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )
        # Compute the flattened size with one dummy forward pass
        with torch.no_grad():
            n_flatten = self.cnn(
                torch.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]
        self.linear = nn.Sequential(
            nn.Linear(n_flatten, features_dim),
            nn.ReLU()
        )

    def forward(self, observations):
        return self.linear(self.cnn(observations))
# Use custom policy
policy_kwargs = dict(
features_extractor_class=CustomCNN,
features_extractor_kwargs=dict(features_dim=256),
)
model = PPO("CnnPolicy", env, policy_kwargs=policy_kwargs)
2. Training Configuration¶
Configure algorithm hyperparameters:
import torch.nn as nn
from stable_baselines3 import PPO
model = PPO(
"MlpPolicy",
env,
# Training hyperparameters
learning_rate=3e-4,
n_steps=2048, # Rollout buffer size
batch_size=64,
n_epochs=10,
gamma=0.99,
gae_lambda=0.95,
clip_range=0.2,
clip_range_vf=None, # No clipping for value function
normalize_advantage=True,
ent_coef=0.0, # Entropy coefficient
vf_coef=0.5, # Value function coefficient
max_grad_norm=0.5,
# Network architecture
policy_kwargs=dict(
net_arch=dict(pi=[256, 256], vf=[256, 256]),  # Separate actor-critic networks
activation_fn=nn.ReLU,
ortho_init=True
),
# Logging
verbose=1,
tensorboard_log="./ppo_tensorboard/",
# Device
device="cuda"
)
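`learning_rate` (like `clip_range`) also accepts a callable that maps remaining training progress, which goes from 1.0 down to 0.0 over training, to a value. A minimal linear-decay schedule:

```python
from typing import Callable

def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Linear decay from initial_value at the start of training to 0 at the end."""
    def schedule(progress_remaining: float) -> float:
        # progress_remaining: 1.0 at the first step, 0.0 at the last
        return progress_remaining * initial_value
    return schedule

# Usage: model = PPO("MlpPolicy", env, learning_rate=linear_schedule(3e-4))
```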
3. Callbacks¶
SB3 provides powerful callbacks for monitoring and control:
from stable_baselines3.common.callbacks import (
EvalCallback,
CheckpointCallback,
CallbackList,
StopTrainingOnRewardThreshold
)
# Stop training once a reward threshold is reached.
# Note: this callback must be attached to an EvalCallback (via
# callback_on_new_best); it does not work standalone in a CallbackList.
callback_on_best = StopTrainingOnRewardThreshold(
    reward_threshold=200,
    verbose=1
)

# Evaluation callback (eval_env: a separate instance of the environment)
eval_callback = EvalCallback(
    eval_env,
    callback_on_new_best=callback_on_best,
    best_model_save_path="./logs/",
    log_path="./logs/",
    eval_freq=10000,
    deterministic=True,
    render=False
)

# Checkpoint callback
checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path="./checkpoints/",
    name_prefix="ppo_model"
)

# Combine callbacks
callbacks = CallbackList([
    eval_callback,
    checkpoint_callback
])
# Train with callbacks
model.learn(total_timesteps=1_000_000, callback=callbacks)
4. Custom Callbacks¶
Create custom callbacks for specific needs:
import os
import numpy as np
from stable_baselines3.common.callbacks import BaseCallback

class CustomCallback(BaseCallback):
    """
    Custom callback that saves the model based on a custom metric.
    Example: save when the mean episode length reaches a new best.
    """

    def __init__(self, check_freq: int, save_path: str, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path
        self.best_mean_length = 0

    def _on_step(self) -> bool:
        if self.n_calls % self.check_freq == 0:
            # Retrieve recent episode lengths
            if len(self.model.ep_info_buffer) > 0:
                mean_length = np.mean([ep_info["l"] for ep_info in self.model.ep_info_buffer])
                if mean_length > self.best_mean_length:
                    self.best_mean_length = mean_length
                    if self.verbose > 0:
                        print(f"Saving new best model with mean length: {mean_length:.2f}")
                    self.model.save(os.path.join(self.save_path, "best_model"))
        return True
# Use custom callback
custom_callback = CustomCallback(check_freq=1000, save_path="./logs/")
model.learn(total_timesteps=100_000, callback=custom_callback)
Advanced Features¶
Vectorized Environments¶
Train on multiple environments in parallel:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env
# Method 1: Simple wrapper (single process)
env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=DummyVecEnv)
# Method 2: Multi-process (faster)
env = make_vec_env("CartPole-v1", n_envs=4, vec_env_cls=SubprocVecEnv)
# Train on vectorized environment
model = PPO("MlpPolicy", env)
model.learn(total_timesteps=100_000)
Wrappers¶
SB3 includes many useful wrappers:
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, VecFrameStack
from stable_baselines3.common.atari_wrappers import AtariWrapper
# Normalize observations and rewards
env = make_vec_env("Ant-v4", n_envs=4)
env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0)
# Frame stacking (for Atari/vision)
env = make_vec_env("PongNoFrameskip-v4", n_envs=4)
env = VecFrameStack(env, n_stack=4)
# Atari preprocessing
env = gym.make("PongNoFrameskip-v4")
env = AtariWrapper(env)
# Train
model = PPO("CnnPolicy", env)
model.learn(total_timesteps=1_000_000)
# Save normalization statistics
env.save("vec_normalize.pkl")
# Load for evaluation
env = make_vec_env("Ant-v4", n_envs=1)
env = VecNormalize.load("vec_normalize.pkl", env)
env.training = False # Don't update statistics
env.norm_reward = False # Don't normalize rewards
Replay Buffers¶
For off-policy algorithms:
from stable_baselines3 import SAC
model = SAC(
"MlpPolicy",
env,
buffer_size=1_000_000, # Replay buffer size
learning_starts=10_000, # Start training after N steps
batch_size=256,
tau=0.005,
gamma=0.99,
train_freq=1, # Update every step
gradient_steps=1, # Number of gradient steps per update
)
# Access replay buffer
print(f"Buffer size: {model.replay_buffer.size()}")
print(f"Buffer capacity: {model.replay_buffer.buffer_size}")
# Sample from buffer
if model.replay_buffer.size() > 0:
    replay_data = model.replay_buffer.sample(batch_size=256)
    print(f"Observations shape: {replay_data.observations.shape}")
Hyperparameter Optimization¶
Use Optuna for automated hyperparameter tuning:
import optuna
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
import gymnasium as gym
def optimize_ppo(trial):
    """Objective function for Optuna"""
    # Sample hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    n_steps = trial.suggest_categorical("n_steps", [256, 512, 1024, 2048])
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128, 256])
    n_epochs = trial.suggest_int("n_epochs", 5, 20)
    gamma = trial.suggest_float("gamma", 0.9, 0.9999, log=True)
    gae_lambda = trial.suggest_float("gae_lambda", 0.8, 0.99)
    clip_range = trial.suggest_float("clip_range", 0.1, 0.4)
    ent_coef = trial.suggest_float("ent_coef", 1e-8, 1e-1, log=True)

    # Create environments
    env = gym.make("CartPole-v1")
    eval_env = gym.make("CartPole-v1")

    # Create model with sampled hyperparameters
    model = PPO(
        "MlpPolicy",
        env,
        learning_rate=learning_rate,
        n_steps=n_steps,
        batch_size=batch_size,
        n_epochs=n_epochs,
        gamma=gamma,
        gae_lambda=gae_lambda,
        clip_range=clip_range,
        ent_coef=ent_coef,
        verbose=0
    )

    # Train
    model.learn(total_timesteps=100_000)

    # Evaluate
    mean_reward, std_reward = evaluate_policy(
        model, eval_env, n_eval_episodes=10
    )
    return mean_reward
# Run optimization
study = optuna.create_study(direction="maximize")
study.optimize(optimize_ppo, n_trials=100, timeout=3600)
print("Best hyperparameters:")
print(study.best_params)
print(f"Best reward: {study.best_value}")
Pre-trained Models (RL Zoo)¶
Use pre-trained models from RL Zoo:
# List available models
python -m rl_zoo3.cli list
# Download pre-trained model
python -m rl_zoo3.load_from_hub --algo ppo --env HalfCheetah-v4 -f logs/ -orga sb3
# Evaluate pre-trained model
python -m rl_zoo3.enjoy --algo ppo --env HalfCheetah-v4 -f logs/ -n 5000
Python API:
import gymnasium as gym
from stable_baselines3 import PPO
from rl_zoo3.load_from_hub import download_from_hub

# Download model
model_path = download_from_hub(
    algo="ppo",
    env_name="HalfCheetah-v4",
    org="sb3"
)

# Load and use
model = PPO.load(model_path)
env = gym.make("HalfCheetah-v4")
obs, info = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
Algorithm-Specific Examples¶
PPO (On-Policy)¶
from stable_baselines3 import PPO
model = PPO(
"MlpPolicy",
env,
learning_rate=3e-4,
n_steps=2048,
batch_size=64,
n_epochs=10,
gamma=0.99,
gae_lambda=0.95,
clip_range=0.2,
ent_coef=0.0,
verbose=1
)
model.learn(total_timesteps=1_000_000)
SAC (Off-Policy, Continuous)¶
from stable_baselines3 import SAC
model = SAC(
"MlpPolicy",
env,
learning_rate=3e-4,
buffer_size=1_000_000,
learning_starts=10_000,
batch_size=256,
tau=0.005,
gamma=0.99,
train_freq=1,
gradient_steps=1,
ent_coef='auto', # Automatic entropy tuning
verbose=1
)
model.learn(total_timesteps=1_000_000)
TD3 (Off-Policy, Continuous)¶
import numpy as np
from stable_baselines3 import TD3
from stable_baselines3.common.noise import NormalActionNoise

# Action noise for exploration
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.1 * np.ones(n_actions)
)
model = TD3(
"MlpPolicy",
env,
learning_rate=1e-3,
buffer_size=1_000_000,
learning_starts=10_000,
batch_size=100,
tau=0.005,
gamma=0.99,
train_freq=(1, "episode"),
gradient_steps=-1, # Update at end of each episode
action_noise=action_noise,
policy_delay=2, # Delayed policy updates
target_policy_noise=0.2,
target_noise_clip=0.5,
verbose=1
)
model.learn(total_timesteps=1_000_000)
DQN (Off-Policy, Discrete)¶
from stable_baselines3 import DQN
model = DQN(
"MlpPolicy",
env,
learning_rate=1e-4,
buffer_size=100_000,
learning_starts=10_000,
batch_size=32,
tau=1.0,
gamma=0.99,
train_freq=4,
gradient_steps=1,
target_update_interval=10_000,
exploration_fraction=0.1,
exploration_initial_eps=1.0,
exploration_final_eps=0.05,
verbose=1
)
model.learn(total_timesteps=1_000_000)
Robotics Integration¶
Custom Gym Environment¶
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from stable_baselines3 import PPO

class RobotEnv(gym.Env):
    """Custom robot environment"""

    def __init__(self):
        super().__init__()
        # Define action and observation space
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(4,), dtype=np.float32
        )
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf, shape=(10,), dtype=np.float32
        )
        # Initialize robot (pseudocode)
        # self.robot = Robot()

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # Reset robot to initial state
        # observation = self.robot.get_observation()
        observation = np.zeros(10, dtype=np.float32)
        info = {}
        return observation, info

    def step(self, action):
        # Apply action to robot
        # self.robot.apply_action(action)
        # Get new observation
        # observation = self.robot.get_observation()
        observation = np.zeros(10, dtype=np.float32)
        # Calculate reward
        reward = self._calculate_reward(observation, action)
        # Check termination conditions
        terminated = False  # Task completed
        truncated = False   # Time limit reached
        info = {}
        return observation, reward, terminated, truncated, info

    def _calculate_reward(self, observation, action):
        # Example reward function
        reward = -np.linalg.norm(observation[:3])  # Distance to goal
        reward -= 0.01 * np.linalg.norm(action)    # Action penalty
        return reward

# Register environment
gym.register(
    id='Robot-v0',
    entry_point='__main__:RobotEnv',
    max_episode_steps=1000
)

# Use with SB3
env = gym.make('Robot-v0')
model = PPO("MlpPolicy", env)
model.learn(total_timesteps=100_000)
Real Robot Example¶
import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback

# Connect to real robot
env = gym.make('RealRobot-v0')  # Your robot environment

# Load pre-trained model (from sim), overriding the learning rate at load
# time so fine-tuning on hardware uses a small step size
# (assigning model.learning_rate after loading would not rebuild the schedule)
model = SAC.load("sim_policy", custom_objects={"learning_rate": 1e-5})
model.set_env(env)

# Fine-tune with safety limits
class SafetyCallback(BaseCallback):
    """Stop training if unsafe behavior is detected"""

    def __init__(self, force_limit=100.0):
        super().__init__()
        self.force_limit = force_limit

    def _on_step(self) -> bool:
        # Check safety constraints reported by the environment
        if "force" in self.locals["infos"][0]:
            force = self.locals["infos"][0]["force"]
            if force > self.force_limit:
                print(f"Safety violation! Force: {force}")
                return False  # Stop training
        return True

safety_callback = SafetyCallback(force_limit=100.0)
model.learn(total_timesteps=10_000, callback=safety_callback)
Tips & Best Practices¶
Choosing an Algorithm¶
# For continuous control (robotics):
# - Sample efficient: SAC
# - Stable: TD3
# - Fast convergence: PPO

# For discrete control (games):
# - Standard baseline: DQN
# - On-policy: A2C/PPO

# Quick guide (pseudocode: the you_need_* flags stand in for your own requirements)
from gymnasium import spaces

if isinstance(env.action_space, spaces.Box):
    # Continuous actions
    if you_need_sample_efficiency:
        algorithm = SAC
    elif you_need_stability:
        algorithm = TD3
    else:
        algorithm = PPO
else:
    # Discrete actions
    algorithm = DQN
Debugging Tips¶
# 1. Start simple
env = gym.make("CartPole-v1")
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)
# 2. Check environment
from stable_baselines3.common.env_checker import check_env
check_env(env)
# 3. Monitor training
from stable_baselines3.common.monitor import Monitor
env = Monitor(env, "./logs/")
# 4. Visualize with TensorBoard (run in a terminal):
#    tensorboard --logdir ./ppo_tensorboard/
# 5. Evaluate frequently
from stable_baselines3.common.evaluation import evaluate_policy
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"Mean reward: {mean_reward} +/- {std_reward}")
Common Issues¶
Problem: Slow training
# Solution: Use vectorized environments
env = make_vec_env("Ant-v4", n_envs=8, vec_env_cls=SubprocVecEnv)
Problem: Unstable training
Problem: Poor exploration
# Solution: Increase entropy coefficient (PPO)
model = PPO("MlpPolicy", env, ent_coef=0.01)
# Or increase exploration noise (TD3/DDPG); mean and sigma must be arrays
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(
    mean=np.zeros(n_actions),
    sigma=0.3 * np.ones(n_actions)
)
model = TD3("MlpPolicy", env, action_noise=action_noise)
References¶
Official Resources¶
- Documentation: https://stable-baselines3.readthedocs.io/
- GitHub: https://github.com/DLR-RM/stable-baselines3
- RL Zoo: https://github.com/DLR-RM/rl-baselines3-zoo
Tutorials¶
- Getting Started: https://stable-baselines3.readthedocs.io/en/master/guide/examples.html
- Custom Policies: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html
- Callbacks: https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html
Papers¶
SB3 implements algorithms from:
- PPO: Schulman et al., "Proximal Policy Optimization", 2017
- SAC: Haarnoja et al., "Soft Actor-Critic", 2018
- TD3: Fujimoto et al., "Addressing Function Approximation Error", 2018
- DQN: Mnih et al., "Human-level control through deep RL", Nature 2015
Community¶
- GitHub Discussions: https://github.com/DLR-RM/stable-baselines3/discussions
- Discord: https://discord.gg/nnWPWFbcCK
- RL Discord: https://discord.gg/xhfNqQv
Next Steps¶
- RSL-RL - Isaac Lab specialized library
- RL Games - High-performance alternative
- SB3 Contrib - Additional algorithms (TQC, QR-DQN, etc.)