4.6.3. Neural Network Layers
4.6.3.1. Activation Functions
rlgraph.components.layers.nn.activation_functions.get_activation_function(activation_function=None, *other_parameters)
Returns an activation function (callable) to use in an NN layer.
Args:
    activation_function (Optional[callable, str]): The activation function to look up. It can be given as:
        - a callable, which is returned unchanged;
        - a lookup key (str);
        - None, meaning a linear activation.
    other_parameters (any): Optional extra parameter(s) required by some of the activation functions.
Returns:
    callable: The backend-dependent activation function.
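The three accepted forms of activation_function can be illustrated with a small, framework-free sketch. It is not RLgraph's implementation: NumPy functions stand in for the backend ops, and the function name get_activation_function_sketch, the lookup keys ("linear", "relu", "tanh", "sigmoid", "lrelu") and the use of other_parameters as a leaky-ReLU slope are all assumptions made for illustration.

```python
import numpy as np

# Hypothetical lookup table; RLgraph's real keys and backend ops may differ.
_ACTIVATIONS = {
    "relu": lambda x: np.maximum(x, 0.0),
    "tanh": np.tanh,
    "sigmoid": lambda x: 1.0 / (1.0 + np.exp(-x)),
}


def get_activation_function_sketch(activation_function=None, *other_parameters):
    """Resolve an activation spec into a callable (NumPy stand-in for a backend op)."""
    # Already a callable: return it unchanged.
    if callable(activation_function):
        return activation_function
    # None (or an explicit "linear"): identity, i.e. a linear activation.
    if activation_function is None or activation_function == "linear":
        return lambda x: x
    # Parameterized example: leaky ReLU whose slope comes from other_parameters.
    if activation_function == "lrelu":
        alpha = other_parameters[0] if other_parameters else 0.2
        return lambda x: np.where(x > 0, x, alpha * x)
    # Plain string key: look it up in the table.
    if activation_function in _ACTIVATIONS:
        return _ACTIVATIONS[activation_function]
    raise ValueError(f"Unknown activation function: {activation_function!r}")
```

For example, get_activation_function_sketch("lrelu", 0.1) returns a leaky ReLU with slope 0.1, while passing None yields the identity.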
4.6.3.2. NNLayer Base Class
class rlgraph.components.layers.nn.nn_layer.NNLayer(**kwargs)
Bases: rlgraph.components.layers.layer.Layer
A generic NN-layer object that implements the apply graph_fn (_graph_fn_apply) and adds activation-function support. It can be used in three ways:
- Thin wrapper around a backend-specific layer object (the normal use case): create the backend layer in the create_variables method and store it under self.layer, then register the backend layer's variables with the RLgraph Component.
- Custom layer (with custom computation): create the necessary variables (e.g. matrices) in create_variables, then override _graph_fn_apply, leaving self.layer as None.
- Single activation function: leave self.layer as None and do not override _graph_fn_apply. The layer then only applies the activation function.
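The following framework-free sketch shows where the computation lives in each of the three usage modes. The classes MiniNNLayer, WrappedScale and CustomAffine are hypothetical stand-ins (NumPy arrays instead of backend variables, plain attributes instead of self.get_variable); they only mirror the dispatch logic described above, not RLgraph's actual code.

```python
import numpy as np


class MiniNNLayer:
    """Hypothetical stand-in for NNLayer's dispatch logic (not RLgraph's API)."""

    def __init__(self, activation=None):
        self.activation = activation or (lambda x: x)  # None -> linear activation
        self.layer = None  # Backend layer object (mode 1) or None (modes 2 and 3).

    def create_variables(self, input_dim):
        pass  # Mode 3 (activation only) needs no variables.

    def _graph_fn_apply(self, x):
        # Default behavior: run the wrapped layer if there is one, else pass x through,
        # then apply the activation function.
        out = self.layer(x) if self.layer is not None else x
        return self.activation(out)

    def apply(self, x):
        return self._graph_fn_apply(x)


class WrappedScale(MiniNNLayer):
    """Mode 1: thin wrapper; the 'backend layer' is stored under self.layer."""

    def create_variables(self, input_dim):
        self.scale = np.full(input_dim, 2.0)
        self.layer = lambda x: x * self.scale


class CustomAffine(MiniNNLayer):
    """Mode 2: custom variables plus an overridden _graph_fn_apply; self.layer stays None."""

    def create_variables(self, input_dim):
        self.w = np.eye(input_dim)
        self.b = np.ones(input_dim)

    def _graph_fn_apply(self, x):
        return self.activation(x @ self.w + self.b)
```

A plain MiniNNLayer(activation=relu) is the activation-only case: with self.layer left as None and _graph_fn_apply not overridden, apply reduces to relu(x).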
4.6.3.3. Concat Layer
class rlgraph.components.layers.nn.concat_layer.ConcatLayer(axis=-1, scope='concat-layer', **kwargs)
Bases: rlgraph.components.layers.nn.nn_layer.NNLayer
A simple concatenation-layer wrapper. The ConcatLayer is a Layer without sub-components, with n inputs and 1 output: its graph function concatenates the input data (along axis) into a single output.
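The core computation amounts to a concatenation along axis. Below is a minimal NumPy sketch of it, assuming batch-major inputs whose shapes agree on every dimension except axis; concat_apply is a hypothetical helper, not part of RLgraph.

```python
import numpy as np


def concat_apply(*inputs, axis=-1):
    """Sketch of ConcatLayer's computation: n inputs -> 1 output."""
    # All inputs must have equal sizes on every dimension except `axis`.
    return np.concatenate(inputs, axis=axis)


batch_a = np.zeros((4, 3))  # batch of 4, 3 features
batch_b = np.ones((4, 5))   # batch of 4, 5 features
out = concat_apply(batch_a, batch_b)
print(out.shape)  # (4, 8)
```

With the default axis=-1, feature vectors are joined per batch item, which is the usual way to merge several preprocessed observation streams.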
4.6.3.4. Conv2D Layer
class rlgraph.components.layers.nn.conv2d_layer.Conv2DLayer(filters, kernel_size, strides, padding='valid', data_format='channels_last', kernel_spec=None, biases_spec=None, **kwargs)
Bases: rlgraph.components.layers.nn.nn_layer.NNLayer
A Conv2D NN-layer.
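The choice of filters, kernel_size, strides and padding determines the output shape. The helper below computes it under the assumption that padding='valid' and padding='same' follow the usual TensorFlow conventions (which a TensorFlow-backed layer would inherit) and that data_format='channels_last'. The function conv2d_output_shape is hypothetical and not part of RLgraph.

```python
import math


def conv2d_output_shape(height, width, filters, kernel_size, strides, padding="valid"):
    """Output shape (H, W, C) of a channels_last Conv2D, assuming TF-style padding."""
    kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
    sh, sw = (strides, strides) if isinstance(strides, int) else strides
    if padding == "valid":
        # Only positions where the kernel fits entirely inside the input.
        out_h = (height - kh) // sh + 1
        out_w = (width - kw) // sw + 1
    elif padding == "same":
        # Input is zero-padded so that output size = ceil(input size / stride).
        out_h = math.ceil(height / sh)
        out_w = math.ceil(width / sw)
    else:
        raise ValueError(f"Unknown padding: {padding!r}")
    # The channel dimension of the output equals the number of filters.
    return out_h, out_w, filters


print(conv2d_output_shape(84, 84, filters=32, kernel_size=8, strides=4))  # (20, 20, 32)
```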
create_variables(input_spaces, action_space=None)
Should create all variables needed within this component, except for variables that are only needed inside a single _graph_fn method, which should be created there instead. Variables must be created via the backend-agnostic self.get_variable method.
Note that variables are not shared between the different scopes in which this component is used.
Args:
    input_spaces (Dict[str, Space]): A dict with Space/shape information. Keys are in-Socket names (str); values are the associated Spaces.
    action_space (Optional[Space]): The action Space of the Agent/GraphBuilder. Can be used to construct and connect further Components that rely on this information, which removes the need to pass the action Space into many Components' constructors.
4.6.3.5. Dense Layer
class rlgraph.components.layers.nn.dense_layer.DenseLayer(units, weights_spec=None, biases_spec=None, **kwargs)
Bases: rlgraph.components.layers.nn.nn_layer.NNLayer
A dense (or “fully connected”) NN-layer.
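A dense layer computes y = activation(x W + b), where W has shape (input_dim, units). The NumPy sketch below shows this computation; dense_apply is a hypothetical helper, and the random weight initialization stands in for whatever weights_spec and biases_spec would specify in RLgraph.

```python
import numpy as np


def dense_apply(x, weights, biases, activation=None):
    """y = activation(x @ W + b); x: (batch, in_dim), W: (in_dim, units), b: (units,)."""
    y = x @ weights + biases
    return activation(y) if activation is not None else y


rng = np.random.default_rng(0)
in_dim, units = 4, 3
weights = rng.normal(scale=0.1, size=(in_dim, units))  # hypothetical init, not weights_spec
biases = np.zeros(units)
x = rng.normal(size=(2, in_dim))
y = dense_apply(x, weights, biases, activation=np.tanh)
print(y.shape)  # (2, 3)
```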
create_variables(input_spaces, action_space=None)
Should create all variables needed within this component, except for variables that are only needed inside a single _graph_fn method, which should be created there instead. Variables must be created via the backend-agnostic self.get_variable method.
Note that variables are not shared between the different scopes in which this component is used.
Args:
    input_spaces (Dict[str, Space]): A dict with Space/shape information. Keys are in-Socket names (str); values are the associated Spaces.
    action_space (Optional[Space]): The action Space of the Agent/GraphBuilder. Can be used to construct and connect further Components that rely on this information, which removes the need to pass the action Space into many Components' constructors.
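Passing input_spaces to create_variables lets a layer defer variable creation until its input shape is known; a dense layer's weight matrix, for instance, depends on the input feature size, which is not available at construction time. The class below is a hypothetical illustration of that pattern: a plain shape tuple under an assumed "inputs" key stands in for an RLgraph Space, and NumPy arrays stand in for variables created via self.get_variable.

```python
import numpy as np


class DenseVariablesSketch:
    """Hypothetical illustration of deferred variable creation; not RLgraph's API."""

    def __init__(self, units):
        self.units = units  # Known at construction time.
        self.variables = {}

    def create_variables(self, input_spaces, action_space=None):
        # The input size is only known once the input Space arrives.
        # Assumption: a shape tuple (batch_rank, features) stands in for a Space.
        in_dim = input_spaces["inputs"][-1]
        # Zero initialization for brevity only.
        self.variables["kernel"] = np.zeros((in_dim, self.units))
        self.variables["bias"] = np.zeros(self.units)


layer = DenseVariablesSketch(units=8)
layer.create_variables({"inputs": (None, 16)})
print(layer.variables["kernel"].shape)  # (16, 8)
```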
4.6.3.6. LSTM Layer
class rlgraph.components.layers.nn.lstm_layer.LSTMLayer(units, use_peepholes=False, cell_clip=None, static_loop=False, forget_bias=1.0, parallel_iterations=32, swap_memory=False, time_major=False, **kwargs)
Bases: rlgraph.components.layers.nn.nn_layer.NNLayer
An LSTM layer that processes an initial internal state vector and a batch of sequences, producing a final internal state and a batch of output sequences.
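To make the data flow concrete, here is a NumPy sketch of what a plain LSTM computes over a batch-major sequence (time_major=False), including the forget_bias term added to the forget gate. It ignores use_peepholes, cell_clip and the loop-control parameters (static_loop, parallel_iterations, swap_memory). The fused weight layout, the gate order (i, f, g, o) and the helper name lstm_sequence are assumptions of this sketch, not a description of RLgraph's backend implementation.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_sequence(x, h0, c0, W, b, forget_bias=1.0):
    """Run a plain LSTM (no peepholes, no cell clipping) over a batch-major sequence.

    x: (batch, time, in_dim); h0, c0: (batch, units)
    W: (in_dim + units, 4 * units); b: (4 * units,)
    Returns the output sequence (batch, time, units) and the final (h, c) state.
    """
    h, c = h0, c0
    outputs = []
    for t in range(x.shape[1]):  # time_major=False: time is axis 1
        z = np.concatenate([x[:, t], h], axis=-1) @ W + b
        i, f, g, o = np.split(z, 4, axis=-1)
        # forget_bias is added to the forget gate's pre-activation.
        c = sigmoid(f + forget_bias) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs, axis=1), (h, c)


rng = np.random.default_rng(0)
batch, time, in_dim, units = 2, 5, 3, 4
x = rng.normal(size=(batch, time, in_dim))
h0 = np.zeros((batch, units))
c0 = np.zeros((batch, units))
W = rng.normal(scale=0.1, size=(in_dim + units, 4 * units))
b = np.zeros(4 * units)
outputs, (h_final, c_final) = lstm_sequence(x, h0, c0, W, b)
```

The last output step equals the final hidden state, which is why the layer can return both a batch of output sequences and a final internal state.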
apply(*args, **kwargs)
check_input_spaces(input_spaces, action_space=None)
Performs sanity checks on the incoming Space: it must not be a Container Space (for now) and must have a batch rank.
create_variables(input_spaces, action_space=None)
Should create all variables needed within this component, except for variables that are only needed inside a single _graph_fn method, which should be created there instead. Variables must be created via the backend-agnostic self.get_variable method.
Note that variables are not shared between the different scopes in which this component is used.
Args:
    input_spaces (Dict[str, Space]): A dict with Space/shape information. Keys are in-Socket names (str); values are the associated Spaces.
    action_space (Optional[Space]): The action Space of the Agent/GraphBuilder. Can be used to construct and connect further Components that rely on this information, which removes the need to pass the action Space into many Components' constructors.
4.6.3.7. MaxPool2D Layer
class rlgraph.components.layers.nn.maxpool2d_layer.MaxPool2DLayer(pool_size, strides, padding='valid', data_format='channels_last', **kwargs)
Bases: rlgraph.components.layers.nn.nn_layer.NNLayer
A 2D max-pooling layer.
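For padding='valid' with channels_last data, max pooling takes the maximum over each pool_size window, moving by strides along height and width. The NumPy function below is a hypothetical reference sketch of that computation only; it does not handle padding='same' or data_format='channels_first'.

```python
import numpy as np


def maxpool2d_valid(x, pool_size, strides):
    """2D max pooling with 'valid' padding, channels_last. x: (batch, H, W, C)."""
    ph, pw = pool_size
    sh, sw = strides
    batch, height, width, channels = x.shape
    # Number of window positions that fit entirely inside the input.
    out_h = (height - ph) // sh + 1
    out_w = (width - pw) // sw + 1
    out = np.empty((batch, out_h, out_w, channels), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            window = x[:, i * sh:i * sh + ph, j * sw:j * sw + pw, :]
            # Maximum over the spatial window, separately per batch item and channel.
            out[:, i, j, :] = window.max(axis=(1, 2))
    return out


x = np.arange(16, dtype=float).reshape(1, 4, 4, 1)
pooled = maxpool2d_valid(x, pool_size=(2, 2), strides=(2, 2))
print(pooled[0, :, :, 0])  # [[ 5.  7.] [13. 15.]]
```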
4.6.3.8. Residual Layer
class rlgraph.components.layers.nn.residual_layer.ResidualLayer(residual_unit, repeats=2, scope='residual-layer', **kwargs)
Bases: rlgraph.components.layers.nn.nn_layer.NNLayer
A residual layer that adds its input value to the result of a computation (the residual_unit). Based on:
[1] Identity Mappings in Deep Residual Networks - He, Zhang, Ren and Sun (Microsoft) 2016 (https://arxiv.org/pdf/1603.05027.pdf)
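The identity shortcut of [1] means the output is the input plus a learned residual, y = x + F(x). The NumPy sketch below assumes one reading of repeats, namely that residual_unit is applied repeats times in sequence before the input is added back; residual_apply is a hypothetical helper, and the exact wiring in RLgraph should be checked against its source. Because of the addition, F must preserve the input's shape.

```python
import numpy as np


def residual_apply(x, residual_unit, repeats=2):
    """Assumed semantics: apply residual_unit `repeats` times, then add the identity shortcut."""
    out = x
    for _ in range(repeats):
        out = residual_unit(out)  # must return an array with the same shape as x
    return x + out  # identity mapping: the input is added back unchanged


unit = lambda v: 0.5 * v  # toy residual unit standing in for, e.g., conv layers
y = residual_apply(np.array([4.0]), unit)
print(y)  # [5.]
```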
API:
    apply(input_) ->