from utils import generate_sample, gbellmf
from Narsese import Budget
import numpy as np
from Model.Concept import Prototype, Composition, Task, Belief, Location
from Model.Retina import Retina
from copy import copy
import matplotlib.pyplot as plt
import pickle
from collections import defaultdict
from ordered_set import OrderedSet
from Model.SMLayer2D import SMLayer2D
import matplotlib.pyplot as plt
from rich import print, traceback; traceback.install()

# Scene setup: feature locations (x, y) and canvas size passed to
# generate_sample (presumably draws one feature at each location on a
# `length` x `length` canvas -- confirm against utils.generate_sample).
locs = [(5, 2), (2, 8), (8, 8)]
length = 21
features = [Prototype() for _ in range(1)]

sample = generate_sample(locs, length, 1)

# The retina's image shape matches the sample canvas; derive it from `length`
# instead of repeating the literal so the two cannot drift apart.
retina = Retina((length, length), len(features), 0.125, radius=3)
retina.update(sample)


# Build the model.
n_features = 1  # NOTE(review): not referenced elsewhere in this script
# Initial retina position; presumably an offset from the first feature at
# (5, 2) -- confirm the intended starting point.
retina.move_to(5+5, 2+5)

layer = SMLayer2D(3)

proto = layer.new_prototype(3, budget_m=Budget(q=0.3), budget_c=Budget(p=0.1))
# Give the prototype one composition per sample location. Derive the
# locations from the scene's `locs` rather than re-typing them, so the
# prototype always matches the generated sample. The 1/11 scaling is kept
# as-is (its meaning is not established here -- TODO confirm).
locs = np.asarray(locs, dtype=float) / 11.0
for loc in locs:
    composition = proto.new_composition(features[0], Location(loc))
    composition.truth.reset(1.0, 0.9)


ts_now = 0
lifetime = 100
print("start the loop")
trace = [tuple(retina.curr_loc)]
for _ in range(lifetime):
    # Get the retina's output (feature map).
    fmap = retina.sense()

    # Coordinates:
    #   o----> x
    #   | .(x, y)
    #   v
    #   y
    # Each active pixel (row, col) becomes an offset relative to the current
    # retina location. The bounds and the location do not change while the
    # map is scanned, so query them once per cycle, not once per pixel.
    lb_x, lb_y, *_ = retina.get_bounds()
    x, y = retina.curr_loc
    active = fmap > 0
    feats = [
        (feat, (lb_x - x + col, lb_y - y + row))
        for feat, row, col in zip(fmap[active], *np.where(active))
    ]

    # Feed the currently observed features to the sensorimotor layer.
    # Assumes fmap values are 1-based integer feature indices -- TODO confirm.
    layer.sensory_input([(features[feat - 1], bias) for feat, bias in feats], ts_now)

    # One working cycle of processing: retrieval and learning.
    dx, dy = layer.working_cycle(ts_now)

    # Move the retina by the layer's output (scaled by layer.scale). The return
    # value is presumably the movement actually performed; it is fed back to
    # the layer with the scaling undone.
    dx, dy = retina.move(dx*layer.scale, dy*layer.scale)
    layer.motor_input(dx/layer.scale, dy/layer.scale)
    layer.decay(ts_now)

    trace.append(tuple(retina.curr_loc))

    ts_now += 1
# Draw the retina image with the visited trajectory on top: the whole path
# (minus the final position) in red, then the most recent steps in green.
# NOTE(review): plt.pause and the "last few steps" highlight suggest this was
# meant to run inside the loop as an animation -- confirm the intent.
plt.cla()
plt.imshow(retina.data, cmap='gray')
for segment, colour in ((trace[:-1], 'r'), (trace[-4:], 'g')):
    plt.plot(*zip(*segment), marker='x', color=colour)
axes = plt.gca()
axes.set_aspect('equal', adjustable='box')
plt.pause(0.1)


plt.show()
print("done.")