"""
DQN Pacman GUI
Based on Stanford CS229 implementation with optimizations
"""

import argparse
import time
import tkinter as tk
from tkinter import ttk
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import sys
import os

# Import from existing files
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from pacman import State, Pacman, convert_walls, PRESET_LAYOUTS
import dqn_agent
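
# Assumed interface of dqn_agent.DQNAgent, as exercised in this file:
# get_action(state), get_best_policy(state), update(state, action, new_state,
# reward), save(path) and load(path); decay_epsilon() and
# adjust_learning_rate() are optional and called behind hasattr() checks, and
# explore_prob, steps and memory are read for the statistics display.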


class DQNTrainingGUI(tk.Tk):
    """Enhanced GUI for training and visualizing DQN agent"""

    SQUARE_SIZE = 50        # pixel size of one grid cell on the canvas
    ANIMATION_SPEED = 0.05  # delay in seconds between frames in play()

    def __init__(self, grid, agent, init_state):
        super().__init__()

        self.state = self.init_state = init_state
        self.agent = agent
        self.last_action = Pacman.Action.Right

        # Training statistics
        self.episode_rewards = []
        self.episode_wins = []
        self.current_episode_reward = 0

        # Create main layout
        self.title('Deep Q-Learning Pacman - Stanford CS229')

        # Create frames
        main_frame = tk.Frame(self)
        main_frame.pack(side=tk.LEFT, padx=10, pady=10)

        stats_frame = tk.Frame(self)
        stats_frame.pack(side=tk.RIGHT, padx=10, pady=10, fill=tk.BOTH, expand=True)

        # Game canvas
        rows, cols = grid['size']
        width = self.SQUARE_SIZE * (cols + 2)
        height = self.SQUARE_SIZE * (rows + 2)
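        # One-cell margin on every side: grid cell (row, col) is drawn at
        # canvas pixel ((col + 1) * SQUARE_SIZE, (row + 1) * SQUARE_SIZE),
        # hence the (cols + 2) x (rows + 2) sizing above.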

        self.canvas = tk.Canvas(main_frame, width=width + 1, height=height + 1,
                                highlightthickness=0, bg='black')

        # Border
        self.canvas.create_line(
            self.SQUARE_SIZE * .7, self.SQUARE_SIZE * .7,
            self.SQUARE_SIZE * (cols + 1.3), self.SQUARE_SIZE * .7,
            self.SQUARE_SIZE * (cols + 1.3), self.SQUARE_SIZE * (rows + 1.3),
            self.SQUARE_SIZE * .7, self.SQUARE_SIZE * (rows + 1.3),
            self.SQUARE_SIZE * .7, self.SQUARE_SIZE * .7,
            fill='blue', width=self.SQUARE_SIZE * .2
        )

        # Walls
        self.canvas.create_line(
            *((coord + 1.5) * self.SQUARE_SIZE for coord in grid['walls']),
            fill='blue', width=self.SQUARE_SIZE * .2
        )
        self.canvas.pack()

        # Stats labels
        self.create_stats_labels(stats_frame)

        # Create plot for learning curve
        self.create_learning_curve_plot(stats_frame)

    def create_stats_labels(self, parent):
        """Create statistics display"""
        stats_container = tk.Frame(parent, relief=tk.RIDGE, borderwidth=2)
        stats_container.pack(fill=tk.BOTH, pady=5)

        tk.Label(stats_container, text="Training Statistics",
                 font=('Arial', 14, 'bold')).pack(pady=5)

        # Current episode
        self.episode_label = tk.Label(stats_container, text="Episode: 0",
                                      font=('Arial', 12))
        self.episode_label.pack(anchor='w', padx=10)

        # Current reward
        self.reward_label = tk.Label(stats_container, text="Current Reward: 0",
                                     font=('Arial', 12))
        self.reward_label.pack(anchor='w', padx=10)

        # Average reward
        self.avg_reward_label = tk.Label(stats_container,
                                         text="Avg Reward (100): N/A",
                                         font=('Arial', 12))
        self.avg_reward_label.pack(anchor='w', padx=10)

        # Win rate
        self.winrate_label = tk.Label(stats_container, text="Win Rate: N/A",
                                      font=('Arial', 12))
        self.winrate_label.pack(anchor='w', padx=10)

        # Epsilon
        self.epsilon_label = tk.Label(stats_container, text="Epsilon: 1.000",
                                      font=('Arial', 12))
        self.epsilon_label.pack(anchor='w', padx=10)

        # Steps
        self.steps_label = tk.Label(stats_container, text="Total Steps: 0",
                                    font=('Arial', 12))
        self.steps_label.pack(anchor='w', padx=10)

        # Memory size
        self.memory_label = tk.Label(stats_container, text="Replay Memory: 0",
                                     font=('Arial', 12))
        self.memory_label.pack(anchor='w', padx=10)

    def create_learning_curve_plot(self, parent):
        """Create matplotlib plot for learning curve"""
        plot_frame = tk.Frame(parent, relief=tk.RIDGE, borderwidth=2)
        plot_frame.pack(fill=tk.BOTH, expand=True, pady=5)

        tk.Label(plot_frame, text="Learning Curve",
                 font=('Arial', 14, 'bold')).pack(pady=5)

        # Create matplotlib figure
        self.fig, (self.ax1, self.ax2) = plt.subplots(2, 1, figsize=(6, 5))
        self.fig.tight_layout(pad=3.0)

        # Reward plot
        self.ax1.set_xlabel('Episode')
        self.ax1.set_ylabel('Reward')
        self.ax1.set_title('Episode Rewards')
        self.ax1.grid(True, alpha=0.3)

        # Win rate plot
        self.ax2.set_xlabel('Episode')
        self.ax2.set_ylabel('Win Rate')
        self.ax2.set_title('Win Rate (Moving Average)')
        self.ax2.grid(True, alpha=0.3)
        self.ax2.set_ylim([0, 1])

        # Embed in tkinter
        self.canvas_plot = FigureCanvasTkAgg(self.fig, master=plot_frame)
        self.canvas_plot.get_tk_widget().pack(fill=tk.BOTH, expand=True)

    def update_stats(self):
        """Update statistics display"""
        episode = len(self.episode_rewards)
        self.episode_label.config(text=f"Episode: {episode}")
        self.reward_label.config(text=f"Current Reward: {self.current_episode_reward:.0f}")

        if len(self.episode_rewards) > 0:
            # Average reward (last 100 episodes)
            recent = self.episode_rewards[-100:]
            avg = sum(recent) / len(recent)
            self.avg_reward_label.config(text=f"Avg Reward (100): {avg:.1f}")

            # Win rate (last 100 episodes)
            recent_wins = self.episode_wins[-100:]
            winrate = sum(recent_wins) / len(recent_wins) if recent_wins else 0
            self.winrate_label.config(text=f"Win Rate: {winrate * 100:.1f}%")

        # Epsilon
        if hasattr(self.agent, 'explore_prob'):
            self.epsilon_label.config(text=f"Epsilon: {self.agent.explore_prob:.3f}")

        # Steps
        if hasattr(self.agent, 'steps'):
            self.steps_label.config(text=f"Total Steps: {self.agent.steps}")

        # Memory
        if hasattr(self.agent, 'memory'):
            self.memory_label.config(text=f"Replay Memory: {len(self.agent.memory)}")

    def update_learning_curve(self):
        """Update learning curve plots"""
        if len(self.episode_rewards) < 2:
            return

        episodes = list(range(1, len(self.episode_rewards) + 1))

        # Clear previous plots
        self.ax1.clear()
        self.ax2.clear()

        # Plot raw episode rewards (labelled so the legend call below always
        # has at least one entry)
        self.ax1.plot(episodes, self.episode_rewards, alpha=0.3, color='blue',
                      label='Episode Reward')

        # Trailing moving average once enough episodes have accumulated
        if len(self.episode_rewards) >= 10:
            window = 50
            moving_avg = []
            for i in range(len(self.episode_rewards)):
                start = max(0, i - window + 1)
                moving_avg.append(sum(self.episode_rewards[start:i + 1]) / (i - start + 1))
            self.ax1.plot(episodes, moving_avg, color='red', linewidth=2,
                          label=f'{window}-Episode MA')

        self.ax1.set_xlabel('Episode')
        self.ax1.set_ylabel('Reward')
        self.ax1.set_title('Episode Rewards')
        self.ax1.grid(True, alpha=0.3)
        self.ax1.legend()

        # Plot win rate
        window = 100
        win_rates = []
        for i in range(len(self.episode_wins)):
            start = max(0, i - window + 1)
            win_rate = sum(self.episode_wins[start:i + 1]) / (i - start + 1)
            win_rates.append(win_rate)

        self.ax2.plot(episodes, win_rates, color='green', linewidth=2)
        self.ax2.set_xlabel('Episode')
        self.ax2.set_ylabel('Win Rate')
        self.ax2.set_title(f'Win Rate ({window}-Episode MA)')
        self.ax2.set_ylim([0, 1])
        self.ax2.grid(True, alpha=0.3)

        # Redraw
        self.canvas_plot.draw()

    def update_board(self):
        """Update game board visualization"""
        self.canvas.delete('state')

        # Draw dots
        for x, y in self.state._dots:
            self.canvas.create_rectangle(
                (y + 1.4) * self.SQUARE_SIZE, (x + 1.4) * self.SQUARE_SIZE,
                (y + 1.6) * self.SQUARE_SIZE, (x + 1.6) * self.SQUARE_SIZE,
                tags='state', fill='white')

        # Draw Pacman
        x, y = self.state._pacman
        self.canvas.create_oval(
            (y + 1) * self.SQUARE_SIZE, (x + 1) * self.SQUARE_SIZE,
            (y + 2) * self.SQUARE_SIZE, (x + 2) * self.SQUARE_SIZE,
            tags='state', fill='yellow')

        # Pacman mouth
        pts = ((y + 1, x + 1), (y + 1, x + 2), (y + 2, x + 2), (y + 2, x + 1))
        self.canvas.create_polygon(
            pts[self.last_action][0] * self.SQUARE_SIZE,
            pts[self.last_action][1] * self.SQUARE_SIZE,
            (y + 1.5) * self.SQUARE_SIZE, (x + 1.5) * self.SQUARE_SIZE,
            pts[(self.last_action + 1) % 4][0] * self.SQUARE_SIZE,
            pts[(self.last_action + 1) % 4][1] * self.SQUARE_SIZE,
            tags='state')

        # Draw Ghost
        x, y = self.state._ghost
        self.canvas.create_polygon(
            (y + 1) * self.SQUARE_SIZE, (x + 1) * self.SQUARE_SIZE,
            (y + 1) * self.SQUARE_SIZE, (x + 2) * self.SQUARE_SIZE,
            (y + 1 + 1 / 6) * self.SQUARE_SIZE, (x + 1.7) * self.SQUARE_SIZE,
            (y + 1 + 2 / 6) * self.SQUARE_SIZE, (x + 2) * self.SQUARE_SIZE,
            (y + 1 + 3 / 6) * self.SQUARE_SIZE, (x + 1.7) * self.SQUARE_SIZE,
            (y + 1 + 4 / 6) * self.SQUARE_SIZE, (x + 2) * self.SQUARE_SIZE,
            (y + 1 + 5 / 6) * self.SQUARE_SIZE, (x + 1.7) * self.SQUARE_SIZE,
            (y + 2) * self.SQUARE_SIZE, (x + 2) * self.SQUARE_SIZE,
            (y + 2) * self.SQUARE_SIZE, (x + 1) * self.SQUARE_SIZE,
            tags='state', fill='red', smooth=True)

        # Ghost eyes
        self.canvas.create_oval(
            (y + 1.1) * self.SQUARE_SIZE, (x + 1.1) * self.SQUARE_SIZE,
            (y + 1.4) * self.SQUARE_SIZE, (x + 1.4) * self.SQUARE_SIZE,
            tags='state', fill='white', outline='white')
        self.canvas.create_oval(
            (y + 1.6) * self.SQUARE_SIZE, (x + 1.1) * self.SQUARE_SIZE,
            (y + 1.9) * self.SQUARE_SIZE, (x + 1.4) * self.SQUARE_SIZE,
            tags='state', fill='white', outline='white')

    def train_episode(self):
        """Train for one episode"""
        self.state = self.init_state
        self.current_episode_reward = 0

        while True:
            # Get action
            self.last_action = self.agent.get_action(self.state)

            # Take action
            new_state = self.state._move(self.last_action)

            # Reward shaping: +10 per dot eaten, -1 per step, and a +/-500
            # bonus/penalty for winning/losing the episode
            reward = (len(self.state._dots) - len(new_state._dots)) * 10 - 1
            if new_state._lost:
                reward -= 500
            elif new_state._won:
                reward += 500

            # Update agent
            self.agent.update(self.state, self.last_action, new_state, reward)

            self.current_episode_reward += reward
            self.state = new_state

            # Check terminal
            if self.state._lost or self.state._won:
                # Record statistics
                self.episode_rewards.append(self.current_episode_reward)
                self.episode_wins.append(1 if self.state._won else 0)

                # Decay epsilon
                if hasattr(self.agent, 'decay_epsilon'):
                    self.agent.decay_epsilon()

                # Adjust learning rate
                if hasattr(self.agent, 'adjust_learning_rate'):
                    self.agent.adjust_learning_rate()

                return

    def train(self, episodes, update_freq=100):
        """Train for multiple episodes"""
        print(f"\nStarting training for {episodes} episodes...")
        print(f"{'=' * 60}\n")

        for episode in range(episodes):
            self.train_episode()

            # Update display periodically
            if (episode + 1) % update_freq == 0:
                print(f"Episode {episode + 1}/{episodes}")
                print(
                    f"  Avg Reward (last 100): {sum(self.episode_rewards[-100:]) / min(100, len(self.episode_rewards)):.1f}")
                print(
                    f"  Win Rate (last 100): {sum(self.episode_wins[-100:]) / min(100, len(self.episode_wins)) * 100:.1f}%")
                print(f"  Epsilon: {self.agent.explore_prob:.3f}")
                print(f"  Total Steps: {self.agent.steps}")
                print()

                # Update GUI
                self.update_stats()
                self.update_learning_curve()
                self.update()

        print(f"\n{'=' * 60}")
        print(f"Training completed!")
        print(f"{'=' * 60}\n")

        # Save final plot
        self.fig.savefig('dqn_learning_curve.png', dpi=300, bbox_inches='tight')
        print("Learning curve saved to 'dqn_learning_curve.png'")

        # Save training data
        with open('dqn_training_rewards.txt', 'w') as f:
            for reward in self.episode_rewards:
                f.write(f"{reward}\n")
        print("Training rewards saved to 'dqn_training_rewards.txt'\n")

    def play(self):
        """Play one episode with visualization"""
        self.state = self.init_state
        total_reward = 0

        while True:
            self.update_board()
            self.update()
            time.sleep(self.ANIMATION_SPEED)

            # Get action (greedy)
            self.last_action = self.agent.get_best_policy(self.state)

            # Take action
            new_state = self.state._move(self.last_action)

            # Reward shaping (same scheme as in train_episode)
            reward = (len(self.state._dots) - len(new_state._dots)) * 10 - 1
            if new_state._lost:
                reward -= 500
            elif new_state._won:
                reward += 500

            total_reward += reward
            self.state = new_state

            # Check terminal
            if self.state._lost or self.state._won:
                if self.state._won:
                    print(f"✓ Pacman WON! Total Reward: {total_reward}")
                else:
                    print(f"✗ Pacman LOST! Total Reward: {total_reward}")
                return


def main():
    parser = argparse.ArgumentParser(
        description='Deep Q-Learning for Pacman (Stanford CS229 Implementation)',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument('layout', choices=PRESET_LAYOUTS.keys(),
                        default='small', nargs='?',
                        help='Layout preset')
    parser.add_argument('-d', '--discount', type=float, default=0.9,
                        help='Discount factor gamma')
    parser.add_argument('-r', '--learning-rate', type=float, default=0.00025,
                        help='Learning rate (Stanford paper: 0.00025)')
    parser.add_argument('-e', '--epsilon', type=float, default=1.0,
                        help='Initial exploration probability')
    parser.add_argument('-t', '--train', type=int, required=True,
                        help='Number of training episodes (0 to skip training)')
    parser.add_argument('-p', '--play', type=int, required=True,
                        help='Number of episodes to play after training (0 to skip)')
    parser.add_argument('-s', '--save', type=str, default='dqn_model.pth',
                        help='Path to save trained model')
    parser.add_argument('-l', '--load', type=str, default=None,
                        help='Path to load pretrained model')

    args = parser.parse_args()

    # Get grid configuration
    grid = PRESET_LAYOUTS[args.layout]
    grid_size = grid['size']

    # Create DQN agent
    agent = dqn_agent.DQNAgent(
        Pacman,
        discount=args.discount,
        learning_rate=args.learning_rate,
        explore_prob=args.epsilon,
        grid_size=grid_size
    )

    # Load pretrained model if specified
    if args.load:
        agent.load(args.load)

    # Create initial state
    init_state = State(**grid, conv_walls=convert_walls(grid['size'], grid['walls']))

    # Create GUI
    gui = DQNTrainingGUI(grid, agent, init_state)

    # Training
    if args.train > 0:
        gui.train(args.train)
        agent.save(args.save)

    # Playing
    if args.play > 0:
        print(f"\nPlaying {args.play} episodes...")
        for i in range(args.play):
            print(f"\nEpisode {i + 1}/{args.play}")
            gui.play()
            time.sleep(0.5)

    gui.mainloop()


if __name__ == '__main__':
    main()