#!/usr/bin/env python3
"""
Simple script to calculate win rate and rewards from your Pac-Man training
Run this AFTER your training to get presentation metrics
"""

import matplotlib.pyplot as plt
import json
import sys


def calculate_metrics_from_file(filename='training_results.txt'):
    """
    Calculate win/loss and reward metrics from a training results file.

    The file is expected to contain one numeric episode reward per line;
    blank lines are ignored. An episode counts as a win when its reward is
    strictly positive, so a reward of exactly 0 counts as a loss.

    Besides returning the metrics, this prints a report to stdout and calls
    create_plots() and save_summary() to write the plot and text summary.

    Args:
        filename: Path of the results file to read.

    Returns:
        A dict with the computed metrics, or None when the file is missing,
        contains no data, or contains a line that is not a number.
    """

    print("=" * 70)
    print("CALCULATING METRICS FROM TRAINING RESULTS")
    print("=" * 70 + "\n")

    rewards = []
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                text = line.strip()
                if not text:
                    continue
                try:
                    rewards.append(float(text))
                except ValueError:
                    # Report bad input the same way as a missing file instead
                    # of crashing with a traceback.
                    print(f"Error: line {line_no} of {filename} is not a number: {text!r}")
                    return None
    except FileNotFoundError:
        print(f"Error: {filename} not found!")
        print("Run your training first to generate this file.")
        return None

    if not rewards:
        print("No data found in file!")
        return None

    total_episodes = len(rewards)

    # Win = strictly positive reward; zero or negative rewards are losses.
    wins = [1 if r > 0 else 0 for r in rewards]
    total_wins = sum(wins)
    total_losses = total_episodes - total_wins
    best_reward = max(rewards)
    worst_reward = min(rewards)

    # Overall statistics
    overall_win_rate = (total_wins / total_episodes) * 100
    overall_avg_reward = sum(rewards) / total_episodes

    # Last 100 episodes; slicing already yields every episode when there
    # are fewer than 100.
    last_100_rewards = rewards[-100:]
    last_100_wins = sum(wins[-100:])
    last_100_win_rate = (last_100_wins / len(last_100_rewards)) * 100
    last_100_avg_reward = sum(last_100_rewards) / len(last_100_rewards)

    # Print results
    print("TRAINING RESULTS")
    print("-" * 70)
    print(f"Total Episodes:              {total_episodes:,}")
    print(f"Total Wins:                  {total_wins:,}")
    print(f"Total Losses:                {total_losses:,}")
    print()
    print(f"Overall Win Rate:            {overall_win_rate:.2f}%")
    print(f"Overall Average Reward:      {overall_avg_reward:.2f}")
    print()
    print(f"Last 100 Win Rate:           {last_100_win_rate:.2f}%")
    print(f"Last 100 Average Reward:     {last_100_avg_reward:.2f}")
    print()
    print(f"Best Episode Reward:         {best_reward:.2f}")
    print(f"Worst Episode Reward:        {worst_reward:.2f}")
    print("=" * 70 + "\n")

    # Create visualizations
    create_plots(rewards, wins)

    # Save summary
    save_summary(total_episodes, total_wins, total_losses, overall_win_rate,
                 overall_avg_reward, last_100_win_rate, last_100_avg_reward,
                 best_reward, worst_reward)

    return {
        'total_episodes': total_episodes,
        'wins': total_wins,
        'losses': total_losses,
        'overall_win_rate': overall_win_rate,
        'overall_avg_reward': overall_avg_reward,
        'last_100_win_rate': last_100_win_rate,
        'last_100_avg_reward': last_100_avg_reward,
        'best_reward': best_reward,
        'worst_reward': worst_reward,
    }


def _moving_average(values, window):
    """
    Return the trailing mean of ``values`` at every position.

    Point i averages values[max(0, i - window + 1)] through values[i], so the
    first ``window - 1`` points average over however many values exist so
    far. A running sum keeps this O(len(values)) instead of
    O(len(values) * window).
    """
    averages = []
    running = 0.0
    for i, value in enumerate(values):
        running += value
        if i >= window:
            # Drop the value that just slid out of the window.
            running -= values[i - window]
        averages.append(running / min(i + 1, window))
    return averages


