Source code for rlgraph.components.neural_networks.preprocessor_stack

# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from rlgraph import get_backend
from rlgraph.components.layers.preprocessing import PreprocessLayer
from rlgraph.utils.util import default_dict
from rlgraph.components.neural_networks.stack import Stack
from rlgraph.utils.decorators import rlgraph_api, graph_fn

if get_backend() == "tf":
    import tensorflow as tf


class PreprocessorStack(Stack):
    """
    A Stack specialization meant to hold PreprocessLayer Components only.

    The sub-Components' `apply` methods are chained into a single `preprocess`
    API-method (the output of one layer is the input of the next), and all
    of their `reset` ops are bundled into one exposed `reset` op.

    API:
        preprocess(``input_``): Outputs the preprocessed input after sending it
            through all sub-Components of this Stack.
        reset(): An op to trigger all PreprocessLayers of this Stack to be reset.
    """

    def __init__(self, *preprocessors, **kwargs):
        """
        Args:
            preprocessors (PreprocessLayer): The PreprocessLayers to add to the
                Stack and connect to each other.

        Keyword Args:
            scope (str): Scope of this Component. Falls back to
                "preprocessor-stack" if not given.
            Any further keyword arguments are forwarded to the parent
            constructor. A caller-provided `api_methods` entry is overwritten.

        Raises:
            RLGraphError: If a sub-component is not a PreprocessLayer object.
                NOTE(review): no such check is done in this class; presumably
                it is enforced further up the hierarchy - confirm.
        """
        # Chain the sub-Components' `apply` methods into our own `preprocess` method.
        # NOTE: `reset` is deliberately left out here; it is defined explicitly below.
        kwargs["api_methods"] = {("preprocess", "apply")}

        # Take `scope` out of the kwargs (using our default if absent) and hand
        # it back in through `default_dict`.
        scope = kwargs.pop("scope", "preprocessor-stack")
        default_dict(kwargs, dict(scope=scope))

        super(PreprocessorStack, self).__init__(*preprocessors, **kwargs)

    @rlgraph_api
    def reset(self):
        """
        Resets all sub-Components of this Stack.

        Returns:
            For the "tf" backend: The result of `_graph_fn_reset`, called with
            the `reset` results of all sub-Components. For the python backend
            (or any other backend): None.
        """
        # TODO: python-Components: For now, each preprocessor's graph_fn is called directly.
        if self.backend == "python" or get_backend() == "python":
            for layer in self.sub_components.values():  # type: PreprocessLayer
                layer._graph_fn_reset()
        elif get_backend() == "tf":
            # Collect every layer's "reset" output op and merge all of them
            # into one single op via our own graph_fn.
            layer_resets = []
            for layer in self.sub_components.values():  # type: PreprocessLayer
                layer_resets.append(layer.reset())
            return self._graph_fn_reset(*layer_resets)

    @graph_fn
    def _graph_fn_reset(self, *preprocessor_resets):
        """
        Bundles the given reset ops into one single op.

        Args:
            preprocessor_resets: The `reset` results of the single sub-Components.

        Returns:
            For the "tf" backend: A `tf.no_op` created under control
            dependencies on all `preprocessor_resets`. Otherwise: None.
        """
        if get_backend() == "tf":
            with tf.control_dependencies(preprocessor_resets):
                return tf.no_op()

    def get_preprocessed_space(self, space):
        """
        Returns the Space obtained after pushing the input through all layers
        of this Stack.

        Args:
            space (Space): The incoming Space object.

        Returns:
            Space: The Space after preprocessing.
        """
        # Feed the Space through each layer in turn; each layer's result is
        # the next layer's input.
        current_space = space
        for layer in self.sub_components.values():
            current_space = layer.get_preprocessed_space(current_space)
        return current_space