# Source code for rlgraph.components.loss_functions.loss_function

# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from rlgraph import get_backend
from rlgraph.components import Component
from rlgraph.utils.decorators import rlgraph_api

if get_backend() == "tf":
    import tensorflow as tf
elif get_backend() == "pytorch":
    import torch

class LossFunction(Component):
    """
    A loss function component offers a simple interface into some error/loss calculation function.

    This base class only implements the batch-averaging step; `loss` and `_graph_fn_loss_per_item` raise
    NotImplementedError here and are expected to be overridden by concrete loss functions.

    API:
        loss(*inputs) -> The total loss (average over the batch) plus the per-item loss values vector.
        loss_per_item(*inputs) -> The loss value vector holding single loss values (one per item in a batch).
        loss_average(loss_per_item) -> The average value of the input `loss_per_item`.
    """
    def __init__(self, discount=0.98, **kwargs):
        """
        Args:
            discount (float): The discount factor (gamma).
            **kwargs: Forwarded to the `Component` constructor. If no `scope` is given,
                "loss-function" is used.
        """
        # Pop `scope` before forwarding **kwargs so that it is not passed to the parent twice.
        scope = kwargs.pop("scope", "loss-function")
        super(LossFunction, self).__init__(scope=scope, **kwargs)
        self.discount = discount

    @rlgraph_api
    def loss(self, *inputs):
        """
        API-method that calculates the total loss (average over per-batch-item loss) from the original input to
        per-item-loss.

        Args:
            see `self._graph_fn_loss_per_item`.

        Returns:
            Tuple (2x SingleDataOp):
                - The tensor specifying the final loss (over the entire batch).
                - The loss values vector (one single value for each batch item).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @rlgraph_api
    def _graph_fn_loss_per_item(self, *inputs):
        """
        Returns the single loss values (one for each item in a batch).

        Args:
            *inputs (DataOpTuple): The various data that this function needs to calculate the loss.

        Returns:
            SingleDataOp: The tensor specifying the loss per item. The batch dimension of this tensor corresponds
                to the number of items in the batch.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @rlgraph_api(must_be_complete=False)
    def _graph_fn_loss_average(self, loss_per_item):
        """
        The actual loss function that an optimizer will try to minimize. This is usually the average over a batch.

        Args:
            loss_per_item (SingleDataOp): The output of our loss_per_item graph_fn.

        Returns:
            SingleDataOp: The final loss tensor holding the average loss over the entire batch
                (the mean of `loss_per_item` taken along axis 0).
        """
        if get_backend() == "tf":
            return tf.reduce_mean(input_tensor=loss_per_item, axis=0)
        elif get_backend() == "pytorch":
            return torch.mean(loss_per_item, 0)
        # NOTE(review): for any other backend value this falls through and implicitly returns None.