Source code for rlgraph.spaces.text_box

# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
# import random
# import string

from rlgraph.spaces.box_space import BoxSpace


class TextBox(BoxSpace):
    """
    A text box in TXT^n where the shape means the number of text chunks in each dimension.
    """

    def __init__(self, shape=(), **kwargs):
        """
        Args:
            shape (tuple): The shape of this space.
            **kwargs: Passed through unchanged to the `BoxSpace` constructor.
        """
        # Set both low/high to 0 (bounds make no sense for text).
        super(TextBox, self).__init__(low=0, high=0, **kwargs)

        # Set dtype to numpy's unicode scalar type. It is resolved through the "U" dtype code
        # rather than the `np.unicode_` alias, because that alias was removed in NumPy 2.0.
        self.dtype = np.dtype("U").type

        assert isinstance(shape, tuple), "ERROR: `shape` must be a tuple."
        self._shape = shape

    def sample(self, size=None, fill_value=None):
        """
        Builds a text array in which every element is the (stringified) `fill_value`.

        Args:
            size: Forwarded as `num_samples` to `self._get_np_shape` (defined outside this
                class) to obtain the shape of the returned array.
            fill_value: The value to fill the array with. It is converted to text by numpy,
                so the default of None ends up as the string "None".

        Returns:
            np.ndarray: A unicode array of the computed shape, filled with `fill_value`.
        """
        shape = self._get_np_shape(num_samples=size)

        # TODO: Make it such that it doesn't only produce number strings (using `petname` module?).

        # Convert the fill value first so that numpy derives the string width from it.
        # Passing the unsized unicode type directly to `np.full` allocates 1-character
        # strings and silently truncates every fill value to its first character.
        filler = np.asarray(fill_value, dtype=self.dtype)
        return np.full(shape=shape, fill_value=filler, dtype=filler.dtype)

    def contains(self, sample):
        """
        Checks whether `sample` has the shape of this space.

        Args:
            sample: A plain string (treated as having the empty shape `()`) or any object
                exposing a `shape` attribute, e.g. a numpy array.

        Returns:
            bool: True if the shape of `sample` equals `self.shape`. False otherwise, including
                for non-string objects without a `shape` attribute.
        """
        if isinstance(sample, str):
            sample_shape = ()
        else:
            # Objects that are neither strings nor array-likes cannot be members of this
            # space; report that instead of failing with an AttributeError.
            sample_shape = getattr(sample, "shape", None)
            if sample_shape is None:
                return False
        return sample_shape == self.shape