"""
Deep Q-Learning Agent for Pacman
Based on Stanford CS229 Project: "Reinforcement Learning in Pacman"
https://cs229.stanford.edu/proj2017/final-reports/5241109.pdf

Key improvements:
- Equivalent image representation (100x faster than pixel input)
- 3 Convolutional layers with optimized architecture
- Target network updated every 100 steps
- Deque-based replay memory for efficiency
- Gradient clipping to prevent exploding gradients
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random

student_name = "Type your full name here."


class ConvQNetwork(nn.Module):
    """
    Convolutional Q-Network based on Stanford CS229 paper

    Architecture:
    - Input: Grid representation (channels x height x width)
    - Conv Layer 1: 3x3 kernel, 8 filters, ReLU
    - Conv Layer 2: 3x3 kernel, 16 filters, ReLU
    - Conv Layer 3: 3x3 kernel, 32 filters, ReLU
    - Flatten
    - Fully Connected: 256 neurons, ReLU
    - Output: Q-values for each action
    """

    def __init__(self, input_channels, height, width, num_actions):
        super().__init__()

        # Convolutional layers (as per Stanford paper)
        self.conv1 = nn.Conv2d(input_channels, 8, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)

        self.relu = nn.ReLU()

        # Calculate flattened size after convolutions
        # With padding=1 and kernel=3, size is preserved
        # After conv1: (H, W) -> (H, W)
        # After conv2: (H, W) -> (H, W)
        # After conv3: (H, W) -> (H, W)
        conv_output_size = 32 * height * width

        # Fully connected layers
        self.fc1 = nn.Linear(conv_output_size, 256)
        self.fc2 = nn.Linear(256, num_actions)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        # Convolutional layers
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))

        # Flatten
        x = x.view(x.size(0), -1)

        # Fully connected layers
        x = self.relu(self.fc1(x))
        x = self.fc2(x)

        return x
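
# A minimal shape sanity check for ConvQNetwork (illustrative sketch; the
# 5-channel 7x10 grid and 4 actions below match the defaults assumed by
# DQNAgent further down, not values dictated by the paper):
#
#     net = ConvQNetwork(input_channels=5, height=7, width=10, num_actions=4)
#     dummy = torch.zeros(1, 5, 7, 10)      # (batch, channels, height, width)
#     assert net(dummy).shape == (1, 4)     # one Q-value per action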


class DQNAgent:
    """
    Deep Q-Network Agent based on Stanford CS229 implementation

    Key Features:
    - Equivalent image representation for 100x speedup
    - Experience replay with deque (efficient storage)
    - Target network for stability
    - Epsilon-greedy exploration with decay
    - Gradient clipping
    """

    def __init__(self, game, discount=0.9, learning_rate=0.00025,
                 explore_prob=1.0, grid_size=(7, 10)):
        self.game = game
        self.discount = discount
        self.learning_rate = learning_rate
        self.explore_prob = explore_prob
        self.initial_epsilon = explore_prob

        # Hyperparameters (from Stanford paper)
        self.batch_size = 32
        self.memory_size = 100000  # Large replay buffer
        self.target_update_freq = 100  # Update target network every 100 steps
        self.min_replay_size = 1000
        self.learning_rate_decay = 0.00005  # Final learning rate

        # Grid configuration
        self.grid_height, self.grid_width = grid_size
        self.num_channels = 5  # Pacman, Ghost, Food, Capsules, Walls

        # Networks
        self.q_network = None
        self.target_network = None
        self.optimizer = None

        # Action mapping - ALWAYS use all 4 actions
        self.action_list = None
        self.num_actions = 0

        # Replay memory (using deque for efficiency as per paper)
        self.memory = deque(maxlen=self.memory_size)
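        # With maxlen set, the deque automatically evicts its oldest
        # transition once it is full, so no manual pruning is needed.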

        # Counters
        self.steps = 0
        self.episodes = 0

        # Device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"\n{'=' * 60}")
        print(f"Stanford CS229 DQN Agent Initialized")
        print(f"{'=' * 60}")
        print(f"Device: {self.device}")
        print(f"Grid Size: {self.grid_height}x{self.grid_width}")
        print(f"Replay Buffer: {self.memory_size}")
        print(f"Target Update Frequency: {self.target_update_freq}")
        print(f"{'=' * 60}\n")

    def _action_to_int(self, action):
        """Convert an action (enum or int) to its integer value"""
        if isinstance(action, int):
            return action
        # Action enum values: Left=0, Down=1, Right=2, Up=3
        # Use the enum's value if the action is not directly int-convertible
        # (e.g., a plain Enum rather than an IntEnum).
        return int(getattr(action, "value", action))

    def _create_equivalent_image(self, state):
        """
        Create equivalent image representation (Key optimization from paper!)

        Instead of capturing actual game pixels, we create a compact
        grid representation where each cell represents game objects.
        This gives a 100x speedup without losing information!

        Channels:
        0: Pacman position (1 at pacman location, 0 elsewhere)
        1: Ghost position (1 at ghost location, 0 elsewhere)
        2: Food dots (1 at food locations, 0 elsewhere)
        3: Capsules (1 at capsule locations, 0 elsewhere)
        4: Walls (1 at wall locations, 0 elsewhere)
        """
        # Initialize empty grid
        grid = np.zeros((self.num_channels, self.grid_height, self.grid_width),
                        dtype=np.float32)

        # Channel 0: Pacman
        px, py = state._pacman
        if 0 <= px < self.grid_height and 0 <= py < self.grid_width:
            grid[0, int(px), int(py)] = 1.0

        # Channel 1: Ghost
        gx, gy = state._ghost
        if 0 <= gx < self.grid_height and 0 <= gy < self.grid_width:
            grid[1, int(gx), int(gy)] = 1.0

        # Channel 2: Food dots
        for dot in state._dots:
            dx, dy = dot
            if 0 <= dx < self.grid_height and 0 <= dy < self.grid_width:
                grid[2, int(dx), int(dy)] = 1.0

        # Channel 3: Capsules (if your game has them)
        # Assuming capsules are special dots - modify based on your game
        # For now, this channel is left as all zeros

        # Channel 4: Walls
        if hasattr(state, '_State__walls'):
            walls = state._State__walls
            for i in range(self.grid_height):
                for j in range(self.grid_width):
                    if i < len(walls) and j < len(walls[i]) and walls[i][j]:
                        grid[4, i, j] = 1.0

        return torch.FloatTensor(grid).to(self.device)

    def _initialize_networks(self, state):
        """Initialize Q-network and target network"""
        if self.q_network is not None:
            return

        # CRITICAL FIX: Always use ALL 4 Pacman actions as integers
        # Don't get actions from the first state - it might not have all 4 available
        self.action_list = [0, 1, 2, 3]  # Left, Down, Right, Up
        self.num_actions = 4

        print(f"Initializing Networks...")
        print(f"  Input: {self.num_channels} channels × {self.grid_height}×{self.grid_width}")
        print(f"  Output: {self.num_actions} actions")
        print(f"  Action list: {self.action_list}")

        # Create networks
        self.q_network = ConvQNetwork(
            self.num_channels,
            self.grid_height,
            self.grid_width,
            self.num_actions
        ).to(self.device)

        self.target_network = ConvQNetwork(
            self.num_channels,
            self.grid_height,
            self.grid_width,
            self.num_actions
        ).to(self.device)

        # Copy weights to target network
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.target_network.eval()  # Target network always in eval mode

        # Optimizer with learning rate from paper
        self.optimizer = optim.Adam(
            self.q_network.parameters(),
            lr=self.learning_rate
        )

        # Count parameters
        total_params = sum(p.numel() for p in self.q_network.parameters())
        print(f"  Total Parameters: {total_params:,}")
        print(f"  Learning Rate: {self.learning_rate}")
        print()

    def get_q_value(self, state, action):
        """Get Q-value for a specific state-action pair"""
        if self.q_network is None:
            return 0.0

        action_int = self._action_to_int(action)
        if action_int not in self.action_list:
            return 0.0

        self.q_network.eval()
        with torch.no_grad():
            grid = self._create_equivalent_image(state).unsqueeze(0)
            q_values = self.q_network(grid)
            action_idx = self.action_list.index(action_int)
            return q_values[0, action_idx].item()

    def get_value(self, state):
        """Get state value V(s) = max_a Q(s,a)"""
        actions = self.game.get_actions(state)
        if not actions:
            return 0.0

        if self.q_network is None:
            return 0.0

        self.q_network.eval()
        with torch.no_grad():
            grid = self._create_equivalent_image(state).unsqueeze(0)
            q_values = self.q_network(grid)
            return q_values.max().item()

    def get_best_policy(self, state):
        """Get best action (greedy policy)"""
        actions = self.game.get_actions(state)
        if not actions:
            return None

        if self.q_network is None:
            return random.choice(list(actions))

        self.q_network.eval()
        with torch.no_grad():
            grid = self._create_equivalent_image(state).unsqueeze(0)
            q_values = self.q_network(grid)

            # Convert actions to integers and find valid indices
            valid_actions = [(self._action_to_int(a), a) for a in actions]
            valid_indices = [self.action_list.index(action_int)
                             for action_int, _ in valid_actions
                             if action_int in self.action_list]

            if not valid_indices:
                return random.choice(list(actions))

            # Get best valid action index
            best_idx = max(valid_indices, key=lambda i: q_values[0, i].item())
            best_action_int = self.action_list[best_idx]

            # Return the corresponding action enum
            for action_int, action_enum in valid_actions:
                if action_int == best_action_int:
                    return action_enum

            return random.choice(list(actions))

    def get_action(self, state):
        """Epsilon-greedy action selection"""
        actions = self.game.get_actions(state)
        if not actions:
            return None

        # Initialize networks on first state
        if self.q_network is None:
            self._initialize_networks(state)

        # Epsilon-greedy
        if random.random() < self.explore_prob:
            return random.choice(list(actions))
        else:
            return self.get_best_policy(state)

    def update(self, state, action, next_state, reward):
        """
        Store transition and perform training step

        Training happens at EVERY step (as per paper)
        """
        # Initialize if needed
        if self.q_network is None:
            self._initialize_networks(state)

        # Store transition in replay memory
        self.memory.append((state, action, next_state, reward))
        self.steps += 1

        # Wait for enough experiences
        if len(self.memory) < self.min_replay_size:
            return

        # Train on mini-batch
        self._train_step()

        # Update target network periodically
        if self.steps % self.target_update_freq == 0:
            self.update_target_network()
            print(f"  [Step {self.steps}] Target network updated")

    def _train_step(self):
        """
        Perform one training step using experience replay

        Key aspects from Stanford paper:
        - Random sampling from replay buffer (breaks correlation)
        - Use target network for stable targets
        - Vectorized max over legal actions only
        - Gradient clipping
        """
        # Sample random mini-batch
        batch = random.sample(self.memory, self.batch_size)

        # Prepare batch tensors
        states_batch = torch.stack([
            self._create_equivalent_image(s) for s, _, _, _ in batch
        ])

        actions_batch = torch.LongTensor([
            self.action_list.index(self._action_to_int(a)) for _, a, _, _ in batch
        ]).to(self.device)

        next_states_batch = torch.stack([
            self._create_equivalent_image(ns) for _, _, ns, _ in batch
        ])

        rewards_batch = torch.FloatTensor([
            r for _, _, _, r in batch
        ]).to(self.device)

        # Check for terminal states
        dones_batch = torch.FloatTensor([
            1.0 if (ns._won or ns._lost) else 0.0
            for _, _, ns, _ in batch
        ]).to(self.device)

        # Get legal actions for next states (vectorized)
        legal_actions_mask = []
        for _, _, ns, _ in batch:
            legal_actions = self.game.get_actions(ns)
            mask = torch.zeros(self.num_actions, device=self.device)
            for action in legal_actions:
                action_int = self._action_to_int(action)
                if action_int in self.action_list:
                    mask[self.action_list.index(action_int)] = 1.0
            legal_actions_mask.append(mask)
        legal_actions_mask = torch.stack(legal_actions_mask)
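        # Shape: (batch_size, num_actions), with 1.0 marking actions that are
        # legal in the corresponding next state.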

        # Compute current Q-values
        self.q_network.train()
        current_q_values = self.q_network(states_batch).gather(
            1, actions_batch.unsqueeze(1)
        ).squeeze(1)

        # Compute target Q-values using target network
        with torch.no_grad():
            next_q_values = self.target_network(next_states_batch)

            # Mask out illegal actions (set to very negative value)
            next_q_values = next_q_values.masked_fill(legal_actions_mask == 0, -1e9)

            # Max over legal actions only
            max_next_q = next_q_values.max(1)[0]

            # Target: r + γ * max Q(s', a')
            target_q_values = rewards_batch + (1 - dones_batch) * self.discount * max_next_q
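            # Worked example: reward = 10, discount = 0.9, max_next_q = 5.0
            # and a non-terminal transition (done = 0) give a target of
            # 10 + 0.9 * 5.0 = 14.5; for a terminal transition (done = 1)
            # the target reduces to the reward alone.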

        # Compute loss (Mean Squared Error)
        loss = nn.MSELoss()(current_q_values, target_q_values)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (prevent exploding gradients)
        torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), max_norm=10.0)

        self.optimizer.step()

    def update_target_network(self):
        """Copy Q-network weights to target network"""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def decay_epsilon(self):
        """
        Decay epsilon from 1.0 to 0.1 over training
        (As per Stanford paper)
        """
        self.episodes += 1

        # Linear decay over first 2000 episodes
        decay_episodes = 2000
        min_epsilon = 0.1

        if self.episodes < decay_episodes:
            self.explore_prob = (
                self.initial_epsilon
                - (self.initial_epsilon - min_epsilon)
                * (self.episodes / decay_episodes)
            )
        else:
            self.explore_prob = min_epsilon
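
        # Example schedule with initial_epsilon = 1.0: after 500 episodes
        # epsilon is 0.775, after 1000 episodes 0.55, and from episode 2000
        # onward it stays at the 0.1 floor.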

    def adjust_learning_rate(self):
        """
        Gradually reduce learning rate
        (Stanford paper: 0.00025 → 0.00005)
        """
        if self.episodes > 1000 and self.optimizer is not None:
            for param_group in self.optimizer.param_groups:
                # Decay the current learning rate by 1% per call, clamped at
                # the final learning rate (0.00005) from the paper.
                param_group['lr'] = max(self.learning_rate_decay,
                                        param_group['lr'] * 0.99)

    def save(self, filepath):
        """Save model checkpoint"""
        torch.save({
            'q_network': self.q_network.state_dict(),
            'target_network': self.target_network.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'episodes': self.episodes,
            'steps': self.steps,
            'epsilon': self.explore_prob,
            'action_list': self.action_list,
        }, filepath)
        print(f"✓ Model saved to {filepath}")

    def load(self, filepath):
        """Load model checkpoint"""
        checkpoint = torch.load(filepath, map_location=self.device)

        self.action_list = checkpoint['action_list']
        self.num_actions = len(self.action_list)

        # Reinitialize networks
        self.q_network = ConvQNetwork(
            self.num_channels, self.grid_height,
            self.grid_width, self.num_actions
        ).to(self.device)

        self.target_network = ConvQNetwork(
            self.num_channels, self.grid_height,
            self.grid_width, self.num_actions
        ).to(self.device)

        self.q_network.load_state_dict(checkpoint['q_network'])
        self.target_network.load_state_dict(checkpoint['target_network'])

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.episodes = checkpoint['episodes']
        self.steps = checkpoint['steps']
        self.explore_prob = checkpoint['epsilon']

        print(f"✓ Model loaded from {filepath}")
        print(f"  Episodes: {self.episodes}")
        print(f"  Steps: {self.steps}")
        print(f"  Epsilon: {self.explore_prob:.4f}")


# Feedback
feedback_question_1 = 10

feedback_question_2 = """
The most challenging aspect was understanding the equivalent image representation
and how to efficiently implement the convolutional neural network to process
grid-based game states instead of raw pixels.
"""

feedback_question_3 = """
I enjoyed implementing the Stanford CS229 approach and seeing how the equivalent
image trick provides 100x speedup. The progression from Q-learning to Approximate
Q-learning to Deep Q-learning shows the power of modern deep reinforcement learning.
"""