Source code for rlgraph.components.layers.nn.residual_layer

# Copyright 2018 The RLgraph authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from rlgraph import get_backend
from rlgraph.components.layers.nn.activation_functions import get_activation_function
from rlgraph.components.layers.nn.nn_layer import NNLayer
from rlgraph.utils.decorators import rlgraph_api

class ResidualLayer(NNLayer):
    """
    A residual layer that adds its input to the output of a (repeated) residual unit.

    Based on:
    [1] Identity Mappings in Deep Residual Networks - He, Zhang, Ren and Sun (Microsoft) 2016
        (https://arxiv.org/abs/1603.05027)

    API:
        apply(inputs) -> activation(units_n(...units_1(inputs)...) + inputs), where n is `repeats` and the
            activation (if any) is the one configured on this layer.
    """
    def __init__(self, residual_unit, repeats=2, scope="residual-layer", **kwargs):
        """
        Args:
            residual_unit (NeuralNetwork): The Component applied repeatedly to the input. Its final output is
                added to this layer's input, so the two must be addable (presumably same shape - confirm for
                the unit in use).
            repeats (int): The number of times that the residual unit should be repeated before applying the
                addition with the original input and the activation function. Must be >= 1. The first
                repetition uses `residual_unit` itself; the remaining `repeats - 1` use copies of it with scopes
                "<unit-scope>-rep1", "<unit-scope>-rep2", ...
            scope (str): The scope of this layer.
            **kwargs: Forwarded to NNLayer (presumably where the activation settings read in
                `_graph_fn_apply` come from - verify against NNLayer).

        Raises:
            ValueError: If `repeats` is smaller than 1 (the residual unit would otherwise never be applied).
        """
        if repeats < 1:
            raise ValueError(
                "ResidualLayer requires `repeats` >= 1, but got {}!".format(repeats)
            )
        super(ResidualLayer, self).__init__(scope=scope, **kwargs)

        self.residual_unit = residual_unit
        self.repeats = repeats

        # Use the given unit for the first repetition and `repeats - 1` copies (each under its own scope,
        # presumably so every repetition owns separate variables) for the rest; register all as sub-components.
        self.residual_units = [self.residual_unit] + [
            self.residual_unit.copy(scope="{}-rep{}".format(self.residual_unit.scope, i + 1))
            for i in range(repeats - 1)
        ]
        self.add_components(*self.residual_units)

    @rlgraph_api
    def _graph_fn_apply(self, inputs):
        """
        Args:
            inputs (SingleDataOp): The flattened inputs to this layer.

        Returns:
            SingleDataOp: The output after passing the input through the residual unit n times, adding the
                original input, and then applying the activation function (if one is configured).
                Only the "tf" backend is implemented; for any other backend None is returned.
        """
        if get_backend() != "tf":
            # NOTE(review): non-tf backends have no implementation here and fall through to None.
            return None

        results = inputs
        # Apply the residual units (original + copies) one after the other.
        for residual_unit in self.residual_units:
            results = residual_unit.apply(results)

        # First add up with the original input, then activate the sum.
        added_with_input = results + inputs
        activation_function = get_activation_function(self.activation, self.activation_params)
        if activation_function is None:
            return added_with_input
        return activation_function(added_with_input)