import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Categorical
from collections import deque
import argparse
# pip install "gymnasium[classic-control]" pygame numpy torch
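#
# Example invocations for the CLI defined at the bottom of this file (the
# filename sac_acrobot.py is only an assumption; use the actual script name):
#   python sac_acrobot.py --mode train --no-render --device cuda
#   python sac_acrobot.py --mode test --path SAC_Acrobot_final.pth --episodes 5
#   python sac_acrobot.py --mode test_all --no-render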


class ReplayBuffer:
    """Experience replay buffer for SAC (discrete)"""

    def __init__(self, capacity=100000):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, next_states, dones = zip(*[self.buffer[idx] for idx in indices])

        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(np.array(actions)).unsqueeze(1),   # [B,1] for gather
            torch.FloatTensor(np.array(rewards)).unsqueeze(1),  # [B,1]
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(np.array(dones)).unsqueeze(1)     # [B,1]
        )

    def size(self):
        return len(self.buffer)


class Actor(nn.Module):
    """Discrete actor: outputs logits over actions"""

    def __init__(self, num_states, num_actions, hidden_size=256):
        super(Actor, self).__init__()

        self.fc1 = nn.Linear(num_states, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.logits_layer = nn.Linear(hidden_size, num_actions)

    def forward(self, state):
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        logits = self.logits_layer(x)
        return logits  # [B, num_actions]

    def sample(self, state):
        """
        Returns:
          actions: [B]
          log_probs: [B,1]
          probs: [B, num_actions]
        """
        logits = self.forward(state)
        probs = F.softmax(logits, dim=-1)
        dist = Categorical(probs=probs)
        actions = dist.sample()                      # [B]
        log_probs = dist.log_prob(actions).unsqueeze(1)  # [B,1]
        return actions, log_probs, probs
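
# Illustrative shape check for Actor.sample, kept as a comment so nothing runs at
# import time (a minimal sketch assuming Acrobot-v1's 6-dim observations and 3 actions):
#
#   >>> actor = Actor(num_states=6, num_actions=3)
#   >>> actions, log_probs, probs = actor.sample(torch.randn(4, 6))
#   >>> actions.shape, log_probs.shape, probs.shape
#   (torch.Size([4]), torch.Size([4, 1]), torch.Size([4, 3]))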


class Critic(nn.Module):
    """Critic network (Q-function for all actions)"""

    def __init__(self, num_states, num_actions, hidden_size=256):
        super(Critic, self).__init__()

        self.fc1 = nn.Linear(num_states, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_actions)

    def forward(self, state):
        """
        Returns Q(s, a) for all actions: [B, num_actions]
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        q_values = self.fc3(x)
        return q_values


class SAC:
    def __init__(self, path=None, render_mode=None, device='cpu'):
        # Acrobot-v1 has a discrete action space, so this agent uses the discrete SAC variant
        self.env = gym.make('Acrobot-v1', render_mode=render_mode)

        # Device
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        self.gamma = 0.99
        self.tau = 0.005
        self.max_steps_per_episode = 500  # default max steps for Acrobot-v1

        self.num_states = self.env.observation_space.shape[0]
        self.num_actions = self.env.action_space.n

        # Hyperparameters
        self.batch_size = 256
        self.buffer_capacity = 100000
        self.actor_lr = 0.0003
        self.critic_lr = 0.0003
        self.alpha_lr = 0.0003
        self.hidden_size = 256
        self.update_frequency = 1
        self.warmup_steps = 1000  # random-action warmup before the learned policy is used

        # Networks
        self.actor = Actor(self.num_states, self.num_actions, self.hidden_size).to(self.device)
        self.critic_1 = Critic(self.num_states, self.num_actions, self.hidden_size).to(self.device)
        self.critic_2 = Critic(self.num_states, self.num_actions, self.hidden_size).to(self.device)
        self.target_critic_1 = Critic(self.num_states, self.num_actions, self.hidden_size).to(self.device)
        self.target_critic_2 = Critic(self.num_states, self.num_actions, self.hidden_size).to(self.device)

        # Copy weights to target networks
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())

        # Optimizers
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=self.critic_lr)
        self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=self.critic_lr)

        # Temperature (alpha) with auto-tuning.
        # A discrete policy's entropy lies in [0, log(|A|)], so the target is a
        # positive fraction of log(|A|) (0.98 * log(|A|) is the usual discrete-SAC
        # choice); a negative target would push alpha toward zero.
        self.target_entropy = 0.98 * np.log(self.num_actions)
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=self.alpha_lr)

        # Replay buffer
        self.replay_buffer = ReplayBuffer(self.buffer_capacity)

        if path is not None:
            self.actor.load_state_dict(torch.load(path, map_location=self.device))
            print(f"Loaded model from {path}")

    @property
    def alpha(self):
        return self.log_alpha.exp()

    def get_action(self, state, deterministic=False):
        """Select an action from the policy (discrete); greedy argmax if deterministic."""
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)

        with torch.no_grad():  # no gradient tracking needed for action selection
            logits = self.actor(state)
            probs = F.softmax(logits, dim=-1)
            if deterministic:
                action = probs.argmax(dim=-1)  # [1]
            else:
                action = Categorical(probs=probs).sample()  # [1]

        return action.item()

    def update(self, states, actions, rewards, next_states, dones):
        """Update SAC networks (discrete variant)"""
        states = states.to(self.device)           # [B, S]
        actions = actions.to(self.device)         # [B,1]
        rewards = rewards.to(self.device)         # [B,1]
        next_states = next_states.to(self.device) # [B, S]
        dones = dones.to(self.device)             # [B,1]

        alpha = self.alpha

        # ----- Critic update -----
        with torch.no_grad():
            # Next-state policy
            next_logits = self.actor(next_states)               # [B,A]
            next_probs = F.softmax(next_logits, dim=-1)         # [B,A]
            next_log_probs = torch.log(next_probs + 1e-8)       # [B,A]

            # Target Q-values for next state
            target_q1_next = self.target_critic_1(next_states)  # [B,A]
            target_q2_next = self.target_critic_2(next_states)  # [B,A]
            target_min_q_next = torch.min(target_q1_next, target_q2_next)  # [B,A]

            # V(s') = sum_a pi(a|s') [ Q_min(s',a) - alpha * log_pi(a|s') ]
            next_v = (next_probs * (target_min_q_next - alpha * next_log_probs)).sum(dim=1, keepdim=True)  # [B,1]

            target_q = rewards + (1 - dones) * self.gamma * next_v  # [B,1]

        # Current Q(s,a) for chosen actions
        current_q1_all = self.critic_1(states)       # [B,A]
        current_q2_all = self.critic_2(states)       # [B,A]
        current_q1 = current_q1_all.gather(1, actions)  # [B,1]
        current_q2 = current_q2_all.gather(1, actions)  # [B,1]

        critic_1_loss = F.mse_loss(current_q1, target_q)
        critic_2_loss = F.mse_loss(current_q2, target_q)

        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()

        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # ----- Actor update -----
        logits = self.actor(states)                  # [B,A]
        probs = F.softmax(logits, dim=-1)            # [B,A]
        log_probs = torch.log(probs + 1e-8)          # [B,A]

        q1_all = self.critic_1(states)               # [B,A]
        q2_all = self.critic_2(states)               # [B,A]
        min_q_all = torch.min(q1_all, q2_all)        # [B,A]

        # J_pi = E_s[ sum_a pi(a|s) (alpha * log pi(a|s) - Q_min(s,a)) ]
        actor_loss = (probs * (alpha * log_probs - min_q_all)).sum(dim=1).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # ----- Alpha (temperature) update -----
        # J(alpha) = E_s[ sum_a pi(a|s) * (-alpha (log pi(a|s) + target_entropy)) ]
        alpha_loss = -(self.log_alpha * (probs * (log_probs + self.target_entropy)).sum(dim=1).detach()).mean()

        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()

        # ----- Soft update target critics -----
        self._soft_update(self.target_critic_1, self.critic_1)
        self._soft_update(self.target_critic_2, self.critic_2)

        return critic_1_loss.item(), critic_2_loss.item(), actor_loss.item(), alpha.item()

    def _soft_update(self, target_model, source_model):
        """Soft update target network"""
        for target_param, source_param in zip(target_model.parameters(), source_model.parameters()):
            target_param.data.copy_(self.tau * source_param.data + (1.0 - self.tau) * target_param.data)

    def train(self):
        episode_rewards = []
        running_reward = 0
        episode_count = 0
        total_steps = 0

        average = deque(maxlen=100)

        print("Starting SAC (discrete) training on Acrobot-v1...")
        print(f"Hyperparameters: batch_size={self.batch_size}, actor_lr={self.actor_lr}, critic_lr={self.critic_lr}")

        while True:
            state, info = self.env.reset()
            episode_reward = 0

            for step in range(self.max_steps_per_episode):
                # Render environment if enabled
                if self.env.render_mode == 'human':
                    self.env.render()

                # Exploration warmup with a random policy
                if total_steps < self.warmup_steps:
                    action = self.env.action_space.sample()
                else:
                    action = self.get_action(state)

                # Execute action
                next_state, reward, terminated, truncated, info = self.env.step(action)
                done = terminated or truncated

                # Store transition
                self.replay_buffer.add(state, action, reward, next_state, done)

                episode_reward += reward
                state = next_state
                total_steps += 1

                # Train after collecting enough samples
                if self.replay_buffer.size() >= self.batch_size and total_steps % self.update_frequency == 0:
                    batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones = \
                        self.replay_buffer.sample(self.batch_size)

                    critic_1_loss, critic_2_loss, actor_loss, alpha = self.update(
                        batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones
                    )

                if done:
                    break

            # Update statistics
            episode_rewards.append(episode_reward)
            average.append(episode_reward)
            running_reward = 0.05 * episode_reward + (1 - 0.05) * running_reward
            episode_count += 1

            # Log progress
            if episode_count % 10 == 0:
                avg_reward = np.mean(average)  # mean of the last (up to) 100 episode rewards
                current_alpha = self.alpha.item()
                print(f"Episode: {episode_count}, Running Reward: {running_reward:.2f}, "
                      f"Avg (last {len(average)} ep): {avg_reward:.2f}, "
                      f"Alpha: {current_alpha:.3f}, Total Steps: {total_steps}")

            # Save checkpoints at episode 10, then every 100 episodes
            if episode_count == 10 or episode_count % 100 == 0:
                save_path = f'SAC_Acrobot_ep{episode_count}.pth'
                torch.save(self.actor.state_dict(), save_path)
                print(f"Model saved: {save_path}")

            # "Solved" criterion (Acrobot reward is negative; closer to 0 is better)
            if len(average) >= 100:
                avg_reward = np.mean(list(average))
                if avg_reward > -100:  # heuristic threshold; tune as you like
                    print(f"Stopped at episode {episode_count}. Average reward: {avg_reward:.2f}")
                    torch.save(self.actor.state_dict(), 'SAC_Acrobot_final.pth')
                    return self.actor

    def test(self, path, num_episodes=10):
        """Test trained policy"""
        if path:
            self.actor.load_state_dict(torch.load(path, map_location=self.device))
            print(f"Loaded model from {path}")

        total_rewards = []

        for episode in range(num_episodes):
            state, info = self.env.reset()
            episode_reward = 0

            for step in range(self.max_steps_per_episode):
                if self.env.render_mode == 'human':
                    self.env.render()
                action = self.get_action(state, deterministic=True)
                state, reward, terminated, truncated, info = self.env.step(action)
                done = terminated or truncated
                episode_reward += reward

                if done:
                    break

            total_rewards.append(episode_reward)
            print(f'Episode {episode + 1}: {episode_reward:.2f}')

        avg_reward = np.mean(total_rewards)
        print(f'Average reward over {num_episodes} episodes: {avg_reward:.2f}')
        return avg_reward

    def test_all_models(self, num_episodes=5):
        """Test all saved models and display their performance"""
        import glob
        import re

        # Find all saved model files
        model_files = glob.glob('SAC_Acrobot_ep*.pth')

        if not model_files:
            print("No saved models found!")
            return

        # Sort by episode number
        def get_episode_num(filename):
            match = re.search(r'ep(\d+)', filename)
            return int(match.group(1)) if match else 0

        model_files.sort(key=get_episode_num)

        print(f"\nFound {len(model_files)} saved models")
        print("=" * 70)

        results = []

        for model_file in model_files:
            episode_num = get_episode_num(model_file)
            print(f"\nTesting model: {model_file} (Episode {episode_num})")
            print("-" * 70)

            # Load and test the model
            self.actor.load_state_dict(torch.load(model_file, map_location=self.device))
            avg_reward = self.test(None, num_episodes=num_episodes)

            results.append({
                'model': model_file,
                'episode': episode_num,
                'avg_reward': avg_reward
            })

            print("-" * 70)

        # Display summary
        print("\n" + "=" * 70)
        print("SUMMARY OF ALL MODELS")
        print("=" * 70)
        print(f"{'Model':<40} {'Episode':<10} {'Avg Reward':<15}")
        print("-" * 70)

        for result in results:
            print(f"{result['model']:<40} {result['episode']:<10} {result['avg_reward']:<15.2f}")

        # Find best model
        best_model = max(results, key=lambda x: x['avg_reward'])
        print("=" * 70)
        print(f"Best Model: {best_model['model']} with avg reward: {best_model['avg_reward']:.2f}")
        print("=" * 70)

        return results


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str,
                        help="train - train model, test - test single model, test_all - test all saved models",
                        default='train')
    parser.add_argument("--path", type=str, help="policy path for single test", default=None)
    parser.add_argument("--episodes", type=int, help="number of episodes for testing", default=5)
    parser.add_argument("--render", action='store_true', help="enable rendering/visualization")
    parser.add_argument("--no-render", dest='render', action='store_false', help="disable rendering/visualization")
    parser.add_argument("--device", type=str, help="cpu or cuda", default='cpu')
    parser.set_defaults(render=True)
    return parser


if __name__ == '__main__':
    args = get_args().parse_args()
    # Enable rendering based on command line argument
    render_mode = 'human' if args.render else None

    if args.render:
        print("Rendering enabled: You will see the environment visualization")
    else:
        print("Rendering disabled: Training/testing will run faster without visualization")

    # A --path given in train mode resumes from a saved actor; in test/test_all
    # mode the checkpoint(s) are loaded later by test()/test_all_models().
    sac = SAC(args.path if args.mode == 'train' else None, render_mode=render_mode, device=args.device)

    if args.mode == 'train':
        sac.train()
    elif args.mode == 'test':
        if args.path is None:
            print("Error: Please provide --path argument for testing a single model")
        else:
            sac.test(args.path, num_episodes=args.episodes)
    elif args.mode == 'test_all':
        sac.test_all_models(num_episodes=args.episodes)
    else:
        print(f"Unknown mode: {args.mode}. Use 'train', 'test', or 'test_all'")
