import matplotlib.pyplot as plt
import numpy as np


def _read_rewards(filename):
    """Return the per-episode rewards stored one numeric value per line.

    Blank lines (e.g. a trailing empty line at end of file) are skipped so
    they do not trip ``float('')``. ``float`` itself tolerates surrounding
    whitespace and the trailing newline.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        return [float(line) for line in f if line.strip()]


def _moving_average(values, window):
    """Return the trailing moving average of *values* over *window* points.

    Element ``i`` is the mean of ``values[max(0, i - window + 1):i + 1]``: the
    first ``window - 1`` entries average over the shorter prefix available
    instead of being dropped. Computed from a cumulative sum in O(n) rather
    than re-averaging every slice.
    """
    arr = np.asarray(values, dtype=float)
    csum = np.cumsum(arr)
    sums = csum.copy()
    # Once a full window is available, drop the running total from `window`
    # steps back so each entry holds only its trailing window's sum.
    sums[window:] -= csum[:-window]
    # Divisor grows 1, 2, ..., window and then stays at window.
    counts = np.minimum(np.arange(1, arr.size + 1), window)
    return sums / counts


def plot_learning_curve(filename='training_results.txt', window=100,
                        output='learning_curve.png'):
    """Plot per-episode rewards with their moving average, then save and show.

    Args:
        filename: Text file with one numeric reward per line; blank lines are
            ignored.
        window: Number of trailing episodes averaged for the smoothed curve.
            Must be at least 1.
        output: Path the figure is saved to at 300 dpi (matplotlib infers the
            image format from the extension).

    Raises:
        ValueError: If ``window`` is less than 1, or a non-blank line of
            ``filename`` is not a valid float.
        OSError: If ``filename`` cannot be read or ``output`` cannot be
            written.
    """
    if window < 1:
        # A non-positive window would produce empty slices and NaN averages.
        raise ValueError(f'window must be >= 1, got {window}')

    rewards = _read_rewards(filename)
    moving_avg = _moving_average(rewards, window)

    fig = plt.figure(figsize=(12, 6))

    # Raw rewards drawn faintly so the smoothed trend stands out.
    plt.plot(rewards, alpha=0.3, label='Episode Reward')
    plt.plot(moving_avg, linewidth=2, label=f'{window}-Episode Moving Average')

    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.title('DQN Training Curve')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plt.savefig(output, dpi=300)
    print(f'Plot saved as {output}')
    plt.show()
    # Release the figure so repeated calls (e.g. under a non-interactive
    # backend where show() does not block) don't accumulate open figures.
    plt.close(fig)


if __name__ == '__main__':
    # Script entry point: plot using plot_learning_curve's default
    # arguments (input 'training_results.txt', 100-episode window).
    plot_learning_curve()