from Narsese import Budget
from .Item import Item
from typing import Generic, TypeVar
import random
from typing import Iterable

ItemType = TypeVar("ItemType", bound=Item)


class Memory(Generic[ItemType]):
    """Set-backed item store with an optional size bound.

    A ``capacity`` of zero or less means "no bound": `truncate` leaves the
    buffer untouched in that case.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buff: set[ItemType] = set()  # backing buffer

    def remove(self, item: Item):
        """Drop ``item`` from the buffer; an absent item is silently ignored."""
        self.buff.discard(item)

    def truncate(self):
        """Shrink the buffer to at most ``capacity`` items.

        The items with the largest ``budget_m.q`` are kept. Nothing happens
        when the capacity is non-positive or the buffer already fits.
        """
        limit = self.capacity
        if limit <= 0 or len(self.buff) <= limit:
            return
        ranked = list(self.buff)
        ranked.sort(key=lambda entry: entry.budget_m.q)
        # Ascending order, so the tail holds the highest-q items.
        self.buff = set(ranked[-limit:])

    def set_capacity(self, capacity):
        """Store the new capacity and immediately truncate to it."""
        self.capacity = capacity
        self.truncate()

    def __contains__(self, item: Item):
        """Return whether ``item`` is currently stored."""
        return item in self.buff

    def __len__(self):
        """Return the number of stored items."""
        return len(self.buff)

    def __iter__(self):
        """Iterate over the stored items (set order, i.e. unspecified)."""
        return iter(self.buff)


class LTM(Memory[ItemType]):
    """Long-term buffer: once full, competes items on ``budget_m.q``."""

    def __init__(self, capacity):
        super().__init__(capacity)

    def insert(self, item: Item):
        """Insert ``item``, displacing the lowest-``q`` member if the buffer is full.

        Returns:
            ``None`` if ``item`` was stored without displacing anything, or
            was already present.
            The evicted item if ``item`` replaced it.
            ``item`` itself if it was rejected because no stored item has a
            strictly lower ``budget_m.q``.
        """
        # Re-inserting a member must be a no-op. Without this guard a full
        # buffer would either evict an unrelated item (and shrink, because
        # ``add`` does nothing for an existing member) or report a stored
        # item as rejected.
        if item in self.buff:
            return None

        # A non-positive capacity means the buffer is unbounded.
        if self.capacity <= 0 or len(self.buff) < self.capacity:
            self.buff.add(item)
            return None

        weakest: Item = min(self.buff, key=lambda u: u.budget_m.q)
        if weakest.budget_m.q < item.budget_m.q:
            self.buff.remove(weakest)
            self.buff.add(item)
            return weakest
        # Ties favour the incumbent: the newcomer is handed back.
        return item


class STM(Memory[ItemType]):
    """Short-term buffer: once full, competes items on ``budget_m.p``."""

    def __init__(self, capacity):
        super().__init__(capacity)

    def truncate(self):
        """Shrink the buffer to at most ``capacity`` items, keeping the highest ``budget_m.p``.

        Overrides the inherited implementation, which ranks by ``budget_m.q``,
        so that shrinking the capacity applies the same criterion `insert`
        uses when it evicts.
        """
        if 0 < self.capacity < len(self.buff):
            by_p = sorted(self.buff, key=lambda u: u.budget_m.p)
            # Ascending order, so the tail holds the highest-p items.
            self.buff = set(by_p[-self.capacity:])

    def insert(self, item: Item):
        """Insert ``item``, displacing the lowest-``p`` member if the buffer is full.

        Returns:
            ``None`` if ``item`` was stored without displacing anything, or
            was already present.
            The evicted item if ``item`` replaced it.
            ``item`` itself if it was rejected because no stored item has a
            strictly lower ``budget_m.p``.
        """
        # Re-inserting a member must be a no-op. Without this guard a full
        # buffer would either evict an unrelated item (and shrink, because
        # ``add`` does nothing for an existing member) or report a stored
        # item as rejected.
        if item in self.buff:
            return None

        # A non-positive capacity means the buffer is unbounded.
        if self.capacity <= 0 or len(self.buff) < self.capacity:
            self.buff.add(item)
            return None

        weakest: Item = min(self.buff, key=lambda u: u.budget_m.p)
        if weakest.budget_m.p < item.budget_m.p:
            self.buff.remove(weakest)
            self.buff.add(item)
            return weakest
        # Ties favour the incumbent: the newcomer is handed back.
        return item


class LSTM(Generic[ItemType]):
    """Long Short-Term Buffer.

    Allocates memory resources according to ``budget_m``, taking both
    long-term and short-term memory into account: items that are useful in
    the long term and items that are useful in the short term are both
    retained.
    """

    def __init__(self, n_ltb, n_stb):
        self.ltb = LTM[ItemType](n_ltb)  # long term buffer
        self.stb = STM[ItemType](n_stb)  # short term buffer
        self.n_ltb = n_ltb
        self.n_stb = n_stb
        self.capacity = n_ltb + n_stb

    def insert_to_ltb(self, item: ItemType):
        """Insert ``item`` into the long-term buffer and return its overflow."""
        return self.ltb.insert(item)

    def insert_to_stb(self, item: ItemType):
        """Insert ``item`` into the short-term buffer and return its overflow."""
        return self.stb.insert(item)

    def insert(self, item: ItemType):
        """Insert ``item`` into the short-term buffer, cascading overflow to long-term.

        Whatever the short-term buffer hands back (an evicted item or the
        rejected newcomer) is offered to the long-term buffer.

        Returns:
            ``None`` if nothing fell out of both buffers, otherwise the item
            handed back by the long-term buffer.
        """
        overflow = self.insert_to_stb(item)
        if overflow is None:
            return None
        return self.insert_to_ltb(overflow)

    def remove_from_ltb(self, unit):
        """Remove ``unit`` from the long-term buffer."""
        self.ltb.remove(unit)

    def remove_from_stb(self, unit):
        """Remove ``unit`` from the short-term buffer."""
        self.stb.remove(unit)

    def remove(self, unit):
        """Remove ``unit`` from both buffers."""
        self.remove_from_ltb(unit)
        self.remove_from_stb(unit)

    def set_capacity(self, n_ltb, n_stb):
        """Resize both buffers, truncating each to its new capacity.

        Items truncated from the short-term buffer are dropped; they are not
        moved into the long-term buffer.
        """
        # Keep the bookkeeping attributes in sync with the sub-buffers;
        # otherwise they would keep reporting the constructor-time sizes.
        self.n_ltb = n_ltb
        self.n_stb = n_stb
        self.capacity = n_ltb + n_stb
        self.stb.set_capacity(n_stb)
        self.ltb.set_capacity(n_ltb)

    def contain(self, unit):
        """Return whether ``unit`` is held by either buffer."""
        return unit in self.ltb or unit in self.stb

    def __contains__(self, unit):
        return self.contain(unit)

    def __len__(self):
        """Return the combined size; an item held by both buffers counts twice."""
        return len(self.ltb) + len(self.stb)

    def __iter__(self) -> Iterable[ItemType]:
        """Iterate over a snapshot of the long-term items followed by the short-term items."""
        # Materialise a tuple first so that mutating either buffer while
        # iterating does not invalidate the iterator.
        return iter((*self.ltb, *self.stb))
