# Source code for rlgraph.components.neural_networks.dict_preprocessor_stack

# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from rlgraph import get_backend
from rlgraph.components.layers.preprocessing import PreprocessLayer
from rlgraph.components.neural_networks.preprocessor_stack import PreprocessorStack
from rlgraph.spaces import ContainerSpace, Dict
from rlgraph.utils.decorators import rlgraph_api
from rlgraph.utils.ops import flatten_op, unflatten_op
from rlgraph.utils.util import default_dict

if get_backend() == "tf":
    import tensorflow as tf


class DictPreprocessorStack(PreprocessorStack):
    """
    A generic PreprocessorStack that can handle Dict/Tuple Spaces and preprocess the different sub-Spaces
    in parallel within different (and separate) single PreprocessorStack components.
    The output is again a dict of preprocessed inputs.

    API:
        preprocess(input_): Outputs the preprocessed input after sending it through all sub-Components
            of this Stack.
        reset(): An op to trigger all PreprocessorStacks of this Vector to be reset.
    """
    def __init__(self, preprocessors, **kwargs):
        """
        Args:
            preprocessors (dict): A (possibly nested) dict mapping flat-keys to PreprocessorStack specs.
                One separate PreprocessorStack is built per key; input keys without an entry here are
                passed through unchanged by `preprocess`.

        Raises:
            RLGraphError: If a sub-component is not a PreprocessLayer object.
        """
        # Create one separate PreprocessorStack per given key.
        # All possibly other keys in an input will be passed through un-preprocessed.
        self.preprocessors = flatten_op(preprocessors)
        for i, (flat_key, spec) in enumerate(self.preprocessors.items()):
            # Replace each spec in-place with a constructed PreprocessorStack
            # (scoped uniquely so sibling stacks don't collide).
            self.preprocessors[flat_key] = PreprocessorStack.from_spec(
                spec, scope="preprocessor-stack-{}".format(i)
            )

        # NOTE: No automatic API-methods. Define them all ourselves.
        kwargs["api_methods"] = {}
        default_dict(kwargs, dict(scope=kwargs.pop("scope", "dict-preprocessor-stack")))
        super(DictPreprocessorStack, self).__init__(*list(self.preprocessors.values()), **kwargs)

    @rlgraph_api(flatten_ops=True, split_ops=True, add_auto_key_as_first_param=True)
    def _graph_fn_preprocess(self, key, inputs):
        """
        Preprocesses one flat-key's input through its matching PreprocessorStack (if any).

        Args:
            key (str): The auto-injected flat-key of the currently split-off input op.
            inputs (op): The input op belonging to `key`.

        Returns:
            op: The preprocessed input (or the unchanged input if no stack exists for `key`).
        """
        # Is a PreprocessorStack defined for this key?
        if key in self.preprocessors:
            return self.preprocessors[key].preprocess(inputs)
        # Simple pass through, no preprocessing.
        else:
            return inputs

    @rlgraph_api
    def reset(self):
        """
        Resets all sub-PreprocessorStacks of this DictPreprocessorStack.

        Returns:
            Optional[op]: Under the tf backend, a single op bundling all sub-stacks' reset ops;
                under the python backend, None (resets happen eagerly).
        """
        # TODO: python-Components: For now, we call each preprocessor's graph_fn directly.
        if self.backend == "python" or get_backend() == "python":
            for preprocessor in self.preprocessors.values():  # type: PreprocessLayer
                preprocessor.reset()

        elif get_backend() == "tf":
            # Connect each pre-processor's "reset" output op via our graph_fn into one op.
            resets = list()
            for preprocessor in self.preprocessors.values():  # type: PreprocessorStack
                resets.append(preprocessor.reset())
            reset_op = self._graph_fn_reset(*resets)
            return reset_op

    def _graph_fn_reset(self, *preprocessor_resets):
        """
        Bundles the given reset ops into one no-op that depends on all of them.

        Args:
            *preprocessor_resets (op): The reset ops of all sub-PreprocessorStacks.

        Returns:
            op: A tf.no_op with control dependencies on every given reset op.
        """
        if get_backend() == "tf":
            with tf.control_dependencies(preprocessor_resets):
                return tf.no_op()

    def get_preprocessed_space(self, space):
        """
        Returns the Space obtained after pushing the input through all layers of this Stack.

        Args:
            space (Dict): The incoming (container) Space object.

        Returns:
            Space: The Space after preprocessing (sub-Spaces without a matching
                PreprocessorStack are kept unchanged).
        """
        assert isinstance(space, ContainerSpace)
        dict_spec = dict()
        for flat_key, sub_space in space.flatten().items():
            if flat_key in self.preprocessors:
                dict_spec[flat_key] = self.preprocessors[flat_key].get_preprocessed_space(sub_space)
            else:
                dict_spec[flat_key] = sub_space
        dict_spec = unflatten_op(dict_spec)
        return Dict(dict_spec)