import numpy as np
from numpy.random import randint


class Retina:
    """A movable sensing window over a 2-D grid of values.

    The grid ``data`` is indexed ``data[y, x]``: ``y`` selects the row
    (bounded by ``shape[0]``) and ``x`` selects the column (bounded by
    ``shape[1]``). ``curr_loc`` stores the current position as ``(x, y)``.

    Args:
        shape: Grid shape as ``(rows, cols)``.
        n_levels: Quantization scale used by :meth:`sense`.
        sigma: Stored as ``self.sigma``; not read anywhere in this class.
        radius: Half-width of the square window read by :meth:`sense`.
    """

    def __init__(self, shape=(5, 5), n_levels=16, sigma=0.125, radius=0) -> None:
        self.radius = radius
        self.shape = np.array(shape)
        # Float storage so values in [0, 1] survive ``update``; an int array
        # would truncate them to 0/1 before ``sense`` quantizes them.
        # Cells start at -1, which ``sense`` clips to 0.
        self.data = np.full(shape, -1, dtype=float)
        # Start at (x=2, y=2), clamped into the grid so shapes with fewer
        # than 3 rows/cols still begin at a valid location.
        self.curr_loc = np.array(
            (min(2, self.shape[1] - 1), min(2, self.shape[0] - 1))
        )

        self.n_levels = n_levels
        self.sigma = sigma

    def move(self, dx, dy):
        """Shift by ``(dx, dy)``, clamping the new location into the grid.

        Returns:
            The displacement ``(dx, dy)`` actually applied after clamping.
        """
        x = self.curr_loc[0] + dx
        x = x if 0 <= x < self.shape[1] else (0 if x < 0 else self.shape[1] - 1)
        y = self.curr_loc[1] + dy
        y = y if 0 <= y < self.shape[0] else (0 if y < 0 else self.shape[0] - 1)
        dx_ = x - self.curr_loc[0]
        dy_ = y - self.curr_loc[1]
        self.curr_loc[:] = (x, y)
        return dx_, dy_

    @property
    def x(self):
        """Current column index."""
        return self.curr_loc[0]

    @property
    def y(self):
        """Current row index."""
        return self.curr_loc[1]

    def move_to(self, x, y):
        """Jump to ``(x, y)`` and return the displacement ``(dx, dy)``.

        Unlike :meth:`move`, the target is not clamped to the grid.
        """
        dx = x - self.curr_loc[0]
        dy = y - self.curr_loc[1]
        self.curr_loc[:] = (x, y)
        return dx, dy

    def move_to_center(self):
        """Jump to ``(cols // 2, rows // 2)`` and return the displacement."""
        yx = self.shape // 2
        dx = yx[1] - self.curr_loc[0]
        dy = yx[0] - self.curr_loc[1]
        self.curr_loc[:] = (yx[1], yx[0])
        return dx, dy

    def move_to_rand(self):
        """Jump to a uniformly random cell of the grid."""
        self.curr_loc[:] = randint(0, self.shape[1]), randint(0, self.shape[0])

    def randomly_move(self, r):
        """Apply a clamped random step; each component is drawn from ``[-r, r]``."""
        self.move(*randint(-r, r + 1, size=2))

    def get_bounds(self, bias=(0, 0)):
        """Return the window ``(lb_x, lb_y, ub_x, ub_y)`` around the location.

        The window is centred on ``curr_loc + bias`` with half-width
        ``radius`` and clipped to the grid. Lower bounds are inclusive and
        upper bounds exclusive, ready for ``data[lb_y:ub_y, lb_x:ub_x]``.
        """
        radius = self.radius
        cx = self.curr_loc[0] + bias[0]
        cy = self.curr_loc[1] + bias[1]
        lb_x = max(cx - radius, 0)
        lb_y = max(cy - radius, 0)
        # x indexes columns (shape[1]) and y indexes rows (shape[0]),
        # matching the clamping in ``move`` and the slice in ``sense``.
        ub_x = min(cx + radius + 1, self.shape[1])
        ub_y = min(cy + radius + 1, self.shape[0])
        return lb_x, lb_y, ub_x, ub_y

    def sense(self, bias=(0, 0)) -> np.ndarray:
        """Read and quantize the window around the current location.

        Values are clipped to ``[0, 1]`` and mapped to integers with
        ``round(value * n_levels)``, so results lie in ``0..n_levels``
        inclusive. The returned 2-D int array may be smaller than
        ``(2*radius+1)`` square where the window touches the grid edge.
        """
        # NOTE(review): this yields n_levels + 1 distinct levels; confirm
        # whether ``n_levels - 1`` was intended as the scale.
        lb_x, lb_y, ub_x, ub_y = self.get_bounds(bias)
        window = self.data[lb_y:ub_y, lb_x:ub_x]
        return np.round(window.clip(0.0, 1.0) * self.n_levels).astype(int)

    def update(self, data):
        """Copy ``data`` into the grid in place (it must broadcast to ``shape``)."""
        self.data[:] = data

    def __len__(self):
        """Return the larger grid dimension."""
        return max(self.shape)


if __name__ == "__main__":
    # Demo: a few clamped relative moves on a 5x5 retina, then recentre.
    demo = Retina((5, 5))
    for step in ((1, 0), (4, 0), (-3, 3)):
        print(demo.move(*step))
    print(demo.move_to_center())

    # Build a scene whose cells default to -1, with three marked cells.
    scene = np.full((5, 5), -1, dtype=int)
    marks = {(1, 2): 1, (3, 1): 2, (3, 3): 3}
    for (row, col), value in marks.items():
        scene[row, col] = value
    demo.update(scene)
    print(demo.data)

    # Step up one row, report the location, and read the window there.
    print(demo.move(0, -1))
    print(demo.curr_loc)

    print(demo.sense())
