from utils import generate_sample, gbellmf
from Model.Prototype import Prototype, Component
from Narsese import Budget
import numpy as np
from Model.FLPair import FLPair, Location, Feature
from Model.Retina import Retina
from copy import copy
import matplotlib.pyplot as plt
import pickle
from collections import defaultdict
from ordered_set import OrderedSet

# Pixel coordinates handed to generate_sample, and the side length they are
# normalised by further down when building the prototype components.
locs = [(5, 2), (2, 8), (8, 8)]
length = 11

# A single feature type is used in this experiment.
feats = [Feature()]

# Synthetic input image containing the feature at each entry of `locs`.
sample = generate_sample(locs, length, 1)

# The retina is created once and immediately loaded with the sample.
retina = Retina(
    (11, 11),
    len(feats),
    0.125,
    radius=3,
)
retina.update(sample)


def get_closest_feature(patch: np.ndarray, R: "Retina", bias: tuple[int, int]):
    """Find the active pixel of ``patch`` nearest to the biased retina position.

    A pixel is active when its value is greater than zero.  Patch columns map
    to x and rows map to y; the first two values returned by
    ``R.get_bounds(bias)`` are used as the x/y origin of the patch.  Distances
    are measured to the point ``(R.x + bias[0], R.y + bias[1])``.  When several
    pixels are equally close, the first one in row-major order wins.

    Args:
        patch: Array of feature values indexed as ``patch[row, col]``.
        R: Retina-like object providing ``get_bounds(bias)`` (four values),
            ``x``, ``y`` and ``len()``.
        bias: ``(dx, dy)`` offset added to ``(R.x, R.y)`` to get the target.

    Returns:
        A ``(feature, location)`` pair.  ``feature`` is the value of the
        nearest active pixel, or ``None`` if the patch has no active pixel.
        ``location`` is that pixel's offset from ``(R.x, R.y)`` divided by
        ``len(R)``, as a float array of two elements; it is a scalar NaN when
        no feature was found.
    """
    lb_x, lb_y, _, _ = R.get_bounds(bias)

    # Compute the mask once; it drives both the values and their indices.
    active = patch > 0
    if not active.any():
        # Explicit "nothing found" result (previously produced implicitly by
        # coercing None to NaN).
        return None, np.float64(np.nan)

    # The target point does not depend on the pixel, so build it only once.
    target = np.array((R.x + bias[0], R.y + bias[1]))

    dist_min = np.inf
    feature = None
    location = None
    for f, row, col in zip(patch[active], *np.nonzero(active)):
        x = lb_x + col
        y = lb_y + row
        dist = np.linalg.norm(np.array([x, y]) - target)
        # Strict comparison keeps the first pixel on ties.
        if dist < dist_min:
            dist_min = dist
            feature = f
            location = (x - R.x, y - R.y)
    return feature, np.array(location, dtype=float) / len(R)


def _draw_matching_state(R, traj_retina, traj_p, traj_bias, values):
    """Redraw the 2x2 diagnostic figure for the current matching step."""
    plt.clf()

    # Top-left: retina image with the path of retina locations; the most
    # recent segment is highlighted in red.
    plt.subplot(2, 2, 1)
    plt.imshow(R.data)
    plt.title("retina")
    plt.plot(*zip(*traj_retina))
    plt.plot(*zip(*traj_retina[-2:]), "r")

    # Top-right: path of prototype locations, y axis flipped.
    ax = plt.subplot(2, 2, 2)
    ax.invert_yaxis()
    plt.title("prototype")
    plt.plot(*zip(*traj_p))
    plt.plot(*zip(*traj_p[-2:]), "r")

    # Bottom-left: path of the scaled bias estimate, y axis flipped.
    ax = plt.subplot(2, 2, 3)
    ax.invert_yaxis()
    plt.title("bias")
    plt.plot(*zip(*traj_bias))

    # Bottom-right: accumulated match value per step.
    plt.subplot(2, 2, 4)
    plt.title("match")
    plt.plot(values)

    plt.tight_layout()
    plt.pause(0.1)


