Source code for rlgraph.spaces.containers

# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import re

from rlgraph.utils.rlgraph_errors import RLGraphError
from rlgraph.utils.ops import DataOpDict, DataOpTuple, FLAT_TUPLE_OPEN, FLAT_TUPLE_CLOSE
from rlgraph.spaces.space import Space


[docs]class ContainerSpace(Space): """ A simple placeholder class for Spaces that contain other Spaces. """
[docs] def sample(self, size=None, horizontal=False): """ Child classes must overwrite this one again with support for the `horizontal` parameter. Args: horizontal (bool): False: Within this container, sample each child-space `size` times. True: Produce `size` single containers in an np.array of len `size`. """ raise NotImplementedError
[docs]class Dict(ContainerSpace, dict): """ A Dict space (an ordered and keyed combination of n other spaces). Supports nesting of other Dict/Tuple spaces (or any other Space types) inside itself. """ def __init__(self, spec=None, **kwargs): add_batch_rank = kwargs.pop("add_batch_rank", False) add_time_rank = kwargs.pop("add_time_rank", False) time_major = kwargs.pop("time_major", False) ContainerSpace.__init__(self, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank, time_major=time_major) # Allow for any spec or already constructed Space to be passed in as values in the python-dict. # Spec may be part of kwargs. if spec is None: spec = kwargs dict_ = {} for key in sorted(spec.keys()): # Keys must be strings. if not isinstance(key, str): raise RLGraphError("ERROR: No non-str keys allowed in a Dict-Space!") # Prohibit reserved characters (for flattened syntax). if re.search(r'/|{}\d+{}'.format(FLAT_TUPLE_OPEN, FLAT_TUPLE_CLOSE), key): raise RLGraphError("ERROR: Key to Dict must not contain '/' or '{}\d+{}'! Is {}.". format(FLAT_TUPLE_OPEN, FLAT_TUPLE_CLOSE, key)) value = spec[key] # Value is already a Space: Copy it (to not affect original Space) and maybe add/remove batch/time-ranks. if isinstance(value, Space): w_batch_w_time = value.with_extra_ranks(add_batch_rank, add_time_rank, time_major) dict_[key] = w_batch_w_time # Value is a list/tuple -> treat as Tuple space. elif isinstance(value, (list, tuple)): dict_[key] = Tuple( *value, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank, time_major=time_major ) # Value is a spec (or a spec-dict with "type" field) -> produce via `from_spec`. elif (isinstance(value, dict) and "type" in value) or not isinstance(value, dict): dict_[key] = Space.from_spec( value, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank, time_major=time_major ) # Value is a simple dict -> recursively construct another Dict Space as a sub-space of this one. else: dict_[key] = Dict( value, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank, time_major=time_major ) dict.__init__(self, dict_) def _add_batch_rank(self, add_batch_rank=False): super(Dict, self)._add_batch_rank(add_batch_rank) for v in self.values(): v._add_batch_rank(add_batch_rank) def _add_time_rank(self, add_time_rank=False, time_major=False): super(Dict, self)._add_time_rank(add_time_rank, time_major) for v in self.values(): v._add_time_rank(add_time_rank, time_major)
[docs] def force_batch(self, samples): return dict([(key, self[key].force_batch(samples[key])) for key in sorted(self.keys())])
@property def shape(self): return tuple([self[key].shape for key in sorted(self.keys())])
[docs] def get_shape(self, with_batch_rank=False, with_time_rank=False, time_major=None, with_category_rank=False): return tuple([self[key].get_shape( with_batch_rank=with_batch_rank, with_time_rank=with_time_rank, time_major=time_major, with_category_rank=with_category_rank ) for key in sorted(self.keys())])
@property def rank(self): return tuple([self[key].rank for key in sorted(self.keys())]) @property def flat_dim(self): return int(np.sum([c.flat_dim for c in self.values()])) @property def dtype(self): return DataOpDict([(key, subspace.dtype) for key, subspace in self.items()])
[docs] def get_variable(self, name, is_input_feed=False, add_batch_rank=None, add_time_rank=None, time_major=None, **kwargs): return DataOpDict( [(key, subspace.get_variable( name + "/" + key, is_input_feed=is_input_feed, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank, time_major=time_major, **kwargs )) for key, subspace in self.items()] )
def _flatten(self, mapping, custom_scope_separator, scope_separator_at_start, scope_, list_): # Iterate through this Dict. scope_ += custom_scope_separator if len(scope_) > 0 or scope_separator_at_start else "" for key in sorted(self.keys()): self[key].flatten(mapping, custom_scope_separator, scope_separator_at_start, scope_ + key, list_)
[docs] def sample(self, size=None, horizontal=False): if horizontal: return np.array([{key: self[key].sample() for key in sorted(self.keys())}] * (size or 1)) else: return {key: self[key].sample(size=size) for key in sorted(self.keys())}
[docs] def zeros(self, size=None): return DataOpDict([(key, subspace.zeros(size=size)) for key, subspace in self.items()])
[docs] def contains(self, sample): return isinstance(sample, dict) and all(self[key].contains(sample[key]) for key in self.keys())
def __repr__(self): return "Dict({})".format([(key, self[key].__repr__()) for key in self.keys()]) def __eq__(self, other): if not isinstance(other, Dict): return False return dict(self) == dict(other)
[docs]class Tuple(ContainerSpace, tuple): """ A Tuple space (an ordered sequence of n other spaces). Supports nesting of other container (Dict/Tuple) spaces inside itself. """ def __new__(cls, *components, **kwargs): if isinstance(components[0], (list, tuple)): assert len(components) == 1 components = components[0] add_batch_rank = kwargs.get("add_batch_rank", False) add_time_rank = kwargs.get("add_time_rank", False) time_major = kwargs.get("time_major", False) # Allow for any spec or already constructed Space to be passed in as values in the python-list/tuple. list_ = list() for value in components: # Value is already a Space: Copy it (to not affect original Space) and maybe add/remove batch-rank. if isinstance(value, Space): list_.append(value.with_extra_ranks(add_batch_rank, add_time_rank, time_major)) # Value is a list/tuple -> treat as Tuple space. elif isinstance(value, (list, tuple)): list_.append( Tuple(*value, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank, time_major=time_major) ) # Value is a spec (or a spec-dict with "type" field) -> produce via `from_spec`. elif (isinstance(value, dict) and "type" in value) or not isinstance(value, dict): list_.append(Space.from_spec( value, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank, time_major=time_major )) # Value is a simple dict -> recursively construct another Dict Space as a sub-space of this one. else: list_.append(Dict( value, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank, time_major=time_major )) return tuple.__new__(cls, list_) def __init__(self, *components, **kwargs): add_batch_rank = kwargs.get("add_batch_rank", False) add_time_rank = kwargs.get("add_time_rank", False) time_major = kwargs.get("time_major", False) super(Tuple, self).__init__(add_batch_rank=add_batch_rank, add_time_rank=add_time_rank, time_major=time_major) def _add_batch_rank(self, add_batch_rank=False): super(Tuple, self)._add_batch_rank(add_batch_rank) for v in self: v._add_batch_rank(add_batch_rank) def _add_time_rank(self, add_time_rank=False, time_major=False): super(Tuple, self)._add_time_rank(add_time_rank, time_major) for v in self: v._add_time_rank(add_time_rank, time_major)
[docs] def force_batch(self, samples): return tuple([c.force_batch(samples[i]) for i, c in enumerate(self)])
@property def shape(self): return tuple([c.shape for c in self])
[docs] def get_shape(self, with_batch_rank=False, with_time_rank=False, time_major=None, with_category_rank=False): return tuple([c.get_shape( with_batch_rank=with_batch_rank, with_time_rank=with_time_rank, time_major=time_major, with_category_rank=with_category_rank ) for c in self])
@property def rank(self): return tuple([c.rank for c in self]) @property def flat_dim(self): return np.sum([c.flat_dim for c in self]) @property def dtype(self): return DataOpTuple([c.dtype for c in self])
[docs] def get_variable(self, name, is_input_feed=False, add_batch_rank=None, add_time_rank=None, time_major=None, **kwargs): return DataOpTuple( [subspace.get_variable( name+"/"+str(i), is_input_feed=is_input_feed, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank, time_major=time_major, **kwargs ) for i, subspace in enumerate(self)] )
def _flatten(self, mapping, custom_scope_separator, scope_separator_at_start, scope_, list_): # Iterate through this Tuple. scope_ += (custom_scope_separator if len(scope_) > 0 or scope_separator_at_start else "") + FLAT_TUPLE_OPEN for i, component in enumerate(self): component.flatten( mapping, custom_scope_separator, scope_separator_at_start, scope_ + str(i) + FLAT_TUPLE_CLOSE, list_ )
[docs] def sample(self, size=None, horizontal=False): if horizontal: return np.array([tuple(subspace.sample() for subspace in self)] * (size or 1)) else: return tuple(x.sample(size=size) for x in self)
[docs] def zeros(self, size=None): return tuple([c.zeros(size=size) for i, c in enumerate(self)])
[docs] def contains(self, sample): return isinstance(sample, (tuple, list, np.ndarray)) and len(self) == len(sample) and \ all(c.contains(xi) for c, xi in zip(self, sample))
def __repr__(self): return "Tuple({})".format(tuple([cmp.__repr__() for cmp in self])) def __eq__(self, other): return tuple.__eq__(self, other)