Skip to content

Imitation Learning Methods

Detailed guide to specific imitation learning algorithms.

Behavioral Cloning (BC)

Standard BC

The simplest IL approach: supervised learning on expert demonstrations (regress the expert's action from the observation).

import random

import torch
import torch.nn as nn
import torch.nn.functional as F

class BCPolicy(nn.Module):
    """Deterministic MLP policy for behavioral cloning.

    Maps an observation vector of size ``obs_dim`` to an action vector of
    size ``action_dim`` through two 256-unit ReLU hidden layers.
    """

    def __init__(self, obs_dim, action_dim):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, obs):
        """Return the predicted action for ``obs`` (last dim = obs_dim)."""
        predicted = self.net(obs)
        return predicted

def train_bc(policy, demonstrations, epochs=100, lr=1e-3):
    """Fit ``policy`` to expert (obs, action) pairs by MSE regression.

    Args:
        policy: ``nn.Module`` mapping observations to actions.
        demonstrations: re-iterable collection of ``(obs, actions)`` batch
            tuples (e.g. a list or a ``DataLoader``). It is iterated once per
            epoch, so a one-shot generator would only be used in epoch 0.
        epochs: number of passes over ``demonstrations``.
        lr: Adam learning rate (previously hard-coded to 1e-3).

    Returns:
        The same ``policy`` object, trained in place.
    """
    # Callers (e.g. DAgger) may have switched the policy to eval mode for
    # rollouts; dropout / batch-norm layers must be back in training mode.
    policy.train()
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        for obs, actions in demonstrations:
            loss = criterion(policy(obs), actions)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return policy

Pros:
- Simple to implement
- Fast training
- No environment interaction needed

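Cons:
- Distribution shift: the policy visits states the expert never demonstrated
- Cannot recover from its own errors, because there are no corrective demonstrations for off-distribution states
- Averages multi-modal actions: MSE regression collapses distinct expert choices into their mean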

BC with Data Augmentation

Improve robustness with augmentation:

def augment_observation(obs, noise_std=0.01):
    """Return a randomly perturbed copy of ``obs`` for BC training.

    Args:
        obs: observation tensor. A 4-D tensor is treated as an image batch;
            anything else as a proprioceptive / coordinate observation.
        noise_std: standard deviation of the additive Gaussian noise
            (previously hard-coded to 0.01).

    Returns:
        The augmented observation; the additive-noise step creates a new
        tensor, so ``obs`` itself is not modified here.

    Relies on ``random_crop`` and ``random_rotate_coords`` helpers that are
    not defined in this snippet.
    """
    # Small Gaussian noise applies to every observation type
    obs_noisy = obs + torch.randn_like(obs) * noise_std

    if obs.dim() == 4:  # Images: random crop/shift
        obs_noisy = random_crop(obs_noisy)
    else:
        # Coordinate rotation is meant for proprioceptive vectors only; it
        # is no longer applied to image tensors after the crop.
        obs_noisy = random_rotate_coords(obs_noisy)

    return obs_noisy

# Training with augmentation
# (assumes policy, criterion and optimizer are set up as in train_bc)
for obs, action in demonstrations:
    obs_aug = augment_observation(obs)
    loss = criterion(policy(obs_aug), action)

    optimizer.zero_grad()  # otherwise gradients accumulate across batches
    loss.backward()
    optimizer.step()       # without this the policy is never updated

Goal-Conditioned BC

Learn a single policy conditioned on a goal input:

class GoalConditionedPolicy(nn.Module):
    """MLP policy whose action depends on both the observation and a goal.

    The observation and goal vectors are concatenated along the last
    dimension and fed through two 256-unit ReLU hidden layers.
    """

    def __init__(self, obs_dim, goal_dim, action_dim):
        super().__init__()
        sizes = [obs_dim + goal_dim, 256, 256]
        layers = []
        for fan_in, fan_out in zip(sizes, sizes[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(sizes[-1], action_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, obs, goal):
        """Return the action for ``obs`` when pursuing ``goal``."""
        joint = torch.cat((obs, goal), dim=-1)
        return self.net(joint)

# Can achieve different goals with one policy: the same network is queried
# with a different goal vector. Presumably `policy` is a
# GoalConditionedPolicy and `desired_goal` has size goal_dim -- confirm.
action = policy(obs, desired_goal)

DAgger (Dataset Aggregation)

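An interactive IL method that addresses distribution shift: states visited by the learner's own rollouts are relabeled with the expert's actions and added to the training set.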

Algorithm

def dagger(expert, env, initial_dataset, num_iterations=10, policy=None,
           episode_length=100):
    """Train a policy with DAgger (Dataset Aggregation).

    Each iteration rolls out the *current* policy, asks the expert what it
    would have done in every visited state, adds those labelled states to
    the dataset, and retrains with behavioral cloning.

    Args:
        expert: object whose ``get_action(obs)`` returns the expert action.
        env: environment with ``reset()`` and a 4-tuple ``step(action)``.
        initial_dataset: expert demonstrations supporting
            ``add(states, actions)`` and iteration as ``train_bc`` expects.
            NOTE: it is extended in place.
        num_iterations: number of collect / relabel / retrain rounds.
        policy: the ``nn.Module`` to train (e.g. a ``BCPolicy``). Required:
            the original ``train_bc(initial_dataset)`` call had no policy to
            fit and always raised.
        episode_length: environment steps collected per iteration
            (previously an undefined global).

    Returns:
        The trained policy.

    Raises:
        TypeError: if ``policy`` is not given.
    """
    if policy is None:
        raise TypeError("dagger() requires a `policy` module to train")

    # Initial behavioral cloning on the expert demonstrations
    policy = train_bc(policy, initial_dataset)
    dataset = initial_dataset  # aggregated in place

    for iteration in range(num_iterations):
        # Roll out the current policy (not the expert) so the dataset covers
        # the states the learner actually visits.
        # NOTE(review): assumes env observations are already tensors the
        # policy accepts -- confirm.
        rollout_states = []
        obs = env.reset()

        for step in range(episode_length):
            with torch.no_grad():  # acting only; no autograd graph needed
                action = policy(obs)
            rollout_states.append(obs)

            obs, _, done, _ = env.step(action)
            if done:
                obs = env.reset()

        # Relabel every visited state with the expert's action
        expert_actions = [expert.get_action(s) for s in rollout_states]
        dataset.add(rollout_states, expert_actions)

        # Retrain on the aggregated dataset
        policy = train_bc(policy, dataset)

    return policy

Hyperparameters:

# Example DAgger hyperparameters
dagger:
  num_iterations: 10
  rollouts_per_iteration: 20
  expert_queries_per_rollout: 100
  mixing_parameter: 0.5  # Mix expert and policy actions (presumably the Beta-DAgger beta: probability of executing the expert's action)

Practical Considerations

# Beta-DAgger: Gradually reduce expert influence.
# beta is the probability of *executing* the expert's action at each step.
# The expert's label is recorded for every visited state regardless of which
# action was executed; otherwise nothing new is learned from the rollout.
# (Assumes env, expert, policy, dataset, episode_length and train_bc are set
# up as in dagger(); requires `import random`.)
beta_schedule = [1.0, 0.8, 0.6, 0.4, 0.2, 0.1, 0.0]

for iteration, beta in enumerate(beta_schedule):
    visited_states, expert_labels = [], []
    obs = env.reset()

    # Collect data with a beta-mixture of expert and policy actions
    for step in range(episode_length):
        expert_action = expert.get_action(obs)
        visited_states.append(obs)
        expert_labels.append(expert_action)

        if random.random() < beta:
            action = expert_action
        else:
            action = policy(obs)

        # Advance the environment so the next decision sees a new state
        obs, _, done, _ = env.step(action)
        if done:
            obs = env.reset()

    dataset.add(visited_states, expert_labels)
    policy = train_bc(policy, dataset)

Generative Behavioral Cloning

Model multi-modal action distributions instead of averaging them into a single (possibly invalid) action.

Diffusion Policy

class DiffusionPolicy:
    """DDPM-style diffusion policy over actions, conditioned on observations.

    Training teaches ``noise_predictor`` to recover the Gaussian noise added
    to expert actions; inference starts from pure noise and denoises for
    ``num_steps`` steps.

    Args:
        obs_dim: observation size.
        action_dim: action size.
        num_steps: number of diffusion steps T.
        noise_predictor: optional network called as ``net(x, t)``, where
            ``x = cat([obs, noisy_action], -1)`` and ``t`` is a (batch,) long
            tensor of timesteps. It returns predicted noise of size
            ``action_dim``. Defaults to ``UNet(obs_dim + action_dim,
            action_dim)``, which is defined elsewhere.
        beta_start, beta_end: endpoints of the linear noise schedule.
    """

    def __init__(self, obs_dim, action_dim, num_steps=100, noise_predictor=None,
                 beta_start=1e-4, beta_end=0.02):
        if noise_predictor is None:
            noise_predictor = UNet(obs_dim + action_dim, action_dim)
        self.noise_predictor = noise_predictor
        self.num_steps = num_steps
        self.action_dim = action_dim  # needed by predict()

        # Linear beta schedule and its cumulative products (alpha-bar)
        self.betas = torch.linspace(beta_start, beta_end, num_steps)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def add_noise(self, actions, noise, t):
        """Sample q(x_t | x_0) = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise.

        ``t`` is a (batch,) long tensor of timesteps.
        """
        alpha_bar = self.alpha_bars.to(actions.device)[t]
        # Broadcast the per-sample coefficient over the action dimensions
        alpha_bar = alpha_bar.view(-1, *([1] * (actions.dim() - 1)))
        return alpha_bar.sqrt() * actions + (1.0 - alpha_bar).sqrt() * noise

    def denoise_step(self, actions, predicted_noise, t):
        """One DDPM reverse step from x_t to x_{t-1} at integer timestep ``t``."""
        beta = self.betas[t].item()
        alpha = self.alphas[t].item()
        alpha_bar = self.alpha_bars[t].item()
        mean = (actions - beta / (1.0 - alpha_bar) ** 0.5 * predicted_noise) / alpha ** 0.5
        if t == 0:
            return mean  # the final step adds no noise
        return mean + beta ** 0.5 * torch.randn_like(actions)

    def train_step(self, obs, expert_actions):
        """Return the noise-prediction MSE loss for one batch."""
        # Batch size comes from the data (was an undefined global)
        batch_size = expert_actions.shape[0]

        # Sample a random timestep per example
        t = torch.randint(0, self.num_steps, (batch_size,),
                          device=expert_actions.device)

        # Add noise to expert actions
        noise = torch.randn_like(expert_actions)
        noisy_actions = self.add_noise(expert_actions, noise, t)

        # Predict the noise that was added
        predicted_noise = self.noise_predictor(
            torch.cat([obs, noisy_actions], dim=-1),
            t
        )

        return F.mse_loss(predicted_noise, noise)

    @torch.no_grad()
    def predict(self, obs):
        """Sample one action per row of ``obs`` by iterative denoising."""
        batch_size = obs.shape[0]
        # Start from pure Gaussian noise
        actions = torch.randn(batch_size, self.action_dim,
                              device=obs.device, dtype=obs.dtype)

        for step in reversed(range(self.num_steps)):
            # Same timestep format as in train_step: a (batch,) long tensor
            t = torch.full((batch_size,), step, dtype=torch.long, device=obs.device)
            predicted_noise = self.noise_predictor(
                torch.cat([obs, actions], dim=-1),
                t
            )
            actions = self.denoise_step(actions, predicted_noise, step)

        return actions

VAE Policy

class VAEPolicy(nn.Module):
    """Conditional-VAE policy that can represent multi-modal actions.

    The encoder maps (obs, action) to a Gaussian over a latent ``z``; the
    decoder maps (obs, z) back to an action. At inference ``z`` is drawn
    from a standard normal, so different samples can decode to different
    action modes.
    """

    def __init__(self, obs_dim, action_dim, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim  # used to sample z at inference

        # Encoder: (obs, action) -> mean and log-variance of z
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim * 2)  # mean and logvar
        )

        # Decoder: (obs, z) -> action
        self.decoder = nn.Sequential(
            nn.Linear(obs_dim + latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim)
        )

    def encode(self, obs, action):
        """Return ``(mean, logvar)`` of the latent posterior."""
        h = self.encoder(torch.cat([obs, action], dim=-1))
        mean, logvar = torch.chunk(h, 2, dim=-1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Sample z = mean + std * eps with eps ~ N(0, I), differentiably."""
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(std)

    def decode(self, obs, z):
        """Return the action decoded from ``obs`` and latent ``z``."""
        return self.decoder(torch.cat([obs, z], dim=-1))

    def forward(self, obs, action=None):
        """Reconstruct (training) or sample (inference) actions.

        Returns:
            With ``action``: ``(reconstructed_action, loss)``, where loss is
            MSE reconstruction plus KL(q(z|obs, action) || N(0, I)). The KL
            is summed over latent dims and averaged over the batch, so
            (like the MSE term) it does not grow with batch size.
            Without ``action``: one sampled action per row of ``obs``.
        """
        if action is not None:  # Training
            mean, logvar = self.encode(obs, action)
            z = self.reparameterize(mean, logvar)
            reconstructed_action = self.decode(obs, z)

            # VAE loss
            recon_loss = F.mse_loss(reconstructed_action, action)
            kl_per_sample = -0.5 * torch.sum(1 + logvar - mean**2 - logvar.exp(), dim=-1)
            kl_loss = kl_per_sample.mean()

            return reconstructed_action, recon_loss + kl_loss
        else:  # Inference: sample z from the prior on obs's device/dtype
            z = torch.randn(obs.shape[0], self.latent_dim,
                            device=obs.device, dtype=obs.dtype)
            return self.decode(obs, z)

Inverse Reinforcement Learning (IRL)

Learn a reward function from demonstrations, then train a policy that maximizes it.

MaxEnt IRL

def max_ent_irl(expert_demos, env, num_iterations=100, lr=1e-3):
    """Learn a neural reward with sample-based Maximum-Entropy IRL.

    Each iteration does two things:
    (a) trains a policy with RL under the current reward;
    (b) updates the reward so that expert states score higher than the
        states visited by that policy.

    Minimizing the surrogate ``E_policy[r] - E_expert[r]`` follows the
    MaxEnt IRL gradient ``E_policy[grad r] - E_expert[grad r]``.

    Args:
        expert_demos: expert trajectories accepted by ``compute_features``.
        env: environment used by ``train_rl_policy`` / ``collect_rollouts``.
        num_iterations: number of RL / reward-update rounds.
        lr: Adam learning rate for the reward network.

    Returns:
        ``(reward_function, policy)``. ``policy`` is ``None`` when
        ``num_iterations`` is 0.

    NOTE(review): assumes ``compute_features`` returns a float tensor of
    per-state features shaped (N, state_dim). ``state_dim``,
    ``compute_features``, ``train_rl_policy`` and ``collect_rollouts`` are
    defined elsewhere.
    """
    # Initialize reward function
    reward_function = nn.Sequential(
        nn.Linear(state_dim, 128),
        nn.ReLU(),
        nn.Linear(128, 1)
    )
    # The original used an undefined `optimizer`
    optimizer = torch.optim.Adam(reward_function.parameters(), lr=lr)

    # Expert features do not depend on the reward, so compute them once
    expert_features = compute_features(expert_demos)
    policy = None

    for iteration in range(num_iterations):
        # 1. Train policy with current reward using RL
        policy = train_rl_policy(reward_function, env)

        # 2. Features of states visited by the learned policy
        policy_features = compute_features(collect_rollouts(policy, env))

        # 3. Raise reward on expert states, lower it on policy states. An MSE
        #    between the two feature tensors does not depend on the reward
        #    parameters, so backward() could never update the reward.
        loss = (reward_function(policy_features).mean()
                - reward_function(expert_features).mean())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return reward_function, policy

Adversarial IRL (GAIL/AIRL)

class Discriminator(nn.Module):
    """Distinguish expert from policy trajectories.

    Scores a (state, action) pair with a probability in (0, 1). In GAIL
    training, expert pairs are labelled 1 and policy pairs 0.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        layers = [
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state, action):
        """Return a (..., 1) sigmoid score for the concatenated pair."""
        pair = torch.cat((state, action), dim=-1)
        return self.net(pair)

def train_gail(expert_demos, env, num_iterations=100, disc_epochs=5, disc_lr=3e-4):
    """Generative Adversarial Imitation Learning.

    Alternates between training a discriminator to separate expert from
    policy (state, action) pairs and training the policy with RL on a reward
    derived from the discriminator.

    Args:
        expert_demos: ``(states, actions)`` pair of expert tensors.
        env: training environment.
        num_iterations: outer GAIL iterations (was an undefined global).
        disc_epochs: discriminator updates per iteration (was hard-coded 5).
        disc_lr: Adam learning rate for the discriminator.

    Returns:
        The trained policy.

    NOTE(review): ``PPO``, ``collect_rollouts``, ``state_dim`` and
    ``action_dim`` come from elsewhere. ``collect_rollouts`` is assumed to
    return a ``(states, actions)`` pair. ``policy.learn(reward_fn=...)``
    assumes a PPO wrapper that accepts a custom reward -- confirm.
    """
    policy = PPO('MlpPolicy', env)
    discriminator = Discriminator(state_dim, action_dim)
    disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr=disc_lr)
    eps = 1e-8  # keeps log() finite when the discriminator saturates at 0 or 1

    for iteration in range(num_iterations):
        # Collect policy rollouts
        policy_data = collect_rollouts(policy, env)

        # Train discriminator
        for epoch in range(disc_epochs):
            # Label expert as 1, policy as 0 (binary cross-entropy).
            # Discriminator.forward takes (state, action), so unpack the pairs.
            d_expert = discriminator(*expert_demos)
            d_policy = discriminator(*policy_data)
            disc_loss = (
                -torch.log(d_expert + eps).mean()
                - torch.log(1 - d_policy + eps).mean()
            )
            disc_optimizer.zero_grad()  # don't accumulate grads across epochs
            disc_loss.backward()
            disc_optimizer.step()

        # Use discriminator output as reward: higher when the pair looks expert-like
        def learned_reward(state, action):
            with torch.no_grad():
                return -torch.log(1 - discriminator(state, action) + eps)

        # Train policy with learned reward
        policy.learn(reward_fn=learned_reward)

    return policy

One-Shot Imitation Learning

Learn to perform a task from a single demonstration of it.

class OneShotPolicy(nn.Module):
    """Policy conditioned on a single demonstration of the task.

    An LSTM summarizes the demonstration's (obs, action) sequence into a
    context vector; an MLP then maps (current obs, context) to an action.

    Args:
        obs_dim: observation size.
        action_dim: action size.
        context_dim: LSTM hidden size, i.e. the context vector size
            (default 128, the previously hard-coded value).
    """

    def __init__(self, obs_dim, action_dim, context_dim=128):
        super().__init__()

        # Context encoder (processes the demonstration). nn.LSTM defaults to
        # batch_first=False, so sequences are (seq_len, batch, features) or
        # unbatched (seq_len, features).
        self.context_encoder = nn.LSTM(obs_dim + action_dim, context_dim)

        # Policy network (conditioned on context)
        self.policy = nn.Sequential(
            nn.Linear(obs_dim + context_dim, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim)
        )

    def encode_demonstration(self, demo_obs, demo_actions):
        """Return the LSTM's final hidden state as the demonstration context.

        The result is (batch, context_dim) for batched input, or
        (context_dim,) for unbatched input.
        """
        demo_sequence = torch.cat([demo_obs, demo_actions], dim=-1)
        _, (context, _) = self.context_encoder(demo_sequence)
        return context.squeeze(0)  # drop the single num_layers axis

    def forward(self, obs, context):
        """Generate an action conditioned on the demonstration context."""
        return self.policy(torch.cat([obs, context], dim=-1))


# Backward-compatible alias for the original (misspelled) class name
OneeShotPolicy = OneShotPolicy

# Usage: encode the single demonstration once, then condition every action
# on that context. Presumably demo_obs / demo_actions are time-major
# (seq_len, batch, dim) tensors, since the LSTM is built without
# batch_first -- confirm.
context = policy.encode_demonstration(demo_obs, demo_actions)
action = policy(current_obs, context)

Next Steps