"""
A simple grid world environment
"""
import numpy as np
import random
from ..common import MDP
from ..exceptions import EpisodeDoneError, InvalidActionError
from ..actions import LEFT, RIGHT, UP, DOWN
from .helper_utilities import flatten_state, unflatten_state
from typing import List, Tuple
class GridWorldMDP(MDP):
    r"""A square grid-world MDP whose cells are addressed by (row, col) pairs.

    .. note::
        If ``terminal_states`` is not empty then there will be an absorbing
        state, so the actual number of states will be :math:`size^2 + 1`.
        If there is an absorbing state, it should be the last one.

    Args:
        P (np.ndarray): state transition matrix
            :math:`P: \mathcal{S}\times\mathcal{A}\times\mathcal{S}\mapsto\mathbb{R}`;
            the shape is :math:`|\mathcal{S}| \times |\mathcal{A}| \times |\mathcal{S}|`.
        R (np.ndarray): reward matrix
            :math:`r: \mathcal{S}\times\mathcal{A}\mapsto\mathbb{R}`;
            the shape is :math:`|\mathcal{S}| \times |\mathcal{A}|`.
        gamma (float): discount factor :math:`\gamma`.
        p0 (np.ndarray): initial state distribution :math:`p_0`, with one
            entry per state in :math:`\mathcal{S}` (see the note above for
            how many states there are).
        terminal_states (List[Tuple[int, int]]): terminal cells given as
            (row, col) tuples. Each one is converted to the row-major flat
            index ``size * row + col`` before being passed to :class:`MDP`.
            If the states are already flat integer indices, pass
            ``convert_terminal_states_to_ints=True`` to skip the conversion.
        size (int): side length of the grid, i.e. there are
            :math:`size \times size` grid cells, plus one absorbing state
            when ``terminal_states`` is non-empty.
        seed (int, optional): the random seed for simulations.
            Defaults to 1337.
        skip_check (bool, optional): forwarded unchanged to :class:`MDP`.
            Defaults to False.
        convert_terminal_states_to_ints (bool, optional): despite its name,
            ``True`` means ``terminal_states`` are *already* flat integer
            indices and are used as-is; ``False`` converts (row, col) tuples
            to flat indices. Defaults to False.

    Attributes:
        size (int): side length of the grid.
        human_state (Tuple[int, int]): (row, col) view of the current state,
            refreshed by :meth:`reset`, :meth:`step` and
            :meth:`set_current_state_to`; ``(None, None)`` before any of
            those is called.
        has_absorbing_state (bool): whether ``terminal_states`` was non-empty.
    """

    def __init__(self, P, R, gamma, p0,
                 terminal_states: List[Tuple[int, int]], size: int,
                 seed=1337, skip_check=False,
                 convert_terminal_states_to_ints=False):
        if not convert_terminal_states_to_ints:
            # Row-major flattening: cell (row, col) -> size * row + col.
            terminal_states = [int(size * cell[0] + cell[1])
                               for cell in terminal_states]
        self.size = size
        # No state has been sampled yet; filled in by reset()/step().
        self.human_state = (None, None)
        # Per the class note, any terminal state implies an extra absorbing state.
        self.has_absorbing_state = len(terminal_states) > 0
        super().__init__(P, R, gamma, p0, terminal_states,
                         seed=seed, skip_check=skip_check)

    def reset(self):
        """Start a new episode.

        Delegates to :meth:`MDP.reset`, then refreshes :attr:`human_state`
        from the new ``current_state``.

        Returns:
            ``self.current_state`` as left by :meth:`MDP.reset`.
        """
        super().reset()
        self.human_state = self.unflatten_state(self.current_state)
        return self.current_state

    def flatten_state(self, state):
        """Flatten a (row, col) state into a one-hot vector.

        See also: :func:`emdp.gridworld.helper_utilities.flatten_state`.

        Args:
            state (Tuple[int, int]): (row, col) pair.

        Returns:
            np.ndarray: one-hot encoding of ``state``, built by the helper
            with ``self.state_space`` as the number of states (presumably
            including the absorbing state, if any).
        """
        return flatten_state(state, self.size, self.state_space)

    def unflatten_state(self, onehot) -> Tuple[int, int]:
        """Unflatten a one-hot state vector into a (row, col) pair.

        See also: :func:`emdp.gridworld.helper_utilities.unflatten_state`.

        Args:
            onehot (np.ndarray): one-hot state vector, in the same encoding
                :meth:`flatten_state` produces.

        Returns:
            Tuple[int, int]: (row, col) pair.
        """
        return unflatten_state(onehot, self.size, self.has_absorbing_state)

    def step(self, action):
        """Take one environment step with ``action``.

        Delegates to :meth:`MDP.step`, then refreshes :attr:`human_state`.

        Args:
            action: the action to take; passed unchanged to :meth:`MDP.step`.

        Returns:
            tuple: ``(state, reward, done, info)`` exactly as returned by
            :meth:`MDP.step`.
        """
        state, reward, done, info = super().step(action)
        self.human_state = self.unflatten_state(self.current_state)
        return state, reward, done, info

    def set_current_state_to(self, tuple_state):
        """Move the agent to the grid cell ``tuple_state``.

        Args:
            tuple_state (Tuple[int, int]): (row, col) cell to move to.

        Returns:
            Whatever :meth:`MDP.set_current_state_to` returns.
        """
        # The base class is handed a flat index: one-hot encode, then argmax.
        result = super().set_current_state_to(
            self.flatten_state(tuple_state).argmax())
        # Keep human_state in sync with current_state, as reset() and step() do.
        self.human_state = self.unflatten_state(self.current_state)
        return result