Source code for rlgraph.spaces.float_box

# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from rlgraph.utils.util import dtype as dtype_
from rlgraph.spaces.box_space import BoxSpace


[docs]class FloatBox(BoxSpace): def __init__(self, low=None, high=None, shape=None, dtype="float32", **kwargs): if low is None: assert high is None, "ERROR: If `low` is None, `high` must be None as well!" low = float("-inf") high = float("inf") self.unbounded = True else: self.unbounded = False # support calls like (FloatBox(1.0) -> low=0.0, high=1.0) if high is None: high = low low = 0.0 dtype = dtype_(dtype, "np") assert dtype in [np.float16, np.float32, np.float64], "ERROR: FloatBox does not allow dtype '{}'!".format(dtype) super(FloatBox, self).__init__(low=low, high=high, shape=shape, dtype=dtype, **kwargs)
[docs] def sample(self, size=None, fill_value=None): shape = self._get_np_shape(num_samples=size) if fill_value is not None: sample_ = np.full(shape=shape, fill_value=fill_value) else: if self.unbounded: sample_ = np.random.uniform(size=shape) else: sample_ = np.random.uniform(low=self.low, high=self.high, size=shape) # Make sure return values have the right dtype (float64 is np.random's default). return np.asarray(sample_, dtype=self.dtype)