# Source code for rlgraph.spaces.space_utils

# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from rlgraph import get_backend
from six.moves import xrange as range_
import numpy as np

from rlgraph.spaces.bool_box import BoolBox
from rlgraph.spaces.box_space import BoxSpace
from rlgraph.spaces.containers import Dict, Tuple
from rlgraph.spaces.float_box import FloatBox
from rlgraph.spaces.int_box import IntBox
from rlgraph.spaces.text_box import TextBox
from rlgraph.utils.util import RLGraphError, dtype, get_shape

if get_backend() == "pytorch":
    import torch


# TODO: replace completely by `Component.get_variable` (python-backend)
def get_list_registry(from_space, capacity=None, initializer=0, flatten=True, add_batch_rank=False):
    """
    Builds plain python list storage for a Space.

    If `flatten` is True, `from_space` is flattened (auto-keys joined by "-", no separator at the start) and
    every flat key gets its own list. Otherwise, one single list is returned and `from_space` is not touched.

    Args:
        from_space: Space to create the registry from (only used if `flatten` is True).
        capacity (Optional[int]): If not None, every list is pre-filled with this many `initializer` entries.
            If None, the list(s) start out empty.
        initializer (Optional[any]): The value used to pre-fill the list(s). Only used if `capacity` is not None.
        flatten (bool): Whether to produce one list per flat key of `from_space` (True) or one single list (False).
        add_batch_rank (Optional[bool,int]): Accepted for signature compatibility; currently not used by this
            function.

    Returns:
        The result of `from_space.flatten()` with each flat key mapped to a list if `flatten` is True,
        a single list otherwise.
    """
    def new_storage():
        # Build a fresh list on every call so that no two flat keys ever share the same list object.
        if capacity is None:
            return []
        return [initializer] * capacity

    if not flatten:
        return new_storage()

    return from_space.flatten(
        custom_scope_separator="-", scope_separator_at_start=False,
        mapping=lambda k, primitive: new_storage()
    )
def get_space_from_op(op):
    """
    Tries to re-create a Space object given some DataOp.

    This is useful for shape inference when passing a Socket's ops through a GraphFunction and
    auto-inferring the resulting shape/Space.

    Args:
        op (DataOp): The op to create a corresponding Space for. Handled inputs are: dicts (-> Dict Space),
            tuples (-> Tuple Space), python bools/ints/floats, numpy arrays and backend tensors.

    Returns:
        Union[Space,int]: The inferred Space object, or 0 if no Space can be derived from `op` (e.g. something
            without a `dtype`, a tensor of unknown shape, or a container holding at least one such item).

    Raises:
        RLGraphError: If `op` is a tensor whose dtype is none of float, int, bool or string.
    """
    # a Dict
    if isinstance(op, dict):  # DataOpDict
        spec = {}
        add_batch_rank = False
        add_time_rank = False
        for key, value in op.items():
            space = get_space_from_op(value)
            # One item without a Space makes the whole container Space-less (same rule as for tuples below).
            # Without this check, the rank lookups below would fail on the int 0.
            if space == 0:
                return 0
            spec[key] = space
            if space.has_batch_rank:
                add_batch_rank = True
            if space.has_time_rank:
                add_time_rank = True
        return Dict(spec, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank)

    # a Tuple
    if isinstance(op, tuple):  # DataOpTuple
        spec = []
        add_batch_rank = False
        add_time_rank = False
        for item in op:
            space = get_space_from_op(item)
            if space == 0:
                return 0
            spec.append(space)
            if space.has_batch_rank:
                add_batch_rank = True
            if space.has_time_rank:
                add_time_rank = True
        return Tuple(spec, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank)

    # primitive Space -> infer from op dtype and shape
    # Simple constant value DataOps (python type or an np.ndarray) are no longer expected here.
    assert not hasattr(op, "constant_value")  # we should be done with this by now

    # Op itself is a single value, simple python type.
    if isinstance(op, (bool, int, float)):
        return BoxSpace.from_spec(spec=type(op), shape=())
    # A single numpy array.
    if isinstance(op, np.ndarray):
        return BoxSpace.from_spec(spec=dtype(str(op.dtype), "np"), shape=op.shape)
    # No Space: e.g. the tf.no_op, a distribution (anything that's not a tensor).
    # PyTorch Tensors do not have get_shape so must check backend.
    if not hasattr(op, "dtype") or (get_backend() == "tf" and not hasattr(op, "get_shape")):
        return 0

    # Some tensor: can be converted into a BoxSpace.
    shape = get_shape(op)
    # Unknown shape (e.g. a cond op).
    if shape is None:
        return 0

    add_batch_rank = False
    add_time_rank = False
    time_major = False

    # New way: Detect via op._batch_rank and op._time_rank properties where these ranks are.
    # The flagged positions are marked with -1 and then removed from the shape.
    new_shape = list(shape)
    if hasattr(op, "_batch_rank") and isinstance(op._batch_rank, int):
        add_batch_rank = True
        new_shape[op._batch_rank] = -1
    if hasattr(op, "_time_rank") and isinstance(op._time_rank, int):
        add_time_rank = True
        # Time rank in the very first position means time-major layout.
        if op._time_rank == 0:
            time_major = True
        new_shape[op._time_rank] = -1
    shape = tuple(n for n in new_shape if n != -1)

    # Old way: Detect automatically whether the first rank(s) are batch and/or time rank.
    if add_batch_rank is False and add_time_rank is False and shape != () and shape[0] is None:
        if len(shape) > 1 and shape[1] is None:
            # Two unknown leading dims: Treated as batch- plus time-rank. Which of the two comes first
            # cannot be determined here, so `time_major` keeps its default (False).
            shape = shape[2:]
            add_time_rank = True
        else:
            shape = shape[1:]
        add_batch_rank = True

    base_dtype = op.dtype.base_dtype if hasattr(op.dtype, "base_dtype") else op.dtype
    # PyTorch does not have a bool type; uint8 tensors are treated as bool.
    if get_backend() == "pytorch":
        if op.dtype is torch.uint8:
            base_dtype = bool
    base_dtype_str = str(base_dtype)

    rank_args = dict(add_batch_rank=add_batch_rank, add_time_rank=add_time_rank, time_major=time_major)
    # Order matters: "float" and "int" are substring checks on the dtype's string representation.
    # FloatBox
    if "float" in base_dtype_str:
        return FloatBox(shape=shape, dtype=dtype(base_dtype, "np"), **rank_args)
    # IntBox
    if "int" in base_dtype_str:
        return IntBox(shape=shape, dtype=dtype(base_dtype, "np"), **rank_args)
    # a BoolBox
    if "bool" in base_dtype_str:
        return BoolBox(shape=shape, **rank_args)
    # a TextBox
    if "string" in base_dtype_str:
        return TextBox(shape=shape, **rank_args)

    raise RLGraphError("ERROR: Cannot derive Space from op '{}' (unknown type?)!".format(op))
