# Source code for rlgraph.components.neural_networks.stack

# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from rlgraph import get_backend
from rlgraph.utils.decorators import rlgraph_api
from rlgraph.components.component import Component
from rlgraph.utils.util import force_tuple, force_list


class Stack(Component):
    """
    A component container stack that incorporates one or more sub-components some of whose API-methods
    (default: only `apply`) are automatically connected with each other (in the sequence the sub-Components are given
    in the c'tor), resulting in an API of the Stack.
    All sub-components' API-methods need to match in the number of input and output values. E.g. the third
    sub-component's API-method's number of return values has to match the fourth sub-component's API-method's number
    of input parameters.
    """
    def __init__(self, *sub_components, **kwargs):
        """
        Args:
            sub_components (Union[Component,List[Component]]): The sub-components to add to the Stack and connect to
                each other.

        Keyword Args:
            api_methods (Set[Union[str,Tuple[str,str],Tuple[str,callable]]]): A set of names of API-methods to connect
                through the stack. Defaults to {"apply"}. All sub-Components must implement all API-methods in this
                set. Alternatively, a tuple can be used (instead of a string), in which case the first tuple-item is
                used as the Stack's API-method name and the second item is the sub-Components' API-method name.
                E.g. api_methods={("stack_run", "run")}. This will create "stack_run" for the Stack, which will call -
                one by one - all the "run" methods of the sub-Components.
                If the second tuple-item is a callable instead of a string, that callable is registered directly as
                the Stack's API-method under the name given by the first tuple-item (no automatic chaining).

                Connecting always works by first calling the first sub-Component's API-method, then - with the
                result - calling the second sub-Component's API-method, etc..
                This is done for all API-methods in the given set.
            scope (str): The scope of this Stack. Defaults to "stack".
        """
        # Network object for fast-path execution where we do not repeatedly call `call` between layers.
        self.stack_obj = None
        api_methods = kwargs.pop("api_methods", {"apply"})
        super(Stack, self).__init__(*sub_components, scope=kwargs.pop("scope", "stack"), **kwargs)
        self.num_allowed_inputs = None
        self.num_allowed_returns = None

        self._build_stack(api_methods)

    def _build_stack(self, api_methods):
        """
        For each api-method in set `api_methods`, automatically create this Stack's own API-method by connecting
        through all sub-Component's API-methods. This is skipped if this Stack already has a custom API-method
        by that name.

        Args:
            api_methods (Set[Union[str,Tuple[str,str],Tuple[str,callable]]]): See ctor kwargs.
        """
        def make_auto_api_method(components_api_method_name):
            """
            Returns a function that chains `components_api_method_name` through all sub-Components in order.

            This factory exists so that each generated function binds its OWN sub-Component method name. Defining
            the function directly inside the loop below would close over the loop variable; since the function is
            only executed after the loop has finished, every generated Stack API-method would then call the
            sub-Components' method of the *last* loop iteration.
            """
            def method(self_, *inputs, **kwargs):
                args_ = inputs
                kwargs_ = kwargs
                # NOTE(review): assumes `sub_components` preserves insertion (c'tor) order - confirm in Component.
                for sub_component in self_.sub_components.values():  # type: Component
                    # TODO: python-Components: For now, we call each preprocessor's graph_fn
                    # directly (assuming that inputs are not ContainerSpaces).
                    if self_.backend == "python" or get_backend() == "python":
                        graph_fn = getattr(sub_component, "_graph_fn_" + components_api_method_name)
                        # TODO: kwargs_ (e.g. a dict returned by the previous sub-Component) are not forwarded here.
                        results = graph_fn(*args_)
                    elif get_backend() == "pytorch":
                        # Do NOT convert to tuple, has to be unpacked again immediately.
                        # TODO: kwargs_ are not forwarded on this path either.
                        results = getattr(sub_component, components_api_method_name)(*force_list(args_))
                    else:  # e.g. get_backend() == "tf"
                        results = getattr(sub_component, components_api_method_name)(*args_, **kwargs_)

                    # Recycle args_, kwargs_ for reuse in next sub-Component's API-method call:
                    # a dict result becomes keyword-args, anything else positional args.
                    if isinstance(results, dict):
                        args_ = ()
                        kwargs_ = results
                    else:
                        args_ = force_tuple(results)
                        kwargs_ = {}

                # Unwrap the final result: kwargs-dict if there are no positional values, a single value if
                # there is exactly one, otherwise the full tuple.
                if args_ == ():
                    return kwargs_
                elif len(args_) == 1:
                    return args_[0]
                else:
                    return args_

            return method

        # Loop through the API-method set.
        for api_method_spec in api_methods:
            function_to_use = None
            components_api_method_name = None

            # API-method of sub-Components and this Stack should have different names.
            if isinstance(api_method_spec, tuple):
                # Custom method given, use that instead of creating one automatically.
                if callable(api_method_spec[1]):
                    stack_api_method_name = api_method_spec[0]
                    function_to_use = api_method_spec[1]
                else:
                    stack_api_method_name, components_api_method_name = api_method_spec[0], api_method_spec[1]
            # API-method of sub-Components and this Stack should have the same name.
            else:
                stack_api_method_name = components_api_method_name = api_method_spec

            # API-method for this Stack already exists (custom one) -> Do not overwrite it.
            if hasattr(self, stack_api_method_name):
                continue

            # Custom API-method is given (w/o decorator) -> Call the decorator directly here to register it.
            if function_to_use is not None:
                rlgraph_api(api_method=function_to_use, component=self, name=stack_api_method_name)
            # No API-method given -> Create auto-API-method and set it up through the decorator
            # (`rlgraph_api(...)(func)` is exactly what `@rlgraph_api(...)` over a def does).
            else:
                rlgraph_api(name=stack_api_method_name, component=self)(
                    make_auto_api_method(components_api_method_name)
                )

    @classmethod
    def from_spec(cls, spec=None, **kwargs):
        """
        Creates a Stack from a spec.

        Args:
            spec (Optional[Union[dict,list,tuple]]): Either a dict spec, whose optional "layers" entry holds the
                sub-Components (all other keys are handed on to the parent `from_spec`), or a list/tuple of
                sub-Components. Any other value is handed on to the parent `from_spec` unchanged.

        Returns:
            Stack: The created Stack (as returned by the parent class' `from_spec`).
        """
        if isinstance(spec, dict):
            # Work on a shallow copy so that popping "layers" does not mutate the caller's spec dict
            # (which would break re-using the same config dict to build another Stack).
            spec = dict(spec)
            kwargs["_args"] = list(spec.pop("layers", []))
        elif isinstance(spec, (tuple, list)):
            # presumably `_args` is turned into positional c'tor args by the parent `from_spec` - confirm.
            kwargs["_args"] = spec
            spec = None
        return super(Stack, cls).from_spec(spec, **kwargs)