Step-by-Step Tutorial on Deep Q-Network (DQN)

This tutorial provides a comprehensive guide to implementing a Deep Q-Network (DQN), a reinforcement learning algorithm that combines Q-learning with deep neural networks. We'll cover the theory and key components, and provide a Python implementation using PyTorch.

Prerequisites

  • Basic understanding of Python, PyTorch, and reinforcement learning concepts.
  • Familiarity with neural networks and Q-learning.
  • Libraries: PyTorch, NumPy, Gymnasium (for environments like CartPole).

Step 1: Understand the Theory of DQN

DQN is a model-free, off-policy reinforcement learning algorithm that approximates the Q-value function using a deep neural network. The agent learns to make decisions by estimating the expected future rewards for each action in a given state.

Key Concepts

  • Q-Value: Represents the expected cumulative reward for taking action (a) in state (s) and following the policy thereafter.
  • Neural Network: Approximates the Q-function, mapping states to action values.
  • Experience Replay: Stores past experiences (state, action, reward, next state) in a replay buffer to break correlation in sequential data.
  • Target Network: A separate network used to stabilize training by providing consistent Q-value targets.
  • Epsilon-Greedy Policy: Balances exploration (random actions) and exploitation (best-known actions).
  • Loss Function: Typically mean squared error between predicted and target Q-values.
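
The target network and the loss come together in the temporal-difference (TD) target. For a transition (s, a, r, s'), the value the Q-network is trained toward is:

[
y = r + \gamma \max_{a'} Q(s', a'; \theta^-)
]

where (\theta^-) are the target network parameters; for terminal transitions the bootstrap term is dropped, so (y = r). The same expression appears inside the loss in Step 4.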

DQN Algorithm Overview

  1. Initialize the Q-network and target network with random weights.
  2. For each episode:
    • Reset the environment.
    • For each step:
      • Select an action using epsilon-greedy policy.
      • Execute the action, observe reward and next state.
      • Store the experience in the replay buffer.
      • Sample a batch from the replay buffer.
      • Compute the target Q-values using the target network.
      • Update the Q-network by minimizing the loss.
      • Periodically update the target network.
  3. Repeat until convergence or a set number of episodes.

Step 2: Set Up the Environment

We’ll use the CartPole-v1 environment from Gymnasium, where the goal is to balance a pole on a cart.

Install Dependencies

pip install gymnasium torch numpy
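
As a quick sanity check (optional), you can create the environment and inspect its spaces before writing any agent code; for CartPole-v1 the state is a 4-dimensional vector and there are two discrete actions:

import gymnasium as gym

env = gym.make("CartPole-v1")
print(env.observation_space.shape)  # (4,): cart position, cart velocity, pole angle, pole angular velocity
print(env.action_space.n)           # 2: push cart left or right

state, info = env.reset(seed=0)
print(state)                        # a NumPy array of 4 floats near zero
env.close()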

Step 3: Implement the DQN

Below is a complete Python implementation of DQN with detailed explanations.

Import Libraries

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

Define the Q-Network

Create a neural network to approximate the Q-function.

class QNetwork(nn.Module):
    def __init__(self, state_size, action_size, hidden_size=64):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)
    
    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)
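
As a quick shape check (optional, with illustrative sizes), you can instantiate the network for CartPole's 4-dimensional state and 2 actions and pass a dummy batch through it:

# Illustrative shape check: a batch of 8 states maps to an 8 x 2 table of Q-values
net = QNetwork(state_size=4, action_size=2)
dummy_states = torch.randn(8, 4)
print(net(dummy_states).shape)  # torch.Size([8, 2])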

Implement Experience Replay

Use a deque to store experiences and sample random batches.

class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        state, action, reward, next_state, done = zip(*random.sample(self.buffer, batch_size))
        return np.array(state), action, reward, np.array(next_state), done
    
    def __len__(self):
        return len(self.buffer)
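
A brief usage sketch (with made-up transitions) shows how the buffer is filled and sampled:

# Illustrative usage with dummy transitions
buffer = ReplayBuffer(capacity=1000)
for _ in range(100):
    s = np.random.randn(4).astype(np.float32)
    s_next = np.random.randn(4).astype(np.float32)
    buffer.push(s, action=np.random.randint(2), reward=1.0, next_state=s_next, done=False)

states, actions, rewards, next_states, dones = buffer.sample(batch_size=32)
print(states.shape)  # (32, 4)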

DQN Agent

Define the agent that interacts with the environment and updates the Q-network.

class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Hyperparameters
        self.gamma = 0.99  # Discount factor
        self.epsilon = 1.0  # Exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.batch_size = 64
        self.memory = ReplayBuffer(10000)
        self.update_target_freq = 1000
        
        # Networks
        self.q_network = QNetwork(state_size, action_size).to(self.device)
        self.target_network = QNetwork(state_size, action_size).to(self.device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)
    
    def act(self, state):
        if random.random() < self.epsilon:
            return random.randrange(self.action_size)
        state = torch.FloatTensor(state).to(self.device)
        with torch.no_grad():
            q_values = self.q_network(state)
        return np.argmax(q_values.cpu().numpy())
    
    def train(self, step):
        if len(self.memory) < self.batch_size:
            return
        
        states, actions, rewards, next_states, dones = self.memory.sample(self.batch_size)
        
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        
        # Compute Q-values
        q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        
        # Compute target Q-values
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            targets = rewards + self.gamma * next_q_values * (1 - dones)
        
        # Compute loss
        loss = nn.MSELoss()(q_values, targets)
        
        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        # Update epsilon
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
        
        # Update target network
        if step % self.update_target_freq == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())
    
    def store(self, state, action, reward, next_state, done):
        self.memory.push(state, action, reward, next_state, done)

Training Loop

Train the agent on the CartPole environment.

def train_dqn(episodes):
    env = gym.make("CartPole-v1")
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = DQNAgent(state_size, action_size)
    step = 0
    
    for episode in range(episodes):
        state, _ = env.reset()
        episode_reward = 0
        
        while True:
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            agent.store(state, action, reward, next_state, done)
            agent.train(step)
            
            state = next_state
            episode_reward += reward
            step += 1
            
            if done:
                print(f"Episode: {episode+1}/{episodes}, Reward: {episode_reward}, Epsilon: {agent.epsilon:.3f}")
                break
    
    env.close()
    return agent

# Run training
agent = train_dqn(1000)
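
If you want to keep the trained weights for later use (optional, not part of the original script), a standard PyTorch checkpoint works:

# Optional: persist the trained Q-network (uses the `agent` returned by train_dqn above)
torch.save(agent.q_network.state_dict(), "dqn_cartpole.pt")
# ...and restore it later into a fresh agent with:
# agent.q_network.load_state_dict(torch.load("dqn_cartpole.pt"))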

Step 4: Explanation of Key Components

Q-Network

  • A simple feedforward neural network with two hidden layers.
  • Input: State vector (e.g., 4D for CartPole).
  • Output: Q-values for each action (e.g., 2 for left/right).

Experience Replay

  • Stores experiences to decorrelate training data.
  • Random sampling improves stability and efficiency.

Target Network

  • A copy of the Q-network updated less frequently.
  • Provides stable Q-value targets to prevent oscillations.

Epsilon-Greedy Policy

  • Initially, high epsilon (e.g., 1.0) encourages exploration.
  • Gradually decays to favor exploitation (e.g., 0.01).

Loss Function

  • Mean squared error between predicted Q-values and target Q-values:
    [
    \text{Loss} = \mathbb{E}\left[\left(r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta)\right)^2\right]
    ]
    where (\theta) are Q-network parameters, and (\theta^-) are target network parameters.

Step 5: Hyperparameter Tuning

  • Gamma (0.99): Balances immediate vs. future rewards.
  • Learning Rate (0.001): Controls step size in gradient descent.
  • Batch Size (64): Number of experiences sampled per update.
  • Epsilon Decay (0.995): Rate at which exploration decreases.
  • Target Update Frequency (1000): Steps between target network updates.

Experiment with these values to improve performance on different environments.
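
One convenient way to run such experiments (a refactoring suggestion, not part of the implementation above) is to collect the hyperparameters in a single object and have DQNAgent read from it:

from dataclasses import dataclass

@dataclass
class DQNConfig:
    gamma: float = 0.99
    learning_rate: float = 1e-3
    batch_size: int = 64
    epsilon_decay: float = 0.995
    update_target_freq: int = 1000

# e.g. try a shorter reward horizon and smaller batches
config = DQNConfig(gamma=0.95, batch_size=32)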

Step 6: Testing the Trained Agent

After training, pass the trained agent (returned by train_dqn) to a test function and use a low epsilon (e.g., 0.01) so it mostly exploits what it has learned.

def test_dqn(agent):
    env = gym.make("CartPole-v1", render_mode="human")
    agent.epsilon = 0.01  # Minimal exploration
    
    state, _ = env.reset()
    episode_reward = 0
    
    while True:
        action = agent.act(state)
        state, reward, terminated, truncated, _ = env.step(action)  # render_mode="human" displays each step automatically
        episode_reward += reward
        if terminated or truncated:
            print(f"Test Episode Reward: {episode_reward}")
            break
    
    env.close()
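
To run the evaluation with the agent returned by train_dqn above:

# Evaluate the trained agent
test_dqn(agent)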

Step 7: Enhancements and Variants

  • Double DQN: Reduces overestimation by decoupling action selection and evaluation (see the sketch after this list).
  • Dueling DQN: Separates state value and action advantage estimation.
  • Prioritized Experience Replay: Samples important experiences more frequently.
  • Noisy Nets: Adds noise to network parameters for exploration.
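
To make the Double DQN idea concrete, here is a minimal sketch of how the target computation in DQNAgent.train() would change; only the lines inside the torch.no_grad() block differ, everything else stays the same:

# Double DQN target (sketch): select actions with the online network,
# evaluate them with the target network
with torch.no_grad():
    next_actions = self.q_network(next_states).argmax(1, keepdim=True)
    next_q_values = self.target_network(next_states).gather(1, next_actions).squeeze(1)
    targets = rewards + self.gamma * next_q_values * (1 - dones)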

Step 8: Debugging Tips

  • Reward Monitoring: Ensure episode rewards increase over time.
  • Q-Value Inspection: Check if Q-values are reasonable and stable.
  • Replay Buffer: Verify experiences are stored and sampled correctly.
  • Gradient Clipping: Add if gradients explode (e.g., torch.nn.utils.clip_grad_norm_).
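
For example, gradient clipping slots in between loss.backward() and self.optimizer.step() in DQNAgent.train(); the max-norm of 10 here is a common starting point, not a tuned value:

# In DQNAgent.train(), between the backward pass and the optimizer step:
loss.backward()
torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), max_norm=10.0)
self.optimizer.step()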

Step 9: Running the Code

  1. Save the code in a Python file (e.g., dqn_cartpole.py).
  2. Ensure dependencies are installed.
  3. Run the script: python dqn_cartpole.py
  4. Expect the agent to achieve rewards close to 500 (CartPole’s max) after 500–1000 episodes.

Conclusion

This tutorial covered the theory and implementation of a DQN for the CartPole environment. You can extend this to other environments (e.g., Atari games) by adjusting the network architecture and hyperparameters. Experiment with enhancements like Double DQN or Dueling DQN for better performance.
