from .Concept import Location, Prototype, Task, Belief, Composition, CompositionMirror
from .Base.LSTM import LSTM
from .Base.Buffer import Buffer
from Narsese import Budget, Truth

from random import random, choice, uniform
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
import math
from collections import defaultdict
from utils import circle_diff


# Module-level coefficients that scale the budget adjustments made by
# SMLayer2D. `alpha` is applied while processing sensory input and `gamma`
# while choosing the next movement (exploit); `beta` is not referenced in the
# code visible in this file section.
alpha = 0.1  # correctness/matching
beta = 0.1  # compactness/simplicity
gamma = 0.1  # concreteness/goal-directedness


class SMLayer2D:
    """Two-dimensional sensorimotor layer built around prototypes.

    Prototypes are kept in a `memory` store and a `workspace` buffer.
    Sensory input adjusts the `budget_c` of prototypes, their tasks and the
    tasks' mirrors; motor input shifts every task location in the workspace;
    `working_cycle` proposes the next movement.

    Annotations that name project types are written as strings so they are
    not evaluated when the class body is executed.
    """

    # The three class attributes below are not read anywhere in this class.
    novelty = 1.0

    halflife_period_budget = 20
    halflife_period_truth = 20

    def __init__(self, n_protos=100) -> None:
        """Create the prototype stores.

        Args:
            n_protos: passed (twice) to the memory constructor; the workspace
                is created with twice this value.
        """
        self.memory = LSTM[Prototype](n_protos, n_protos)
        self.workspace = Buffer[Prototype](n_protos*2)

        # working_cycle() exploits when random() <= thresh_curiosity; with the
        # default of 1.0 that is always the case.
        self.thresh_curiosity = 1.0
        # Divisor applied to every incoming bias in sensory_input().
        self.scale = 11.0

    def sensory_input(self, feat_bias_list: "list[tuple[Prototype, tuple[int, int]]]", ts_now=None):
        """Update budgets from a batch of observed (feature, bias) pairs.

        For every pair, the prototypes reachable through the feature's
        `upper_links` are collected; for each of them the existing tasks are
        matched first, then one new task per composition is created and
        scored.

        Args:
            feat_bias_list: pairs of a feature and its two-component bias;
                both bias components are divided by `self.scale` before use.
            ts_now: currently unused.
        """
        # Revising: perform one learning step.
        # TODO

        # Retrieving: match the input against the related components and
        # adjust the budget_c of instances, components and prototypes.
        #   1. Compute match values from the input.
        #   2. Adjust the match values by the anticipation:
        #      2.1 prediction succeeded -> multiply by a positive coefficient
        #      2.2 prediction failed    -> multiply by a negative coefficient
        #   3. Derive from the match value how much the budget_c of the
        #      component, instance and prototype should increase.
        for feat, raw_bias in feat_bias_list:
            bias = raw_bias[0]/self.scale, raw_bias[1]/self.scale

            protos = {
                composition.whole
                for composition in feat.upper_links.values()
            }

            # Matched components should see their budget_c rise so that they
            # receive more attention.
            for proto in protos:
                self._match_existing_tasks(proto)
                self._spawn_tasks(proto, bias)

    def _match_existing_tasks(self, proto) -> None:
        """Score the existing tasks of `proto` and adjust the budgets."""
        task_matches: list[float] = []  # best match value of each task
        for task in proto.tasks:
            mirror_matches: list[float] = []  # match value of each mirror
            shortfall = 0.0
            for mirror in task.mirrors:
                # Match the mirror's item location against the task centre.
                value = mirror.item.location.match(task.location.center)
                mirror_matches.append(value)

                anticipated = mirror.truth_anticipation.c
                # How far the match fell short of what was anticipated.
                # NOTE(review): the task-level penalty below uses the value of
                # the LAST mirror visited, exactly as the original code did
                # (a `max` over all mirrors was commented out there) --
                # confirm which one is intended.
                shortfall = max(anticipated - value, 0.0)

                # The anticipation is consumed once it has been compared.
                mirror.truth_anticipation.c = 0
                mirror.budget_c.inhibit_p(alpha*(value + anticipated))

            if not mirror_matches:
                # A task without mirrors has nothing to match; max() below
                # would raise ValueError on the empty list.
                continue

            # A task has several match values; the largest represents it.
            best = max(mirror_matches)
            task.budget_c.inhibit_p(alpha*shortfall)
            task.budget_c.exhibit_p(alpha*best)
            task_matches.append(best)

        # A prototype has several task match values; the largest represents it.
        if task_matches:
            proto.budget_c.exhibit_p(alpha*max(task_matches))

    def _spawn_tasks(self, proto, bias) -> None:
        """Create one task per composition of `proto` and score its mirrors."""
        new_tasks: "list[Task]" = []
        for composition in proto.compositions:
            # Place the task at the composition centre moved by -bias.
            center = Location._move(
                composition.location.center, -bias[0], -bias[1])
            new_tasks.append(proto.new_task(center, budget_c=Budget(p=0.1)))

        # All tasks are created before any of them is scored.
        for task in new_tasks:
            for mirror in task.mirrors:
                # Match against the task centre moved by +bias.
                shifted = Location._move(task.location.center, *bias)
                value = mirror.item.location.match(shifted)
                mirror.budget_c.exhibit_p(alpha*value)

    def motor_input(self, dx: float, dy: float):
        """Move the location of every task in the workspace by (dx, dy)."""
        for proto in self.workspace:
            for task in proto.tasks:
                task.location.move(dx, dy)

    def explore(self, ts_now: int):
        """Exploration step; not implemented, always raises.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Not implemented yet! Please don't set parameter 'curiosity' as less than 1.0, until this function is implemented.")

    def exploit(self, ts_now: int):
        """Pick the top-priority mirror and return the movement towards it.

        The highest-priority prototype, task and mirror are selected in turn.
        The chosen mirror is inhibited; if a different mirror then becomes the
        top one, it is boosted and its anticipation is revised.

        Args:
            ts_now: timestamp forwarded to the anticipation revision.

        Returns:
            `mirror.item.location - task.location`, or `(0.0, 0.0)` when no
            prototype, task or mirror is available.
        """
        # item_maxpriority() is treated as returning None for an empty
        # container, as the original task check already assumed.
        prototype = self.workspace.item_maxpriority()
        if prototype is None:
            return (0.0, 0.0)
        task = prototype.tasks.item_maxpriority()
        if task is None:
            return (0.0, 0.0)
        mirror = task.mirrors.item_maxpriority()
        if mirror is None:
            return (0.0, 0.0)

        # Inhibit the chosen mirror, then see which one is on top afterwards.
        mirror.budget_c.inhibit_p(gamma*2)
        mirror_next = task.mirrors.item_maxpriority()
        if mirror_next is not mirror:
            mirror_next.budget_c.exhibit_p(gamma)
            mirror_next.truth_anticipation.revise_w(1.0, 1.0, ts_now)

        movement = mirror.item.location - task.location
        print(
            "task:", hex(id(task))[-4:],
            f"{task.budget_c.p:.3f}",
            f"({task.location.center[0]:.2f},{task.location.center[1]:.2f})",
            ", mirror:", hex(id(mirror))[-4:],
            f"{mirror.budget_c.p:.3f}",
            ", mirror_next:", hex(id(mirror_next))[-4:],
            f"{mirror_next.budget_c.p:.3f}",
            ", movement:", f"{movement[0]:.2f},{movement[1]:.2f}"
        )

        return movement

    def new_prototype(self, capacity: int = 7, capacity_instances: "int | None" = None, budget_m: "Budget | None" = None, budget_c: "Budget | None" = None):
        """Create a prototype, insert it into memory and workspace, return it.

        All arguments are forwarded unchanged to the `Prototype` constructor.
        """
        proto = Prototype(capacity, capacity_instances, budget_m, budget_c)
        self.memory.insert(proto)
        self.workspace.insert(proto)
        return proto

    def working_cycle(self, ts_now: int) -> "tuple[float, float]":
        """Run one cycle and return the proposed movement.

        Exploits when a fresh `random()` draw is at most `thresh_curiosity`,
        otherwise explores (which currently raises NotImplementedError).
        """
        if random() <= self.thresh_curiosity:
            return self.exploit(ts_now)
        return self.explore(ts_now)

    def decay(self, ts_now: int) -> None:
        """Call `decay(ts_now)` on every prototype in memory."""
        proto: "Prototype"
        for proto in self.memory:
            proto.decay(ts_now)
