Source code for rlgraph.components.action_adapters.baseline_action_adapter
# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from rlgraph import get_backend
from rlgraph.components.action_adapters.action_adapter import ActionAdapter
from rlgraph.utils.decorators import rlgraph_api, graph_fn
if get_backend() == "tf":
import tensorflow as tf
class BaselineActionAdapter(ActionAdapter):
    """
    An ActionAdapter that adds 1 node to its action layer for an additional state-value output per batch item.

    The first node of the flat action-layer output is split off as the state value (V); the remaining nodes are
    reshaped into the action logits.

    API:
        get_logits(nn_output) (SingleDataOp): The (reshaped) action logits only.
        get_logits_probabilities_log_probs(nn_output) (Dict[str, SingleDataOp]): The state values, the
            (reshaped) logits, the probabilities and the log-probabilities.
        get_state_values_and_logits(nn_output) (Tuple[SingleDataOp x 2]): The state-value and action logits
            (reshaped).
    """
    def __init__(self, scope="baseline-action-adapter", **kwargs):
        # Change the number of units in the action layer (+1 for the extra Value function node).
        super(BaselineActionAdapter, self).__init__(add_units=1, scope=scope, **kwargs)
        # Input Space whose `time_major` / `has_time_rank` flags decide how output ops get their rank tags.
        # NOTE(review): initialized to None here but dereferenced in `_graph_fn_get_state_values_and_logits`;
        # presumably it is populated before the graph is built (not visible in this class) -- confirm.
        self.input_space = None

    @rlgraph_api
    def get_logits(self, nn_output):
        """
        Returns only the action logits (the state-value node is sliced away and discarded).

        Args:
            nn_output (DataOpRecord): The NN output of the preceding neural network.

        Returns:
            SingleDataOp: The logits (raw nn_output, BUT reshaped).
        """
        aa_output = self.get_action_layer_output(nn_output)
        _, logits = self._graph_fn_get_state_values_and_logits(aa_output["output"])
        return logits

    @rlgraph_api
    def get_logits_probabilities_log_probs(self, nn_output):
        """
        API-method. Returns separated V and logit values split from the action layer.

        Args:
            nn_output (DataOpRecord): The NN output of the preceding neural network.

        Returns:
            dict (4x DataOpRecord):
                `state_values`: The state values (V).
                `logits`: The (reshaped) logits coming from the action layer.
                `probabilities`: softmaxed(logits)
                `log_probs`: log(probabilities)
        """
        # Slice away the first node for the state value and reshape the rest to yield the action logits.
        state_values, logits = self.get_state_values_and_logits(nn_output)
        probabilities, log_probs = self._graph_fn_get_probabilities_log_probs(logits)
        return dict(state_values=state_values, logits=logits, probabilities=probabilities, log_probs=log_probs)

    @rlgraph_api
    def get_state_values_and_logits(self, nn_output):
        """
        Runs `nn_output` through the action layer and splits the result into state values and action logits.

        Args:
            nn_output (DataOpRecord): The NN output of the preceding neural network.

        Returns:
            tuple (2x DataOpRecord): The state values (V) and the (reshaped) action logits.
        """
        # Run through the action layer.
        aa_output = self.get_action_layer_output(nn_output)
        state_values, logits = self._graph_fn_get_state_values_and_logits(aa_output["output"])
        return state_values, logits

    @graph_fn
    def _graph_fn_get_state_values_and_logits(self, action_layer_output):
        """
        Slices away the state-value node from the raw action_layer_output (dense) and returns the single state-value
        and the remaining (reshaped) action-logits.

        Args:
            action_layer_output (SingleDataOp): The flat action layer output.

        Returns:
            tuple (2x SingleDataOp):
                - The state value: the first node along the last axis. `tf.split` keeps the split axis, so this
                  has a trailing dimension of size 1 after any leading batch/time dimensions.
                - The reshaped action logits.
        """
        # NOTE(review): only the "tf" backend is implemented; any other backend falls through and returns None.
        if get_backend() == "tf":
            # Separate the single state-value node from the flat logits.
            # `self.action_layer.units` presumably already includes the +1 node added in `__init__` -- confirm.
            state_value, flat_logits = tf.split(
                value=action_layer_output, num_or_size_splits=(1, self.action_layer.units - 1), axis=-1
            )

            # TODO: automate this: batch in -> batch out; time in -> time out; batch+time in -> batch+time out, etc..
            # TODO: if not default behavior: have to specify in decorator (see design_problems.txt).
            # Tag the flat logits with batch/time ranks before reshaping them into action-shaped logits.
            self._set_batch_time_ranks(flat_logits)
            logits = self.reshape.apply(flat_logits)

            # Tag both outputs so downstream components see the same batch/time layout as the input.
            self._set_batch_time_ranks(state_value)
            self._set_batch_time_ranks(logits)

            return state_value, logits

    def _set_batch_time_ranks(self, op):
        """
        Tags `op` in place with the batch- and (optionally) time-rank positions implied by `self.input_space`.

        Args:
            op (SingleDataOp): The op whose `_batch_rank` (always) and `_time_rank` (only if the input Space has a
                time rank) attributes get set.
        """
        # Identity checks are deliberate: the batch rank is 0 only if `time_major` is exactly False, and the
        # time rank is 0 only if `time_major` is exactly True (any other value yields 1 in both cases).
        op._batch_rank = 0 if self.input_space.time_major is False else 1
        if self.input_space.has_time_rank:
            op._time_rank = 0 if self.input_space.time_major is True else 1