def get_state(self) -> list[float]:
    """Encode the current game state as a flat float feature vector.

    Coordinates follow screen convention: ``dy == -1`` is "up" (negative y).
    All heading-relative features (turn dangers and the local grid) are
    derived from one shared egocentric frame, so "straight", "right" and
    "left" mean the same thing in both.

    Layout, in order:
      * 6 relative dangers: ``_is_snake_collision`` for the cell one step to
        the snake's left / straight ahead / right, then ``_is_wall_collision``
        for the same three cells.
      * 4 heading one-hot flags: left, right, up, down.
      * 6 food flags: food is left / right / above / below the head, food is
        in the head's column (``food_on_x``), food is in the head's row
        (``food_on_y``).
      * 4 tail flags: the last body segment is left / right / above / below
        the head (the head itself is used when the body is empty).
      * ``STATE_LENGTH_BITS`` entries: ``get_snake_length()`` encoded by
        ``_int_to_bits``.
      * 49 cells of a 7x7 window centred on the head and rotated so the
        heading points at the top row and the snake's right is the right
        column: -1.0 outside the board, 1.0 snake, 0.5 food, 0.0 empty.
        The centre (head) cell is always 0.0.

    Returns:
        The features converted to float; the total length is asserted to
        equal ``DNetDef.INPUT_SIZE``.
    """
    head = self.snake_head
    dir_vec = self.direction

    # === 1. Heading one-hot ===
    dir_l = int(dir_vec.dx == -1 and dir_vec.dy == 0)
    dir_r = int(dir_vec.dx == 1 and dir_vec.dy == 0)
    dir_u = int(dir_vec.dx == 0 and dir_vec.dy == -1)
    dir_d = int(dir_vec.dx == 0 and dir_vec.dy == 1)

    # Egocentric frame: (fx, fy) is one step forward, (rx, ry) one step to the
    # snake's right. Anything that is not one of the four cardinal headings
    # falls back to "down", exactly like the original branch chain.
    if dir_r:
        fx, fy = 1, 0
    elif dir_l:
        fx, fy = -1, 0
    elif dir_u:
        fx, fy = 0, -1
    else:
        fx, fy = 0, 1
    # Rotating "forward" 90 degrees clockwise on a y-down screen gives "right".
    rx, ry = -fy, fx

    # === 2. Relative dangers (snake + wall), one step away ===
    straight = Position(head.x + fx, head.y + fy)
    right = Position(head.x + rx, head.y + ry)
    left = Position(head.x - rx, head.y - ry)

    snake_d_straight = self._is_snake_collision(straight)
    snake_d_right = self._is_snake_collision(right)
    snake_d_left = self._is_snake_collision(left)
    wall_d_straight = self._is_wall_collision(straight)
    wall_d_right = self._is_wall_collision(right)
    wall_d_left = self._is_wall_collision(left)

    # === 3. Food & tail cardinal flags ===
    food = self.food_position
    food_left = int(food.x < head.x)
    food_right = int(food.x > head.x)
    food_up = int(food.y < head.y)
    food_down = int(food.y > head.y)
    food_on_x = int(head.x == food.x)
    food_on_y = int(head.y == food.y)

    tail = self.snake_body[-1] if self.snake_body else head
    tail_left = int(tail.x < head.x)
    tail_right = int(tail.x > head.x)
    tail_up = int(tail.y < head.y)
    tail_down = int(tail.y > head.y)

    # === 4. Length bits ===
    length_bits = self._int_to_bits(self.STATE_LENGTH_BITS, self.get_snake_length())

    # === 5. Local 7x7 egocentric occupancy grid ===
    # Local offset (dx, dy): dy = -3 is the farthest row AHEAD of the snake and
    # dx = +3 the farthest column to its RIGHT, for every heading. The world
    # cell is head + dx * right - dy * forward (same frame as the dangers).
    grid: list[float] = []
    for dy in range(-3, 4):      # 7 rows, top row = ahead
        for dx in range(-3, 4):  # 7 columns, right column = snake's right
            if dx == 0 and dy == 0:
                grid.append(0.0)  # head cell: encoded like an empty cell
                continue

            pos = Position(head.x + dx * rx - dy * fx,
                           head.y + dx * ry - dy * fy)
            if not self.is_position_within_bounds(pos):
                grid.append(-1.0)  # outside the board (wall)
            elif self._is_snake_collision(pos):
                grid.append(1.0)   # snake
            elif pos == food:
                grid.append(0.5)   # food
            else:
                grid.append(0.0)   # empty

    # Build final state
    state = [
        snake_d_left, snake_d_straight, snake_d_right,                    # 6
        wall_d_left, wall_d_straight, wall_d_right,
        dir_l, dir_r, dir_u, dir_d,                                       # 4
        food_left, food_right, food_up, food_down, food_on_x, food_on_y,  # 6
        tail_left, tail_right, tail_up, tail_down,                        # 4
        *length_bits,                                   # STATE_LENGTH_BITS
        *grid,                                                            # 49
    ]

    # Sanity check that the feature layout matches the network input width.
    assert len(state) == DNetDef.INPUT_SIZE
    return [float(x) for x in state]