# Source code for emdp.gridworld.txt_utilities

"""Utilities to help load gridworlds from a text file.
"""
from .helper_utilities import flatten_state
from .builder_tools import (TransitionMatrixBuilder,
                            create_reward_matrix)
from . import GridWorldMDP
from typing import Union, Tuple, List


def get_char_matrix(raw_file) -> List[List[str]]:
    """Split the lines of a text gridworld into a matrix of characters.

    Line terminators (``'\\n'`` and ``'\\r'``) are stripped from both ends of
    each line, so input with Unix or Windows line endings yields the same
    matrix. Every other character, including spaces, is kept.

    Examples:
        >>> get_char_matrix(['#####', '#  g#', '#   #', '#s# #', '#####'])
        [['#', '#', '#', '#', '#'], ['#', ' ', ' ', 'g', '#'], ['#', ' ', ' ', ' ', '#'], ['#', 's', '#', ' ', '#'], ['#', '#', '#', '#', '#']]

    Args:
        raw_file: Either an opened python file object or a list of strings
            containing the lines.

    Returns:
        One list of single-character strings per input line.
    """
    # ``list(str)`` splits a string into its characters. '\r' is stripped as
    # well as '\n' so CRLF input (e.g. a file opened with ``newline=''``)
    # does not leave a stray '\r' character in the last column.
    return [list(line.strip('\r\n')) for line in raw_file]
def build_gridworld_from_char_matrix(
        char_matrix,
        p_success=1,
        seed=2017,
        gamma=1,
        skip_checks=False,
        transition_matrix_builder_cls=TransitionMatrixBuilder
) -> Tuple[GridWorldMDP, List[Tuple[int, int]]]:
    """Build a gridworld MDP from a square matrix of characters.

    Recognised characters are ``'#'`` (wall), ``'s'`` (start), ``'g'``
    (goal) and ``' '`` (empty cell). The grid must contain exactly ONE start
    and ONE goal location. A reward of +1 is positioned at the goal location,
    which is also the only state passed as terminal.

    Examples:
        >>> char_matrix = get_char_matrix(['#####', '#  g#', '#   #', '#s# #', '#####'])
        >>> mdp, wall_locs = build_gridworld_from_char_matrix(char_matrix)
        >>> wall_locs
        [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (1, 0), (1, 4), (2, 0), (2, 4), (3, 0), (3, 2), (3, 4), (4, 0), (4, 1), (4, 2), (4, 3), (4, 4)]

    Args:
        char_matrix: Square matrix (sequence of rows) of single characters,
            e.g. as returned by ``get_char_matrix``.
        p_success: Probability that the action is successful.
        seed: The seed for the GridWorldMDP object.
        gamma: Forwarded to GridWorldMDP as its ``gamma`` (discount factor).
        skip_checks: Skips the assertion checks that the matrix is square.
            The duplicate start/goal checks always run.
        transition_matrix_builder_cls: The transition matrix builder to use.

    Returns:
        Tuple[GridWorldMDP, List[Tuple[int,int]]]: MDP object, wall locations
        as list of ``(row, col)`` tuples in row-major order.

    Raises:
        ValueError: If ``char_matrix`` is empty, contains an unknown
            character, or has no start or no goal location.
        AssertionError: If the matrix is not square (unless ``skip_checks``)
            or contains more than one start or goal location.
    """
    # ``len`` rather than truthiness so numpy character arrays also work.
    if len(char_matrix) == 0:
        raise ValueError('char_matrix is empty.')
    grid_size = len(char_matrix[0])
    if not skip_checks:
        assert len(char_matrix) == grid_size, 'Mismatch in the columns.'
        for row in char_matrix:
            assert len(row) == grid_size, 'Mismatch in the rows.'

    wall_locs = []
    start_loc = None
    goal_loc = None
    for r in range(grid_size):
        for c in range(grid_size):
            char = char_matrix[r][c]
            if char == '#':
                wall_locs.append((r, c))
            elif char == 's':
                assert start_loc is None, 'Start loc was overwritten!'
                start_loc = (r, c)
            elif char == 'g':
                assert goal_loc is None, 'Goal loc was overwritten!'
                goal_loc = (r, c)
            elif char != ' ':
                raise ValueError('Unknown character {} in grid.'.format(char))

    # Without these checks a missing marker surfaces later as an opaque
    # error from indexing / flattening ``None``.
    if start_loc is None:
        raise ValueError("No start location ('s') found in grid.")
    if goal_loc is None:
        raise ValueError("No goal location ('g') found in grid.")

    # Attempt to make the desired gridworld.
    reward_spec = {goal_loc: +1}
    tmb = transition_matrix_builder_cls(grid_size, has_terminal_state=True)
    tmb.add_grid(terminal_states=reward_spec.keys(), p_success=p_success)
    for wall_loc in wall_locs:
        tmb.add_wall_at(wall_loc)
    P = tmb.P
    R = create_reward_matrix(P.shape[0], grid_size, reward_spec, action_space=4)
    p0 = flatten_state(start_loc, grid_size, R.shape[0])
    gw = GridWorldMDP(P, R, gamma, p0, terminal_states=reward_spec.keys(),
                      size=grid_size, seed=seed)
    return gw, wall_locs