def create_plots(rewards, wins, window=100, output_path='training_analysis.png'):
    """
    Render a 2x2 figure summarising training and save it as a PNG.

    Panels: per-episode reward with a trailing moving average, trailing win
    rate, reward histogram with a mean line, and cumulative wins.

    Args:
        rewards: Per-episode rewards in episode order.
        wins: Per-episode 1/0 win flags, same length as ``rewards``.
        window: Trailing window size for the moving-average and win-rate
            curves; both curves are drawn only when there are at least
            ``window`` episodes.
        output_path: Where the image is written (300 dpi).
    """
    from itertools import accumulate

    if not rewards:
        # Histogram/mean below would divide by zero on empty input.
        print("No rewards to plot.")
        return

    print("Creating plots...")

    episodes = list(range(1, len(rewards) + 1))
    mean_reward = sum(rewards) / len(rewards)

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    try:
        fig.suptitle('Pac-Man Training Results', fontsize=16, fontweight='bold')

        # Plot 1: Rewards over time
        ax1 = axes[0, 0]
        ax1.plot(episodes, rewards, alpha=0.3, color='blue', label='Episode Reward')
        if len(rewards) >= window:
            ax1.plot(episodes, _moving_average(rewards, window), color='red',
                     linewidth=2, label=f'{window}-Episode Moving Avg')
        ax1.axhline(y=0, color='black', linestyle='--', linewidth=1, alpha=0.5)
        ax1.set_xlabel('Episode')
        ax1.set_ylabel('Reward')
        ax1.set_title('Training Rewards Over Time')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Plot 2: Win rate over time
        ax2 = axes[0, 1]
        if len(wins) >= window:
            win_rates = [rate * 100 for rate in _moving_average(wins, window)]
            ax2.plot(episodes, win_rates, color='green', linewidth=2)
            ax2.fill_between(episodes, 0, win_rates, alpha=0.3, color='green')
        ax2.set_xlabel('Episode')
        ax2.set_ylabel('Win Rate (%)')
        ax2.set_title(f'Win Rate Over Time ({window}-Episode Window)')
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(0, 105)

        # Plot 3: Reward distribution
        ax3 = axes[1, 0]
        ax3.hist(rewards, bins=50, color='purple', alpha=0.7, edgecolor='black')
        ax3.axvline(mean_reward, color='red', linestyle='--',
                    linewidth=2, label=f'Mean: {mean_reward:.1f}')
        ax3.set_xlabel('Reward')
        ax3.set_ylabel('Frequency')
        ax3.set_title('Reward Distribution')
        ax3.legend()
        ax3.grid(True, alpha=0.3, axis='y')

        # Plot 4: Cumulative wins
        ax4 = axes[1, 1]
        cumulative_wins = list(accumulate(wins))
        ax4.plot(episodes, cumulative_wins, color='orange', linewidth=2)
        ax4.fill_between(episodes, 0, cumulative_wins, alpha=0.3, color='orange')
        ax4.set_xlabel('Episode')
        ax4.set_ylabel('Cumulative Wins')
        ax4.set_title('Cumulative Wins Over Time')
        ax4.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release this specific figure, even if rendering/saving fails.
        plt.close(fig)

    print(f"✓ Plots saved to: {output_path}\n")


def save_summary(total_episodes, wins, losses, overall_wr, overall_ar,
                 last_100_wr, last_100_ar, best, worst,
                 output_path='presentation_summary.txt'):
    """
    Write a plain-text results summary for use in presentation slides.

    Args:
        total_episodes: Number of training episodes.
        wins, losses: Episode counts.
        overall_wr, last_100_wr: Win rates in percent over all episodes and
            over the final (up to) 100 episodes.
        overall_ar, last_100_ar: Corresponding average rewards.
        best, worst: Highest and lowest single-episode rewards.
        output_path: File to (over)write. Written as UTF-8 because the
            takeaways use the non-ASCII bullet character.
    """

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("=" * 70 + "\n")
        f.write("PACMAN Q-LEARNING - PRESENTATION SUMMARY\n")
        f.write("=" * 70 + "\n\n")

        f.write("TRAINING RESULTS\n")
        f.write("-" * 70 + "\n")
        f.write(f"Total Training Episodes:     {total_episodes:,}\n")
        f.write(f"Wins:                        {wins:,}\n")
        f.write(f"Losses:                      {losses:,}\n")
        f.write("\n")
        f.write(f"Overall Win Rate:            {overall_wr:.2f}%\n")
        f.write(f"Final Win Rate (last 100):   {last_100_wr:.2f}%\n")
        f.write("\n")
        f.write(f"Overall Avg Reward:          {overall_ar:.2f}\n")
        f.write(f"Final Avg Reward (last 100): {last_100_ar:.2f}\n")
        f.write("\n")
        f.write(f"Best Episode:                {best:.2f}\n")
        f.write(f"Worst Episode:               {worst:.2f}\n")
        f.write("\n")
        f.write("=" * 70 + "\n")
        f.write("KEY TAKEAWAYS FOR PRESENTATION\n")
        f.write("=" * 70 + "\n")
        f.write(f"• Agent trained for {total_episodes:,} episodes\n")
        f.write(f"• Achieved {last_100_wr:.1f}% win rate in final 100 episodes\n")

        # Describe the reward trend honestly: it can go down as well as up.
        if last_100_ar == overall_ar:
            f.write(f"• Average reward held steady at {overall_ar:.1f}\n")
        else:
            trend = "improved" if last_100_ar > overall_ar else "declined"
            f.write(f"• Average reward {trend} from {overall_ar:.1f} "
                    f"to {last_100_ar:.1f}\n")

        # Learning progress: a difference of two percentages is measured in
        # percentage points, not percent.
        improvement = last_100_wr - overall_wr
        if improvement > 0:
            f.write(f"• Win rate improved by {improvement:.1f} percentage points "
                    f"during training\n")

        f.write("\n")

    print(f"✓ Summary saved to: {output_path}\n")


def main():
    """
    Command-line entry point.

    Usage: python <script> [RESULTS_FILE]
    RESULTS_FILE defaults to training_results.txt.

    Returns:
        Exit status: 0 when metrics were computed and the output files were
        written, 1 when calculate_metrics_from_file() produced no metrics.
    """
    filename = sys.argv[1] if len(sys.argv) > 1 else 'training_results.txt'

    metrics = calculate_metrics_from_file(filename)
    if metrics is None:
        # Nothing was plotted or saved, so don't advertise output files.
        return 1

    print("Generated files:")
    print("1. training_analysis.png       - Visual plots")
    print("2. presentation_summary.txt    - Text summary")
    print("\nUse these files in your presentation slides!")
    print("=" * 70)
    return 0


if __name__ == '__main__':
    # Propagate main()'s status to the shell; a None return exits with 0.
    sys.exit(main())