from Narsese import Budget
from .Item import Item
from typing import Generic, TypeVar
import random
from typing import Iterable

ItemType = TypeVar("ItemType", bound=Item)


class Memory(Generic[ItemType]):
    """Set-backed item buffer with an optional size limit.

    A ``capacity`` of zero or less disables the limit. This class only
    stores, removes and trims items; admitting new items is left to
    subclasses.
    """

    def __init__(self, capacity):
        self.mem = set[ItemType]()  # buffer contents
        self.capacity = capacity

    def remove(self, item: Item):
        """Drop ``item`` if it is stored; an absent item is silently ignored."""
        self.mem.discard(item)

    def truncate(self):
        """Shrink the buffer to at most ``capacity`` items.

        The items with the largest ``budget_m.q`` are kept.

        Returns:
            A list of the dropped items, lowest ``q`` first. Returns an empty
            tuple when the limit is disabled or already satisfied.
        """
        limit = self.capacity
        if not 0 < limit < len(self.mem):
            return ()
        ranked = sorted(self.mem, key=lambda member: member.budget_m.q)
        cut = len(ranked) - limit  # number of lowest-q items to discard
        self.mem = set(ranked[cut:])
        return ranked[:cut]

    def set_capacity(self, capacity):
        """Install a new limit and immediately trim to it.

        Returns whatever ``truncate`` dropped.
        """
        self.capacity = capacity
        dropped = self.truncate()
        return dropped

    def __contains__(self, item: Item):
        return item in self.mem

    def __len__(self):
        return len(self.mem)

    def __iter__(self):
        # Iterates the underlying set directly; order is unspecified.
        return iter(self.mem)


class LTM(Memory[ItemType]):
    """Long Term Buffer.

    A bounded set of items ranked by ``item.budget_m.q``. When the buffer is
    full, a newcomer displaces the current lowest-``q`` item only if its own
    ``q`` is strictly greater. A ``capacity`` of zero or less means unbounded.
    """

    def __init__(self, capacity):
        super().__init__(capacity)

    def insert(self, item: Item):
        """Insert ``item``, evicting the lowest-``q`` item if the buffer is full.

        Returns:
            ``None`` if ``item`` was stored without evicting anything, or was
            already present; the evicted item if ``item`` displaced it; or
            ``item`` itself if it was rejected.
        """
        # Re-adding a stored item is a no-op for the set. Without this guard,
        # a full buffer would either evict another item for nothing (shrinking
        # the buffer) or report a still-stored item as rejected.
        if item in self.mem:
            return None

        if self.capacity <= 0 or len(self.mem) < self.capacity:
            self.mem.add(item)
            return None

        weakest: Item = min(self.mem, key=lambda u: u.budget_m.q)
        if weakest.budget_m.q < item.budget_m.q:
            self.mem.remove(weakest)
            self.mem.add(item)
            return weakest
        # Ties go to the incumbent: the newcomer is rejected.
        return item


class STM(Memory[ItemType]):
    """Short Term Buffer.

    A bounded set of items ranked by ``item.budget_m.p``. When the buffer is
    full, a newcomer displaces the current lowest-``p`` item only if its own
    ``p`` is strictly greater. A ``capacity`` of zero or less means unbounded.
    """

    def __init__(self, capacity):
        super().__init__(capacity)

    def insert(self, item: Item):
        """Insert ``item``, evicting the lowest-``p`` item if the buffer is full.

        Returns:
            ``None`` if ``item`` was stored without evicting anything, or was
            already present; the evicted item if ``item`` displaced it; or
            ``item`` itself if it was rejected.
        """
        # Re-adding a stored item is a no-op for the set. Without this guard,
        # a full buffer would either evict another item for nothing (shrinking
        # the buffer) or report a still-stored item as rejected.
        if item in self.mem:
            return None

        if self.capacity <= 0 or len(self.mem) < self.capacity:
            self.mem.add(item)
            return None

        weakest: Item = min(self.mem, key=lambda u: u.budget_m.p)
        if weakest.budget_m.p < item.budget_m.p:
            self.mem.remove(weakest)
            self.mem.add(item)
            return weakest
        # Ties go to the incumbent: the newcomer is rejected.
        return item


class LSTM(Generic[ItemType]):
    """Long- and Short-Term Buffer.

    Pairs a short-term buffer (``STM``) with a long-term buffer (``LTM``).
    New items enter the short-term buffer, and anything it pushes out is
    offered to the long-term buffer.
    """

    def __init__(self, n_ltb, n_stb):
        self.ltb = LTM[ItemType](n_ltb)  # long term buffer
        self.stb = STM[ItemType](n_stb)  # short term buffer
        self.n_ltb = n_ltb
        self.n_stb = n_stb
        self.capacity = n_ltb + n_stb

    def insert_to_ltb(self, item: ItemType):
        return self.ltb.insert(item)

    def insert_to_stb(self, item: ItemType):
        return self.stb.insert(item)

    def insert(self, item: ItemType):
        """Insert into the short-term buffer, cascading overflow into the long-term one.

        Returns:
            ``None`` if nothing left the combined structure. Otherwise the item
            that the long-term buffer evicted or rejected.
        """
        of_stb = self.insert_to_stb(item)
        of_ltb = None if of_stb is None else self.insert_to_ltb(of_stb)
        return of_ltb

    def remove_from_ltb(self, unit):
        self.ltb.remove(unit)

    def remove_from_stb(self, unit):
        self.stb.remove(unit)

    def remove(self, unit):
        """Remove ``unit`` from both buffers (absent entries are ignored)."""
        self.remove_from_ltb(unit)
        self.remove_from_stb(unit)

    def set_capacity(self, n_ltb, n_stb):
        """Resize both buffers and return every item dropped by the resize.

        Returns:
            A tuple holding the short-term buffer's dropped items followed by
            the long-term buffer's dropped items.
        """
        # NOTE(review): items trimmed from the short-term buffer are returned
        # rather than offered to the long-term buffer as ``insert`` does;
        # confirm this is intended.
        dropped_stb = self.stb.set_capacity(n_stb)
        dropped_ltb = self.ltb.set_capacity(n_ltb)
        # Keep the cached sizes in sync with the buffers.
        self.n_ltb = n_ltb
        self.n_stb = n_stb
        self.capacity = n_ltb + n_stb
        # Tuple display, not tuple(*...): the constructor takes one iterable.
        return (*dropped_stb, *dropped_ltb)

    def contain(self, unit):
        return unit in self.ltb or unit in self.stb

    def __contains__(self, unit):
        return self.contain(unit)

    def __len__(self):
        return len(self.ltb) + len(self.stb)

    def __iter__(self) -> Iterable[ItemType]:
        # Snapshot the long-term items, then the short-term items, so callers
        # may mutate the buffers while iterating.
        return iter((*self.ltb, *self.stb))
