import argparse
import math
import random
import time
import tkinter as tk
from enum import IntEnum
from functools import lru_cache
from typing import Set
import os
from collections import deque

# Try importing PyTorch for DQN
try:
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    import numpy as np

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    print("WARNING: PyTorch not installed. DQN agent will not be available.")
    print("Install with: pip install torch numpy")


# ============================================================================
# BERKELEY LAYOUT PARSER
# ============================================================================

class Grid:
    """Simple 2D grid structure from Berkeley"""

    def __init__(self, width, height, initialValue=False):
        self.width = width
        self.height = height
        self.data = [[initialValue for y in range(height)] for x in range(width)]

    def __getitem__(self, i):
        return self.data[i]

    def __setitem__(self, key, item):
        self.data[key] = item


class Layout:
    """Berkeley Layout class - manages static game board information"""

    def __init__(self, layoutText):
        self.width = len(layoutText[0])
        self.height = len(layoutText)
        self.walls = Grid(self.width, self.height, False)
        self.food = Grid(self.width, self.height, False)
        self.capsules = []
        self.agentPositions = []
        self.numGhosts = 0
        self.processLayoutText(layoutText)
        self.layoutText = layoutText

    def processLayoutText(self, layoutText):
        """
        Coordinates are flipped from input format to (x,y) convention
        % - Wall, . - Food, o - Capsule, G - Ghost, P - Pacman
        """
        maxY = self.height - 1
        for y in range(self.height):
            for x in range(self.width):
                layoutChar = layoutText[maxY - y][x]
                self.processLayoutChar(x, y, layoutChar)
        self.agentPositions.sort()
        self.agentPositions = [(i == 0, pos) for i, pos in self.agentPositions]

    def processLayoutChar(self, x, y, layoutChar):
        if layoutChar == '%':
            self.walls[x][y] = True
        elif layoutChar == '.':
            self.food[x][y] = True
        elif layoutChar == 'o':
            self.capsules.append((x, y))
        elif layoutChar == 'P':
            self.agentPositions.append((0, (x, y)))
        elif layoutChar == 'G':
            self.agentPositions.append((1, (x, y)))
            self.numGhosts += 1


def getLayout(name):
    """Load layout from file"""
    paths = [
        name,
        f'layouts/{name}',
        f'layouts/{name}.lay',
        f'{name}.lay'
    ]

    for path in paths:
        if os.path.exists(path):
            with open(path, 'r') as f:
                return Layout([line.rstrip() for line in f])
    return None


# ============================================================================
# STATE AND GAME LOGIC
# ============================================================================

class State:
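    """Dynamic game state: Pacman and ghost positions, remaining dots, walls.

    Positions use (row, col) indexing, where rows run over the layout height
    (self._n) and columns over the layout width (self._m).
    """
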
    def __init__(self, layout=None, **grid):
        if layout:
            # Initialize from Berkeley layout
            self._m = layout.width
            self._n = layout.height

            # Get pacman position
            pacman_pos = None
            ghost_pos = None
            for is_pacman, pos in layout.agentPositions:
                if is_pacman:
                    pacman_pos = (pos[1], pos[0])  # Flip to (row, col)
                elif ghost_pos is None:
                    ghost_pos = (pos[1], pos[0])

            self._pacman = pacman_pos if pacman_pos else (1, 1)
            self._ghost = ghost_pos if ghost_pos else (self._n - 2, self._m - 2)

            # Get dots from food
            dots = []
            for x in range(layout.width):
                for y in range(layout.height):
                    if layout.food[x][y]:
                        dots.append((y, x))  # Flip to (row, col)
            self._dots = tuple(dots)

            # Convert walls to 2D array
            self.__walls = [[layout.walls[x][y] for x in range(layout.width)]
                            for y in range(layout.height)]
        else:
            # Initialize from grid format (original format)
            self._n, self._m = grid['size']
            self._pacman = grid['pacman']
            self._ghost = grid['ghost']
            self._dots = grid['dots']
            self.__walls = grid['conv_walls']

        if 'ghost_dir' in grid:
            self._ghost_dir = grid['ghost_dir']
        else:
            possible = list(self._possible(self._ghost))
            self._ghost_dir = random.choice(possible) if possible else Pacman.Action.Up

    def __eq__(self, other: 'State'):
        return (self._pacman == other._pacman and
                self._ghost == other._ghost and
                self._ghost_dir == other._ghost_dir and
                self._dots == other._dots)

    def __hash__(self):
        return hash((self._pacman, self._ghost, self._ghost_dir, self._dots))

    @property
    def _won(self):
        return len(self._dots) <= 0

    @property
    def _lost(self):
        return self._pacman == self._ghost

    def __oob(self, x, y):
        return x < 0 or x >= self._n or y < 0 or y >= self._m or self.__walls[x][y]

    def _possible(self, loc):
        x, y = loc
        possible = set()
        if not self.__oob(x - 1, y):
            possible.add(Pacman.Action.Up)
        if not self.__oob(x + 1, y):
            possible.add(Pacman.Action.Down)
        if not self.__oob(x, y - 1):
            possible.add(Pacman.Action.Left)
        if not self.__oob(x, y + 1):
            possible.add(Pacman.Action.Right)
        return possible

    @staticmethod
    def _do_action(loc, action):
        x, y = loc
        if action == Pacman.Action.Up:
            return x - 1, y
        if action == Pacman.Action.Down:
            return x + 1, y
        if action == Pacman.Action.Left:
            return x, y - 1
        return x, y + 1

    def _move(self, action):
        if action not in self._get_actions():
            raise ValueError('not a valid action')

        target_x, target_y = State._do_action(self._pacman, action)

        if self.__oob(target_x, target_y):
            pacman = self._pacman
        else:
            pacman = (target_x, target_y)

        dots = set(self._dots)
        if pacman in dots:
            dots.remove(pacman)

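        # The ghost avoids reversing: while its current direction is still
        # open, the opposite direction (0b10 ^ dir) is dropped from the
        # candidate moves.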
        ghost_poss = self._possible(self._ghost)
        if self._ghost_dir in ghost_poss:
            ghost_poss.discard(0b10 ^ self._ghost_dir)

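        # Bias the ghost toward Pacman: candidate moves are sorted by the
        # resulting distance to Pacman and sampled with linearly decreasing
        # weights, so closer moves are more likely.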
        def calc_dist(d):
            p = self._do_action(self._ghost, d)
            return math.hypot(p[0] - pacman[0], p[1] - pacman[1])

        ghost_poss_list = sorted(list(ghost_poss), key=calc_dist)
        if ghost_poss_list:
            ghost_dir = random.choices(ghost_poss_list,
                                       weights=list(range(len(ghost_poss_list), 0, -1)))[0]
        else:
            ghost_dir = self._ghost_dir

        new_ghost = self._do_action(self._ghost, ghost_dir)
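        # If Pacman and the ghost swap cells in one step, park both at the
        # midpoint so the collision still registers as a loss (pacman == ghost).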
        if self._pacman == new_ghost and pacman == self._ghost:
            x1, y1 = self._pacman
            x2, y2 = self._ghost
            pacman = new_ghost = ((x1 + x2) / 2, (y1 + y2) / 2)

        return State(
            size=(self._n, self._m),
            pacman=pacman,
            ghost=new_ghost,
            dots=tuple(sorted(dots)),
            conv_walls=self.__walls,
            ghost_dir=ghost_dir,
        )

    @lru_cache(maxsize=None)
    def _get_actions(self):
        x, y = self._pacman
        if x < 0 or x >= self._n or y < 0 or y >= self._m:
            raise ValueError('not a valid state')
        if self._lost or self._won:
            return set()

        poss = set()
        all_poss = self._possible(self._pacman)
        for i in all_poss:
            poss.add(i)
            rev = 0b10 ^ i
            if rev in all_poss:
                poss.add(rev)
        return poss


class Pacman:
    class Action(IntEnum):
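        # Values are chosen so that XOR with 0b10 flips a direction to its
        # opposite (Left=0 <-> Right=2, Down=1 <-> Up=3); _move and
        # _get_actions rely on this to find reverse moves.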
        Up = 3
        Down = 1
        Left = 0
        Right = 2

    @staticmethod
    def get_actions(state: State) -> Set['Pacman.Action']:
        return state._get_actions()


# ============================================================================
# DQN AGENT
# ============================================================================

if TORCH_AVAILABLE:
    class ConvQNetwork(nn.Module):
        """Convolutional Q-Network from Stanford CS229 paper"""

        def __init__(self, input_channels, height, width, num_actions):
            super(ConvQNetwork, self).__init__()

            self.conv1 = nn.Conv2d(input_channels, 8, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)
            self.conv3 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
            self.relu = nn.ReLU()

            conv_output_size = 32 * height * width
            self.fc1 = nn.Linear(conv_output_size, 256)
            self.fc2 = nn.Linear(256, num_actions)

            self.apply(self._init_weights)

        def _init_weights(self, module):
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        def forward(self, x):
            x = self.relu(self.conv1(x))
            x = self.relu(self.conv2(x))
            x = self.relu(self.conv3(x))
            x = x.view(x.size(0), -1)
            x = self.relu(self.fc1(x))
            x = self.fc2(x)
            return x


    class DQNAgent:
        """Deep Q-Network Agent"""

        def __init__(self, game, discount=0.9, learning_rate=0.00025,
                     explore_prob=1.0, grid_size=(7, 10)):
            self.game = game
            self.discount = discount
            self.learning_rate = learning_rate
            self.explore_prob = explore_prob
            self.initial_epsilon = explore_prob

            self.batch_size = 64
            self.memory_size = 100000
            self.target_update_freq = 50
            self.min_replay_size = 1000

            self.grid_height, self.grid_width = grid_size
            self.num_channels = 5

            self.q_network = None
            self.target_network = None
            self.optimizer = None
            self.action_list = None
            self.num_actions = 0

            self.memory = deque(maxlen=self.memory_size)
            self.steps = 0
            self.episodes = 0

            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"\nDQN Agent Initialized")
            print(f"Device: {self.device}")
            print(f"Grid Size: {self.grid_height}x{self.grid_width}\n")

        def _action_to_int(self, action):
            if isinstance(action, int):
                return action
            return int(action)

        def _create_equivalent_image(self, state):
            grid = np.zeros((self.num_channels, self.grid_height, self.grid_width),
                            dtype=np.float32)
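            # Channel layout: 0 = Pacman, 1 = ghost, 2 = food, 4 = walls;
            # channel 3 is allocated but never written by this implementation.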

            px, py = state._pacman
            if 0 <= px < self.grid_height and 0 <= py < self.grid_width:
                grid[0, int(px), int(py)] = 1.0

            gx, gy = state._ghost
            if 0 <= gx < self.grid_height and 0 <= gy < self.grid_width:
                grid[1, int(gx), int(gy)] = 1.0

            for dot in state._dots:
                dx, dy = dot
                if 0 <= dx < self.grid_height and 0 <= dy < self.grid_width:
                    grid[2, int(dx), int(dy)] = 1.0

            if hasattr(state, '_State__walls'):
                walls = state._State__walls
                for i in range(self.grid_height):
                    for j in range(self.grid_width):
                        if i < len(walls) and j < len(walls[i]) and walls[i][j]:
                            grid[4, i, j] = 1.0

            return torch.FloatTensor(grid).to(self.device)

        def _initialize_networks(self, state):
            if self.q_network is not None:
                return

            self.action_list = [0, 1, 2, 3]
            self.num_actions = 4

            print(f"Initializing Networks...")
            print(f"  Actions: {self.num_actions}")

            self.q_network = ConvQNetwork(
                self.num_channels,
                self.grid_height,
                self.grid_width,
                self.num_actions
            ).to(self.device)

            self.target_network = ConvQNetwork(
                self.num_channels,
                self.grid_height,
                self.grid_width,
                self.num_actions
            ).to(self.device)

            self.target_network.load_state_dict(self.q_network.state_dict())
            self.target_network.eval()

            self.optimizer = optim.Adam(
                self.q_network.parameters(),
                lr=self.learning_rate
            )

            total_params = sum(p.numel() for p in self.q_network.parameters())
            print(f"  Total Parameters: {total_params:,}\n")

        def get_best_policy(self, state):
            actions = self.game.get_actions(state)
            if not actions:
                return None

            if self.q_network is None:
                return random.choice(list(actions))

            self.q_network.eval()
            with torch.no_grad():
                grid = self._create_equivalent_image(state).unsqueeze(0)
                q_values = self.q_network(grid)

                valid_actions = [(self._action_to_int(a), a) for a in actions]
                valid_indices = [self.action_list.index(action_int)
                                 for action_int, _ in valid_actions
                                 if action_int in self.action_list]

                if not valid_indices:
                    return random.choice(list(actions))

                best_idx = max(valid_indices, key=lambda i: q_values[0, i].item())
                best_action_int = self.action_list[best_idx]

                for action_int, action_enum in valid_actions:
                    if action_int == best_action_int:
                        return action_enum

                return random.choice(list(actions))

        def get_action(self, state):
            actions = self.game.get_actions(state)
            if not actions:
                return None

            if self.q_network is None:
                self._initialize_networks(state)

            if random.random() < self.explore_prob:
                return random.choice(list(actions))
            else:
                return self.get_best_policy(state)

        def update(self, state, action, next_state, reward):
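            # Store the transition, then train on a random minibatch once the
            # replay buffer holds min_replay_size transitions; the target
            # network is synced every target_update_freq steps.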
            if self.q_network is None:
                self._initialize_networks(state)

            self.memory.append((state, action, next_state, reward))
            self.steps += 1

            if len(self.memory) < self.min_replay_size:
                return

            self._train_step()

            if self.steps % self.target_update_freq == 0:
                self.update_target_network()

        def _train_step(self):
            batch = random.sample(self.memory, self.batch_size)

            states_batch = torch.stack([
                self._create_equivalent_image(s) for s, _, _, _ in batch
            ])

            actions_batch = torch.LongTensor([
                self.action_list.index(self._action_to_int(a)) for _, a, _, _ in batch
            ]).to(self.device)

            next_states_batch = torch.stack([
                self._create_equivalent_image(ns) for _, _, ns, _ in batch
            ])

            rewards_batch = torch.FloatTensor([
                r for _, _, _, r in batch
            ]).to(self.device)

            dones_batch = torch.FloatTensor([
                1.0 if (ns._won or ns._lost) else 0.0
                for _, _, ns, _ in batch
            ]).to(self.device)

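            # Mask actions that are illegal in each next state so the Bellman
            # target's max only ranges over moves Pacman could actually take.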
            legal_actions_mask = []
            for _, _, ns, _ in batch:
                legal_actions = self.game.get_actions(ns)
                mask = torch.zeros(self.num_actions, device=self.device)
                for action in legal_actions:
                    action_int = self._action_to_int(action)
                    if action_int in self.action_list:
                        mask[self.action_list.index(action_int)] = 1.0
                legal_actions_mask.append(mask)
            legal_actions_mask = torch.stack(legal_actions_mask)

            self.q_network.train()
            current_q_values = self.q_network(states_batch).gather(
                1, actions_batch.unsqueeze(1)
            ).squeeze()

            with torch.no_grad():
                next_q_values = self.target_network(next_states_batch)
                next_q_values = next_q_values.masked_fill(legal_actions_mask == 0, -1e9)
                max_next_q = next_q_values.max(1)[0]
                target_q_values = rewards_batch + (1 - dones_batch) * self.discount * max_next_q

            loss = nn.MSELoss()(current_q_values, target_q_values)

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), max_norm=10.0)
            self.optimizer.step()

        def update_target_network(self):
            self.target_network.load_state_dict(self.q_network.state_dict())

        def decay_epsilon(self):
            self.episodes += 1
            decay_episodes = 1000
            min_epsilon = 0.01

            # Linear decay from the initial epsilon to min_epsilon over
            # decay_episodes episodes.
            if self.episodes < decay_episodes:
                self.explore_prob = self.initial_epsilon - (
                    (self.initial_epsilon - min_epsilon)
                    * (self.episodes / decay_episodes)
                )
            else:
                self.explore_prob = min_epsilon

        def save(self, filepath):
            torch.save({
                'q_network': self.q_network.state_dict(),
                'target_network': self.target_network.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'episodes': self.episodes,
                'steps': self.steps,
                'epsilon': self.explore_prob,
                'action_list': self.action_list,
            }, filepath)
            print(f"✓ Model saved to {filepath}")

        def load(self, filepath):
            checkpoint = torch.load(filepath, map_location=self.device)
            self.action_list = checkpoint['action_list']
            self.num_actions = len(self.action_list)

            self.q_network = ConvQNetwork(
                self.num_channels, self.grid_height,
                self.grid_width, self.num_actions
            ).to(self.device)

            self.target_network = ConvQNetwork(
                self.num_channels, self.grid_height,
                self.grid_width, self.num_actions
            ).to(self.device)

            self.q_network.load_state_dict(checkpoint['q_network'])
            self.target_network.load_state_dict(checkpoint['target_network'])

            self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)
            self.optimizer.load_state_dict(checkpoint['optimizer'])

            self.episodes = checkpoint['episodes']
            self.steps = checkpoint['steps']
            self.explore_prob = checkpoint['epsilon']

            print(f"✓ Model loaded from {filepath}")


# ============================================================================
# Q-LEARNING AGENTS
# ============================================================================

class QLearningAgent:
    """Q-Learning Agent using Q-table"""

    def __init__(self, game, discount, learning_rate, explore_prob):
        self.game = game
        self.discount = discount
        self.learning_rate = learning_rate
        self.explore_prob = explore_prob
        self.q_table = {}

    def get_q_value(self, state, action):
        return self.q_table.get((state, action), 0.0)

    def get_value(self, state):
        actions = self.game.get_actions(state)
        if not actions:
            return 0.0
        return max(self.get_q_value(state, action) for action in actions)

    def get_best_policy(self, state):
        actions = self.game.get_actions(state)
        if not actions:
            return None
        max_q = max(self.get_q_value(state, action) for action in actions)
        best_actions = [action for action in actions
                        if self.get_q_value(state, action) == max_q]
        return random.choice(best_actions)

    def update(self, state, action, next_state, reward):
        current_q = self.get_q_value(state, action)
        next_value = self.get_value(next_state)
        new_q = ((1 - self.learning_rate) * current_q +
                 self.learning_rate * (reward + self.discount * next_value))
        self.q_table[(state, action)] = new_q

    def get_action(self, state):
        actions = self.game.get_actions(state)
        if not actions:
            return None
        if random.random() < self.explore_prob:
            return random.choice(list(actions))
        else:
            return self.get_best_policy(state)


class ApproximateQAgent(QLearningAgent):
    """Approximate Q-Learning Agent using feature weights"""

    def __init__(self, *args, extractor):
        super().__init__(*args)
        self.extractor = extractor
        self.weights = {}

    def get_weight(self, feature):
        return self.weights.get(feature, 0.0)

    def get_q_value(self, state, action):
        features = self.extractor(state, action)
        q_value = sum(self.get_weight(feature) * value
                      for feature, value in features.items())
        return q_value

    def update(self, state, action, next_state, reward):
        next_value = self.get_value(next_state)
        current_q = self.get_q_value(state, action)
        td_error = reward + self.discount * next_value - current_q

        features = self.extractor(state, action)
        for feature, feature_value in features.items():
            current_weight = self.get_weight(feature)
            new_weight = current_weight + self.learning_rate * td_error * feature_value
            self.weights[feature] = new_weight


# ============================================================================
# FEATURE EXTRACTORS
# ============================================================================

def identity_extractor(state, action):
    return {(state, action): 1.}


def closest_food(start, state: State):
    queue = deque([start])
    dist = {start: 0}
    while queue:
        pos = queue.popleft()
        dist_n = dist[pos]
        if pos in state._dots:
            return dist_n
        for act in state._possible(pos):
            neighbor = state._do_action(pos, act)
            if neighbor not in dist:
                dist[neighbor] = dist_n + 1
                queue.append(neighbor)
    return None


@lru_cache(maxsize=10)
def simple_extractor(state, action):
    features = {'bias': 1}
    next_loc = state._do_action(state._pacman, action)

    features['ghost-step-away'] = next_loc == state._ghost or any(
        next_loc == state._do_action(state._ghost, act)
        for act in state._possible(state._ghost))

    if not features['ghost-step-away'] and next_loc in state._dots:
        features['food'] = 1

    dist = closest_food(next_loc, state)
    if dist is not None:
        features['closest-food'] = dist / (state._n * state._m)

    for k in features:
        features[k] /= 10
    return features


@lru_cache(maxsize=10)
def better_extractor(state, action):
    """Improved feature extractor with more informative features"""
    features = {}

    next_loc = state._do_action(state._pacman, action)

    # Distance to ghost (Manhattan distance)
    ghost_dist = abs(next_loc[0] - state._ghost[0]) + abs(next_loc[1] - state._ghost[1])
    features['ghost-distance'] = ghost_dist / (state._n + state._m)

    # Is ghost very close? (immediate danger!)
    features['ghost-danger'] = 1.0 if ghost_dist <= 2 else 0.0

    # Will this action eat food?
    features['eats-food'] = 1.0 if next_loc in state._dots else 0.0

    # Distance to closest food
    dist_to_food = closest_food(next_loc, state)
    if dist_to_food is not None:
        features['food-distance'] = 1.0 / (1.0 + dist_to_food)
    else:
        features['food-distance'] = 0.0

    # Number of dots remaining (want to minimize)
    features['dots-remaining'] = len(state._dots) / 100.0

    # Bias term
    features['bias'] = 1.0

    return features


# ============================================================================
# GUI
# ============================================================================

class GUI(tk.Tk):
    SQUARE_SIZE = 30
    ANIMATION_SPEED = 0.1

    def __init__(self, init_state: State, agent, layout_obj=None):
        super().__init__()

        self.__state = self.__init_state = init_state
        self.__agent = agent
        self.__last_action = Pacman.Action.Right
        self.__layout = layout_obj

        rows, cols = init_state._n, init_state._m
        width = GUI.SQUARE_SIZE * (cols + 2)
        height = GUI.SQUARE_SIZE * (rows + 2)
        self.__canvas = tk.Canvas(self, width=width + 1, height=height + 1,
                                  highlightthickness=0, bg='black')

        # Draw border
        self.__canvas.create_line(
            GUI.SQUARE_SIZE * .7, GUI.SQUARE_SIZE * .7,
            GUI.SQUARE_SIZE * (cols + 1.3), GUI.SQUARE_SIZE * .7,
            GUI.SQUARE_SIZE * (cols + 1.3), GUI.SQUARE_SIZE * (rows + 1.3),
            GUI.SQUARE_SIZE * .7, GUI.SQUARE_SIZE * (rows + 1.3),
            GUI.SQUARE_SIZE * .7, GUI.SQUARE_SIZE * .7,
            fill='blue', width=GUI.SQUARE_SIZE * .2
        )

        # Draw walls
        if layout_obj:
            for y in range(rows):
                for x in range(cols):
                    if layout_obj.walls[x][y]:
                        self.__canvas.create_rectangle(
                            (x + 1) * GUI.SQUARE_SIZE, (y + 1) * GUI.SQUARE_SIZE,
                            (x + 2) * GUI.SQUARE_SIZE, (y + 2) * GUI.SQUARE_SIZE,
                            fill='blue', outline='blue')

        self.__canvas.pack()

        self.__reward = tk.Label(self)
        self.__update_score(0)
        self.__reward.pack(side=tk.BOTTOM)

        self.title('Pacman with Q-Learning - Berkeley Layouts')

    def __iterate(self, learning=True):
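        # Returns the numeric step reward, or a bool (True if Pacman won) when
        # the episode has just ended and the board has been reset.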
        if self.__state._lost or self.__state._won:
            result = self.__state._won
            self.__state = self.__init_state
            return result

        if learning:
            self.__last_action = self.__agent.get_action(self.__state)
        else:
            self.__last_action = self.__agent.get_best_policy(self.__state)

        new_state = self.__state._move(self.__last_action)
        # +10 for each dot eaten, -1 living penalty per step
        reward = (len(self.__state._dots) - len(new_state._dots)) * 10 - 1

        if new_state._lost:
            reward -= 500
        elif new_state._won:
            reward += 500

        if learning:
            self.__agent.update(self.__state, self.__last_action, new_state, reward)

        self.__state = new_state
        return reward

    def __update_board(self):
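        # Canvas x tracks the state's column index (y) and canvas y tracks the
        # row index (x), each offset by one square for the border.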
        self.__canvas.delete('state')

        # Draw dots
        for x, y in self.__state._dots:
            self.__canvas.create_rectangle(
                (y + 1.4) * GUI.SQUARE_SIZE, (x + 1.4) * GUI.SQUARE_SIZE,
                (y + 1.6) * GUI.SQUARE_SIZE, (x + 1.6) * GUI.SQUARE_SIZE,
                tags='state', fill='white')

        # Draw Pacman
        x, y = self.__state._pacman
        self.__canvas.create_oval(
            (y + 1) * GUI.SQUARE_SIZE, (x + 1) * GUI.SQUARE_SIZE,
            (y + 2) * GUI.SQUARE_SIZE, (x + 2) * GUI.SQUARE_SIZE,
            tags='state', fill='yellow')

        pts = ((y + 1, x + 1), (y + 1, x + 2), (y + 2, x + 2), (y + 2, x + 1))
        self.__canvas.create_polygon(
            pts[self.__last_action][0] * GUI.SQUARE_SIZE,
            pts[self.__last_action][1] * GUI.SQUARE_SIZE,
            (y + 1.5) * GUI.SQUARE_SIZE, (x + 1.5) * GUI.SQUARE_SIZE,
            pts[(self.__last_action + 1) % 4][0] * GUI.SQUARE_SIZE,
            pts[(self.__last_action + 1) % 4][1] * GUI.SQUARE_SIZE,
            tags='state')

        # Draw Ghost
        x, y = self.__state._ghost
        self.__canvas.create_polygon(
            (y + 1) * GUI.SQUARE_SIZE, (x + 1) * GUI.SQUARE_SIZE,
            (y + 1) * GUI.SQUARE_SIZE, (x + 2) * GUI.SQUARE_SIZE,
            (y + 1 + 1 / 6) * GUI.SQUARE_SIZE, (x + 1.7) * GUI.SQUARE_SIZE,
            (y + 1 + 2 / 6) * GUI.SQUARE_SIZE, (x + 2) * GUI.SQUARE_SIZE,
            (y + 1 + 3 / 6) * GUI.SQUARE_SIZE, (x + 1.7) * GUI.SQUARE_SIZE,
            (y + 1 + 4 / 6) * GUI.SQUARE_SIZE, (x + 2) * GUI.SQUARE_SIZE,
            (y + 1 + 5 / 6) * GUI.SQUARE_SIZE, (x + 1.7) * GUI.SQUARE_SIZE,
            (y + 2) * GUI.SQUARE_SIZE, (x + 2) * GUI.SQUARE_SIZE,
            (y + 2) * GUI.SQUARE_SIZE, (x + 1) * GUI.SQUARE_SIZE,
            tags='state', fill='red', smooth=True)

        # Ghost eyes
        self.__canvas.create_oval(
            (y + 1.1) * GUI.SQUARE_SIZE, (x + 1.1) * GUI.SQUARE_SIZE,
            (y + 1.4) * GUI.SQUARE_SIZE, (x + 1.4) * GUI.SQUARE_SIZE,
            tags='state', fill='white', outline='white')
        self.__canvas.create_oval(
            (y + 1.6) * GUI.SQUARE_SIZE, (x + 1.1) * GUI.SQUARE_SIZE,
            (y + 1.9) * GUI.SQUARE_SIZE, (x + 1.4) * GUI.SQUARE_SIZE,
            tags='state', fill='white', outline='white')

    def __update_score(self, score):
        self.__reward.configure(text=f'Score: {score}')

    def train(self, episodes):
        rewards = []
        while len(rewards) < episodes:
            total = 0
            while True:
                reward = self.__iterate()
                if isinstance(reward, bool):
                    break
                total += reward
                self.__update_score(total)
            rewards.append(total)

            # Decay epsilon for DQN after each episode
            if hasattr(self.__agent, 'decay_epsilon'):
                self.__agent.decay_epsilon()

            if len(rewards) % 100 == 0:
                print(f'{len(rewards)}/{episodes} completed')
                print(f'\tAverage rewards over all episodes: {sum(rewards) / len(rewards):.2f}')
                print(f'\tAverage rewards for last 100 episodes: {sum(rewards[-100:]) / 100:.2f}')
                if hasattr(self.__agent, 'explore_prob'):
                    print(f'\tCurrent epsilon: {self.__agent.explore_prob:.4f}')
        return rewards

    def play(self):
        total = 0
        while True:
            self.__update_board()
            self.update()
            time.sleep(self.ANIMATION_SPEED)

            reward = self.__iterate(learning=False)
            if isinstance(reward, bool):
                if reward:
                    print(f'Pacman won! Episode reward {total}')
                else:
                    print(f'Pacman lost! Episode reward {total}')
                return
            total += reward
            self.__update_score(total)


# ============================================================================
# MAIN
# ============================================================================
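# Example invocations (assuming this file is saved as pacman_rl.py and the
# Berkeley .lay files live in ./layouts):
#   python pacman_rl.py smallClassic -a q -t 2000 -p 3
#   python pacman_rl.py mediumClassic -a approx -f better -t 3000 -p 3
#   python pacman_rl.py mediumClassic -a dqn -e 1.0 -t 5000 -p 3 --save-model dqn.pt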

def main():
    parser = argparse.ArgumentParser(
        description='Pacman Q-Learning with Berkeley Layouts',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('layout', nargs='?', default='mediumClassic',
                        choices=('mediumClassic', 'smallClassic', 'tinyMaze'),
                        help='Berkeley layout name (e.g., mediumClassic)')
    parser.add_argument('-a', '--agent', choices=('q', 'approx', 'dqn'), default='q',
                        help='Q learning agent type')
    parser.add_argument('-f', '--feature', choices=('identity', 'simple', 'better'),
                        default='simple', help='Feature extraction type')
    parser.add_argument('-d', '--discount', type=float, default=0.8,
                        help='Discount factor gamma (0-1)')
    parser.add_argument('-r', '--learning-rate', type=float, default=0.2,
                        help='Learning rate')
    parser.add_argument('-e', '--epsilon', type=float, default=0.05,
                        help='Exploration probability epsilon')
    parser.add_argument('-t', '--train', type=int, required=True,
                        help='Number of training episodes')
    parser.add_argument('-p', '--play', type=int, required=True,
                        help='Number of play episodes')
    parser.add_argument('--save-model', type=str, default=None,
                        help='Path to save DQN model')
    parser.add_argument('--load-model', type=str, default=None,
                        help='Path to load DQN model')

    args = parser.parse_args()

    # Load Berkeley layout
    layout_obj = getLayout(args.layout)
    if layout_obj is None:
        print(f"Error: Could not find layout '{args.layout}'")
        print("Make sure the layout file exists in the 'layouts' directory")
        return

    print(f"Loaded layout: {args.layout}")
    print(f"Size: {layout_obj.width} x {layout_obj.height}")
    print(f"Number of ghosts: {layout_obj.numGhosts}")

    # Create agent
    if args.agent == 'q':
        agent = QLearningAgent(Pacman, args.discount, args.learning_rate, args.epsilon)
    elif args.agent == 'dqn':
        if not TORCH_AVAILABLE:
            print("ERROR: PyTorch not installed. Cannot use DQN agent.")
            print("Install with: pip install torch numpy")
            return
        grid_size = (layout_obj.height, layout_obj.width)
        agent = DQNAgent(Pacman, args.discount, args.learning_rate, args.epsilon, grid_size=grid_size)

        # Load model if specified
        if args.load_model:
            agent.load(args.load_model)
    else:
        extractor_map = {
            'identity': identity_extractor,
            'simple': simple_extractor,
            'better': better_extractor
        }
        extractor = extractor_map[args.feature]
        agent = ApproximateQAgent(Pacman, args.discount, args.learning_rate,
                                  args.epsilon, extractor=extractor)

    # Initialize state from Berkeley layout
    init_state = State(layout=layout_obj)

    # Create GUI
    gui = GUI(init_state, agent, layout_obj)

    # Train
    print(f'\nStarting training for {args.train} episodes...')
    rewards = gui.train(args.train)

    # Decay epsilon for DQN
    if args.agent == 'dqn':
        print(f'Final epsilon: {agent.explore_prob:.4f}')

    # Save DQN model if specified
    if args.agent == 'dqn' and args.save_model:
        agent.save(args.save_model)

    # Save results
    with open('training_results.txt', 'w') as f:
        for r in rewards:
            f.write(f'{r}\n')
    print(f'Training results saved to training_results.txt')

    # Play
    print(f'\nStarting {args.play} play episodes...')
    for i in range(args.play):
        print(f'Playing episode {i + 1}/{args.play}')
        gui.play()
        time.sleep(0.5)

    gui.mainloop()


if __name__ == '__main__':
    main()