def _within_bounds(value, bounds):
    """Returns True if `value` lies in the inclusive `(min, max)` range. A None bound means: not bounded."""
    # Only None falls back to the default. An explicit bound of 0 must be honored.
    lower = 0 if bounds[0] is None else bounds[0]
    upper = float("inf") if bounds[1] is None else bounds[1]
    return lower <= value <= upper


def sanity_check_space(
        space, allowed_types=None, non_allowed_types=None, must_have_batch_rank=None,
        must_have_time_rank=None, must_have_batch_or_time_rank=False, must_have_categories=None,
        num_categories=None, rank=None
):
    """
    Sanity checks a given Space for certain criteria and raises exceptions if they are not met.

    Args:
        space (Space): The Space object to check.
        allowed_types (Optional[List[type]]): A list of types that this Space must be an instance of.
        non_allowed_types (Optional[List[type]]): A list of types that this Space must not be an instance of.

        must_have_batch_rank (Optional[bool]): Whether the Space must (True) or must not (False) have the
            `has_batch_rank` property set to True. None, if it doesn't matter.

        must_have_time_rank (Optional[bool]): Whether the Space must (True) or must not (False) have the
            `has_time_rank` property set to True. None, if it doesn't matter.

        must_have_batch_or_time_rank (Optional[bool]): If True, the Space must have either the
            `has_batch_rank` or the `has_time_rank` property set. Any other value disables this check.

        must_have_categories (Optional[bool]): For IntBoxes, whether the Space must (True) or must not (False)
            have global bounds with `num_categories` > 0. None, if it doesn't matter.

        num_categories (Optional[int,tuple]): An int or a tuple (min,max) range within which the Space's
            `num_categories` must lie. A None entry in the tuple means: not bounded on that side.
            Only valid for IntBoxes. None if it doesn't matter.

        rank (Optional[int,tuple]): An int or a tuple (min,max) range within which the rank of each flattened
            sub-Space must lie. A None entry in the tuple means: not bounded on that side.
            None if it doesn't matter.

    Raises:
        RLGraphError: Various RLGraphErrors, if any of the conditions is not met.
    """
    # Check the types.
    if allowed_types is not None:
        if not isinstance(space, tuple(allowed_types)):
            raise RLGraphError("ERROR: Space ({}) is not an instance of {}!".format(space, allowed_types))

    if non_allowed_types is not None:
        if isinstance(space, tuple(non_allowed_types)):
            raise RLGraphError("ERROR: Space ({}) must not be an instance of {}!".format(space, non_allowed_types))

    if must_have_batch_or_time_rank is True:
        if space.has_batch_rank is False and space.has_time_rank is False:
            raise RLGraphError(
                "ERROR: Space ({}) does not have a batch- or a time-rank, but must have either one of "
                "these!".format(space)
            )

    if must_have_batch_rank is not None:
        if (space.has_batch_rank is False and must_have_batch_rank is True) or \
                (space.has_batch_rank is not False and must_have_batch_rank is False):
            # Last chance: Check for rank >= 2, that would be ok as well.
            if must_have_batch_rank is True and len(space.get_shape(with_batch_rank=True)) >= 2:
                pass
            # Something is wrong.
            elif space.has_batch_rank is not False:
                raise RLGraphError("ERROR: Space ({}) has a batch rank, but is not allowed to!".format(space))
            else:
                raise RLGraphError("ERROR: Space ({}) does not have a batch rank, but must have one!".format(space))

    if must_have_time_rank is not None:
        if (space.has_time_rank is False and must_have_time_rank is True) or \
                (space.has_time_rank is not False and must_have_time_rank is False):
            # Last chance: A shape with at least 2 dims is accepted as well.
            # NOTE(review): The previous comment here asked for "rank >= 3", but the check has always
            # been `>= 2`. Confirm which one is intended.
            if must_have_time_rank is True and \
                    len(space.get_shape(with_batch_rank=True, with_time_rank=True)) >= 2:
                pass
            # Something is wrong.
            elif space.has_time_rank is not False:
                raise RLGraphError("ERROR: Space ({}) has a time rank, but is not allowed to!".format(space))
            else:
                raise RLGraphError("ERROR: Space ({}) does not have a time rank, but must have one!".format(space))

    if must_have_categories is not None:
        if not isinstance(space, IntBox):
            raise RLGraphError("ERROR: Space ({}) is not an IntBox. Only IntBox Spaces can have categories!".
                               format(space))
        # Must NOT have categories: Complain if global bounds are set.
        elif must_have_categories is False:
            if space.global_bounds is not False:
                raise RLGraphError(
                    "ERROR: Space ({}) has categories (globally valid value bounds), but is not allowed "
                    "to!".format(space)
                )
        # Must have categories: Complain if there are no global bounds.
        elif space.global_bounds is False:
            raise RLGraphError("ERROR: Space ({}) must have categories (globally valid value bounds)!".format(space))

    if rank is not None:
        # The rank criterion applies to every single (flattened) sub-Space.
        flattened = space.flatten()
        if isinstance(rank, int):
            for key, sub_space in flattened.items():
                if sub_space.rank != rank:
                    raise RLGraphError(
                        "ERROR: A Space (flat-key={}) of '{}' has rank {}, but must have rank "
                        "{}!".format(key, space, sub_space.rank, rank)
                    )
        else:
            for key, sub_space in flattened.items():
                if not _within_bounds(sub_space.rank, rank):
                    raise RLGraphError(
                        "ERROR: A Space (flat-key={}) of '{}' has rank {}, but its rank must be between {} and "
                        "{}!".format(key, space, sub_space.rank, rank[0], rank[1]))

    if num_categories is not None:
        if not isinstance(space, IntBox):
            raise RLGraphError("ERROR: Space ({}) is not an IntBox. Only IntBox Spaces can have "
                               "categories!".format(space))
        elif isinstance(num_categories, int):
            if space.num_categories != num_categories:
                raise RLGraphError("ERROR: Space ({}) has `num_categories` {}, but must have {}!".
                                   format(space, space.num_categories, num_categories))
        elif not _within_bounds(space.num_categories, num_categories):
            raise RLGraphError("ERROR: Space ({}) has `num_categories` {}, but this value must be between {} and "
                               "{}!".format(space, space.num_categories, num_categories[0], num_categories[1]))
