# Source code for rlgraph.components.layers.nn.activation_functions

# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial

from rlgraph import get_backend
from rlgraph.utils.rlgraph_errors import RLGraphError


if get_backend() == "tf":
    import tensorflow as tf
elif get_backend() == "pytorch":
    import torch.nn as nn


def get_activation_function(activation_function=None, *other_parameters):
    """
    Returns an activation function (callable) to use in a NN layer.

    Args:
        activation_function (Optional[callable,str]): The activation function to lookup. Could be given as:
            - already a callable (return just that)
            - a lookup key (str)
            - None: Use linear activation.
        other_parameters (any): Possible extra parameter(s) used for some of the activation functions.
            Currently only "lrelu"/"leaky_relu" reads one: other_parameters[0] is the negative-slope
            alpha (defaults to 0.2 if not given).

    Returns:
        callable: The backend-dependent activation function. `None` is returned (i.e. no activation
            is applied) when `activation_function` is None, and for "linear" under the PyTorch backend.

    Raises:
        RLGraphError: If `activation_function` is a lookup key unknown to the active backend.
    """
    backend = get_backend()
    if backend == "tf":
        return _get_tf_activation_function(activation_function, other_parameters)
    elif backend == "pytorch":
        return _get_pytorch_activation_function(activation_function, other_parameters)
    # NOTE(review): any other backend falls through and yields None (same as before) -- confirm intended.
    return None


def _get_tf_activation_function(activation_function, other_parameters):
    """Resolves `activation_function` (see `get_activation_function`) to a TensorFlow op."""
    if activation_function is None or callable(activation_function):
        return activation_function
    elif activation_function == "linear":
        return tf.identity
    # Rectifier linear unit (ReLU) : 0 if x < 0 else x
    elif activation_function == "relu":
        return tf.nn.relu
    # Exponential linear: exp(x) - 1 if x < 0 else x
    elif activation_function == "elu":
        return tf.nn.elu
    # Sigmoid: 1 / (1 + exp(-x))
    elif activation_function == "sigmoid":
        return tf.sigmoid
    # Scaled exponential linear unit: scale * [alpha * (exp(x) - 1) if < 0 else x]
    # https://arxiv.org/pdf/1706.02515.pdf
    elif activation_function == "selu":
        return tf.nn.selu
    # Swish function: x * sigmoid(x)
    # https://arxiv.org/abs/1710.05941
    elif activation_function == "swish":
        return lambda x: x * tf.sigmoid(x=x)
    # Leaky ReLU: x * [alpha if x < 0 else 1.0]
    elif activation_function in ["lrelu", "leaky_relu"]:
        alpha = other_parameters[0] if len(other_parameters) > 0 else 0.2
        return partial(tf.nn.leaky_relu, alpha=alpha)
    # Concatenated ReLU:
    elif activation_function == "crelu":
        return tf.nn.crelu
    # Softmax function:
    elif activation_function == "softmax":
        return tf.nn.softmax
    # Softplus function:
    elif activation_function == 'softplus':
        return tf.nn.softplus
    # Softsign function:
    elif activation_function == "softsign":
        return tf.nn.softsign
    # tanh activation function:
    elif activation_function == "tanh":
        return tf.nn.tanh
    else:
        raise RLGraphError("ERROR: Unknown activation_function '{}' for TensorFlow backend!".
                           format(activation_function))


def _get_pytorch_activation_function(activation_function, other_parameters):
    """Resolves `activation_function` (see `get_activation_function`) to a PyTorch callable."""
    # PyTorch activations are nn.Module classes, so a fresh object is instantiated on every call.
    if activation_function is None or callable(activation_function):
        return activation_function
    elif activation_function == "linear":
        # No activation module needed; callers receive None.
        return None
    # Rectifier linear unit (ReLU) : 0 if x < 0 else x
    elif activation_function == "relu":
        return nn.ReLU()
    # Exponential linear: exp(x) - 1 if x < 0 else x
    elif activation_function == "elu":
        return nn.ELU()
    # Sigmoid: 1 / (1 + exp(-x))
    elif activation_function == "sigmoid":
        return nn.Sigmoid()
    # Scaled exponential linear unit: scale * [alpha * (exp(x) - 1) if < 0 else x]
    # https://arxiv.org/pdf/1706.02515.pdf
    elif activation_function == "selu":
        return nn.SELU()
    # Swish function: x * sigmoid(x) -- mirrors the TensorFlow branch.
    # https://arxiv.org/abs/1710.05941
    elif activation_function == "swish":
        sigmoid = nn.Sigmoid()
        return lambda x: x * sigmoid(x)
    # Leaky ReLU: x * [alpha if x < 0 else 1.0]
    elif activation_function in ["lrelu", "leaky_relu"]:
        alpha = other_parameters[0] if len(other_parameters) > 0 else 0.2
        # The slope must go to the constructor: LeakyReLU.forward() only accepts the input tensor,
        # so binding an `alpha` keyword at call time would raise a TypeError.
        return nn.LeakyReLU(negative_slope=alpha)
    # Softmax function:
    elif activation_function == "softmax":
        # Explicit last-axis dim matches tf.nn.softmax and avoids PyTorch's implicit-dim behavior
        # (deprecated; it picks axis 0 for 3D inputs).
        return nn.Softmax(dim=-1)
    # Softplus function:
    elif activation_function == 'softplus':
        return nn.Softplus()
    # Softsign function:
    elif activation_function == "softsign":
        return nn.Softsign()
    # tanh activation function:
    elif activation_function == "tanh":
        return nn.Tanh()
    else:
        raise RLGraphError("ERROR: Unknown activation_function '{}' for PyTorch backend!".
                           format(activation_function))