"""
Long-Term Investment Analysis Module
Comprehensive evaluation for buy-and-hold portfolio strategies
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
import warnings

# Suppress every warning for the whole process (this is global, not scoped to
# this module) -- presumably to hide NumPy/Matplotlib noise from the analysis.
warnings.filterwarnings("ignore")

# Shared plot styling for every chart this module produces.
sns.set_style("whitegrid")
plt.rcParams.update({"figure.figsize": (14, 8), "font.size": 10})


class LongTermAnalyzer:
    """
    Analyzes portfolio performance for long-term investing
    Focus: Risk-adjusted returns, drawdowns, consistency

    Metrics are grouped into four tiers. Each ``calculate_tierN_metrics``
    method returns a plain dict of metric name -> value. The
    ``plot_tierN_charts`` methods write PNG figures into ``save_dir``.
    ``generate_report`` prints every tier and writes ``metrics_summary.csv``.
    """

    def __init__(
        self,
        portfolio_values,
        returns_array,
        weights_history,
        start_date,
        trading_days_per_year=252,
        benchmark_returns=None,
    ):
        """
        Args:
            portfolio_values: List of daily portfolio values (starting at 1.0)
            returns_array: Array of daily returns (portfolio changes - 1)
            weights_history: Array of portfolio weights over time [time, assets],
                or None when unavailable. Column 0 is treated as cash.
            start_date: Start date as string "YYYY/MM/DD"
            trading_days_per_year: 252 for stocks, 365 for crypto. It also decides
                how row dates are generated: >= 365 gives one row per calendar day;
                anything smaller gives one row per business day (Mon-Fri).
            benchmark_returns: Optional benchmark daily returns (e.g., SPY).
                Defaults to a constant, synthetic 10%-per-year benchmark.
        """
        self.portfolio_values = np.asarray(portfolio_values, dtype=float)
        self.returns = np.asarray(returns_array, dtype=float)
        # Convert to an ndarray so 2-D indexing (weights[:, 0], .shape) also
        # works when the caller passes nested lists.
        self.weights_history = (
            None
            if weights_history is None
            else np.asarray(weights_history, dtype=float)
        )
        self.start_date = datetime.strptime(start_date, "%Y/%m/%d")
        self.trading_days = trading_days_per_year
        self.num_days = len(self.portfolio_values)
        self.years = self.num_days / self.trading_days

        # Benchmark is always stored as an ndarray so element-wise comparisons
        # (e.g. benchmark_returns > 0) work for list inputs too.
        if benchmark_returns is not None:
            self.benchmark_returns = np.asarray(benchmark_returns, dtype=float)
        else:
            # Default: assume 10% annualized SPY-like benchmark
            daily_benchmark = (1.10 ** (1 / self.trading_days)) - 1
            self.benchmark_returns = np.full(self.num_days, daily_benchmark)
        self.benchmark_values = np.cumprod(1 + self.benchmark_returns)

        # Date index: one entry per row of portfolio_values. Rows are trading
        # days, so equity-style calendars (< 365 days/year) skip weekends; this
        # keeps chart dates consistent with `self.years`. Exchange holidays are
        # not modelled.
        if self.trading_days >= 365:
            self.dates = [
                self.start_date + timedelta(days=i) for i in range(self.num_days)
            ]
        else:
            self.dates = list(
                pd.bdate_range(
                    start=self.start_date, periods=self.num_days
                ).to_pydatetime()
            )

    # ============================================================================
    # INTERNAL HELPERS
    # ============================================================================

    def _drawdown(self):
        """Return the drawdown from the running peak as fractions (<= 0)."""
        peak = np.maximum.accumulate(self.portfolio_values)
        return (self.portfolio_values - peak) / peak

    def _annualized_sharpe(self, returns):
        """Annualized Sharpe ratio (zero risk-free rate); NaN if volatility is 0."""
        std_return = np.std(returns)
        if std_return == 0:
            return float("nan")
        return (np.mean(returns) / std_return) * np.sqrt(self.trading_days)

    def _rolling_sharpe(self, window):
        """Annualized Sharpe over every full window of ``window`` returns.

        Returns:
            (end_indices, sharpes): the index of the last return in each window
            and that window's Sharpe. Windows with zero volatility are skipped.
        """
        end_indices, sharpes = [], []
        for start in range(len(self.returns) - window + 1):
            chunk = self.returns[start : start + window]
            if np.std(chunk) > 0:
                end_indices.append(start + window - 1)
                sharpes.append(self._annualized_sharpe(chunk))
        return end_indices, sharpes

    def _has_weight_history(self):
        """True when at least two weight snapshots exist (needed for diffs)."""
        return self.weights_history is not None and len(self.weights_history) > 1

    def _weight_changes(self, weights=None):
        """Per-step L1 weight change: sum over assets of |w_t - w_{t-1}|."""
        if weights is None:
            weights = self.weights_history
        return np.sum(np.abs(np.diff(weights, axis=0)), axis=1)

    def _drawdown_periods(self, threshold=0.01):
        """Contiguous spans where the drawdown is deeper than ``threshold``.

        Args:
            threshold: drawdown fraction (0.01 = 1%) that opens a period.

        Returns:
            List of {"duration": rows, "depth": worst drawdown in % (negative)}.
            A drawdown still open at the end of the sample is included.
        """
        drawdown = self._drawdown()
        periods = []
        start_idx = None
        for i, is_dd in enumerate(drawdown < -threshold):
            if is_dd and start_idx is None:
                start_idx = i
            elif not is_dd and start_idx is not None:
                periods.append(
                    {
                        "duration": i - start_idx,
                        "depth": np.min(drawdown[start_idx:i]) * 100,
                    }
                )
                start_idx = None
        if start_idx is not None:
            end_idx = len(drawdown)
            periods.append(
                {
                    "duration": end_idx - start_idx,
                    "depth": np.min(drawdown[start_idx:end_idx]) * 100,
                }
            )
        return periods

    @staticmethod
    def _format_metric(key, value):
        """Render one report row, choosing the unit suffix from the metric name."""
        if "Duration" in key:
            return f"  {key:<35} {value:>10.0f} days"
        if "Period" in key:
            return f"  {key:<35} {value:>10.1f} days"
        if "Rebalances" in key:
            # A count of events, not a percentage.
            return f"  {key:<35} {value:>10.0f}"
        if any(
            tag in key for tag in ("Ratio", "Sharpe", "HHI", "Effective", "Frequency")
        ):
            # Dimensionless quantities: no unit suffix.
            return f"  {key:<35} {value:>10.2f}"
        return f"  {key:<35} {value:>10.2f}%"

    # ============================================================================
    # TIER 1 - ABSOLUTELY ESSENTIAL
    # ============================================================================

    def calculate_tier1_metrics(self):
        """Core metrics for long-term investing.

        Returns a dict with Total Return, CAGR, benchmark equivalents, Alpha
        (CAGR difference) and Max Drawdown, all in percent. It also includes
        Drawdown Duration in rows, annualized Volatility in percent and an
        annualized Sharpe Ratio (NaN when volatility is zero).
        """
        metrics = {}

        # Cumulative return (portfolio assumed to start at 1.0)
        final_value = self.portfolio_values[-1]
        metrics["Total Return"] = (final_value - 1.0) * 100

        # CAGR (Compound Annual Growth Rate)
        metrics["CAGR"] = ((final_value ** (1 / self.years)) - 1) * 100

        # Benchmark returns
        benchmark_final = self.benchmark_values[-1]
        metrics["Benchmark Total Return"] = (benchmark_final - 1.0) * 100
        metrics["Benchmark CAGR"] = ((benchmark_final ** (1 / self.years)) - 1) * 100

        # Alpha (excess CAGR over benchmark)
        metrics["Alpha"] = metrics["CAGR"] - metrics["Benchmark CAGR"]

        # Max Drawdown
        drawdown = self._drawdown()
        metrics["Max Drawdown"] = np.min(drawdown) * 100

        # Drawdown Duration: from the peak preceding the deepest trough until the
        # value regains that peak. If it never does, the drawdown is still open
        # and is measured up to the last observation.
        trough_idx = int(np.argmin(drawdown))
        peak_idx = int(np.argmax(self.portfolio_values[: trough_idx + 1]))
        peak_value = self.portfolio_values[peak_idx]
        recovery_idx = next(
            (
                i
                for i in range(trough_idx, self.num_days)
                if self.portfolio_values[i] >= peak_value
            ),
            self.num_days - 1,
        )
        metrics["Drawdown Duration (days)"] = recovery_idx - peak_idx

        # Volatility (annualized)
        metrics["Volatility"] = np.std(self.returns) * np.sqrt(self.trading_days) * 100

        # Sharpe Ratio (annualized, zero risk-free rate)
        metrics["Sharpe Ratio"] = self._annualized_sharpe(self.returns)

        return metrics

    # ============================================================================
    # TIER 2 - RISK-ADJUSTED QUALITY
    # ============================================================================

    def calculate_tier2_metrics(self):
        """Advanced risk-adjusted metrics.

        Returns a dict with Sortino and Calmar ratios, Win Rate, and Best/Worst
        Day Return (percent). When at least one one-year window has non-zero
        volatility, it also includes min/avg rolling one-year Sharpe.
        """
        metrics = {}

        # Sortino Ratio: mean return over the downside deviation, i.e. the RMS
        # of returns below a zero target, taken over ALL observations.
        downside_dev = np.sqrt(np.mean(np.minimum(self.returns, 0.0) ** 2))
        if not downside_dev > 0:  # no losing days (or no data): avoid 0-division
            downside_dev = 1e-8
        metrics["Sortino Ratio"] = (np.mean(self.returns) / downside_dev) * np.sqrt(
            self.trading_days
        )

        # Calmar Ratio (CAGR / |max drawdown|, both as fractions)
        cagr = (self.portfolio_values[-1] ** (1 / self.years)) - 1
        max_dd = abs(np.min(self._drawdown()))
        metrics["Calmar Ratio"] = cagr / (max_dd + 1e-8)

        # Win Rate
        metrics["Win Rate"] = (np.sum(self.returns > 0) / len(self.returns)) * 100

        # Best/Worst Day
        metrics["Best Day Return"] = np.max(self.returns) * 100
        metrics["Worst Day Return"] = np.min(self.returns) * 100

        # Rolling Sharpe (1-year); omitted when no window qualifies
        _, rolling_sharpes = self._rolling_sharpe(self.trading_days)
        if rolling_sharpes:
            metrics["Min Rolling 1Y Sharpe"] = np.min(rolling_sharpes)
            metrics["Avg Rolling 1Y Sharpe"] = np.mean(rolling_sharpes)

        return metrics

    # ============================================================================
    # TIER 3 - PORTFOLIO BEHAVIOR
    # ============================================================================

    def calculate_tier3_metrics(self):
        """Portfolio construction and trading behavior.

        Returns an empty dict when fewer than two weight snapshots exist.
        """
        metrics = {}

        if self._has_weight_history():
            # Turnover (annualized); counts both sides of each trade
            weight_changes = self._weight_changes()
            daily_turnover = np.mean(weight_changes)
            metrics["Turnover (annualized)"] = daily_turnover * self.trading_days * 100

            # Average holding period (rough estimate)
            if daily_turnover > 0:
                metrics["Avg Holding Period (days)"] = 1.0 / daily_turnover

            # Rebalancing frequency (>10% total weight change in one step)
            significant_rebalances = int(np.sum(weight_changes > 0.1))
            metrics["Major Rebalances"] = significant_rebalances
            metrics["Rebalancing Frequency (per year)"] = (
                significant_rebalances / self.num_days
            ) * self.trading_days

            # Concentration (Herfindahl index of the average weights)
            avg_weights = np.mean(self.weights_history, axis=0)
            herfindahl = np.sum(avg_weights**2)
            metrics["Concentration (HHI)"] = herfindahl
            metrics["Effective # Assets"] = 1.0 / herfindahl if herfindahl > 0 else 0

        return metrics

    # ============================================================================
    # TIER 4 - TRADING DRAG
    # ============================================================================

    def calculate_tier4_metrics(self, transaction_cost=0.0002):
        """Transaction costs and implementation drag.

        Args:
            transaction_cost: cost per unit of traded weight (0.0002 = 2 bps).

        Returns an empty dict when fewer than two weight snapshots exist.
        """
        metrics = {}

        if self._has_weight_history():
            # Estimate transaction costs
            total_turnover = np.sum(self._weight_changes())
            estimated_costs = total_turnover * transaction_cost

            # as percentage of initial capital
            metrics["Est. Transaction Costs"] = estimated_costs * 100
            metrics["Cost Drag (annual)"] = (estimated_costs / self.years) * 100

            # Cash allocation
            cash_weights = self.weights_history[:, 0]  # Assume first asset is cash
            metrics["Avg Cash Allocation"] = np.mean(cash_weights) * 100
            metrics["Max Cash Allocation"] = np.max(cash_weights) * 100

        return metrics

    # ============================================================================
    # VISUALIZATION
    # ============================================================================

    def plot_tier1_charts(self, save_dir="./analysis"):
        """Essential charts for long-term investing.

        Writes ``tier1_analysis.png`` into ``save_dir`` (created if missing). It
        shows value vs. benchmark (log scale), the drawdown curve, rolling
        1-year returns and a monthly-returns heatmap.
        """
        import os

        os.makedirs(save_dir, exist_ok=True)

        fig = plt.figure(figsize=(16, 10))

        # 1. Portfolio Value (log scale)
        ax1 = fig.add_subplot(2, 2, 1)
        ax1.semilogy(
            self.dates,
            self.portfolio_values,
            linewidth=2,
            label="RAT Portfolio",
            color="#2E86AB",
        )
        ax1.semilogy(
            self.dates,
            self.benchmark_values,
            linewidth=2,
            label="Benchmark (SPY)",
            color="#A23B72",
            linestyle="--",
            alpha=0.7,
        )
        ax1.axhline(y=1.0, color="gray", linestyle=":", linewidth=1)
        ax1.set_title(
            "Portfolio Value Over Time (Log Scale)", fontsize=12, fontweight="bold"
        )
        ax1.set_xlabel("Date")
        ax1.set_ylabel("Portfolio Value")
        ax1.legend()
        ax1.grid(alpha=0.3)

        # 2. Drawdown Curve
        ax2 = fig.add_subplot(2, 2, 2)
        drawdown = self._drawdown() * 100
        ax2.fill_between(self.dates, drawdown, 0, alpha=0.3, color="red")
        ax2.plot(self.dates, drawdown, linewidth=1.5, color="darkred")
        ax2.set_title("Drawdown Over Time", fontsize=12, fontweight="bold")
        ax2.set_xlabel("Date")
        ax2.set_ylabel("Drawdown (%)")
        ax2.grid(alpha=0.3)

        # 3. Rolling 1-Year Returns: value change over each trailing year,
        # indexed on portfolio_values so it cannot run past the end.
        ax3 = fig.add_subplot(2, 2, 3)
        rolling_window = self.trading_days
        if self.num_days > rolling_window:
            values = self.portfolio_values
            rolling_returns = (values[rolling_window:] / values[:-rolling_window] - 1) * 100
            rolling_dates = self.dates[rolling_window:]

            ax3.plot(rolling_dates, rolling_returns, linewidth=2, color="#F18F01")
            ax3.axhline(y=0, color="gray", linestyle="--", linewidth=1)
            ax3.set_title("Rolling 1-Year Returns", fontsize=12, fontweight="bold")
            ax3.set_xlabel("Date")
            ax3.set_ylabel("1-Year Return (%)")
            ax3.grid(alpha=0.3)

        # 4. Monthly Returns Heatmap
        ax4 = fig.add_subplot(2, 2, 4)
        monthly_returns = self.calculate_monthly_returns()
        if len(monthly_returns) > 0:
            # Pivot to heatmap format (rows = year, columns = month)
            monthly_df = pd.DataFrame(monthly_returns)
            monthly_df["Year"] = monthly_df["Date"].dt.year
            monthly_df["Month"] = monthly_df["Date"].dt.month
            pivot = monthly_df.pivot(index="Year", columns="Month", values="Return")

            sns.heatmap(
                pivot,
                annot=True,
                fmt=".1f",
                cmap="RdYlGn",
                center=0,
                cbar_kws={"label": "Return (%)"},
                ax=ax4,
                linewidths=0.5,
            )
            ax4.set_title("Monthly Returns Heatmap", fontsize=12, fontweight="bold")
            ax4.set_xlabel("Month")
            ax4.set_ylabel("Year")

        fig.tight_layout()
        fig.savefig(f"{save_dir}/tier1_analysis.png", dpi=150, bbox_inches="tight")
        print(f"✅ Saved: {save_dir}/tier1_analysis.png")
        plt.close(fig)

    def plot_tier2_charts(self, save_dir="./analysis"):
        """Risk-adjusted quality charts.

        Writes ``tier2_analysis.png`` into ``save_dir``. It shows rolling 1-year
        Sharpe, the daily-return histogram, upside/downside capture, and
        drawdown duration vs. depth.
        """
        import os

        os.makedirs(save_dir, exist_ok=True)

        fig = plt.figure(figsize=(16, 10))

        # 1. Rolling Sharpe Ratio
        ax1 = fig.add_subplot(2, 2, 1)
        rolling_window = self.trading_days
        if len(self.returns) >= rolling_window:
            end_indices, rolling_sharpes = self._rolling_sharpe(rolling_window)
            rolling_dates = [self.dates[i] for i in end_indices]

            ax1.plot(rolling_dates, rolling_sharpes, linewidth=2, color="#6A4C93")
            ax1.axhline(
                y=2.0, color="green", linestyle="--", linewidth=1, label="Target (2.0)"
            )
            ax1.axhline(
                y=1.0,
                color="orange",
                linestyle="--",
                linewidth=1,
                label="Acceptable (1.0)",
            )
            ax1.set_title("Rolling 1-Year Sharpe Ratio", fontsize=12, fontweight="bold")
            ax1.set_xlabel("Date")
            ax1.set_ylabel("Sharpe Ratio")
            ax1.legend()
            ax1.grid(alpha=0.3)

        # 2. Return Distribution
        ax2 = fig.add_subplot(2, 2, 2)
        ax2.hist(
            self.returns * 100, bins=50, alpha=0.7, color="#1982C4", edgecolor="black"
        )
        ax2.axvline(
            x=np.mean(self.returns) * 100,
            color="red",
            linestyle="--",
            linewidth=2,
            label=f"Mean: {np.mean(self.returns) * 100:.3f}%",
        )
        ax2.axvline(x=0, color="gray", linestyle="-", linewidth=1)
        ax2.set_title("Daily Returns Distribution", fontsize=12, fontweight="bold")
        ax2.set_xlabel("Daily Return (%)")
        ax2.set_ylabel("Frequency")
        ax2.legend()
        ax2.grid(alpha=0.3)

        # 3. Upside vs Downside Capture
        ax3 = fig.add_subplot(2, 2, 3)
        # Align the two series so boolean masks have matching lengths.
        n_common = min(len(self.returns), len(self.benchmark_returns))
        if n_common > 0:
            strategy = self.returns[:n_common]
            benchmark = self.benchmark_returns[:n_common]

            def _capture(mask):
                # Mean strategy return / mean benchmark return on the masked days;
                # NaN when there are no such days (e.g. a never-down benchmark).
                if not mask.any():
                    return float("nan")
                return (np.mean(strategy[mask]) / np.mean(benchmark[mask])) * 100

            upside_capture = _capture(benchmark > 0)
            downside_capture = _capture(benchmark < 0)

            categories = ["Upside Capture", "Downside Capture"]
            values = [upside_capture, downside_capture]
            # Capturing more of the upside is good; capturing more of the
            # downside is bad, so the colour test is reversed for it.
            colors = [
                "green" if upside_capture > 100 else "red",
                "green" if downside_capture < 100 else "red",
            ]

            bars = ax3.bar(
                categories, values, color=colors, alpha=0.7, edgecolor="black"
            )
            ax3.axhline(
                y=100,
                color="black",
                linestyle="--",
                linewidth=1,
                label="100% (Benchmark)",
            )
            ax3.set_title(
                "Upside/Downside Capture Ratio", fontsize=12, fontweight="bold"
            )
            ax3.set_ylabel("Capture Ratio (%)")
            ax3.legend()
            ax3.grid(alpha=0.3, axis="y")

            # Add value labels (skip undefined captures)
            for bar in bars:
                height = bar.get_height()
                if np.isfinite(height):
                    ax3.text(
                        bar.get_x() + bar.get_width() / 2.0,
                        height,
                        f"{height:.1f}%",
                        ha="center",
                        va="bottom",
                    )

        # 4. Drawdown Duration Analysis (>1% drawdowns, including an open one)
        ax4 = fig.add_subplot(2, 2, 4)
        drawdown_periods = self._drawdown_periods(threshold=0.01)

        if len(drawdown_periods) > 0:
            durations = [dp["duration"] for dp in drawdown_periods]
            depths = [abs(dp["depth"]) for dp in drawdown_periods]

            scatter = ax4.scatter(
                durations,
                depths,
                s=100,
                alpha=0.6,
                c=depths,
                cmap="Reds",
                edgecolors="black",
                linewidths=1,
            )
            ax4.set_title("Drawdown Duration vs Depth", fontsize=12, fontweight="bold")
            ax4.set_xlabel("Duration (days)")
            ax4.set_ylabel("Depth (%)")
            ax4.grid(alpha=0.3)
            fig.colorbar(scatter, ax=ax4, label="Drawdown Depth (%)")

        fig.tight_layout()
        fig.savefig(f"{save_dir}/tier2_analysis.png", dpi=150, bbox_inches="tight")
        print(f"✅ Saved: {save_dir}/tier2_analysis.png")
        plt.close(fig)

    def plot_tier3_charts(self, save_dir="./analysis"):
        """Portfolio behavior charts.

        Writes ``tier3_analysis.png`` into ``save_dir``. It shows allocation over
        time, rolling turnover, effective number of assets and top average
        holdings. It prints a warning and returns early without weights.
        """
        if self.weights_history is None:
            print("⚠️ No weights history available for Tier 3 analysis")
            return

        import os

        os.makedirs(save_dir, exist_ok=True)

        # Plot only rows that have both a weight snapshot and a date.
        n_rows = min(len(self.weights_history), self.num_days)
        weights = self.weights_history[:n_rows]
        dates = self.dates[:n_rows]

        fig = plt.figure(figsize=(16, 10))

        # 1. Asset Allocation Over Time (Stacked Area)
        ax1 = fig.add_subplot(2, 2, 1)
        num_assets = weights.shape[1]

        ax1.stackplot(
            dates,
            weights.T,
            alpha=0.7,
            labels=[f"Asset {i}" if i > 0 else "Cash" for i in range(num_assets)],
        )
        ax1.set_title("Asset Allocation Over Time", fontsize=12, fontweight="bold")
        ax1.set_xlabel("Date")
        ax1.set_ylabel("Portfolio Weight")
        ax1.legend(loc="upper left", bbox_to_anchor=(1, 1), ncol=1)
        ax1.grid(alpha=0.3)

        # 2. Turnover Over Time (rolling average of per-step weight change)
        ax2 = fig.add_subplot(2, 2, 2)
        weight_changes = self._weight_changes(weights)

        window = 20  # ~1 month
        if len(weight_changes) >= window:
            rolling_turnover = pd.Series(weight_changes).rolling(window).mean()
            ax2.plot(dates[1:], rolling_turnover * 100, linewidth=2, color="#E76F51")
            ax2.set_title(
                f"Rolling {window}-Day Turnover", fontsize=12, fontweight="bold"
            )
            ax2.set_xlabel("Date")
            ax2.set_ylabel("Turnover (%)")
            ax2.grid(alpha=0.3)

        # 3. Concentration Over Time (Herfindahl index -> effective # of assets);
        # all-zero weight rows give NaN instead of a division by zero.
        ax3 = fig.add_subplot(2, 2, 3)
        hhi_over_time = np.sum(weights**2, axis=1)
        effective_assets = np.divide(
            1.0,
            hhi_over_time,
            out=np.full_like(hhi_over_time, np.nan),
            where=hhi_over_time > 0,
        )

        ax3.plot(dates, effective_assets, linewidth=2, color="#264653")
        ax3.set_title(
            "Diversification (Effective # of Assets)", fontsize=12, fontweight="bold"
        )
        ax3.set_xlabel("Date")
        ax3.set_ylabel("Effective # of Assets")
        ax3.grid(alpha=0.3)

        # 4. Top Holdings Distribution (average over the full weight history)
        ax4 = fig.add_subplot(2, 2, 4)
        avg_weights = np.mean(self.weights_history, axis=0)
        sorted_indices = np.argsort(avg_weights)[::-1]
        top_k = min(10, len(avg_weights))

        top_weights = avg_weights[sorted_indices[:top_k]]
        labels = [f"Asset {i}" if i > 0 else "Cash" for i in sorted_indices[:top_k]]

        ax4.bar(
            range(top_k),
            top_weights * 100,
            color="#8AB17D",
            edgecolor="black",
            alpha=0.7,
        )
        ax4.set_title(
            f"Top {top_k} Holdings (Average Weight)", fontsize=12, fontweight="bold"
        )
        ax4.set_xlabel("Asset Rank")
        ax4.set_ylabel("Average Weight (%)")
        ax4.set_xticks(range(top_k))
        ax4.set_xticklabels(labels, rotation=45, ha="right")
        ax4.grid(alpha=0.3, axis="y")

        fig.tight_layout()
        fig.savefig(f"{save_dir}/tier3_analysis.png", dpi=150, bbox_inches="tight")
        print(f"✅ Saved: {save_dir}/tier3_analysis.png")
        plt.close(fig)

    # ============================================================================
    # COMPREHENSIVE REPORT
    # ============================================================================

    def generate_report(self, save_dir="./analysis"):
        """Generate comprehensive analysis report.

        Prints every metric tier and a pass/fail assessment against fixed
        targets. It writes ``metrics_summary.csv`` (one row, one column per
        metric) into ``save_dir`` and returns the merged metrics dict.
        """
        import os

        os.makedirs(save_dir, exist_ok=True)

        print("\n" + "=" * 80)
        print("LONG-TERM INVESTMENT ANALYSIS REPORT")
        print("=" * 80)

        # (title, calculator, whether an empty result means "no weights")
        sections = (
            ("\n📊 TIER 1 - CORE PERFORMANCE METRICS", self.calculate_tier1_metrics, False),
            ("\n🎯 TIER 2 - RISK-ADJUSTED QUALITY", self.calculate_tier2_metrics, False),
            ("\n📈 TIER 3 - PORTFOLIO BEHAVIOR", self.calculate_tier3_metrics, True),
            ("\n💰 TIER 4 - TRADING DRAG", self.calculate_tier4_metrics, True),
        )
        tiers = []
        for title, calculate, needs_weights in sections:
            print(title)
            print("-" * 80)
            metrics = calculate()
            for key, value in metrics.items():
                print(self._format_metric(key, value))
            if needs_weights and not metrics:
                print("  No weight history available")
            tiers.append(metrics)
        tier1, tier2, tier3, tier4 = tiers

        # Assessment
        print("\n" + "=" * 80)
        print("ASSESSMENT")
        print("=" * 80)

        # Check targets (NaN metrics compare False and therefore fail)
        sr_target = tier1["Sharpe Ratio"] >= 2.0
        cr_target = tier2["Calmar Ratio"] >= 1.5
        dd_target = abs(tier1["Max Drawdown"]) <= 15.0
        alpha_positive = tier1["Alpha"] > 0

        print(f"  Sharpe Ratio >= 2.0:       {'✅ YES' if sr_target else '❌ NO'}")
        print(f"  Calmar Ratio >= 1.5:       {'✅ YES' if cr_target else '❌ NO'}")
        print(f"  Max Drawdown <= 15%:       {'✅ YES' if dd_target else '❌ NO'}")
        print(f"  Positive Alpha:            {'✅ YES' if alpha_positive else '❌ NO'}")

        all_targets = sr_target and cr_target and dd_target and alpha_positive

        print("\n" + "=" * 80)
        if all_targets:
            print("🎉 EXCELLENT! All targets met. Ready for deployment.")
        else:
            print("⚠️  Some targets not met. Consider fine-tuning hyperparameters.")
        print("=" * 80 + "\n")

        # Save to CSV
        all_metrics = {**tier1, **tier2, **tier3, **tier4}
        df = pd.DataFrame([all_metrics])
        df.to_csv(f"{save_dir}/metrics_summary.csv", index=False)
        print(f"✅ Saved metrics to: {save_dir}/metrics_summary.csv\n")

        return all_metrics

    def calculate_monthly_returns(self):
        """Calculate monthly returns for heatmap.

        Each month's return is chained from the previous month's last value to
        this month's last value, so the first trading day of every month is
        counted. The first month is measured from portfolio_values[0].

        Returns:
            List of {"Date": first date of the month in the data,
                     "Return": monthly return in percent}.
        """
        monthly_data = []
        if self.num_days == 0:
            return monthly_data

        month_keys = [(d.year, d.month) for d in self.dates]
        base_value = self.portfolio_values[0]
        month_start = 0
        for i in range(self.num_days):
            is_month_end = i == self.num_days - 1 or month_keys[i + 1] != month_keys[i]
            if not is_month_end:
                continue
            end_value = self.portfolio_values[i]
            monthly_data.append(
                {
                    "Date": self.dates[month_start],
                    "Return": (end_value / base_value - 1) * 100,
                }
            )
            base_value = end_value
            month_start = i + 1

        return monthly_data

    def plot_all(self, save_dir="./analysis"):
        """Generate all charts and report; returns the merged metrics dict."""
        print("\n🎨 Generating comprehensive analysis...")
        self.plot_tier1_charts(save_dir)
        self.plot_tier2_charts(save_dir)
        self.plot_tier3_charts(save_dir)
        metrics = self.generate_report(save_dir)
        print(f"\n✅ Analysis complete! Check {save_dir}/ for all outputs.\n")
        return metrics


def analyze_from_csv(
    csv_path, start_date="2010/01/01", save_dir="./analysis", trading_days_per_year=252
):
    """
    Convenience function to analyze from train_summary.csv

    Only the first row of the CSV is used:
    - column ``St_v`` holds comma-separated portfolio values;
    - column ``backtest_test_history`` holds comma-separated per-period values
      from which 1.0 is subtracted to obtain returns.
    Weight history is not available from the CSV, so tiers 3/4 are empty.

    Args:
        csv_path: Path to train_summary.csv
        start_date: Trading start date ("YYYY/MM/DD")
        save_dir: Directory that receives the charts and metrics_summary.csv
        trading_days_per_year: 252 for stocks, 365 for crypto

    Returns:
        The merged metrics dict produced by ``LongTermAnalyzer.plot_all``.
    """
    df = pd.read_csv(csv_path)

    def _parse_floats(cell):
        # Split a "v1, v2, ..." cell into floats, skipping blanks such as a
        # trailing comma. str() also accepts cells pandas parsed as numbers.
        return [float(item) for item in str(cell).split(",") if item.strip()]

    # Parse portfolio history
    portfolio_values = _parse_floats(df["St_v"].iloc[0])

    # Parse returns (stored as 1 + r)
    returns = [value - 1.0 for value in _parse_floats(df["backtest_test_history"].iloc[0])]

    # Create analyzer
    analyzer = LongTermAnalyzer(
        portfolio_values=portfolio_values,
        returns_array=returns,
        weights_history=None,  # Not available from CSV
        start_date=start_date,
        trading_days_per_year=trading_days_per_year,
    )

    return analyzer.plot_all(save_dir)


if __name__ == "__main__":
    import sys

    # Command line: exactly one positional argument, the summary CSV path.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python longterm_analysis.py <path_to_train_summary.csv>")
        print(
            "Example: python longterm_analysis.py log/optimized_longterm/train_summary.csv"
        )
    else:
        csv_path = cli_args[0]
        print(f"Analyzing results from: {csv_path}")
        analyze_from_csv(csv_path)
