from .helper_utilities import unflatten_state
from .env import GridWorldMDP
import numpy as np
class GridWorldPlotter(object):
    """
    Utility to plot gridworlds.

    Args:
        grid_size(int): size of the gridworld; the plots use a
            ``grid_size x grid_size`` array of cells.
        has_absorbing_state(bool, optional): boolean representing if the
            gridworld has an absorbing state. It is only forwarded to
            ``unflatten_state`` when states are unflattened.
    """

    def __init__(self, grid_size, has_absorbing_state=True):
        if isinstance(grid_size, (GridWorldMDP,)):
            raise TypeError('grid_size cannot be a GridWorldMDP. '
                            'To instantiate from GridWorldMDP use GridWorldPlotter.from_mdp()')
        # Kept as an assert so the exception type seen by callers is unchanged.
        assert type(grid_size) is int, 'Gridworld size must be int'
        self.size = grid_size
        self.has_absorbing_state = has_absorbing_state
        # TODO: store where the rewards are so we can plot them.

    def _unflatten(self, onehot_state):
        """Convert one flattened state via ``unflatten_state``.

        The plotting methods index the result as ``state[0]`` (x) and
        ``state[1]`` (y).
        """
        return unflatten_state(onehot_state, self.size, self.has_absorbing_state)

    def _wall_overlay(self, wall_locs, rgba):
        """Build a ``(size, size, 4)`` RGBA image that is transparent except at the walls.

        Args:
            wall_locs: iterable of wall locations indexed as ``(row, col)``, or None.
            rgba: the four channel values written into every wall cell.

        Returns:
            numpy.ndarray: the RGBA image (all zeros when ``wall_locs`` is None).
        """
        overlay = np.zeros((self.size, self.size, 4))
        if wall_locs is not None:
            for loc in wall_locs:
                overlay[loc[0], loc[1], :] = rgba
        return overlay

    @staticmethod
    def from_mdp(mdp: "GridWorldMDP"):
        """
        Build a plotter from the ``size`` and ``has_absorbing_state`` of an MDP.

        Args:
            mdp(GridWorldMDP): the MDP to read the grid configuration from.

        Returns:
            GridWorldPlotter: a plotter configured like ``mdp``.

        Raises:
            TypeError: if ``mdp`` is not a GridWorldMDP.
        """
        # TODO: obtain reward specifications
        if not isinstance(mdp, GridWorldMDP):
            raise TypeError('Only GridWorldMDPs can be used with GridWorldPlotters')
        return GridWorldPlotter(mdp.size, mdp.has_absorbing_state)

    def plot_grid(self, ax):
        """
        Plots the skeleton of the grid world.

        Draws ``size + 1`` horizontal and ``size + 1`` vertical black lines,
        offset by -0.5 so that integer coordinates fall at cell centres,
        labels the axes 'x' and 'y' and switches the axes' own grid off.

        Args:
            ax: The axes to plot this on.

        Returns:
            The same ``ax`` that was passed in.
        """
        n_lines = self.size + 1
        extent = np.arange(n_lines) - 0.5
        # Horizontal lines first, then vertical ones.
        for i in range(n_lines):
            ax.plot(extent, np.full(n_lines, i - 0.5), color='k')
        for i in range(n_lines):
            ax.plot(np.full(n_lines, i - 0.5), extent, color='k')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.grid(False)
        return ax

    def plot_trajectories(self, ax, trajectories, dont_unflatten=False, jitter_scale=1):
        """
        Plots individual trajectory paths with some jitter.

        Args:
            ax: The axes to plot this on.
            trajectories: a list of trajectories. Each trajectory is a list of states.
                Unless ``dont_unflatten`` is set, every state is passed through
                ``unflatten_state`` before plotting.
            dont_unflatten(bool): will not automatically unflatten the trajectories
                into (x, y) pairs. (!) this assumes you have already unflattened them!
            jitter_scale: multiplier for the random offset added to every point. The
                offset is drawn uniformly from ``[0, jitter_scale / (2 * size))``,
                independently for x and y.

        Returns:
            The same ``ax`` that was passed in.
        """
        if dont_unflatten:
            trajectories_unflat = trajectories
        else:
            trajectories_unflat = list(self.unflat_trajectories(trajectories))

        for trajectory_unflattened in trajectories_unflat:
            points = list(trajectory_unflattened)
            if not points:
                # Nothing to draw; unpacking zip(*[]) would raise ValueError.
                continue
            x, y = zip(*points)
            n_points = len(x)
            x = np.array(x) + jitter_scale * np.random.rand(n_points) / (2 * self.size)
            y = np.array(y) + jitter_scale * np.random.rand(n_points) / (2 * self.size)
            ax.plot(x, y)
        return ax

    def plot_environment(self, ax, wall_locs=None, plot_grid=False):
        """
        Plots the environment with walls.

        Args:
            ax: The axes to plot this on
            wall_locs(List[Tuple[int,int]]): Locations of the walls for plotting them in a different color.
                The locations is a list of (row, col) tuples.
            plot_grid(bool): Boolean to plot the grid.

        Returns:
            Tuple:
                ax: The axes of the final plot.\n
                imshow_ax: The object returned by ``ax.imshow`` for the wall image.
        """
        # plot walls in lame way -- set them to some hand-engineered color
        # (opaque black); every other cell stays fully transparent.
        wall_img = self._wall_overlay(wall_locs, (0.0, 0.0, 0.0, 1.0))

        imshow_ax = ax.imshow(wall_img, interpolation=None)
        ax.grid(False)

        # Switch on flag if you want to plot grid
        if plot_grid:
            self.plot_grid(ax)
        return ax, imshow_ax

    def plot_heatmap(self, ax, trajectories, dont_unflatten=False, wall_locs=None):
        """
        Plots a state-visitation heatmap with walls.

        Args:
            ax: The axes to plot this on.
            trajectories: a list of trajectories. Each trajectory is a list of states.
                Unless ``dont_unflatten`` is set, every state is passed through
                ``unflatten_state`` first.
            dont_unflatten(bool): will not automatically unflatten the trajectories
                into (x, y) pairs. (!) this assumes you have already unflattened them!
            wall_locs: Locations of the walls, indexed as (row, col), for plotting
                them in a different color.

        Returns:
            Tuple:
                ax: The axes that were passed in.\n
                imshow_ax: The object returned by ``ax.imshow`` for the wall overlay,
                i.e. the last image drawn (not the heatmap itself).
        """
        if dont_unflatten:
            trajectories_unflat = trajectories
        else:
            trajectories_unflat = list(self.unflat_trajectories(trajectories))

        # Count visits per cell; states are (x, y) while the array is [row, col].
        state_visitations = np.zeros((self.size, self.size))
        for trajectory in trajectories_unflat:
            for state in trajectory:
                x_coord = state[0]
                y_coord = state[1]
                state_visitations[y_coord, x_coord] += 1.

        # plot walls in lame way -- set them to some hand-engineered color
        wall_img = self._wall_overlay(wall_locs, (0.6, 0.4, 0.4, 1.0))

        # render heatmap and overlay the walls image
        ax.imshow(state_visitations, interpolation=None)
        # NOTE(review): the returned handle is the wall overlay, as before; the
        # heatmap image handle is not exposed -- confirm this is what callers want.
        imshow_ax = ax.imshow(wall_img, interpolation=None)
        ax.grid(False)
        return ax, imshow_ax

    def unflat_trajectories(self, trajectories):
        """
        Returns a lazy iterator over the trajectories with every state unflattened.

        Args:
            trajectories: an iterable of trajectories, each an iterable of states.

        Returns:
            An iterator yielding, per trajectory, the list of unflattened states.
        """
        return ([self._unflatten(state) for state in traj] for traj in trajectories)