def check_space_equivalence(space1, space2):
    """
    Checks two Spaces for equivalence and hands back the more generic one of the two.

    The more generic Space is the one that has the properties has_batch_rank and/or has_time_rank set
    (instead of hard values in these ranks).
    E.g.: FloatBox((64,)) is equivalent with FloatBox((), +batch-rank). The latter will be returned.

    NOTE: FloatBox((2,)) and FloatBox((3,)) are NOT equivalent.

    Args:
        space1 (Space): The 1st Space to compare.
        space2 (Space): The 2nd Space to compare.

    Returns:
        Union[Space,False]: False if the two spaces are not equivalent. The more generic Space of the two if
            they are equivalent.
    """
    def generalizes(generic, specific):
        # `generic` carries a batch rank, `specific` does not, but has exactly one more rank instead.
        return generic.has_batch_rank and not specific.has_batch_rank and \
            (np.asarray(generic.rank) == np.asarray(specific.rank) - 1).all()

    # Spaces are the same: Return one of them.
    if space1 == space2:
        return space1

    # One has batch-rank, the other doesn't, but has one more rank (checked in both directions).
    for generic, specific in ((space1, space2), (space2, space1)):
        if generalizes(generic, specific):
            return generic

    # PyTorch only: More lenient matching between two Spaces that both lack a batch rank.
    if get_backend() == "pytorch":
        neither_has_batch_rank = not space1.has_batch_rank and not space2.has_batch_rank
        if neither_has_batch_rank and (np.asarray(space1.rank) == np.asarray(space2.rank)).all():
            return space1

        # TODO problem is that batch ranks are principally not handled correctly here ->
        # e.g. (batch_size, 256), (256,) can both be valid between layers
        # With differing ranks, the Space with the lower rank is returned.
        if neither_has_batch_rank and space1.rank > space2.rank:
            return space2
        if neither_has_batch_rank and space1.rank < space2.rank:
            return space1
    # TODO: time rank?

    return False