import sqlite3
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
import time
from torch.autograd import Variable

# import matplotlib.pyplot as plt
# import seaborn
# seaborn.set_context(context="talk")
# from longterm_analysis import LongTermAnalyzer
import os

from datetime import datetime
import xarray as xr

from longterm_analysis import LongTermAnalyzer


# DATABASE_DIR="/data2/kaylakxu/PGPortfolio-master/PGPortfolio-master/database/Data.db"
DATABASE_DIR = "/kaggle/input/stock-data/stock_data.db"
# Time constants (in seconds)
NOW = 0
FIVE_MINUTES = 60 * 5
FIFTEEN_MINUTES = FIVE_MINUTES * 3
HALF_HOUR = FIFTEEN_MINUTES * 2
HOUR = HALF_HOUR * 2
TWO_HOUR = HOUR * 2
FOUR_HOUR = HOUR * 4
DAY = HOUR * 24
YEAR = DAY * 365
# trading table name
TABLE_NAME = "test"


# Configuration Variables
class CONFIG:
    # Basic parameters
    total_step = 500
    x_window_size = 30
    batch_size = 128
    coin_num = 11  # DEPRECATED: Ignored, all symbols in database will be used
    feature_number = 4
    output_step = 100
    model_index = 1
    multihead_num = 2
    local_context_length = 7
    model_dim = 12

    # OSBL parameters
    use_osbl = True  # Use OSBL training
    osbl_beta = 0.05  # OSBL geometric sampling parameter
    osbl_num_batches = 4  # Number of batches per OSBL update (N_b)
    osbl_max_memory = 2000  # Max transitions in OSBL buffer
    gradient_clip = 1.0  # Gradient clipping value for stability

    # Novel modifications for online learning (Sections 3.3 & 3.5)
    use_decay_attention = True  # Use decay-aware context attention
    temporal_decay_lambda = 0.1  # Temporal decay factor for context attention
    use_correlation_matrix = True  # Use incremental asset correlation matrix
    correlation_gamma = 0.9  # Decay rate for correlation matrix EMA
    use_recency_pe = True  # Use positional encoding with recency bias
    recency_beta = 0.1  # Recency bias parameter for positional encoding
    use_ewc = True  # Enable EWC regularization (disable if it over-constrains learning and suppresses trading)
    ewc_lambda = 100.0  # EWC regularization strength
    ewc_update_freq = 100  # Frequency of EWC parameter updates (in periods)

    # Trading parameters - OPTIMIZED TO ENCOURAGE TRADING
    test_portion = 0.4
    validation_portion = 0.2  # Portion of data for validation set
    trading_consumption = 0.0001
    variance_penalty = 0.0  
    cost_penalty = 0.00  
    learning_rate = 0.0005  
    weight_decay = 5e-8
    daily_interest_rate = 0.001

    # Date range and paths
    start = "2016/01/01"
    end = "2025/11/30"
    model_name = "RAT"
    log_dir = ""
    model_dir = ""


FLAGS = CONFIG()


def parse_time(time_string):
    return time.mktime(datetime.strptime(time_string, "%Y/%m/%d").timetuple())


class HistoryManager:
    # if offline, the coin_list could be None
    # NOTE: sqlite queries return a list of tuples; each tuple is a row
    def __init__(
        self, coin_number, end, volume_average_days=1, volume_forward=0, online=True
    ):
        self.initialize_db()
        self.__storage_period = DAY  # Daily stock data
        self._coin_number = coin_number
        self._online = online
        if self._online:
            try:
                from coin_list_module import (
                    CoinList,
                )  # Replace with the actual module name if known
            except ImportError:
                raise ImportError(
                    "Module 'coin_list_module' not found. Ensure it is installed or replace it with the correct module."
                )
            self._coin_list = CoinList(end, volume_average_days, volume_forward)
        self.__volume_forward = volume_forward
        self.__volume_average_days = volume_average_days
        self.__coins = None

    @property
    def coins(self):
        return self.__coins

    def initialize_db(self):
        with sqlite3.connect(DATABASE_DIR) as connection:
            cursor = connection.cursor()
            cursor.execute(
                "CREATE TABLE IF NOT EXISTS History (date INTEGER,"
                " symbol varchar(20), high FLOAT, low FLOAT,"
                " open FLOAT, close FLOAT, volume FLOAT, "
                " trade_count FLOAT, vwap FLOAT,"
                "PRIMARY KEY (date, symbol));"
            )
            connection.commit()

    def get_global_data_matrix(self, start, end, period=86400, features=("close",)):
        """
        :return: a numpy ndarray whose axes are [feature, coin, time]
        """
        return self.get_global_panel(start, end, period, features).values

    def get_global_panel(self, start, end, period=86400, features=("close",)):
        """
        :param start/end: Unix timestamps in seconds
        :param period: time interval of each data access point (not used for daily data)
        :param features: tuple or list of the feature names
        :return: an xarray DataArray with axes [feature, coin, time]
        """
        start = int(start)
        end = int(end)
        coins = self.select_coins(
            start=end - self.__volume_forward - self.__volume_average_days * DAY,
            end=end - self.__volume_forward,
        )
        self.__coins = coins
        # Update _coin_number to match actual number of symbols found
        self._coin_number = len(coins)

        for coin in coins:
            self.update_data(start, end, coin)

        print(f"Using {self._coin_number} symbols from database")
        print("feature type list is %s" % str(features))
        self.__checkperiod(period)

        # Get actual unique dates from database instead of calculating continuous range
        connection = sqlite3.connect(DATABASE_DIR)
        cursor = connection.cursor()
        cursor.execute(
            "SELECT DISTINCT date FROM History WHERE date>=? and date<=? ORDER BY date",
            (int(start), int(end)),
        )
        unique_dates = [row[0] for row in cursor.fetchall()]
        print(f"Found {len(unique_dates)} unique trading days in date range")

        # Create time index from actual dates in database
        time_index = pd.to_datetime(unique_dates, unit="s")

        # Use xarray DataArray instead of deprecated pd.Panel
        panel = xr.DataArray(
            np.zeros((len(features), len(coins), len(time_index)), dtype=np.float32),
            coords={"items": features, "major_axis": coins, "minor_axis": time_index},
            dims=["items", "major_axis", "minor_axis"],
        )
        try:
            for row_number, coin in enumerate(coins):
                for feature in features:
                    # NOTE: For daily data, date is already aligned to day boundaries
                    if feature == "close":
                        sql = (
                            "SELECT date AS date_norm, close FROM History WHERE"
                            " date>={start} and date<={end}"
                            ' and symbol="{coin}"'.format(
                                start=start, end=end, coin=coin
                            )
                        )
                    elif feature == "open":
                        sql = (
                            "SELECT date AS date_norm, open FROM History WHERE"
                            " date>={start} and date<={end}"
                            ' and symbol="{coin}"'.format(
                                start=start, end=end, coin=coin
                            )
                        )
                    elif feature == "volume":
                        sql = (
                            "SELECT date AS date_norm, volume FROM History"
                            ' WHERE date>={start} and date<={end} and symbol="{coin}"'.format(
                                start=start, end=end, coin=coin
                            )
                        )
                    elif feature == "high":
                        sql = (
                            "SELECT date AS date_norm, high FROM History"
                            ' WHERE date>={start} and date<={end} and symbol="{coin}"'.format(
                                start=start, end=end, coin=coin
                            )
                        )
                    elif feature == "low":
                        sql = (
                            "SELECT date AS date_norm, low FROM History"
                            ' WHERE date>={start} and date<={end} and symbol="{coin}"'.format(
                                start=start, end=end, coin=coin
                            )
                        )
                    else:
                        msg = "The feature %s is not supported" % feature
                        print(msg)
                        raise ValueError(msg)
                    serial_data = pd.read_sql_query(
                        sql,
                        con=connection,
                        parse_dates=["date_norm"],
                        index_col="date_norm",
                    )
                    # Use xarray indexing
                    panel.loc[
                        dict(
                            items=feature, major_axis=coin, minor_axis=serial_data.index
                        )
                    ] = serial_data.squeeze().values
            # Fill missing values once after all features and coins are loaded
            panel = panel_fillna(panel, "both")
        finally:
            connection.commit()
            connection.close()
        return panel

    # select all symbols/stocks from the database
    def select_coins(self, start, end):
        if not self._online:
            print(
                "select all symbols from %s to %s"
                % (
                    datetime.fromtimestamp(start).strftime("%Y-%m-%d %H:%M"),
                    datetime.fromtimestamp(end).strftime("%Y-%m-%d %H:%M"),
                )
            )
            connection = sqlite3.connect(DATABASE_DIR)
            try:
                cursor = connection.cursor()
                # Select all distinct symbols that have data in the date range
                cursor.execute(
                    "SELECT DISTINCT symbol FROM History WHERE"
                    " date>=? and date<=? ORDER BY symbol;",
                    (int(start), int(end)),
                )
                coins_tuples = cursor.fetchall()
                print(f"Found {len(coins_tuples)} symbols in database")
            finally:
                connection.commit()
                connection.close()
            coins = [row[0] for row in coins_tuples]
        else:
            # For online mode, still use all available coins
            coins = list(self._coin_list.allActiveCoins.index)
        print("Selected symbols are: " + str(coins))
        return coins

    def __checkperiod(self, period):
        if period in (
            FIVE_MINUTES,
            FIFTEEN_MINUTES,
            HALF_HOUR,
            TWO_HOUR,
            FOUR_HOUR,
            DAY,
        ):
            return
        raise ValueError("period has to be 5min, 15min, 30min, 2hr, 4hr, or a day")

    # add new history data into the database
    def update_data(self, start, end, coin):
        connection = sqlite3.connect(DATABASE_DIR)
        try:
            cursor = connection.cursor()
            min_date = cursor.execute(
                "SELECT MIN(date) FROM History WHERE symbol=?;", (coin,)
            ).fetchall()[0][0]
            max_date = cursor.execute(
                "SELECT MAX(date) FROM History WHERE symbol=?;", (coin,)
            ).fetchall()[0][0]

            if min_date is None or max_date is None:
                self.__fill_data(start, end, coin, cursor)
            else:
                if max_date + 10 * self.__storage_period < end:
                    if not self._online:
                        print(
                            f"Warning: {coin} data ends at {datetime.fromtimestamp(max_date).strftime('%Y-%m-%d %H:%M')}, requested end is {datetime.fromtimestamp(end).strftime('%Y-%m-%d %H:%M')}"
                        )
                        # Continue anyway with available data
                        pass
                    else:
                        self.__fill_data(
                            max_date + self.__storage_period, end, coin, cursor
                        )

            # if there is no data
        finally:
            connection.commit()
            connection.close()

    def __fill_data(self, start, end, coin, cursor):
        duration = 7819200  # three months
        bk_start = start
        for bk_end in range(start + duration - 1, end, duration):
            self.__fill_part_data(bk_start, bk_end, coin, cursor)
            bk_start += duration
        if bk_start < end:
            self.__fill_part_data(bk_start, end, coin, cursor)

    def __fill_part_data(self, start, end, coin, cursor):
        chart = self._coin_list.get_chart_until_success(
            pair=self._coin_list.allActiveCoins.at[coin, "pair"],
            start=start,
            end=end,
            period=self.__storage_period,
        )
        print(
            "fill %s data from %s to %s"
            % (
                coin,
                datetime.fromtimestamp(start).strftime("%Y-%m-%d %H:%M"),
                datetime.fromtimestamp(end).strftime("%Y-%m-%d %H:%M"),
            )
        )
        for c in chart:
            if c["date"] > 0:
                # For stock data, use vwap if available, otherwise use close price
                if c.get("vwap", 0) == 0:
                    vwap = c["close"]
                else:
                    vwap = c["vwap"]

                # Stock data doesn't have reversed pairs like crypto
                cursor.execute(
                    "INSERT INTO History VALUES (?,?,?,?,?,?,?,?,?)",
                    (
                        c["date"],
                        coin,
                        c["high"],
                        c["low"],
                        c["open"],
                        c["close"],
                        c["volume"],
                        c.get(
                            "trade_count", 0
                        ),  # trade_count may not always be present
                        vwap,
                    ),
                )


def get_type_list(feature_number):
    """
    :param feature_number: an int indicates the number of features
    :return: a list of feature names
    """
    if feature_number == 1:
        type_list = ["close"]
    elif feature_number == 2:
        type_list = ["close", "volume"]
        raise NotImplementedError("the feature volume is not supported currently")
    elif feature_number == 3:
        type_list = ["close", "high", "low"]
    elif feature_number == 4:
        type_list = ["close", "high", "low", "open"]
    else:
        raise ValueError("feature number could not be %s" % feature_number)
    return type_list


def get_volume_forward(time_span, portion, portion_reversed):
    volume_forward = 0
    if not portion_reversed:
        volume_forward = time_span * portion
    return volume_forward


def panel_fillna(panel, type="bfill"):
    """
    fill nan along the 3rd axis
    :param panel: the xarray DataArray to be filled
    :param type: "bfill", "ffill", or "both" (forward fill then backward fill)
    """
    if type == "both":
        # Forward fill then backward fill
        panel = panel.ffill(dim="minor_axis").bfill(dim="minor_axis")
    elif type == "bfill":
        panel = panel.bfill(dim="minor_axis")
    else:
        panel = panel.ffill(dim="minor_axis")
    return panel
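
# A minimal sketch of panel_fillna on the [feature, coin, time] layout used above;
# the toy coordinates below are illustrative only and do not come from the database.
def _panel_fillna_demo():
    times = pd.date_range("2020-01-01", periods=4, freq="D")
    demo = xr.DataArray(
        np.array([[[np.nan, 1.0, np.nan, 2.0]]], dtype=np.float32),
        coords={"items": ["close"], "major_axis": ["AAA"], "minor_axis": times},
        dims=["items", "major_axis", "minor_axis"],
    )
    # "both" forward-fills first, then back-fills the leading gap:
    # [nan, 1, nan, 2] -> [1, 1, 1, 2]
    return panel_fillna(demo, "both")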


class DataMatrices:
    def __init__(
        self,
        start,
        end,
        period,
        batch_size=50,
        volume_average_days=30,
        buffer_bias_ratio=0,
        market="poloniex",
        coin_filter=1,
        window_size=50,
        feature_number=3,
        test_portion=0.15,
        validation_portion=0.0,
        portion_reversed=False,
        online=False,
        is_permed=False,
        use_osbl=False,
        osbl_beta=0.1,
        osbl_max_memory=10000,
    ):
        """
        :param start: Unix time (seconds)
        :param end: Unix time (seconds)
        :param period: the data access period of the global price matrix
        :param batch_size: number of samples per mini-batch
        :param buffer_bias_ratio: sampling bias of the standard replay buffer
        :param coin_filter: ignored - all symbols in the database will be used
        :param window_size: number of periods of input data in each sample
        :param feature_number: number of price features per asset
        :param test_portion: portion of the data used as the test set
        :param validation_portion: portion of the data used as the cross-validation set
        :param portion_reversed: if False, the sets are ordered [train, validation, test],
                                 otherwise [test, validation, train]
        :param is_permed: if False, the samples inside a mini-batch are in temporal order
        :param use_osbl: if True, use the OSBL replay buffer instead of the standard one
        :param osbl_beta: geometric sampling parameter of the OSBL buffer
        :param osbl_max_memory: maximum number of transitions kept in the OSBL buffer
        """
        start = int(start)
        self.__end = int(end)

        # assert window_size >= MIN_NUM_PERIOD
        # coin_filter is ignored - we'll use all symbols from database
        type_list = get_type_list(feature_number)
        self.__features = type_list
        self.feature_number = feature_number
        volume_forward = get_volume_forward(
            self.__end - start, test_portion, portion_reversed
        )
        self.__history_manager = HistoryManager(
            coin_number=coin_filter,  # This will be updated by select_coins
            end=self.__end,
            volume_average_days=volume_average_days,
            volume_forward=volume_forward,
            online=online,
        )
        if market == "poloniex":
            self.__global_data = self.__history_manager.get_global_panel(
                start, self.__end, period=period, features=type_list
            )
        else:
            raise ValueError("market {} is not valid".format(market))

        # Update __coin_no to actual number of coins from database
        self.__coin_no = self.__history_manager._coin_number

        self.__period_length = period
        # portfolio vector memory, [time, assets]
        self.__PVM = pd.DataFrame(
            index=self.__global_data.coords["minor_axis"].values,
            columns=self.__global_data.coords["major_axis"].values,
        )
        self.__PVM = self.__PVM.fillna(1.0 / self.__coin_no)

        self._window_size = window_size
        self._num_periods = len(self.__global_data.coords["minor_axis"])
        self.__divide_data(test_portion, validation_portion, portion_reversed)

        self._portion_reversed = portion_reversed
        self.__is_permed = is_permed
        self.__use_osbl = use_osbl

        self.__batch_size = batch_size
        self.__delta = 0  # the count of global increased
        end_index = self._train_ind[-1]

        # Initialize replay buffer based on mode (OSBL or standard)
        if use_osbl:
            print(
                f"Initializing OSBL Replay Buffer with beta={osbl_beta}, max_memory={osbl_max_memory}"
            )
            self.__replay_buffer = OSBLReplayBuffer(
                start_index=self._train_ind[0],
                end_index=end_index,
                batch_size=self.__batch_size,
                coin_number=self.__coin_no,
                beta=osbl_beta,
                max_memory=osbl_max_memory,
            )
        else:
            print("Initializing Standard Replay Buffer")
            self.__replay_buffer = ReplayBuffer(
                start_index=self._train_ind[0],
                end_index=end_index,
                sample_bias=buffer_bias_ratio,
                batch_size=self.__batch_size,
                coin_number=self.__coin_no,
                is_permed=self.__is_permed,
            )

        print(
            "the number of training examples is %s"
            ", of test examples is %s"
            % (self._num_train_samples, self._num_test_samples)
        )
        print(
            "the training set is from %s to %s"
            % (min(self._train_ind), max(self._train_ind))
        )
        print(
            "the test set is from %s to %s" % (min(self._test_ind), max(self._test_ind))
        )

    @property
    def global_weights(self):
        return self.__PVM

    @staticmethod
    def create_from_config(config):
        """main method to create the DataMatrices in this project
        @:param config: config dictionary
        @:return: a DataMatrices object
        """
        config = config.copy()
        input_config = config["input"]
        train_config = config["training"]
        start = parse_time(input_config["start_date"])
        end = parse_time(input_config["end_date"])
        return DataMatrices(
            start=start,
            end=end,
            market=input_config["market"],
            feature_number=input_config["feature_number"],
            window_size=input_config["window_size"],
            online=input_config["online"],
            period=input_config["global_period"],
            coin_filter=input_config["coin_number"],
            is_permed=input_config["is_permed"],
            buffer_bias_ratio=train_config["buffer_biased"],
            batch_size=train_config["batch_size"],
            volume_average_days=input_config["volume_average_days"],
            test_portion=input_config["test_portion"],
            portion_reversed=input_config["portion_reversed"],
        )

    @property
    def global_matrix(self):
        return self.__global_data

    @property
    def coin_list(self):
        return self.__history_manager.coins

    @property
    def num_train_samples(self):
        return self._num_train_samples

    @property
    def test_indices(self):
        return self._test_ind[: -(self._window_size + 1)]

    @property
    def num_test_samples(self):
        return self._num_test_samples

    @property
    def validation_indices(self):
        return self._validation_ind

    @property
    def num_validation_samples(self):
        return self._num_validation_samples

    @property
    def coin_number(self):
        return self.__coin_no

    def append_experience(self, online_w=None):
        """
        :param online_w: (number of assets + 1, ) numpy array
        Let it be None if in the backtest case.
        """
        self.__delta += 1
        self._train_ind.append(self._train_ind[-1] + 1)
        appended_index = self._train_ind[-1]
        self.__replay_buffer.append_experience(appended_index)

    def get_test_set(self):
        return self.__pack_samples(self.test_indices)

    def get_validation_set(self):
        return self.__pack_samples(self.validation_indices)

    def get_test_set_online(self, ind_start, ind_end, x_window_size):
        return self.__pack_samples_test_online(ind_start, ind_end, x_window_size)

    def get_training_set(self):
        return self.__pack_samples(self._train_ind[: -self._window_size])

    ##############################################################################
    def next_batch(self):
        """
        @:return: the next batch of training samples. The sample is a dictionary with
        keys "X" (input data), "y" (future relative prices), "last_w" (a numpy array
        with shape [batch_size, assets]), and "setw" (a callable used to write the
        newly computed weights back into the portfolio vector memory)
        """
        batch = self.__pack_samples(
            [exp.state_index for exp in self.__replay_buffer.next_experience_batch()]
        )
        #        print(np.shape([exp.state_index for exp in self.__replay_buffer.next_experience_batch()]),[exp.state_index for exp in self.__replay_buffer.next_experience_batch()])
        return batch

    def __pack_samples(self, indexs):
        indexs = np.array(indexs)
        last_w = self.__PVM.values[indexs - 1, :]

        def setw(w):
            self.__PVM.iloc[indexs, :] = w

        #            print("set w index from %d-%d!" %( indexs[0],indexs[-1]))
        M = [self.get_submatrix(index) for index in indexs]
        M = np.array(M)
        X = M[:, :, :, :-1]
        # y is the relative price vector: the last period's features divided by the
        # previous period's close (feature 0); epsilon prevents division by zero
        epsilon = 1e-8
        y = M[:, :, :, -1] / (M[:, 0, None, :, -2] + epsilon)

        return {
            "X": X,
            "y": y,
            "last_w": last_w,
            "setw": setw,
        }

    def __pack_samples_test_online(self, ind_start, ind_end, x_window_size):
        #        indexs = np.array(indexs)
        last_w = self.__PVM.values[ind_start - 1 : ind_start, :]

        #        y_window_size = window_size-x_window_size
        def setw(w):
            self.__PVM.iloc[ind_start, :] = w

        #            print("set w index from %d-%d!" %( indexs[0],indexs[-1]))
        M = [self.get_submatrix_test_online(ind_start, ind_end)]  # [1,4,11,2807]
        M = np.array(M)
        X = M[:, :, :, :-1]
        # Add small epsilon to prevent division by zero
        epsilon = 1e-8
        y = M[:, :, :, x_window_size:] / (
            M[:, 0, None, :, x_window_size - 1 : -1] + epsilon
        )

        return {
            "X": X,
            "y": y,
            "last_w": last_w,
            "setw": setw,
        }

    ##############################################################################################
    def get_submatrix(self, ind):
        return self.__global_data.values[:, :, ind : ind + self._window_size + 1]

    def get_submatrix_test_online(self, ind_start, ind_end):
        return self.__global_data.values[:, :, ind_start:ind_end]

    def __divide_data(self, test_portion, validation_portion, portion_reversed):
        train_portion = 1 - test_portion - validation_portion
        s = float(train_portion + validation_portion + test_portion)

        if portion_reversed:
            # [test, validation, train]
            portions = np.array([test_portion, validation_portion]) / s
            portion_split = (portions * self._num_periods).astype(int).cumsum()
            indices = np.arange(self._num_periods)
            splits = np.split(indices, portion_split)
            self._test_ind = splits[0]
            self._validation_ind = (
                splits[1] if validation_portion > 0 else np.array([], dtype=int)
            )
            self._train_ind = splits[2] if validation_portion > 0 else splits[1]
        else:
            # [train, validation, test]
            portions = np.array([train_portion, train_portion + validation_portion]) / s
            portion_split = (portions * self._num_periods).astype(int)
            indices = np.arange(self._num_periods)
            splits = np.split(indices, portion_split)
            self._train_ind = splits[0]
            self._validation_ind = (
                splits[1] if validation_portion > 0 else np.array([], dtype=int)
            )
            self._test_ind = splits[2] if validation_portion > 0 else splits[1]

        self._train_ind = self._train_ind[: -(self._window_size + 1)]
        # NOTE(zhengyao): change the logic here in order to fit both
        # reversed and normal version
        self._train_ind = list(self._train_ind)
        self._num_train_samples = len(self._train_ind)
        self._num_validation_samples = len(self._validation_ind)
        self._num_test_samples = len(self.test_indices)

        print(
            f"Data split: Train={self._num_train_samples}, "
            f"Validation={self._num_validation_samples}, Test={self._num_test_samples}"
        )


class ReplayBuffer:
    def __init__(
        self,
        start_index,
        end_index,
        batch_size,
        is_permed,
        coin_number,
        sample_bias=1.0,
    ):
        """
        :param start_index: start index of the training set on the global data matrices
        :param end_index: end index of the training set on the global data matrices
        """
        self.__coin_number = coin_number
        self.__experiences = [Experience(i) for i in range(start_index, end_index)]
        self.__is_permed = is_permed
        # NOTE: in order to achieve the previous w feature
        self.__batch_size = batch_size
        self.__sample_bias = sample_bias
        print("buffer_bias is %f" % sample_bias)

    def append_experience(self, state_index):
        self.__experiences.append(Experience(state_index))
        print("a new experience, indexed by %d, was appended" % state_index)

    def __sample(self, start, end, bias):
        """
        @:param end: is excluded
        @:param bias: value in (0, 1)
        """
        # TODO: deal with the case when bias is 0
        ran = np.random.geometric(bias)
        while ran > end - start:
            ran = np.random.geometric(bias)
        result = end - ran
        return result

    def next_experience_batch(self):
        # First get a start point randomly
        batch = []
        if self.__is_permed:
            for i in range(self.__batch_size):
                batch.append(
                    self.__experiences[
                        self.__sample(
                            self.__experiences[0].state_index,
                            self.__experiences[-1].state_index,
                            self.__sample_bias,
                        )
                    ]
                )
        else:
            batch_start = self.__sample(
                0, len(self.__experiences) - self.__batch_size, self.__sample_bias
            )
            batch = self.__experiences[batch_start : batch_start + self.__batch_size]
        return batch


class OSBLReplayBuffer:
    """
    Online Stochastic Batch Learning (OSBL) Replay Buffer
    Implements geometric sampling: P_β(t_b) = β(1-β)^(t-t_b-n_b)
    Maintains temporal ordering within batches as required by RAT.
    """

    def __init__(
        self,
        start_index,
        end_index,
        batch_size,
        coin_number,
        beta=0.1,
        max_memory=10000,
    ):
        """
        :param start_index: start index of the training set
        :param end_index: end index of the training set
        :param batch_size: number of consecutive periods in each batch (n_b)
        :param coin_number: number of assets
        :param beta: geometric distribution parameter (higher = more recent bias)
        :param max_memory: maximum number of transitions to store (M)
        """
        self.__coin_number = coin_number
        self.__experiences = [Experience(i) for i in range(start_index, end_index)]
        self.__batch_size = batch_size
        self.__beta = beta
        self.__max_memory = max_memory
        self.__current_time = end_index
        print(f"OSBL Buffer initialized: beta={beta}, max_memory={max_memory}")

    def append_experience(self, state_index):
        """Add new experience and maintain buffer size limit"""
        self.__experiences.append(Experience(state_index))
        self.__current_time = state_index

        # Maintain max memory size - remove oldest if exceeds limit
        if len(self.__experiences) > self.__max_memory:
            self.__experiences.pop(0)

        print(
            f"OSBL: New experience at index {state_index}, buffer size: {len(self.__experiences)}"
        )

    def sample_batch_start_geometric(self):
        """
        Sample batch start position using geometric distribution.
        P_β(t_b) = β(1-β)^(t-t_b-n_b)
        Returns index in experiences list that allows for batch_size consecutive samples.
        """
        max_start = len(self.__experiences) - self.__batch_size
        if max_start <= 0:
            return 0

        # Sample from geometric distribution
        # More recent batches are exponentially more likely
        p = self.__beta
        sample = np.random.geometric(p)

        # Ensure we don't sample beyond available data
        while sample > max_start:
            sample = np.random.geometric(p)

        # Return start index (counting from end, more recent is preferred)
        batch_start = max_start - sample + 1
        return max(0, batch_start)

    def next_experience_batch(self):
        """
        Sample a batch of consecutive experiences using geometric weighting.
        Preserves temporal ordering as required by RAT's sequential processing.
        """
        if len(self.__experiences) < self.__batch_size:
            # Not enough experiences yet, return what we have
            return self.__experiences

        batch_start = self.sample_batch_start_geometric()
        batch = self.__experiences[batch_start : batch_start + self.__batch_size]
        return batch

    def sample_multiple_batches(self, num_batches):
        """
        Sample N_b mini-batches for OSBL training.
        Each batch contains n_b consecutive periods.
        Batches can overlap, providing diverse training signals.
        """
        batches = []
        for _ in range(num_batches):
            batch = self.next_experience_batch()
            batches.append(batch)
        return batches

    def get_beta(self):
        """Get current beta parameter"""
        return self.__beta

    def set_beta(self, beta):
        """Update beta parameter (for adaptive learning)"""
        self.__beta = beta
        print(f"OSBL: Updated beta to {beta}")

    def get_buffer_size(self):
        """Get current buffer size"""
        return len(self.__experiences)


class Experience:
    def __init__(self, state_index):
        self.state_index = int(state_index)
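

# A minimal sketch (not part of the training pipeline) illustrating how the OSBL
# buffer's geometric sampling favors recent windows; the indices and beta below are
# made up for the demo.
def _osbl_sampling_demo():
    buf = OSBLReplayBuffer(
        start_index=0, end_index=200, batch_size=16, coin_number=11, beta=0.3
    )
    starts = [buf.sample_batch_start_geometric() for _ in range(1000)]
    # With beta=0.3 most sampled batch starts cluster near the most recent window
    # (max_start = 200 - 16 = 184), reflecting P_beta(t_b) = beta * (1 - beta)^k.
    return np.mean(starts), max(starts)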


class AssetCorrelationTracker:
    """
    Incremental Asset Correlation Matrix with EMA.
    Implements Section 3.3: C_t = γ C_{t-1} + (1-γ) y_t y_t^T
    """

    def __init__(self, num_assets, gamma=0.9, device="cpu"):
        """
        Args:
            num_assets: Number of assets (excluding cash)
            gamma: EMA decay rate (0 < γ < 1)
            device: Torch device
        """
        self.num_assets = num_assets
        self.gamma = gamma
        self.device = device

        # Initialize correlation matrix as identity
        self.C_t = torch.eye(num_assets, device=device)
        self.initialized = False

    def update(self, returns):
        """
        Update correlation matrix with new returns.
        Args:
            returns: Tensor of shape (num_assets,) containing asset returns y_t
        """
        returns = returns.to(self.device)

        # Ensure returns is 1D
        if returns.dim() > 1:
            returns = returns.squeeze()

        # Compute outer product y_t y_t^T
        outer = torch.outer(returns, returns)

        # Update with EMA
        if not self.initialized:
            self.C_t = outer
            self.initialized = True
        else:
            self.C_t = self.gamma * self.C_t + (1 - self.gamma) * outer

    def get_correlation_bias(self):
        """
        Get correlation matrix for use as bias in relation attention.
        Returns:
            Correlation matrix C_t of shape (num_assets, num_assets)
        """
        return self.C_t.clone()

    def reset(self):
        """Reset correlation matrix to identity"""
        self.C_t = torch.eye(self.num_assets, device=self.device)
        self.initialized = False
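

# A minimal sketch of the EMA update C_t = gamma * C_{t-1} + (1 - gamma) * y_t y_t^T
# on synthetic returns; the random returns below are purely illustrative.
def _correlation_tracker_demo():
    tracker = AssetCorrelationTracker(num_assets=3, gamma=0.9)
    for _ in range(5):
        y_t = torch.randn(3) * 0.01  # synthetic one-period asset returns
        tracker.update(y_t)
    # The resulting matrix can be passed as correlation_bias to attention() further below.
    return tracker.get_correlation_bias()  # shape (3, 3)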


class FisherInformationMatrix:
    """
    Computes and stores Fisher Information Matrix for EWC.
    Implements Section 3.5: F_i = E[(∂ log p(y|x;θ)/∂θ_i)²]
    """

    def __init__(self, model):
        self.model = model
        self.fisher = {}
        self.optimal_params = {}

    def compute_fisher(self, data_loader, num_samples=200):
        """
        Compute Fisher Information Matrix using sampled data.
        Args:
            data_loader: DataLoader providing training samples
            num_samples: Number of samples to use for Fisher estimation
        """
        self.fisher = {}

        # Initialize Fisher to zeros
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.fisher[name] = torch.zeros_like(param.data)

        self.model.eval()
        count = 0

        for batch in data_loader:
            if count >= num_samples:
                break

            self.model.zero_grad()

            # Forward pass to get log probabilities
            # Assuming batch contains 'src', 'tgt', 'src_mask', 'tgt_mask'
            output = self.model(
                batch["src"], batch["tgt"], batch.get("src_mask"), batch.get("tgt_mask")
            )

            # Compute log likelihood (negative loss)
            loss = -output.sum()
            loss.backward()

            # Accumulate squared gradients
            for name, param in self.model.named_parameters():
                if param.requires_grad and param.grad is not None:
                    self.fisher[name] += param.grad.data**2

            count += 1

        # Average over samples (guard against an empty data_loader)
        if count > 0:
            for name in self.fisher:
                self.fisher[name] /= count

        self.model.train()

    def save_optimal_params(self):
        """Save current parameters as optimal θ*"""
        self.optimal_params = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.optimal_params[name] = param.data.clone()

    def get_fisher(self):
        """Get Fisher Information Matrix"""
        return self.fisher

    def get_optimal_params(self):
        """Get optimal parameters θ*"""
        return self.optimal_params


class EWCRegularizer:
    """
    Elastic Weight Consolidation regularization.
    Implements Section 3.5: L_total = L_policy + (λ_EWC/2)Σ F_i(θ_i - θ*_i)²
    """

    def __init__(self, model, lambda_ewc=100.0):
        """
        Args:
            model: Neural network model
            lambda_ewc: EWC regularization strength
        """
        self.model = model
        self.lambda_ewc = lambda_ewc
        self.fisher_matrix = FisherInformationMatrix(model)

    def compute_ewc_loss(self):
        """
        Compute EWC regularization term.
        Returns:
            EWC loss: (λ_EWC/2)Σ F_i(θ_i - θ*_i)²
        """
        fisher = self.fisher_matrix.get_fisher()
        optimal_params = self.fisher_matrix.get_optimal_params()

        if not fisher or not optimal_params:
            return 0.0

        ewc_loss = 0.0
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in fisher:
                # Compute (θ_i - θ*_i)²
                param_diff = param - optimal_params[name]
                # Weight by Fisher information and sum
                ewc_loss += (fisher[name] * param_diff**2).sum()

        return (self.lambda_ewc / 2.0) * ewc_loss

    def update_fisher(self, data_loader, num_samples=200):
        """Update Fisher matrix and save current parameters as optimal"""
        self.fisher_matrix.compute_fisher(data_loader, num_samples)
        self.fisher_matrix.save_optimal_params()
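

# A minimal sketch (toy linear model, not the RAT network) of the quadratic EWC penalty
# (lambda_EWC / 2) * sum_i F_i * (theta_i - theta*_i)^2 computed by EWCRegularizer;
# the hand-seeded unit Fisher matrix is for illustration only.
def _ewc_penalty_demo():
    toy = nn.Linear(4, 2)
    ewc = EWCRegularizer(toy, lambda_ewc=10.0)
    # Before any Fisher estimate the penalty is exactly zero.
    assert ewc.compute_ewc_loss() == 0.0
    # Seed a unit Fisher matrix and anchor the current weights as theta*.
    ewc.fisher_matrix.fisher = {
        name: torch.ones_like(p) for name, p in toy.named_parameters()
    }
    ewc.fisher_matrix.save_optimal_params()
    # Drift the weights; the penalty grows quadratically with the drift:
    # (10 / 2) * 8 * 0.1^2 ≈ 0.4
    with torch.no_grad():
        toy.weight.add_(0.1)
    return ewc.compute_ewc_loss()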


class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models.
    """

    def __init__(
        self,
        batch_size,
        coin_num,
        window_size,
        feature_number,
        d_model_Encoder,
        d_model_Decoder,
        encoder,
        decoder,
        price_series_pe,
        local_price_pe,
        local_context_length,
    ):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.batch_size = batch_size
        self.coin_num = coin_num
        self.window_size = window_size
        self.feature_number = feature_number
        self.d_model_Encoder = d_model_Encoder
        self.d_model_Decoder = d_model_Decoder
        self.linear_price_series = nn.Linear(
            in_features=feature_number, out_features=d_model_Encoder
        )
        self.linear_local_price = nn.Linear(
            in_features=feature_number, out_features=d_model_Decoder
        )
        self.price_series_pe = price_series_pe
        self.local_price_pe = local_price_pe
        self.local_context_length = local_context_length
        self.linear_out = nn.Linear(in_features=1 + d_model_Encoder, out_features=1)
        self.bias = torch.nn.Parameter(torch.zeros([1, 1, 1]))

    def forward(
        self,
        price_series,
        local_price_context,
        previous_w,
        price_series_mask,
        local_price_mask,
        padding_price,
    ):  ##[4, 128, 31, 11]
        # price_series:[4,128,31,11]
        price_series = price_series / (price_series[0:1, :, -1:, :] + 1e-8)
        price_series = price_series.permute(3, 1, 2, 0)  # [4,128,31,11]->[11,128,31,4]
        price_series = price_series.contiguous().view(
            price_series.size()[0] * price_series.size()[1],
            self.window_size,
            self.feature_number,
        )  # [11,128,31,4]->[11*128,31,4]
        price_series = self.linear_price_series(
            price_series
        )  # [11*128,31,3]->[11*128,31,2*12]
        price_series = self.price_series_pe(price_series)  # [11*128,31,2*12]
        price_series = price_series.view(
            self.coin_num, -1, self.window_size, self.d_model_Encoder
        )  # [11*128,31,2*12]->[11,128,31,2*12]
        encode_out = self.encoder(price_series, price_series_mask)
        #        encode_out=self.linear_src_2_embedding(encode_out)
        ###########################padding price#######################################################################################
        if padding_price is not None:
            local_price_context = torch.cat(
                [padding_price, local_price_context], 2
            )  # [11,128,5-1,4] cat [11,128,1,4] -> [11,128,5,4]
            local_price_context = local_price_context.contiguous().view(
                local_price_context.size()[0] * price_series.size()[1],
                self.local_context_length * 2 - 1,
                self.feature_number,
            )  # [11,128,5,4]->[11*128,5,4]
        else:
            local_price_context = local_price_context.contiguous().view(
                local_price_context.size()[0] * price_series.size()[1],
                1,
                self.feature_number,
            )
        ##############Divide by close price################################
        local_price_context = local_price_context / (
            local_price_context[:, -1:, 0:1] + 1e-8
        )
        local_price_context = self.linear_local_price(
            local_price_context
        )  # [11*128,5,4]->[11*128,5,2*12]
        local_price_context = self.local_price_pe(
            local_price_context
        )  # [11*128,5,2*12]
        if padding_price is not None:
            padding_price = local_price_context[
                :, : -self.local_context_length, :
            ]  # [11*128,5-1,2*12]
            padding_price = padding_price.view(
                self.coin_num, -1, self.local_context_length - 1, self.d_model_Decoder
            )  # [11,128,5-1,2*12]
        local_price_context = local_price_context[
            :, -self.local_context_length :, :
        ]  # [11*128,5,2*12]
        local_price_context = local_price_context.view(
            self.coin_num, -1, self.local_context_length, self.d_model_Decoder
        )  # [11,128,5,2*12]
        #################################padding_price=None###########################################################################
        decode_out = self.decoder(
            local_price_context,
            encode_out,
            price_series_mask,
            local_price_mask,
            padding_price,
        )
        decode_out = decode_out.transpose(1, 0)  # [11,128,1,2*12]->#[128,11,1,2*12]
        decode_out = torch.squeeze(decode_out, 2)  # [128,11,1,2*12]->[128,11,2*12]
        previous_w = previous_w.permute(0, 2, 1)  # [128,1,11]->[128,11,1]
        out = torch.cat(
            [decode_out, previous_w], 2
        )  # [128,11,2*12]  cat [128,11,1] -> [128,11,2*12+1]
        ###################################  Decision making ##################################################
        out = self.linear_out(out)  # [128,11,2*12+1]->[128,11,1]

        bias = self.bias.repeat(out.size()[0], 1, 1)  # [128,1,1]

        out = torch.cat([bias, out], 1)  # [128,11,1] cat [128,1,1] -> [128,12,1]

        out = out.permute(0, 2, 1)  # [128,1,12]

        out = F.softmax(out, dim=-1)

        return out  # [128,1,12]


def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):  # [64,10,512]
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class Encoder(nn.Module):
    "Core encoder is a stack of N layers"

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        "Pass the input (and mask) through each layer in turn."
        for layer in self.layers:
            #            print("Encoder:",x)
            x = layer(x, mask)
        #            print("Encoder:",x.size())
        return self.norm(x)


class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))


class EncoderLayer(nn.Module):
    "Encoder is made up of self-attn and feed forward (defined below)"

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        "Follow Figure 1 (left) for connections."
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask, None, None))
        return self.sublayer[1](x, self.feed_forward)


######################################Decoder############################################
class Decoder(nn.Module):
    "Generic N layer decoder with masking."

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, price_series_mask, local_price_mask, padding_price):
        for layer in self.layers:
            x = layer(x, memory, price_series_mask, local_price_mask, padding_price)
        return self.norm(x)


class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, price_series_mask, local_price_mask, padding_price):
        "Follow Figure 1 (right) for connections."
        m = memory
        x = self.sublayer[0](
            x,
            lambda x: self.self_attn(
                x, x, x, local_price_mask, padding_price, padding_price
            ),
        )
        x = x[:, :, -1:, :]
        x = self.sublayer[1](
            x, lambda x: self.src_attn(x, m, m, price_series_mask, None, None)
        )
        return self.sublayer[2](x, self.feed_forward)


def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8")
    return torch.from_numpy(subsequent_mask) == 0
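
# For example (illustrative output), subsequent_mask(3) yields a causal mask where
# each position may attend only to itself and earlier positions:
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])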


def attention(query, key, value, mask=None, dropout=None, correlation_bias=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)  # 64
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Add correlation matrix as bias if provided (Section 3.3)
    if correlation_bias is not None:
        # correlation_bias shape: (num_assets, num_assets)
        # scores shape: (batch, heads, seq_len, seq_len) or (batch, heads, assets, assets)
        # For relation attention, add correlation bias to asset-asset scores
        scores = scores + correlation_bias.unsqueeze(0).unsqueeze(0)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)  # [30, 8, 9, 9]
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
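

# A minimal sketch of the scaled dot-product attention above, including the optional
# correlation_bias produced by AssetCorrelationTracker (Section 3.3); the shapes below
# are illustrative only.
def _attention_demo():
    batch, heads, assets, d_k = 2, 2, 11, 6
    q = torch.randn(batch, heads, assets, d_k)
    k = torch.randn(batch, heads, assets, d_k)
    v = torch.randn(batch, heads, assets, d_k)
    corr = torch.eye(assets)  # stand-in for AssetCorrelationTracker.get_correlation_bias()
    out, p_attn = attention(q, k, v, correlation_bias=corr)
    return out.shape, p_attn.shape  # (2, 2, 11, 6) and (2, 2, 11, 11)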


class MultiHeadedAttention(nn.Module):
    def __init__(
        self,
        asset_atten,
        h,
        d_model,
        dropout,
        local_context_length,
        use_decay_attention=False,
        temporal_decay_lambda=0.1,
        use_correlation_matrix=False,
        correlation_tracker=None,
    ):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.local_context_length = local_context_length
        self.linears = clones(nn.Linear(d_model, d_model), 2)
        self.conv_q = nn.Conv2d(
            d_model, d_model, (1, 1), stride=1, padding=0, bias=True
        )
        self.conv_k = nn.Conv2d(
            d_model, d_model, (1, 1), stride=1, padding=0, bias=True
        )

        self.ass_linears_v = nn.Linear(d_model, d_model)
        self.ass_conv_q = nn.Conv2d(
            d_model, d_model, (1, 1), stride=1, padding=0, bias=True
        )
        self.ass_conv_k = nn.Conv2d(
            d_model, d_model, (1, 1), stride=1, padding=0, bias=True
        )

        self.attn = None
        self.attn_asset = None
        self.dropout = nn.Dropout(p=dropout)
        self.feature_weight_linear = nn.Linear(d_model, d_model)
        self.asset_atten = asset_atten

        # Novel modifications (Section 3.3)
        self.use_decay_attention = use_decay_attention
        self.use_correlation_matrix = use_correlation_matrix
        self.correlation_tracker = correlation_tracker

        # Learnable temporal decay parameter λ
        if use_decay_attention:
            self.temporal_decay_lambda = nn.Parameter(
                torch.tensor(temporal_decay_lambda)
            )
        else:
            self.temporal_decay_lambda = None

    def forward(self, query, key, value, mask, padding_price_q, padding_price_k):
        # query [4,128,1,2*12] or (4,128,31,2*12) key, value(4,128,31,2*12)
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)  # [128,1,1,31]    [128,1,1,1]
            mask = mask.repeat(
                query.size()[0], 1, 1, 1
            )  # [128*3,1,1,31]  [128*3,1,1,1]    #[9, 1, 1, 31]
            mask = mask.to(device)
        q_size0 = query.size(0)  # 11
        q_size1 = query.size(1)  # 128
        q_size2 = query.size(2)  # 31 or 1
        q_size3 = query.size(3)  # 2*12
        key_size0 = key.size(0)
        key_size1 = key.size(1)
        key_size2 = key.size(2)
        key_size3 = key.size(3)
        ##################################query#################################################
        if padding_price_q is not None:
            padding_price_q = padding_price_q.permute(
                (1, 3, 0, 2)
            )  # [11,128,4,2*12]->[128,2*12,11,4]
            padding_q = padding_price_q
        else:
            if self.local_context_length > 1:
                padding_q = torch.zeros(
                    (q_size1, q_size3, q_size0, self.local_context_length - 1)
                ).to(device)
            else:
                padding_q = None
        query = query.permute((1, 3, 0, 2))
        if padding_q is not None:
            query = torch.cat([padding_q, query], -1)
        ##########################################context-agnostic query matrix##################################################
        # linear
        query = self.conv_q(query)
        query = query.permute((0, 2, 3, 1))  # [128,2*12,11,31+4]->[128,11,31+4,2*12]
        ########################################### local-attention ######################################################
        local_weight_q = torch.matmul(
            query[:, :, self.local_context_length - 1 :, :], query.transpose(-2, -1)
        ) / math.sqrt(q_size3)  # [128,11,31,2*12] *[128,11,2*12,31+4]->[128,11,31,31+4]
        # [128,11,31,31+4]->[128,11,1,5*31]
        local_weight_q_list = [
            F.softmax(
                local_weight_q[:, :, i : i + 1, i : i + self.local_context_length],
                dim=-1,
            )
            for i in range(q_size2)
        ]
        local_weight_q_list = torch.cat(local_weight_q_list, 3)
        # [128,11,1,5*31]->[128,11,5*31,1]
        local_weight_q_list = local_weight_q_list.permute(0, 1, 3, 2)
        # [128,11,31+4,2*12]->[128,11,5*31,2*12]
        q_list = [
            query[:, :, i : i + self.local_context_length, :] for i in range(q_size2)
        ]
        q_list = torch.cat(q_list, 2)
        # [128,11,5*31,1]*[128,11,5*31,2*12]->[128,11,5*31,2*12]
        # Apply temporal decay if enabled (Section 3.3, Equation 3)
        if self.use_decay_attention and self.temporal_decay_lambda is not None:
            # Create decay weights: exp(-λ·j) for j = 0, 1, ..., local_context_length-1
            decay_positions = torch.arange(
                self.local_context_length, dtype=torch.float32, device=query.device
            )
            # Expand for each position in the sequence
            decay_weights = torch.exp(-self.temporal_decay_lambda * decay_positions)
            # Reshape to [1, 1, local_context_length, 1]
            decay_weights = decay_weights.view(1, 1, self.local_context_length, 1)
            # Tile decay weights for all positions
            # q_list shape: [batch, heads, local_context_length * seq_len, features]
            # We need to apply decay pattern repeatedly for each window
            decay_pattern = decay_weights.repeat(1, 1, q_size2, 1)
            query = local_weight_q_list * q_list * decay_pattern
        else:
            query = local_weight_q_list * q_list
        # [128,11,5*31,2*12]->[128,11,5,31,2*12]
        query = query.contiguous().view(
            q_size1, q_size0, self.local_context_length, q_size2, q_size3
        )
        # [128,11,5,31,2*12]->[128,11,31,2*12]
        query = torch.sum(query, 2)
        # [128,11,31,2*12]->[128,2*12,11,31]
        query = query.permute((0, 3, 1, 2))
        ######################################################################################
        query = query.permute((2, 0, 3, 1))  # [128,2*12,11,31] ->[11,128,31,2*12]
        query = query.contiguous().view(
            q_size0 * q_size1, q_size2, q_size3
        )  # [11,128,31,2*12] ->[11*128,31,2*12]
        query = (
            query.contiguous()
            .view(q_size0 * q_size1, q_size2, self.h, self.d_k)
            .transpose(1, 2)
        )  # [11*128,31,2*12] -> [11*128,31,2,12] -> [11*128,2,31,12]
        #####################################key#################################################
        if padding_price_k is not None:
            padding_price_k = padding_price_k.permute(
                (1, 3, 0, 2)
            )  # [11,128,4,2*12] -> [128,2*12,11,4]
            padding_k = padding_price_k
        else:
            if self.local_context_length > 1:
                padding_k = torch.zeros(
                    (key_size1, key_size3, key_size0, self.local_context_length - 1)
                ).to(device)
            else:
                padding_k = None
        key = key.permute((1, 3, 0, 2))
        if padding_k is not None:
            key = torch.cat([padding_k, key], -1)
        ##########################################context-aware key matrix############################################################################
        # linear projection
        key = self.conv_k(key)
        key = key.permute((0, 2, 3, 1))  # [128,2*12,11,31+4]->[128,11,31+4,2*12]
        ########################################### local-attention ##########################################################################
        local_weight_k = torch.matmul(
            key[:, :, self.local_context_length - 1 :, :], key.transpose(-2, -1)
        ) / math.sqrt(
            key_size3
        )  # [128,11,31,2*12] *[128,11,2*12,31+4]->[128,11,31,31+4]
        # [128,11,31,31+4]->[128,11,1,5*31]
        local_weight_k_list = [
            F.softmax(
                local_weight_k[:, :, i : i + 1, i : i + self.local_context_length],
                dim=-1,
            )
            for i in range(key_size2)
        ]
        local_weight_k_list = torch.cat(local_weight_k_list, 3)
        # [128,11,1,5*31]->[128,11,5*31,1]
        local_weight_k_list = local_weight_k_list.permute(0, 1, 3, 2)
        # [128,11,31+4,2*12]->[128,11,5*31,2*12]
        k_list = [
            key[:, :, i : i + self.local_context_length, :] for i in range(key_size2)
        ]
        k_list = torch.cat(k_list, 2)
        # [128,11,5*31,1]*[128,11,5*31,2*12]->[128,11,5*31,2*12]
        key = local_weight_k_list * k_list
        # [128,11,5*31,2*12]->[128,11,5,31,2*12]
        key = key.contiguous().view(
            key_size1, key_size0, self.local_context_length, key_size2, key_size3
        )
        # [128,11,5,31,2*12]->[128,11,31,2*12]
        key = torch.sum(key, 2)
        # [128,11,31,2*12]->[128,2*12,11,31]
        key = key.permute((0, 3, 1, 2))
        #        key = self.conv_k(key)
        key = key.permute((2, 0, 3, 1))
        key = key.contiguous().view(key_size0 * key_size1, key_size2, key_size3)
        key = (
            key.contiguous()
            .view(key_size0 * key_size1, key_size2, self.h, self.d_k)
            .transpose(1, 2)
        )
        ##################################################### value matrix #############################################################################
        value = value.view(
            key_size0 * key_size1, key_size2, key_size3
        )  # [11,128,31,2*12] -> [11*128,31,2*12]
        nbatches = q_size0 * q_size1
        value = (
            self.linears[0](value).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        )  # [11*128,31,2,12]

        ################################################ Multi-head attention ##########################################################################
        x, self.attn = attention(query, key, value, mask=None, dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        x = x.view(
            q_size0, q_size1, q_size2, q_size3
        )  # D[11,128,1,2*12] or E[11,128,31,2*12]

        ########################## Relation-attention ######################################################################
        if self.asset_atten:
            #######################################ass_query#####################################################################
            ass_query = x.permute(
                (2, 1, 0, 3)
            )  # D[11,128,1,2*12]->[1,128,11,2*12] or E[11,128,31,2*12]->[31,128,11,2*12]
            ass_query = ass_query.contiguous().view(
                q_size2 * q_size1, q_size0, q_size3
            )  # [31,128,11,2*12] -> [31*128,11,2*12]
            ass_query = (
                ass_query.contiguous()
                .view(q_size2 * q_size1, q_size0, self.h, self.d_k)
                .transpose(1, 2)
            )  # [31*128,2,11,12]
            ########################################ass_key####################################################################
            ass_key = x.permute(
                (2, 1, 0, 3)
            )  # D[11,128,1,2*12]->[1,128,11,2*12] or E[11,128,31,2*12]->[31,128,11,2*12]
            ass_key = ass_key.contiguous().view(
                q_size2 * q_size1, q_size0, q_size3
            )  # [31,128,11,2*12]->[31*128,11,2*12]
            ass_key = (
                ass_key.contiguous()
                .view(q_size2 * q_size1, q_size0, self.h, self.d_k)
                .transpose(1, 2)
            )  # [31*128,2,11,12]
            ####################################################################################################################
            ass_value = x.permute(
                (2, 1, 0, 3)
            )  # D[11,128,1,2*12]->[1,128,11,2*12] or E[11,128,31,2*12]->[31,128,11,2*12]
            ass_value = ass_value.contiguous().view(
                q_size2 * q_size1, q_size0, q_size3
            )  # [31,128,11,2*12]->[31*128,11,2*12]
            ass_value = (
                ass_value.contiguous()
                .view(q_size2 * q_size1, -1, self.h, self.d_k)
                .transpose(1, 2)
            )  # [31*128,2,11,12]
            ######################################################################################################################
            #            ass_mask=torch.ones(q_size2*q_size1,1,1,q_size0).to(device)  #[31*128,1,1,11]
            # Get correlation bias if enabled (Section 3.3, Equation 4)
            correlation_bias = None
            if self.use_correlation_matrix and self.correlation_tracker is not None:
                correlation_bias = self.correlation_tracker.get_correlation_bias()
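            # (Section 3.3, Equation 4) correlation_bias is assumed to be an
            # [assets x assets] additive term that attention() (defined earlier
            # in this file) adds to the scaled dot-product scores over the asset
            # dimension, so that historically co-moving assets attend to each
            # other more strongly.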

            x, self.attn_asset = attention(
                ass_query,
                ass_key,
                ass_value,
                mask=None,
                dropout=self.dropout,
                correlation_bias=correlation_bias,
            )
            x = (
                x.transpose(1, 2)
                .contiguous()
                .view(q_size2 * q_size1, -1, self.h * self.d_k)
            )  # [31*128,11,2*12]
            x = x.view(q_size2, q_size1, q_size0, q_size3)  # [31,128,11,2*12]
            x = x.permute(2, 1, 0, 3)  # [11,128,31,2*12]
        return self.linears[-1](x)


class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        #        print("ffn:",x.size())
        return self.w_2(self.dropout(F.relu(self.w_1(x))))


class PositionalEncoding(nn.Module):
    "Implement the PE function with optional recency bias (Section 3.3, Equation 5)."

    def __init__(
        self,
        d_model,
        start_indx,
        dropout,
        max_len=5000,
        use_recency_bias=False,
        recency_beta=0.1,
    ):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.start_indx = start_indx
        self.use_recency_bias = use_recency_bias

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

        # Learnable recency bias parameter β (Section 3.3)
        if use_recency_bias:
            self.recency_beta = nn.Parameter(torch.tensor(recency_beta))
        else:
            self.recency_beta = None

    def forward(self, x):
        # Get positional encoding slice
        pe_slice = self.pe[:, self.start_indx : self.start_indx + x.size(1)]

        # Apply recency bias if enabled: PE * exp(-β·(k-pos))
        if self.use_recency_bias and self.recency_beta is not None:
            seq_len = x.size(1)
            k = self.start_indx + seq_len - 1  # Most recent position
            # Create position indices
            positions = torch.arange(
                self.start_indx, self.start_indx + seq_len, device=x.device
            )
            # Compute recency weights: exp(-β·(k-pos))
            recency_weights = torch.exp(-self.recency_beta * (k - positions))
            # Reshape to (1, seq_len, 1) for broadcasting
            recency_weights = recency_weights.unsqueeze(0).unsqueeze(-1)
            # Apply recency bias
            pe_slice = pe_slice * recency_weights

        x = x + Variable(pe_slice, requires_grad=False)
        return self.dropout(x)
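
# Hedged usage sketch for PositionalEncoding (illustrative only; this helper is
# not called anywhere in this script). With start_indx=0, a sequence length of 4
# and recency_beta=0.1, the most recent position k=3 keeps its full encoding
# (weight exp(0)=1.0) while earlier positions are scaled by exp(-0.1*(3-pos)),
# i.e. roughly [0.741, 0.819, 0.905, 1.000].
def _demo_recency_positional_encoding():
    pe_layer = PositionalEncoding(
        d_model=24, start_indx=0, dropout=0.0, use_recency_bias=True, recency_beta=0.1
    )
    x = torch.zeros(2, 4, 24)  # (batch, seq_len, d_model)
    return pe_layer(x)  # encodings for positions 0..3, scaled toward recency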


class NoamOpt:
    "Optim wrapper that implements rate."

    # e.g. model_size=512, factor=1, warmup=400
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        if self.warmup == 0:
            return self.factor
        else:
            return self.factor * (
                self.model_size ** (-0.5)
                * min(step ** (-0.5), step * self.warmup ** (-1.5))
            )
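
# Hedged note on the Noam schedule implemented above (illustrative numbers only):
# with warmup > 0 the learning rate follows
#     lr(step) = factor * model_size**(-0.5) * min(step**(-0.5), step * warmup**(-1.5)),
# rising linearly for `warmup` steps and then decaying as step**(-0.5); e.g. for
# model_size=512, factor=1, warmup=400 the peak rate at step 400 is about 0.0022.
# With warmup == 0 (as configured later in this script), rate() simply returns
# `factor`, so training effectively runs at a constant learning rate.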


class Batch_Loss(nn.Module):
    def __init__(
        self, commission_ratio, interest_rate, gamma=0.1, beta=0.1, size_average=True
    ):
        super(Batch_Loss, self).__init__()
        self.gamma = gamma  # variance penalty
        self.beta = beta
        self.size_average = size_average
        self.commission_ratio = commission_ratio
        self.interest_rate = interest_rate

    def forward(self, w, y):  # w:[128,1,12]   y:[128,11,4]
        close_price = y[:, :, 0:1].to(device)  #   [128,11,1]
        # future close prices (cash column prepended below)
        close_price = torch.cat(
            [torch.ones(close_price.size()[0], 1, 1).to(device), close_price], 1
        ).to(device)  # [128,1,1] cat [128,11,1] -> [128,12,1]
        reward = torch.matmul(w, close_price)  # [128,1,1]
        close_price = close_price.view(
            close_price.size()[0], close_price.size()[2], close_price.size()[1]
        )  # [128,1,12]
        ###############################################################################################################
        element_reward = w * close_price
        interest = torch.zeros(element_reward.size(), dtype=torch.float).to(device)
        interest[element_reward < 0] = element_reward[element_reward < 0]
        interest = torch.sum(interest, 2).unsqueeze(2) * self.interest_rate  # [128,1,1]
        ###############################################################################################################
        future_omega = w * close_price / (reward + 1e-8)  # [128,1,12]
        wt = future_omega[:-1]  # [127,1,12]
        wt1 = w[1:]  # [127,1,12]

        # Calculate transaction cost with commission only
        pure_pc = (
            1 - torch.sum(torch.abs(wt - wt1), -1).unsqueeze(2) * self.commission_ratio
        )  # [127,1,1]
        pure_pc = pure_pc.to(device)
        pure_pc = torch.cat([torch.ones([1, 1, 1]).to(device), pure_pc], 0)  # [128,1,1]

        cost_penalty = torch.sum(torch.abs(wt - wt1), -1)
        ################## Deduct transaction fee ##################
        reward = reward * pure_pc  # reward=pv_vector
        ################## Deduct loan interest ####################
        reward = reward + interest
        portfolio_value = torch.prod(reward, 0)
        batch_loss = -torch.log(reward + 1e-8)
        #####################variance_penalty##############################
        #        variance_penalty = ((torch.log(reward)-torch.log(reward).mean())**2).mean()
        if self.size_average:
            loss = (
                batch_loss.mean()
            )  # + self.gamma*variance_penalty + self.beta*cost_penalty.mean()
            return loss, portfolio_value[0][0]
        else:
            loss = (
                batch_loss.mean()
            )  # +self.gamma*variance_penalty + self.beta*cost_penalty.mean() #(dim=0)
            return loss, portfolio_value[0][0]
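
# Worked example for Batch_Loss (hedged, illustrative numbers only): with
# commission_ratio = 1e-4, rebalancing a two-asset portfolio from
# w_t = [0.5, 0.5] to w_{t+1} = [0.3, 0.7] gives sum|Δw| = 0.4, so the
# transaction factor is pure_pc = 1 - 1e-4 * 0.4 = 0.99996. If the raw period
# reward is 1.01, the post-cost reward is about 1.00996, and that period's
# contribution to the loss is -log(1.00996) ≈ -0.0099.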


class SimpleLossCompute:
    "A simple loss compute and train function."

    def __init__(self, criterion, opt=None):
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y):
        loss, portfolio_value = self.criterion(x, y)
        if self.opt is not None:
            loss.backward()
            self.opt.step()
            self.opt.optimizer.zero_grad()
        return loss, portfolio_value


def max_drawdown(pc_array):
    """calculate the max drawdown with the portfolio changes
    @:param pc_array: all the portfolio changes during a trading process
    @:return: max drawdown
    """
    portfolio_values = []
    drawdown_list = []
    max_benefit = 0
    for i in range(pc_array.shape[0]):
        if i > 0:
            portfolio_values.append(portfolio_values[i - 1] * pc_array[i])
        else:
            portfolio_values.append(pc_array[i])
        if portfolio_values[i] > max_benefit:
            max_benefit = portfolio_values[i]
            drawdown_list.append(0.0)
        else:
            drawdown_list.append(1.0 - portfolio_values[i] / max_benefit)
    return max(drawdown_list)
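
# Hedged usage sketch for max_drawdown (illustrative only): pc_array holds
# per-period portfolio changes, e.g. 1.02 means +2% in that period. For
# pc_array = [1.10, 0.90, 1.05] the cumulative values are [1.10, 0.99, 1.0395],
# the running peak is 1.10, and the maximum drawdown is 1 - 0.99/1.10 = 0.10.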


class Test_Loss(nn.Module):
    def __init__(
        self, commission_ratio, interest_rate, gamma=0.1, beta=0.1, size_average=True
    ):
        super(Test_Loss, self).__init__()
        self.gamma = gamma  # variance penalty
        self.beta = beta
        self.size_average = size_average
        self.commission_ratio = commission_ratio
        self.interest_rate = interest_rate

    def forward(self, w, y):  # w:[128,10,1,12] y(128,10,11,4)
        close_price = y[:, :, :, 0:1].to(device)  #   [128,10,11,1]
        close_price = torch.cat(
            [
                torch.ones(close_price.size()[0], close_price.size()[1], 1, 1).to(
                    device
                ),
                close_price,
            ],
            2,
        ).to(device)  # [128,10,1,1] cat [128,10,11,1] -> [128,10,12,1]
        reward = torch.matmul(
            w, close_price
        )  #  [128,10,1,12] * [128,10,12,1] ->[128,10,1,1]
        close_price = close_price.view(
            close_price.size()[0],
            close_price.size()[1],
            close_price.size()[3],
            close_price.size()[2],
        )  # [128,10,12,1] -> [128,10,1,12]
        ##############################################################################
        element_reward = w * close_price
        interest = torch.zeros(element_reward.size(), dtype=torch.float).to(device)
        interest[element_reward < 0] = element_reward[element_reward < 0]
        #        print("interest:",interest.size(),interest,'\r\n')
        interest = (
            torch.sum(interest, 3).unsqueeze(3) * self.interest_rate
        )  # [128,10,1,1]
        ##############################################################################
        future_omega = (
            w * close_price / (reward + 1e-8)
        )  # [128,10,1,12]*[128,10,1,12]/[128,10,1,1]
        wt = future_omega[:, :-1]  # [128, 9,1,12]
        wt1 = w[:, 1:]  # [128, 9,1,12]

        # Calculate transaction cost with commission only
        pure_pc = (
            1 - torch.sum(torch.abs(wt - wt1), -1).unsqueeze(3) * self.commission_ratio
        )  # [batch,9,1,1]
        pure_pc = pure_pc.to(device)
        pure_pc = torch.cat(
            [torch.ones([pure_pc.size()[0], 1, 1, 1]).to(device), pure_pc], 1
        )  # [batch,1,1,1] cat [batch,9,1,1] ->[batch,10,1,1]

        cost_penalty = torch.sum(torch.abs(wt - wt1), -1)  # [128, 9, 1]
        ################## Deduct transaction fee ##################
        reward = reward * pure_pc  # [128,10,1,1]*[128,10,1,1]  test: [1,2808-31,1,1]
        ################## Deduct loan interest ####################
        reward = reward + interest
        if not self.size_average:
            tst_pc_array = reward.squeeze()  # Daily portfolio changes

            # Calculate daily returns (portfolio changes - 1)
            daily_returns = tst_pc_array - 1

            # Sharpe Ratio: Annualized (Assuming 252 trading days per year for stocks)
            # SR = (Mean Daily Return * 252) / (Std Daily Return * sqrt(252))
            # Simplified: SR = Mean Daily Return / Std Daily Return * sqrt(252)
            mean_daily_return = daily_returns.mean()
            std_daily_return = daily_returns.std() + 1e-8
            SR = (mean_daily_return / std_daily_return) * torch.sqrt(
                torch.tensor(252.0)
            )

            # Calculate cumulative returns
            SN = torch.prod(reward, 1)  # Final portfolio value
            SN = SN.squeeze()

            # Build portfolio value series for MDD calculation
            St_v = []
            St = 1.0
            for k in range(reward.size()[1]):
                St *= reward[0, k, 0, 0]
                St_v.append(St.item())

            # Maximum Drawdown
            MDD = max_drawdown(tst_pc_array)

            # Calmar Ratio: Annualized Return / Maximum Drawdown
            # Annualized Return = (Final Value / Initial Value) ^ (252 / num_days) - 1
            num_days = reward.size()[1]
            annualized_return = torch.pow(SN, 252.0 / num_days) - 1
            CR = annualized_return / (MDD + 1e-8)

            # Turnover
            TO = cost_penalty.mean()
        ##############################################
        portfolio_value = torch.prod(reward, 1)  # [128,1,1]
        batch_loss = -torch.log(portfolio_value)  # [128,1,1]

        if self.size_average:
            loss = batch_loss.mean()
            return loss, portfolio_value.mean()
        else:
            loss = batch_loss.mean()
            return loss, portfolio_value[0][0][0], SR, CR, St_v, tst_pc_array, TO
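
# Worked example for the Test_Loss metrics (hedged, illustrative numbers only):
# if the mean daily return is 0.001 with std 0.01, the annualized Sharpe ratio is
# (0.001 / 0.01) * sqrt(252) ≈ 1.59. If the final portfolio value SN is 1.2 after
# num_days = 126 periods, the annualized return is 1.2 ** (252 / 126) - 1 = 0.44,
# and with a maximum drawdown of 0.2 the Calmar ratio is 0.44 / 0.2 = 2.2.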


class SimpleLossCompute_tst:
    "A simple loss compute and train function."

    def __init__(self, criterion, opt=None):
        self.criterion = criterion
        self.opt = opt

    def __call__(self, x, y):
        if self.opt is not None:
            loss, portfolio_value = self.criterion(x, y)
            loss.backward()
            self.opt.step()
            self.opt.optimizer.zero_grad()
            return loss, portfolio_value
        else:
            loss, portfolio_value, SR, CR, St_v, tst_pc_array, TO = self.criterion(x, y)
            return loss, portfolio_value, SR, CR, St_v, tst_pc_array, TO


def make_std_mask(local_price_context, batch_size):
    "Create a mask to hide padding and future words."
    local_price_mask = torch.ones(batch_size, 1, 1) == 1
    local_price_mask = local_price_mask & (
        subsequent_mask(local_price_context.size(-2)).type_as(local_price_mask.data)
    )
    return local_price_mask


def train_one_step(DM, x_window_size, model, loss_compute, local_context_length):
    batch = DM.next_batch()
    batch_input = batch["X"]  # (128, 4, 11, 31)
    batch_y = batch["y"]  # (128, 4, 11)
    batch_last_w = batch["last_w"]  # (128, 11)
    batch_w = batch["setw"]
    #############################################################################
    previous_w = torch.tensor(batch_last_w, dtype=torch.float).to(device)
    previous_w = torch.unsqueeze(previous_w, 1)  # [128, 11] -> [128,1,11]
    batch_input = batch_input.transpose((1, 0, 2, 3))
    batch_input = batch_input.transpose((0, 1, 3, 2))
    src = torch.tensor(batch_input, dtype=torch.float).to(device)
    price_series_mask = torch.ones(src.size()[1], 1, x_window_size) == 1  # [128, 1, 31]
    currt_price = src.permute((3, 1, 2, 0))  # [4,128,31,11]->[11,128,31,4]
    if local_context_length > 1:
        padding_price = currt_price[:, :, -(local_context_length) * 2 + 1 : -1, :]
    else:
        padding_price = None
    currt_price = currt_price[:, :, -1:, :]  # [11,128,31,4]->[11,128,1,4]
    trg_mask = make_std_mask(currt_price, src.size()[1])
    batch_y = batch_y.transpose((0, 2, 1))  # [128,4,11] -> [128,11,4]
    trg_y = torch.tensor(batch_y, dtype=torch.float).to(device)
    out = model.forward(
        src, currt_price, previous_w, price_series_mask, trg_mask, padding_price
    )
    new_w = out[:, :, 1:]  # drop the cash column
    new_w = new_w[:, 0, :]  # [128,1,11] -> [128,11]
    new_w = new_w.detach().cpu().numpy()
    batch_w(new_w)

    loss, portfolio_value = loss_compute(out, trg_y)
    return loss, portfolio_value


def train_one_step_osbl(
    DM,
    x_window_size,
    model,
    loss_compute,
    local_context_length,
    gradient_clip=1.0,
    correlation_tracker=None,
    ewc_regularizer=None,
):
    """
    OSBL training step: Process one batch with gradient clipping for stability.
    Optionally updates correlation matrix and applies EWC regularization.
    Returns loss and portfolio value for this batch.
    """
    batch = DM.next_batch()
    batch_input = batch["X"]
    batch_y = batch["y"]
    batch_last_w = batch["last_w"]
    batch_w = batch["setw"]

    # Prepare inputs
    previous_w = torch.tensor(batch_last_w, dtype=torch.float).to(device)
    previous_w = torch.unsqueeze(previous_w, 1)
    batch_input = batch_input.transpose((1, 0, 2, 3))
    batch_input = batch_input.transpose((0, 1, 3, 2))
    src = torch.tensor(batch_input, dtype=torch.float).to(device)
    price_series_mask = torch.ones(src.size()[1], 1, x_window_size) == 1
    currt_price = src.permute((3, 1, 2, 0))

    if local_context_length > 1:
        padding_price = currt_price[:, :, -(local_context_length) * 2 + 1 : -1, :]
    else:
        padding_price = None

    currt_price = currt_price[:, :, -1:, :]
    trg_mask = make_std_mask(currt_price, src.size()[1])
    batch_y = batch_y.transpose((0, 2, 1))
    trg_y = torch.tensor(batch_y, dtype=torch.float).to(device)

    # Update correlation matrix with returns if enabled (Section 3.3)
    if correlation_tracker is not None and trg_y.size(0) > 0:
        # Extract returns for all assets (excluding cash)
        # trg_y shape: (batch, assets, features) - take mean over batch and features
        returns = trg_y.mean(dim=(0, 2))  # Average returns for each asset
        if returns.size(0) > 1:  # Ensure we have multiple assets
            correlation_tracker.update(returns[1:])  # Exclude cash asset
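        # The tracker (AssetCorrelationTracker, defined earlier in this file) is
        # assumed to keep an exponentially weighted estimate of asset co-movement,
        # roughly C_t = gamma * C_{t-1} + (1 - gamma) * r_t r_t^T with
        # gamma = FLAGS.correlation_gamma; the relation-attention stage later
        # consumes it via get_correlation_bias() (Section 3.3, Equation 4).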

    # Forward pass
    out = model.forward(
        src, currt_price, previous_w, price_series_mask, trg_mask, padding_price
    )

    # Update portfolio weights
    new_w = out[:, :, 1:]
    new_w = new_w[:, 0, :]
    new_w = new_w.detach().cpu().numpy()
    batch_w(new_w)

    # Compute loss and add EWC regularization if enabled (Section 3.5)
    loss, portfolio_value = loss_compute.criterion(out, trg_y)

    if ewc_regularizer is not None:
        ewc_loss = ewc_regularizer.compute_ewc_loss()
        if isinstance(ewc_loss, torch.Tensor):
            loss = loss + ewc_loss
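        # (Section 3.5) compute_ewc_loss() is assumed to implement the standard
        # elastic weight consolidation penalty
        #     L_EWC = (lambda / 2) * sum_i F_i * (theta_i - theta_star_i) ** 2,
        # where F_i is the Fisher information of parameter i and theta_star_i is
        # the stored anchor value from the previous training segment.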

    if loss_compute.opt is not None:
        loss.backward()
        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        loss_compute.opt.step()
        loss_compute.opt.optimizer.zero_grad()

    return loss, portfolio_value
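
# Hedged sketch of the geometric (OSBL) batch sampling assumed to happen inside
# DataMatrices.next_batch(); DataMatrices is defined elsewhere, and the names
# below are illustrative, not its actual API. Recent batch start indices are
# drawn with geometrically higher probability, controlled by beta
# (FLAGS.osbl_beta). This helper is not called anywhere in this script.
def _sketch_osbl_sample_start(num_periods, batch_len, beta):
    latest_start = num_periods - batch_len
    offsets = np.arange(latest_start + 1)  # distance back from the latest start
    probs = beta * (1.0 - beta) ** offsets
    probs = probs / probs.sum()  # normalize to a proper distribution
    offset = np.random.choice(len(offsets), p=probs)
    return latest_start - offset  # start index, biased toward recent data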


def test_online(DM, x_window_size, model, evaluate_loss_compute, local_context_length):
    """
    Test model on the full test set.
    Returns: loss, portfolio_value, SR, CR, portfolio_values_history, returns_array, TO, weights_history
    """
    tst_batch = DM.get_test_set_online(DM._test_ind[0], DM._test_ind[-1], x_window_size)
    tst_batch_input = tst_batch["X"]
    tst_batch_y = tst_batch["y"]
    tst_batch_last_w = tst_batch["last_w"]
    tst_batch_w = tst_batch["setw"]

    tst_previous_w = torch.tensor(tst_batch_last_w, dtype=torch.float).to(device)
    tst_previous_w = torch.unsqueeze(tst_previous_w, 1)

    tst_batch_input = tst_batch_input.transpose((1, 0, 2, 3))
    tst_batch_input = tst_batch_input.transpose((0, 1, 3, 2))

    long_term_tst_src = torch.tensor(tst_batch_input, dtype=torch.float).to(device)
    #########################################################################################
    tst_src_mask = torch.ones(long_term_tst_src.size()[1], 1, x_window_size) == 1

    long_term_tst_currt_price = long_term_tst_src.permute((3, 1, 2, 0))
    long_term_tst_currt_price = long_term_tst_currt_price[:, :, x_window_size - 1 :, :]
    ###############################################################################################
    tst_trg_mask = make_std_mask(
        long_term_tst_currt_price[:, :, 0:1, :], long_term_tst_src.size()[1]
    )

    tst_batch_y = tst_batch_y.transpose((0, 3, 2, 1))
    tst_trg_y = torch.tensor(tst_batch_y, dtype=torch.float).to(device)
    tst_long_term_w = []
    tst_y_window_size = len(DM._test_ind) - x_window_size - 1 - 1
    for j in range(tst_y_window_size + 1):  # 0-9
        tst_src = long_term_tst_src[:, :, j : j + x_window_size, :]
        tst_currt_price = long_term_tst_currt_price[:, :, j : j + 1, :]
        if local_context_length > 1:
            padding_price = long_term_tst_src[
                :,
                :,
                j + x_window_size - 1 - local_context_length * 2 + 2 : j
                + x_window_size
                - 1,
                :,
            ]
            padding_price = padding_price.permute(
                (3, 1, 2, 0)
            )  # [4, 1, 2, 11] ->[11,1,2,4]
        else:
            padding_price = None
        out = model.forward(
            tst_src,
            tst_currt_price,
            tst_previous_w,  # [109, 1, 11]
            tst_src_mask,
            tst_trg_mask,
            padding_price,
        )
        if j == 0:
            tst_long_term_w = out.unsqueeze(0)  # [1,109,1,12]
        else:
            tst_long_term_w = torch.cat([tst_long_term_w, out.unsqueeze(0)], 0)
        out = out[:, :, 1:]  # drop the cash column  # [109,1,11]
        tst_previous_w = out
    tst_long_term_w = tst_long_term_w.permute(
        1, 0, 2, 3
    )  # [10,128,1,12] -> [128,10,1,12]

    # Get portfolio weights history (convert to numpy for saving)
    weights_history = (
        tst_long_term_w[0, :, 0, :].detach().cpu().numpy()
    )  # [num_periods, num_assets]

    tst_loss, tst_portfolio_value, SR, CR, St_v, tst_pc_array, TO = (
        evaluate_loss_compute(tst_long_term_w, tst_trg_y)
    )
    return (
        tst_loss,
        tst_portfolio_value,
        SR,
        CR,
        St_v,
        tst_pc_array,
        TO,
        weights_history,
    )


def test_net(
    DM,
    x_window_size,
    local_context_length,
    model,
    evaluate_loss_compute,
    log_dir="./log",
):
    """
    Test model on the full test set ONCE.
    Saves portfolio weights, returns, and values history to log files.

    Returns: portfolio_value, SR, CR, portfolio_values_history, returns_array, TO
    """
    print("\n" + "=" * 80)
    print("TESTING ON FULL TEST SET")
    print("=" * 80)

    model.eval()
    test_start = time.time()

    with torch.no_grad():
        (
            tst_loss,
            tst_portfolio_value,
            SR,
            CR,
            St_v,
            tst_pc_array,
            TO,
            weights_history,
        ) = test_online(
            DM,
            x_window_size,
            model,
            evaluate_loss_compute,
            local_context_length,
        )

    test_elapsed = time.time() - test_start

    # Convert tensors to numpy for logging
    if torch.is_tensor(tst_pc_array):
        returns_history = tst_pc_array.detach().cpu().numpy()
    else:
        returns_history = np.array(tst_pc_array)

    portfolio_values_history = np.array(St_v)

    # Print summary
    print("\n" + "=" * 80)
    print("🎯 TEST SET RESULTS")
    print("=" * 80)
    print(f"Final Portfolio Value: {tst_portfolio_value.item():.4f}")
    print(f"Sharpe Ratio:          {SR.item():.4f}")
    print(f"Calmar Ratio:          {CR.item():.4f}")
    print(f"Turnover:              {TO.item():.4f}")
    print(f"Test Duration:         {test_elapsed:.2f} seconds")
    print(f"Number of Test Periods: {len(returns_history)}")
    print("=" * 80)

    # Save detailed history to log files
    os.makedirs(log_dir, exist_ok=True)

    # Save portfolio weights history
    weights_file = os.path.join(log_dir, "portfolio_weights_history.csv")
    np.savetxt(
        weights_file,
        weights_history,
        delimiter=",",
        header="Portfolio weights for each period (rows=periods, cols=assets including cash)",
        comments="",
    )
    print(f"\n✓ Saved portfolio weights to: {weights_file}")

    # Save returns history
    returns_file = os.path.join(log_dir, "returns_history.csv")
    np.savetxt(
        returns_file,
        returns_history,
        delimiter=",",
        header="Portfolio returns for each period",
        comments="",
    )
    print(f"✓ Saved returns history to: {returns_file}")

    # Save portfolio values history
    values_file = os.path.join(log_dir, "portfolio_values_history.csv")
    np.savetxt(
        values_file,
        portfolio_values_history,
        delimiter=",",
        header="Cumulative portfolio value over time",
        comments="",
    )
    print(f"✓ Saved portfolio values to: {values_file}")

    # Save summary statistics
    summary_file = os.path.join(log_dir, "test_summary.txt")
    with open(summary_file, "w") as f:
        f.write("=" * 80 + "\n")
        f.write("TEST SET PERFORMANCE SUMMARY\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Final Portfolio Value: {tst_portfolio_value.item():.6f}\n")
        f.write(f"Sharpe Ratio:          {SR.item():.6f}\n")
        f.write(f"Calmar Ratio:          {CR.item():.6f}\n")
        f.write(f"Turnover:              {TO.item():.6f}\n")
        f.write(f"Test Loss:             {tst_loss.item():.6f}\n")
        f.write(f"Number of Periods:     {len(returns_history)}\n")
        f.write(f"Test Duration:         {test_elapsed:.2f} seconds\n")
        f.write("\n" + "=" * 80 + "\n")
        f.write("RETURNS STATISTICS\n")
        f.write("=" * 80 + "\n")
        f.write(f"Mean Return:           {returns_history.mean():.6f}\n")
        f.write(f"Std Return:            {returns_history.std():.6f}\n")
        f.write(f"Min Return:            {returns_history.min():.6f}\n")
        f.write(f"Max Return:            {returns_history.max():.6f}\n")
        f.write("\n" + "=" * 80 + "\n")
        f.write("PORTFOLIO STATISTICS\n")
        f.write("=" * 80 + "\n")
        f.write("Initial Value:         1.0000\n")
        f.write(f"Final Value:           {portfolio_values_history[-1]:.6f}\n")
        f.write(f"Max Value:             {portfolio_values_history.max():.6f}\n")
        f.write(f"Min Value:             {portfolio_values_history.min():.6f}\n")

    print(f"✓ Saved summary to: {summary_file}")
    print("=" * 80 + "\n")

    return tst_portfolio_value, SR, CR, St_v, tst_pc_array, TO


def validation_batch(
    DM, x_window_size, model, evaluate_loss_compute, local_context_length
):
    """Evaluate model on validation set"""
    val_batch = DM.get_validation_set()
    val_batch_input = val_batch["X"]
    val_batch_y = val_batch["y"]
    val_batch_last_w = val_batch["last_w"]
    val_batch_w = val_batch["setw"]

    val_previous_w = torch.tensor(val_batch_last_w, dtype=torch.float).to(device)
    val_previous_w = torch.unsqueeze(val_previous_w, 1)
    val_batch_input = val_batch_input.transpose((1, 0, 2, 3))
    val_batch_input = val_batch_input.transpose((0, 1, 3, 2))
    val_src = torch.tensor(val_batch_input, dtype=torch.float).to(device)
    val_src_mask = torch.ones(val_src.size()[1], 1, x_window_size) == 1
    val_currt_price = val_src.permute((3, 1, 2, 0))

    if local_context_length > 1:
        padding_price = val_currt_price[:, :, -(local_context_length) * 2 + 1 : -1, :]
    else:
        padding_price = None

    val_currt_price = val_currt_price[:, :, -1:, :]
    val_trg_mask = make_std_mask(val_currt_price, val_src.size()[1])
    val_batch_y = val_batch_y.transpose((0, 2, 1))
    val_trg_y = torch.tensor(val_batch_y, dtype=torch.float).to(device)

    val_out = model.forward(
        val_src,
        val_currt_price,
        val_previous_w,
        val_src_mask,
        val_trg_mask,
        padding_price,
    )

    val_loss, val_portfolio_value = evaluate_loss_compute(val_out, val_trg_y)
    return val_loss, val_portfolio_value


def test_batch(DM, x_window_size, model, evaluate_loss_compute, local_context_length):
    tst_batch = DM.get_test_set()
    tst_batch_input = tst_batch["X"]  # (128, 4, 11, 31)
    tst_batch_y = tst_batch["y"]
    tst_batch_last_w = tst_batch["last_w"]
    tst_batch_w = tst_batch["setw"]

    tst_previous_w = torch.tensor(tst_batch_last_w, dtype=torch.float).to(device)
    tst_previous_w = torch.unsqueeze(tst_previous_w, 1)  # [2426, 1, 11]
    tst_batch_input = tst_batch_input.transpose((1, 0, 2, 3))
    tst_batch_input = tst_batch_input.transpose((0, 1, 3, 2))
    tst_src = torch.tensor(tst_batch_input, dtype=torch.float).to(device)
    tst_src_mask = torch.ones(tst_src.size()[1], 1, x_window_size) == 1  # [128, 1, 31]
    tst_currt_price = tst_src.permute((3, 1, 2, 0))  # (4,128,31,11) -> (11,128,31,4)
    #############################################################################
    if local_context_length > 1:
        padding_price = tst_currt_price[
            :, :, -(local_context_length) * 2 + 1 : -1, :
        ]  # (11,128,8,4)
    else:
        padding_price = None
    #########################################################################

    tst_currt_price = tst_currt_price[:, :, -1:, :]  # (11,128,31,4)->(11,128,1,4)
    tst_trg_mask = make_std_mask(tst_currt_price, tst_src.size()[1])
    tst_batch_y = tst_batch_y.transpose((0, 2, 1))  # (128, 4, 11) ->(128,11,4)
    tst_trg_y = torch.tensor(tst_batch_y, dtype=torch.float).to(device)
    ###########################################################################################################
    tst_out = model.forward(
        tst_src,
        tst_currt_price,
        tst_previous_w,  # [128, 1, 11]
        tst_src_mask,
        tst_trg_mask,
        padding_price,
    )

    tst_loss, tst_portfolio_value = evaluate_loss_compute(tst_out, tst_trg_y)
    return tst_loss, tst_portfolio_value


def train_net(
    DM,
    total_step,
    output_step,
    x_window_size,
    local_context_length,
    model,
    model_dir,
    model_index,
    loss_compute,
    evaluate_loss_compute,
    is_trn=True,
    evaluate=True,
):
    "Standard Training and Logging Function"
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    #### previous_w = 0 at the start of each epoch
    max_tst_portfolio_value = 0
    for i in range(total_step):
        if is_trn:
            model.train()
            loss, portfolio_value = train_one_step(
                DM, x_window_size, model, loss_compute, local_context_length
            )
            total_loss += loss.item()
        if i % output_step == 0 and is_trn:
            elapsed = time.time() - start
            print(
                "Epoch Step: %d| Loss per batch: %f| Portfolio_Value: %f | batch per Sec: %f \r\n"
                % (i, loss.item(), portfolio_value.item(), output_step / elapsed)
            )
            start = time.time()
        #########################################################tst########################################################
        tst_total_loss = 0
        with torch.no_grad():
            if i % output_step == 0 and evaluate:
                model.eval()

                # Use validation set if available, otherwise fall back to test set
                if DM.num_validation_samples > 0:
                    # Evaluate on VALIDATION set
                    eval_loss, eval_portfolio_value = validation_batch(
                        DM,
                        x_window_size,
                        model,
                        evaluate_loss_compute,
                        local_context_length,
                    )
                    tst_total_loss += eval_loss.item()
                    elapsed = time.time() - start
                    print(
                        "Validation: %d Loss: %f| Portfolio_Value: %f | validation per Sec: %f \r\n"
                        % (
                            i,
                            eval_loss.item(),
                            eval_portfolio_value.item(),
                            1 / elapsed,
                        )
                    )
                    start = time.time()

                    if eval_portfolio_value > max_tst_portfolio_value:
                        max_tst_portfolio_value = eval_portfolio_value
                        torch.save(model, model_dir + "/" + str(model_index) + ".pkl")
                        print("Saved best model based on validation performance!")
                else:
                    # WARNING: No validation set
                    print(
                        "WARNING: No validation set! Using test set for model selection (data leakage risk)"
                    )
                    eval_loss, eval_portfolio_value = test_batch(
                        DM,
                        x_window_size,
                        model,
                        evaluate_loss_compute,
                        local_context_length,
                    )
                    tst_total_loss += eval_loss.item()
                    elapsed = time.time() - start
                    print(
                        "Test: %d Loss: %f| Portfolio_Value: %f | testset per Sec: %f \r\n"
                        % (
                            i,
                            eval_loss.item(),
                            eval_portfolio_value.item(),
                            1 / elapsed,
                        )
                    )
                    start = time.time()

                    if eval_portfolio_value > max_tst_portfolio_value:
                        max_tst_portfolio_value = eval_portfolio_value
                        torch.save(model, model_dir + "/" + str(model_index) + ".pkl")
                        print("save model!")
    return eval_loss, eval_portfolio_value


def train_net_osbl(
    DM,
    total_step,
    output_step,
    x_window_size,
    local_context_length,
    model,
    model_dir,
    model_index,
    loss_compute,
    evaluate_loss_compute,
    num_batches=4,
    gradient_clip=1.0,
    correlation_tracker=None,
    ewc_regularizer=None,
    ewc_update_freq=100,
    is_trn=True,
    evaluate=True,
):
    """
    OSBL Training Function with Novel Online Learning Modifications
    At each period t:
    1. Sample N_b mini-batches using geometric distribution
    2. Each batch contains n_b consecutive periods
    3. Update model with gradient clipping for stability
    4. Update correlation matrix (Section 3.3)
    5. Periodically update EWC Fisher matrix (Section 3.5)
    """
    start = time.time()
    max_tst_portfolio_value = 0

    print(
        f"Starting OSBL training with {num_batches} batches per update, gradient_clip={gradient_clip}"
    )
    if correlation_tracker is not None:
        print("Correlation matrix tracking enabled")
    if ewc_regularizer is not None:
        print(f"EWC regularization enabled with update_freq={ewc_update_freq}")

    for i in range(total_step):
        if is_trn:
            model.train()

            # OSBL: Sample multiple batches per update
            total_loss = 0.0
            total_pv = 0.0

            for batch_idx in range(num_batches):
                loss, portfolio_value = train_one_step_osbl(
                    DM,
                    x_window_size,
                    model,
                    loss_compute,
                    local_context_length,
                    gradient_clip,
                    correlation_tracker,
                    ewc_regularizer,
                )
                total_loss += loss.item()
                total_pv += portfolio_value.item()

            # Average metrics across batches
            avg_loss = total_loss / num_batches
            avg_pv = total_pv / num_batches

            # Update EWC Fisher matrix periodically (Section 3.5)
            if ewc_regularizer is not None and i > 0 and i % ewc_update_freq == 0:
                print(f"Step {i}: Updating EWC Fisher matrix...")
                # For simplicity, only snapshot the current parameters as the new
                # EWC anchor here (the Fisher matrix itself is not re-estimated)
                ewc_regularizer.fisher_matrix.save_optimal_params()
                print("EWC parameters updated")

        if i % output_step == 0 and is_trn:
            elapsed = time.time() - start

            # Get current portfolio weights for diagnostics
            with torch.no_grad():
                model.eval()
                batch = DM.next_batch()
                batch_input = batch["X"]
                batch_last_w = batch["last_w"]
                previous_w = torch.tensor(batch_last_w, dtype=torch.float).to(device)
                previous_w = torch.unsqueeze(previous_w, 1)
                batch_input = batch_input.transpose((1, 0, 2, 3))
                batch_input = batch_input.transpose((0, 1, 3, 2))
                src = torch.tensor(batch_input, dtype=torch.float).to(device)
                price_series_mask = torch.ones(src.size()[1], 1, x_window_size) == 1
                currt_price = src.permute((3, 1, 2, 0))

                if local_context_length > 1:
                    padding_price = currt_price[
                        :, :, -(local_context_length) * 2 + 1 : -1, :
                    ]
                else:
                    padding_price = None

                currt_price = currt_price[:, :, -1:, :]
                trg_mask = make_std_mask(currt_price, src.size()[1])

                out = model.forward(
                    src,
                    currt_price,
                    previous_w,
                    price_series_mask,
                    trg_mask,
                    padding_price,
                )

                # Analyze portfolio weights
                weights = out[0, 0, :].cpu().numpy()  # First sample weights
                cash_weight = weights[0]
                asset_weights = weights[1:]
                max_weight = asset_weights.max() if len(asset_weights) > 0 else 0
                min_weight = asset_weights.min() if len(asset_weights) > 0 else 0
                weight_std = asset_weights.std() if len(asset_weights) > 0 else 0
                num_active = (asset_weights > 0.05).sum()  # Assets with >5% allocation

                model.train()

            print(
                f"OSBL Step: {i} | Avg Loss: {avg_loss:.6f} | Avg Portfolio_Value: {avg_pv:.4f} | "
                f"Batches/update: {num_batches} | Steps per Sec: {output_step / elapsed:.4f}\n"
                f"  Portfolio: Cash={cash_weight:.3f} | Max Asset={max_weight:.3f} | Min Asset={min_weight:.3f} | "
                f"Std={weight_std:.4f} | Active Assets (>5%)={num_active}\n"
                f"Full Portfolio Weights: {weights}"
            )
            start = time.time()

        # Evaluation on VALIDATION set (not test!)
        with torch.no_grad():
            if i % output_step == 0 and evaluate:
                model.eval()

                # Use validation set if available, otherwise warn and use test
                if DM.num_validation_samples > 0:
                    # Proper approach: evaluate on validation set using test_batch logic
                    val_batch = DM.get_validation_set()
                    val_batch_input = val_batch["X"]
                    val_batch_y = val_batch["y"]
                    val_batch_last_w = val_batch["last_w"]

                    val_previous_w = torch.tensor(
                        val_batch_last_w, dtype=torch.float
                    ).to(device)
                    val_previous_w = torch.unsqueeze(val_previous_w, 1)
                    val_batch_input = val_batch_input.transpose((1, 0, 2, 3))
                    val_batch_input = val_batch_input.transpose((0, 1, 3, 2))
                    val_src = torch.tensor(val_batch_input, dtype=torch.float).to(
                        device
                    )
                    val_src_mask = torch.ones(val_src.size()[1], 1, x_window_size) == 1
                    val_currt_price = val_src.permute((3, 1, 2, 0))

                    if local_context_length > 1:
                        padding_price = val_currt_price[
                            :, :, -(local_context_length) * 2 + 1 : -1, :
                        ]
                    else:
                        padding_price = None

                    val_currt_price = val_currt_price[:, :, -1:, :]
                    val_trg_mask = make_std_mask(val_currt_price, val_src.size()[1])
                    val_batch_y = val_batch_y.transpose((0, 2, 1))
                    val_trg_y = torch.tensor(val_batch_y, dtype=torch.float).to(device)

                    val_out = model.forward(
                        val_src,
                        val_currt_price,
                        val_previous_w,
                        val_src_mask,
                        val_trg_mask,
                        padding_price,
                    )

                    eval_loss, eval_portfolio_value = evaluate_loss_compute(
                        val_out, val_trg_y
                    )
                    elapsed = time.time() - start
                    print(
                        f"OSBL Validation: {i} | Loss: {eval_loss.item():.6f} | "
                        f"Portfolio_Value: {eval_portfolio_value.item():.4f} | "
                        f"Val per Sec: {1 / elapsed:.4f}\n"
                    )

                    # Save best model based on VALIDATION performance
                    if eval_portfolio_value > max_tst_portfolio_value:
                        max_tst_portfolio_value = eval_portfolio_value
                        torch.save(
                            model, model_dir + "/" + str(model_index) + "_osbl.pkl"
                        )
                        print(
                            "OSBL: Saved best model based on validation performance!\n"
                        )
                else:
                    # WARNING: No validation set, falling back to test (not recommended)
                    print(
                        "WARNING: No validation set! Using test set for model selection (data leakage risk)"
                    )
                    eval_loss, eval_portfolio_value = test_batch(
                        DM,
                        x_window_size,
                        model,
                        evaluate_loss_compute,
                        local_context_length,
                    )
                    elapsed = time.time() - start
                    print(
                        f"OSBL Test: {i} | Loss: {eval_loss.item():.6f} | "
                        f"Portfolio_Value: {eval_portfolio_value.item():.4f} | "
                        f"Test per Sec: {1 / elapsed:.4f}\n"
                    )

                    if eval_portfolio_value > max_tst_portfolio_value:
                        max_tst_portfolio_value = eval_portfolio_value
                        torch.save(
                            model, model_dir + "/" + str(model_index) + "_osbl.pkl"
                        )
                        print("OSBL: Saved best model!\n")

                start = time.time()

    return eval_loss, eval_portfolio_value


start = parse_time(FLAGS.start)
end = parse_time(FLAGS.end)
DM = DataMatrices(
    start=start,
    end=end,
    market="poloniex",
    feature_number=FLAGS.feature_number,
    window_size=FLAGS.x_window_size,
    online=False,
    period=86400,
    coin_filter=11,
    is_permed=False,
    buffer_bias_ratio=5e-5,
    batch_size=FLAGS.batch_size,  # 128,
    volume_average_days=30,
    test_portion=FLAGS.test_portion,  # 0.08,
    validation_portion=FLAGS.validation_portion,  # 0.0 by default
    portion_reversed=False,
    use_osbl=FLAGS.use_osbl,
    osbl_beta=FLAGS.osbl_beta,
    osbl_max_memory=FLAGS.osbl_max_memory,
)


def make_model(
    batch_size,
    coin_num,
    window_size,
    feature_number,
    N=6,
    d_model_Encoder=512,
    d_model_Decoder=16,
    d_ff_Encoder=2048,
    d_ff_Decoder=64,
    h=8,
    dropout=0.0,
    local_context_length=3,
    use_decay_attention=False,
    temporal_decay_lambda=0.1,
    use_correlation_matrix=False,
    correlation_tracker=None,
    use_recency_pe=False,
    recency_beta=0.1,
):
    "Helper: Construct a model from hyperparameters."
    c = copy.deepcopy
    attn_Encoder = MultiHeadedAttention(
        True,
        h,
        d_model_Encoder,
        0.1,
        local_context_length,
        use_decay_attention,
        temporal_decay_lambda,
        use_correlation_matrix,
        correlation_tracker,
    )
    attn_Decoder = MultiHeadedAttention(
        True,
        h,
        d_model_Decoder,
        0.1,
        local_context_length,
        use_decay_attention,
        temporal_decay_lambda,
        use_correlation_matrix,
        correlation_tracker,
    )
    attn_En_Decoder = MultiHeadedAttention(
        False, h, d_model_Decoder, 0.1, 1, False, 0.1, False, None
    )
    ff_Encoder = PositionwiseFeedForward(d_model_Encoder, d_ff_Encoder, dropout)
    ff_Decoder = PositionwiseFeedForward(d_model_Decoder, d_ff_Decoder, dropout)
    position_Encoder = PositionalEncoding(
        d_model_Encoder,
        0,
        dropout,
        use_recency_bias=use_recency_pe,
        recency_beta=recency_beta,
    )
    position_Decoder = PositionalEncoding(
        d_model_Decoder,
        window_size - local_context_length * 2 + 1,
        dropout,
        use_recency_bias=use_recency_pe,
        recency_beta=recency_beta,
    )

    model = EncoderDecoder(
        batch_size,
        coin_num,
        window_size,
        feature_number,
        d_model_Encoder,
        d_model_Decoder,
        Encoder(
            EncoderLayer(d_model_Encoder, c(attn_Encoder), c(ff_Encoder), dropout), N
        ),
        Decoder(
            DecoderLayer(
                d_model_Decoder,
                c(attn_Decoder),
                c(attn_En_Decoder),
                c(ff_Decoder),
                dropout,
            ),
            N,
        ),
        c(position_Encoder),  # price series positional encoding
        c(position_Decoder),  # local_price_context positional encoding
        local_context_length,
    )
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model


#################set device (CPU or CUDA)###################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

#################set learning rate###################
lr_model_sz = 5120
factor = FLAGS.learning_rate  # 1.0
warmup = 0  # 800

total_step = FLAGS.total_step
x_window_size = FLAGS.x_window_size  # 31

batch_size = FLAGS.batch_size
# Get actual coin_num from database (not from FLAGS)
coin_num = DM.coin_number
print(
    f"Using {coin_num} symbols from database (ignoring FLAGS.coin_num={FLAGS.coin_num})"
)
feature_number = FLAGS.feature_number  # 4
trading_consumption = FLAGS.trading_consumption  # 0.0025
variance_penalty = FLAGS.variance_penalty  # 0 #0.01
cost_penalty = FLAGS.cost_penalty  # 0 #0.01
output_step = FLAGS.output_step  # 50
local_context_length = FLAGS.local_context_length
model_dim = FLAGS.model_dim
weight_decay = FLAGS.weight_decay
# Interest rate: Convert daily rate to half-hour rate (original crypto frequency)
# For daily stock data, use: interest_rate = FLAGS.daily_interest_rate (no division)
# For 30-min crypto data: interest_rate = FLAGS.daily_interest_rate / 24 / 2
interest_rate = FLAGS.daily_interest_rate  # Daily rate for daily stock data

# Initialize correlation tracker if enabled (Section 3.3)
correlation_tracker = None
if FLAGS.use_correlation_matrix:
    correlation_tracker = AssetCorrelationTracker(
        num_assets=coin_num, gamma=FLAGS.correlation_gamma, device=device
    )

model = make_model(
    batch_size,
    coin_num,
    x_window_size,
    feature_number,
    N=1,
    d_model_Encoder=FLAGS.multihead_num * model_dim,
    d_model_Decoder=FLAGS.multihead_num * model_dim,
    d_ff_Encoder=FLAGS.multihead_num * model_dim,
    d_ff_Decoder=FLAGS.multihead_num * model_dim,
    h=FLAGS.multihead_num,
    dropout=0.01,
    local_context_length=local_context_length,
    use_decay_attention=FLAGS.use_decay_attention,
    temporal_decay_lambda=FLAGS.temporal_decay_lambda,
    use_correlation_matrix=FLAGS.use_correlation_matrix,
    correlation_tracker=correlation_tracker,
    use_recency_pe=FLAGS.use_recency_pe,
    recency_beta=FLAGS.recency_beta,
)

# model = make_model3(N=6, d_model=512, d_ff=2048, h=8, dropout=0.1)
model = model.to(device)

# Initialize EWC regularizer if enabled (Section 3.5)
ewc_regularizer = None
if FLAGS.use_ewc:
    ewc_regularizer = EWCRegularizer(model, lambda_ewc=FLAGS.ewc_lambda)
    print(
        f"EWC regularization enabled with λ={FLAGS.ewc_lambda}, update_freq={FLAGS.ewc_update_freq}"
    )
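
# Illustrative sketch (assumption, not the EWCRegularizer defined earlier): the
# standard elastic-weight-consolidation penalty added to the training loss is
#   L_EWC = (lambda / 2) * sum_i F_i * (theta_i - theta_star_i)**2
# where F_i is a (diagonal) Fisher-information estimate and theta_star_i the
# parameter snapshot taken when the Fisher estimate was last refreshed
# (every FLAGS.ewc_update_freq periods here).
def _ewc_penalty_sketch(model_, fisher, theta_star, lambda_ewc=FLAGS.ewc_lambda):
    """Quadratic EWC penalty over named parameters (hypothetical helper)."""
    penalty = 0.0
    for name, p in model_.named_parameters():
        if name in fisher:
            penalty = penalty + (fisher[name] * (p - theta_star[name]) ** 2).sum()
    return 0.5 * lambda_ewc * penalty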

# NoamOpt(model_size, factor, warmup, optimizer)
model_opt = NoamOpt(
    lr_model_sz,
    factor,
    warmup,
    torch.optim.Adam(
        model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
    ),
)

loss_compute = SimpleLossCompute(
    Batch_Loss(
        trading_consumption, interest_rate, variance_penalty, cost_penalty, True
    ),
    model_opt,
)
evaluate_loss_compute = SimpleLossCompute(
    Batch_Loss(
        trading_consumption, interest_rate, variance_penalty, cost_penalty, False
    ),
    None,
)
test_loss_compute = SimpleLossCompute_tst(
    Test_Loss(
        trading_consumption, interest_rate, variance_penalty, cost_penalty, False
    ),
    None,
)
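
# Illustrative sketch (assumption, not the Batch_Loss / Test_Loss classes defined
# earlier): the PGPortfolio-style objective maximizes the average log portfolio
# return, so the loss is roughly
#   loss = -mean_t( log( mu_t * <w_t, y_t> ) )
# where y_t is the vector of price relatives, w_t the portfolio weights, and
#   mu_t ~= 1 - commission * sum_i |w_{t,i} - w'_{t-1,i}|
# approximates the transaction-cost shrinkage; interest_rate, variance_penalty and
# cost_penalty add further terms on top in the actual classes.
def _log_return_loss_sketch(weights, price_relatives, prev_weights,
                            commission=trading_consumption):
    """Negative mean log return with a linear transaction-cost approximation."""
    mu = 1.0 - commission * torch.abs(weights - prev_weights).sum(dim=-1)
    portfolio_change = mu * (weights * price_relatives).sum(dim=-1)
    return -torch.log(portfolio_change).mean()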


##########################train net####################################################
if FLAGS.use_osbl:
    print("=" * 80)
    print("Starting OSBL Training Mode")
    print(
        f"Beta: {FLAGS.osbl_beta}, Num Batches: {FLAGS.osbl_num_batches}, Max Memory: {FLAGS.osbl_max_memory}"
    )
    print(f"Gradient Clipping: {FLAGS.gradient_clip}")
    if FLAGS.use_decay_attention:
        print(
            f"Decay-aware Context Attention enabled (λ={FLAGS.temporal_decay_lambda})"
        )
    if FLAGS.use_correlation_matrix:
        print(f"Asset Correlation Matrix enabled (γ={FLAGS.correlation_gamma})")
    if FLAGS.use_recency_pe:
        print(f"Recency-biased Positional Encoding enabled (β={FLAGS.recency_beta})")
    if FLAGS.use_ewc:
        print(
            f"EWC Regularization enabled (λ={FLAGS.ewc_lambda}, freq={FLAGS.ewc_update_freq})"
        )
    print("=" * 80)

    tst_loss, tst_portfolio_value = train_net_osbl(
        DM,
        total_step,
        output_step,
        x_window_size,
        local_context_length,
        model,
        FLAGS.model_dir,
        FLAGS.model_index,
        loss_compute,
        evaluate_loss_compute,
        num_batches=FLAGS.osbl_num_batches,
        gradient_clip=FLAGS.gradient_clip,
        correlation_tracker=correlation_tracker,
        ewc_regularizer=ewc_regularizer,
        ewc_update_freq=FLAGS.ewc_update_freq,
        is_trn=True,
        evaluate=True,
    )

    # Load best OSBL model
    model_path = FLAGS.model_dir + "/" + str(FLAGS.model_index) + "_osbl.pkl"
    if os.path.isfile(model_path):
        model = torch.load(model_path, weights_only=False)
        print(f"Loaded best OSBL model from {model_path}")
    else:
        print(f"Warning: OSBL model not found at {model_path}, using current model")
else:
    print("=" * 80)
    print("Starting Standard Training Mode")
    print("=" * 80)

    tst_loss, tst_portfolio_value = train_net(
        DM,
        total_step,
        output_step,
        x_window_size,
        local_context_length,
        model,
        FLAGS.model_dir,
        FLAGS.model_index,
        loss_compute,
        evaluate_loss_compute,
        True,
        True,
    )

    # Load best standard model
    model = torch.load(
        FLAGS.model_dir + "/" + str(FLAGS.model_index) + ".pkl", weights_only=False
    )

##########################test net#####################################################
print("\nLoaded trained model. Running test on full test set...\n")
tst_portfolio_value, SR, CR, St_v, tst_pc_array, TO = test_net(
    DM,
    x_window_size,
    local_context_length,
    model,
    test_loss_compute,
    log_dir=FLAGS.log_dir if FLAGS.log_dir else "./log",
)
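
# Illustrative sketch (assumption about the metrics test_net reports): with the
# per-period portfolio changes in tst_pc_array, common definitions are
#   SR (Sharpe ratio)       = mean(pc - 1) / std(pc - 1)
#   CR (Calmar-style ratio) = cumulative return / maximum drawdown
#   TO (turnover)           = average fraction of the portfolio traded per period
# The exact formulas used here are implemented in Test_Loss / test_net above.
def _sharpe_and_mdd_sketch(pc_array):
    """Per-period Sharpe ratio and maximum drawdown from portfolio changes (hypothetical)."""
    r = pc_array - 1.0
    sharpe = r.mean() / (r.std() + 1e-12)
    wealth = np.cumprod(pc_array)
    running_max = np.maximum.accumulate(wealth)
    max_drawdown = ((running_max - wealth) / running_max).max()
    return sharpe, max_drawdown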


log_dir = FLAGS.log_dir if FLAGS.log_dir else "./log"
os.makedirs(log_dir, exist_ok=True)
csv_dir = log_dir + "/train_summary.csv"
d = {
    "net_dir": [FLAGS.model_index],
    "fAPV": [tst_portfolio_value.item()],
    "SR": [SR.item()],
    "CR": [CR.item()],
    "TO": [TO.item()],
    "St_v": ["".join(str(e) + ", " for e in St_v)],
    "backtest_test_history": [
        "".join(str(e) + ", " for e in tst_pc_array.cpu().numpy())
    ],
}
new_data_frame = pd.DataFrame(data=d).set_index("net_dir")
if os.path.isfile(csv_dir):
    dataframe = pd.read_csv(csv_dir).set_index("net_dir")
    dataframe = pd.concat([dataframe, new_data_frame])
else:
    dataframe = new_data_frame
dataframe.to_csv(csv_dir)

##########################Long-Term Analysis#####################################################
print("\n" + "=" * 80)
print("📊 GENERATING COMPREHENSIVE LONG-TERM ANALYSIS")
print("=" * 80 + "\n")

try:
    # Prepare data for analyzer
    portfolio_values = St_v
    returns_array = (
        tst_pc_array.cpu().numpy() - 1.0
    )  # Convert portfolio changes to returns

    # Create analyzer
    analyzer = LongTermAnalyzer(
        portfolio_values=portfolio_values,
        returns_array=returns_array,
        weights_history=None,  # Not tracked in current implementation
        start_date=FLAGS.start,
        trading_days_per_year=252,  # Stock market
        benchmark_returns=None,  # Will use default SPY-like benchmark
    )
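
    # Illustrative sketch (assumption about how trading_days_per_year=252 feeds the
    # analyzer's annualized metrics): per-period statistics are typically scaled as
    #   annualized return     = (1 + mean daily return) ** 252 - 1
    #   annualized volatility = std(daily returns) * sqrt(252)
    def _annualize_sketch(daily_returns, periods_per_year=252):
        """Annualized return and volatility from per-period returns (hypothetical)."""
        ann_ret = (1.0 + daily_returns.mean()) ** periods_per_year - 1.0
        ann_vol = daily_returns.std() * np.sqrt(periods_per_year)
        return ann_ret, ann_vol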

    # Generate all analysis
    analysis_dir = log_dir  # same directory as the test log and summary CSV
    metrics = analyzer.plot_all(save_dir=analysis_dir)

    print("\n" + "=" * 80)
    print("✅ ANALYSIS COMPLETE!")
    print("=" * 80)
    print(f"Charts saved to: {analysis_dir}/")
    print(f"  - tier1_analysis.png (Core Performance)")
    print(f"  - tier2_analysis.png (Risk-Adjusted Quality)")
    print(f"  - tier3_analysis.png (Portfolio Behavior)")
    print(f"  - metrics_summary.csv (All Metrics)")
    print("=" * 80 + "\n")

except ImportError:
    print("⚠️ longterm_analysis.py not found. Skipping detailed analysis.")
    print("   Run: python longterm_analysis.py " + csv_dir)
except Exception as e:
    print(f"⚠️ Error during analysis: {e}")
    print("   Results still saved to: " + csv_dir)
