# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import copy
from rlgraph.utils.specifiable import Specifiable
class Space(Specifiable):
    """
    Space class (based on and compatible with openAI Spaces).
    Provides a classification for state-, action-, reward- and other spaces.
    """
    def __init__(self, add_batch_rank=False, add_time_rank=False, time_major=False):
        """
        Args:
            add_batch_rank (bool): Whether to always add a batch rank at the 0th (or 1st) position when creating
                variables from this Space.
            add_time_rank (bool): Whether to always add a time rank at the 1st (or 0th) position when creating
                variables from this Space.
            time_major (bool): Whether the time rank should come before the batch rank. Not important if one
                of these ranks (or both) does not exist.
        """
        super(Space, self).__init__()
        self._shape = None
        self.has_batch_rank = None
        self.has_time_rank = None
        self.time_major = None
        self._add_batch_rank(add_batch_rank)
        self._add_time_rank(add_time_rank, time_major)

    def _add_batch_rank(self, add_batch_rank=False):
        """
        Changes the add_batch_rank property of this Space (and of all child Spaces in a ContainerSpace).

        Args:
            add_batch_rank (bool): Whether this Space (and all child Spaces in a ContainerSpace) should have a
                batch rank.
        """
        # Base implementation only sets this Space's own flag.
        self.has_batch_rank = add_batch_rank

    def _add_time_rank(self, add_time_rank=False, time_major=False):
        """
        Changes the add_time_rank property of this Space (and of all child Spaces in a ContainerSpace).

        Args:
            add_time_rank (bool): Whether this Space (and all child Spaces in a ContainerSpace) should have a
                time rank.
            time_major (bool): Whether the time rank should come before the batch rank. Not important if no batch rank
                exists.
        """
        # Base implementation only sets this Space's own flags.
        self.has_time_rank = add_time_rank
        self.time_major = time_major

    def with_batch_rank(self, add_batch_rank=True):
        """
        Returns a deepcopy of this Space, but with `has_batch_rank` set to the provided value.

        Args:
            add_batch_rank (Union[bool,int]): The fixed size of the batch-rank or True or False.

        Returns:
            Space: The deepcopy of this Space, but with `has_batch_rank` set to `add_batch_rank`
                (the time-rank setting is left unchanged).
        """
        return self.with_extra_ranks(add_batch_rank=add_batch_rank, add_time_rank=None)

    def with_time_rank(self, add_time_rank=True):
        """
        Returns a deepcopy of this Space, but with `has_time_rank` set to the provided value.

        Args:
            add_time_rank (Union[bool,int]): The fixed size of the time-rank or True or False.

        Returns:
            Space: The deepcopy of this Space, but with `has_time_rank` set to `add_time_rank`
                (the batch-rank setting is left unchanged).
        """
        return self.with_extra_ranks(add_batch_rank=None, add_time_rank=add_time_rank)

    def with_extra_ranks(self, add_batch_rank=True, add_time_rank=True, time_major=None):
        """
        Returns a deepcopy of this Space, but with `has_batch_rank` and/or `has_time_rank` set to the
        provided values.

        Args:
            add_batch_rank (Optional[Union[bool,int]]): The new batch-rank setting for the copy.
                None for leaving the batch-rank setting unchanged.
            add_time_rank (Optional[Union[bool,int]]): The new time-rank setting for the copy.
                None for leaving the time-rank setting (and `time_major`) unchanged.
            time_major (Optional[bool]): The new `time_major` setting, only applied together with a non-None
                `add_time_rank`. None for keeping this Space's current `time_major` value.

        Returns:
            Space: The modified deepcopy of this Space (`self` is not altered).
        """
        # Operate on a deep copy so that `self` stays untouched.
        ret = copy.deepcopy(self)
        if add_batch_rank is not None:
            ret._add_batch_rank(add_batch_rank)
        if add_time_rank is not None:
            ret._add_time_rank(add_time_rank, self.time_major if time_major is None else time_major)
        return ret

    def force_batch(self, samples):
        """
        Makes sure that `samples` is always returned with a batch rank no matter whether
        it already has one or not (in which case this method returns a batch of 1) or
        whether this Space has a batch rank or not.

        Args:
            samples (any): The samples to be batched. If already batched, return as-is.

        Returns:
            any: The batched sample.
        """
        raise NotImplementedError

    @property
    def shape(self):
        """
        Returns:
            tuple: The shape of this Space as a tuple. Without batch or time ranks.
        """
        return self._shape

    def get_shape(self, with_batch_rank=False, with_time_rank=False, time_major=None, **kwargs):
        """
        Returns the shape of this Space as a tuple with certain additional ranks at the front (batch) or the back
        (e.g. categories).

        Args:
            with_batch_rank (Union[bool,int]): Whether to include a possible batch-rank as `None` at 0th (or 1st)
                position. If `with_batch_rank` is an int (e.g. -1), the possible batch-rank is returned as that number
                (instead of None) at the 0th (or 1st if time_major is True) position.
                Default: False.
            with_time_rank (Union[bool,int]): Whether to include a possible time-rank as `None` at 1st (or 0th)
                position. If `with_time_rank` is an int, the possible time-rank is returned as that number
                (instead of None) at the 1st (or 0th if time_major is True) position.
                Default: False.
            time_major (bool): Overwrites `self.time_major` if not None. Default: None (use `self.time_major`).

        Returns:
            tuple: The shape of this Space as a tuple.
        """
        raise NotImplementedError

    @property
    def rank(self):
        """
        Returns:
            int: The rank of the Space not including batch- or time-ranks
                (e.g. 3 for a space with shape=(10, 7, 5)).
        """
        return len(self.shape)

    @property
    def flat_dim(self):
        """
        Returns:
            int: The length of a flattened vector derived from this Space.
        """
        raise NotImplementedError

    def get_variable(self, name, is_input_feed=False, add_batch_rank=None, add_time_rank=None,
                     time_major=False, is_python=False, local=False, **kwargs):
        """
        Returns a backend-specific variable/placeholder that matches the space's shape.

        Args:
            name (str): The name for the variable.
            is_input_feed (bool): Whether the returned object should be an input placeholder,
                instead of a full variable.
            add_batch_rank (Optional[bool,int]): If True, will add a 0th (or 1st) rank (None) to
                the created variable. If it is an int, will add that int (-1 means None).
                If None, will use the Space's default value: `self.has_batch_rank`.
                Default: None.
            add_time_rank (Optional[bool,int]): If True, will add a 1st (or 0th) rank (None) to
                the created variable. If it is an int, will add that int (-1 means None).
                If None, will use the Space's default value: `self.has_time_rank`.
                Default: None.
            time_major (bool): Only relevant if both `add_batch_rank` and `add_time_rank` are True.
                Will make the time-rank the 0th rank and the batch-rank the 1st rank.
                Otherwise, batch-rank will be 0th and time-rank will be 1st.
                Default: False.
            is_python (bool): Whether to create a python-based variable (list) or a backend-specific one.
            local (bool): Whether the variable must not be shared across the network.
                Default: False.

        Keyword Args:
            To be passed on to backend-specific methods (e.g. trainable, initializer, etc..).

        Returns:
            any: A Tensor Variable/Placeholder.
        """
        raise NotImplementedError

    def flatten(self, mapping=None, custom_scope_separator='/', scope_separator_at_start=True,
                scope_=None, list_=None):
        r"""
        A mapping function to flatten this Space into an OrderedDict whose only values are
        primitive (non-container) Spaces. The keys are created automatically from Dict keys and
        Tuple indexes.

        Args:
            mapping (Optional[callable]): A mapping function that takes a flattened auto-generated key and a primitive
                Space and converts the primitive Space to something else. Default is pass through.
            custom_scope_separator (str): The separator to use in the returned dict for scopes.
                Default: '/'.
            scope_separator_at_start (bool): Whether to add the scope-separator also at the beginning.
                Default: True.
            scope\_ (Optional[str]): For recursive calls only. Used for automatic key generation.
            list\_ (Optional[list]): For recursive calls only. The list so far.

        Returns:
            OrderedDict: The OrderedDict using auto-generated keys and containing only primitive Spaces
                (or whatever the mapping function maps the primitive Spaces to). Only returned by the
                first (non-recursive) call, i.e. when `list_` was None; recursive calls return None.
        """
        # Default: no mapping (pass-through).
        if mapping is None:
            def mapping(key, x):
                return x

        # Are we in the non-recursive (first) call?
        is_first_call = list_ is None
        if is_first_call:
            list_ = []
            scope_ = ""

        self._flatten(mapping, custom_scope_separator, scope_separator_at_start, scope_, list_)

        # Non recursive (first) call -> Return the final flattened, ordered mapping.
        if is_first_call:
            return OrderedDict(list_)

    def _flatten(self, mapping, custom_scope_separator, scope_separator_at_start, scope_, list_):
        r"""
        Base implementation. May be overridden by ContainerSpace classes.
        Simply sends `self` through the mapping function.

        Args:
            mapping (callable): The mapping function to use on a primitive (non-container) Space.
            custom_scope_separator (str): The separator to use in the returned dict for scopes.
                Default: '/'.
            scope_separator_at_start (bool): Whether to add the scope-separator also at the beginning.
                Default: True.
            scope\_ (str): The flat-key to use to store the mapped result in list_.
            list\_ (list): The list to append the mapped results to (under key=`scope_`).
        """
        list_.append((scope_, mapping(scope_, self)))

    def __repr__(self):
        return "Space(shape=" + str(self.shape) + ")"

    def __eq__(self, other):
        raise NotImplementedError

    def sample(self, size=None, fill_value=None):
        """
        Uniformly randomly samples an element from this space. This is for testing purposes, e.g. to simulate
        a random environment.

        Args:
            size (Optional[int]): The number of samples or batch size to sample.
                If size is > 1: Returns a batch of size samples with the 0th rank being the batch rank
                (even if `self.has_batch_rank` is False).
                If size is None or (1 and self.has_batch_rank is False): Returns a single sample w/o batch rank.
                If size is 1 and self.has_batch_rank is True: Returns a single sample w/ the batch rank.
            fill_value (Optional[any]): The number or initializer specifier to fill the sample. Can be used to create
                a (non-random) sample with a certain fill value in all elements.
                TODO: support initializer spec-strings like 'normal', 'truncated_normal', etc..

        Returns:
            any: The sampled element(s).
        """
        raise NotImplementedError

    def zeros(self, size=None):
        """
        Args:
            size (Optional): Same as `Space.sample()`.

        Returns:
            np.ndarray: `size` zero samples where all values are zero and have the correct type.
        """
        raise NotImplementedError

    def _get_np_shape(self, num_samples=None):
        """
        Helper to determine, which shape one should pass to the numpy random funcs for sampling from a Space.
        Depends on `num_samples`, the `shape` of this Space and the `self.has_batch_rank/has_time_rank` settings.

        Args:
            num_samples (Optional[int,Tuple[int,int]]): Number of samples to pull.
                - None or (): Pull a single sample without any extra rank.
                - 1 while this Space has neither a batch- nor a time-rank: Same as None.
                - Any other int (including 0, and 1 for Spaces with a batch- or time-rank): Prepend it as one
                  extra rank.
                - Tuple/list of two ints: Prepend both (in the given order) as time/batch ranks.

        Returns:
            Optional[Tuple[int]]: Shape to use for numpy random sampling. None if no extra rank is added and
                this Space's shape is empty (scalar Space).
        """
        # No extra batch/time rank. (Parentheses make the intended grouping explicit: `and` binds tighter
        # than `or`, so this is unchanged from the original condition.)
        if num_samples is None or num_samples == () or (
            num_samples == 1 and not self.has_batch_rank and not self.has_time_rank
        ):
            if len(self.shape) == 0:
                return None
            return self.shape
        # With one extra rank.
        elif isinstance(num_samples, int):
            return (num_samples,) + self.shape
        # With two extra ranks (given as list or tuple).
        else:
            assert isinstance(num_samples, (tuple, list)) and len(num_samples) == 2, \
                "ERROR: num_samples must be int or tuple/list of two ints, but is '{}'!".format(num_samples)
            return tuple(num_samples) + self.shape

    def contains(self, sample):
        """
        Checks whether this space contains the given sample. This is more for testing purposes.

        Args:
            sample: The element to check.

        Returns:
            bool: Whether sample is a valid member of this space.
        """
        raise NotImplementedError