def brute_force_matching(p: Prototype, c0: Component, R: Retina, k: int):
    """Step through the prototype's components and match them on the retina.

    Runs ``k * len(p)`` iterations.  In each one the component returned by
    ``p.components_workspace.item_maxpriority()`` is selected, the retina is
    asked to move by the displacement between the current component ``c0`` and
    the selected one (plus the running bias, scaled by ``p.scale`` and rounded
    to integers), and the closest feature is looked up in the sensed patch.
    On a hit the bias is updated as a weighted average and both prototype and
    ``c0`` advance to the selected component; on a miss the retina and the
    prototype are put back where they were.  Every step is printed and the
    diagnostic figure is redrawn.

    Args:
        p: Prototype whose components are traversed.
        c0: Component the matching starts from; it gets an initial weight
            of 1.0.
        R: Retina that is moved and sensed.  All trajectory bookkeeping uses
            this object.
        k: Multiplier for the number of iterations.

    Returns:
        The accumulated match value: 1.0 for the starting component plus the
        latest ``match_bias`` result of every matched component (a component
        matched again has its previous contribution replaced).
    """
    visited = dict()
    value = 1.0
    c0.budget_c.inhibit_p(0.1)
    visited[c0] = value
    bias = np.array([0.0, 0.0])
    values = [value]
    plt.figure()
    traj_retina = []
    traj_p = []
    traj_bias = []
    v = 0.0
    for _ in range(k * len(p)):
        c = p.components_workspace.item_maxpriority()
        loc_0 = np.array(c0.flpair.location.center)
        loc_c = np.array(c.flpair.location.center)

        # Requested retina move, in integer retina units.
        mv_1 = np.round((loc_c - loc_0 + bias) * p.scale).astype(int)
        loc_R = copy(R.curr_loc)  # remembered so a miss can be undone
        mv_2 = R.move(*mv_1)
        # Difference between the requested move and what R.move returned
        # (presumably the move actually performed -- confirm in Retina).
        err = mv_1 - mv_2
        patch = R.sense(err)
        feat, bias_f = get_closest_feature(patch, R, err)

        if feat is not None:
            v = c.flpair.location.match_bias(bias_f)
            if c in visited:
                # Replace, rather than add to, this component's old weight.
                value -= visited[c]
            # Weighted average of the previous bias and the new observation.
            bias = (bias * value + bias_f * v) / (value + v)
            visited[c] = v
            value += v
            c0 = c
            p.move_to(*loc_c)
        else:
            # Nothing found: undo the retina move and stay on c0.
            R.move_to(*loc_R)
            p.move_to(*loc_0)
        c.budget_c.inhibit_p(0.1)

        print(tuple(np.round(bias * p.scale).astype(int)), v)
        values.append(value)
        # Track the retina that was passed in, not the module-level one.
        traj_retina.append(tuple(R.curr_loc))
        traj_p.append(tuple(p.loc_now))
        traj_bias.append(tuple(bias * p.scale))
        _draw_matching_state(R, traj_retina, traj_p, traj_bias, values)

    return value


# Iteration multiplier passed to brute_force_matching.
k = 30

# Start the matching from an active pixel of the sample.  The loop ends with a
# `break`, so only the first active pixel (row-major order) is actually used.
for row, col in zip(*np.where(sample)):
    proto = Prototype(3, budget_m=Budget(q=0.3), budget_c=Budget(p=1.0))
    features = defaultdict(OrderedSet)

    # One component per entry of `locs`, with centres normalised by `length`.
    # Each successive component gets a marginally smaller budget `p`.
    p = 0.9
    for center in np.array(locs, dtype=float) / length:
        feat = feats[0]
        flpair = FLPair(feat, Location(center))
        flpair.location.radius = 0.2
        component = proto.new_component(flpair)
        component.budget_c.p = p
        p *= 0.9999
        component.truth_partof.reset(1.0, 0.9)
        features[feat].add(component)
        print(tuple(component.flpair.location.center))

    # Retina coordinates are (x, y) = (col, row).
    loc = (col, row)
    print(loc)
    # The pixel value (minus one) indexes the feature list.
    feat: Feature = feats[int(sample[row, col]) - 1]
    # First component registered for that feature.
    component: Component = next(iter(features[feat]))
    retina.move_to(*loc)
    # Use the location's centre, as brute_force_matching does when it moves
    # the prototype.
    proto.move_to(*component.flpair.location.center)
    value = brute_force_matching(proto, component, retina, k)
    break

plt.show()
print("Done")

