4.7. Loss Functions

4.7.1. Loss Function Base Class

class rlgraph.components.loss_functions.loss_function.LossFunction(discount=0.98, **kwargs)[source]

Bases: rlgraph.components.component.Component

A loss function component offers a simple interface to an error/loss calculation function.

API:
loss_per_item(*inputs) -> The loss value vector holding single loss values (one per item in a batch).
loss_average(loss_per_item) -> The average value of the input loss_per_item.
loss(*args, **kwargs)
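
To make this interface concrete, the following is a minimal NumPy sketch (a hypothetical stand-in, not RLgraph's actual graph-function implementation) of a loss component that follows the per-item / average split described above:

    import numpy as np

    class SquaredLossSketch:
        """Hypothetical example following the LossFunction interface above."""

        def loss_per_item(self, predictions, targets):
            # One loss value per batch item: the squared error.
            return (np.asarray(predictions) - np.asarray(targets)) ** 2

        def loss_average(self, loss_per_item):
            # Reduce the per-item loss vector to a single scalar.
            return np.mean(loss_per_item)

        def loss(self, predictions, targets):
            # Convenience wrapper: per-item losses first, then their batch average.
            per_item = self.loss_per_item(predictions, targets)
            return self.loss_average(per_item), per_item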

4.7.2. DQN Loss Function

class rlgraph.components.loss_functions.dqn_loss_function.DQNLossFunction(double_q=False, huber_loss=False, importance_weights=False, n_step=1, scope='dqn-loss-function', **kwargs)[source]

Bases: rlgraph.components.loss_functions.loss_function.LossFunction

The classic 2015 DQN loss function:

L = E[uniform batch]( (r + gamma * max_a' Q_t(s',a') - Q_n(s,a))^2 )

where Q_n is the “normal” (online) Q-network and Q_t is the “target” network, which lags slightly behind Q_n for stability.

API:
loss_per_item(q_values_s, actions, rewards, terminals, qt_values_sp, q_values_sp=None): The DQN loss per batch item.
check_input_spaces(input_spaces, action_space=None)[source]

Do some sanity checking on the incoming Spaces.

loss(*args, **kwargs)
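
As an illustration of the per-item computation described above, here is a NumPy sketch of the DQN loss, including the double-Q variant via q_values_sp. It is a hypothetical helper rather than RLgraph's actual graph function; Huber loss and importance weights are omitted:

    import numpy as np

    def dqn_loss_per_item(q_values_s, actions, rewards, terminals,
                          qt_values_sp, q_values_sp=None,
                          discount=0.98, n_step=1):
        # Hypothetical NumPy sketch of the per-item (squared-TD-error) DQN loss.
        batch = np.arange(len(actions))
        # Q_n(s, a): online-network Q-values of the actions actually taken.
        q_s_a = q_values_s[batch, actions]

        if q_values_sp is not None:
            # Double-Q: pick a' with the online net, evaluate it with the target net.
            a_prime = np.argmax(q_values_sp, axis=1)
            qt_sp_ap = qt_values_sp[batch, a_prime]
        else:
            # Vanilla DQN: max over the target net's Q-values for s'.
            qt_sp_ap = np.max(qt_values_sp, axis=1)

        # TD target r + gamma^n * Q_t(s', a'); the bootstrap is cut at terminal states.
        not_terminal = 1.0 - np.asarray(terminals, dtype=np.float64)
        targets = rewards + (discount ** n_step) * qt_sp_ap * not_terminal

        # One squared TD error per batch item.
        return (targets - q_s_a) ** 2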

4.7.3. IMPALA Loss Function

class rlgraph.components.loss_functions.impala_loss_function.IMPALALossFunction(discount=0.99, reward_clipping='clamp_one', weight_pg=None, weight_baseline=None, weight_entropy=None, **kwargs)[source]

Bases: rlgraph.components.loss_functions.loss_function.LossFunction

The IMPALA loss function based on v-trace off-policy policy gradient corrections, described in detail in [1].

The three terms of the loss function are (a schematic per-item combination is sketched below):

  1) The policy gradient term:
     L[pg] = (rho_pg * advantages) * nabla log(pi(a|s)), where (rho_pg * advantages) = pg_advantages in code below.
  2) The value-function baseline term:
     L[V] = 0.5 * (vs - V(xs))^2, such that dL[V]/dtheta = (vs - V(xs)) * nabla V(xs)
  3) The entropy regularizer term:
     L[E] = - SUM[all actions a] pi(a|s) * log pi(a|s)
[1] IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures - Espeholt, Soyer, Munos et al. - 2018 (https://arxiv.org/abs/1802.01561)
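
To show how these three terms add up per batch item, here is a hedged NumPy sketch (a hypothetical helper; the default weights are illustrative only, and pg_advantages, i.e. rho_pg * advantages, as well as the v-trace targets vs are assumed to be precomputed and treated as constants):

    import numpy as np

    def impala_loss_per_item(log_pi_a, entropy_pi, values, vs, pg_advantages,
                             weight_pg=1.0, weight_baseline=0.5, weight_entropy=0.01):
        # log_pi_a:      log pi(a|s) of the taken actions under the current policy pi.
        # entropy_pi:    -SUM[a] pi(a|s) * log pi(a|s) per item (the policy entropy).
        # values:        baseline estimates V(xs).
        # vs:            v-trace targets (treated as constants here).
        # pg_advantages: rho_pg * advantages (treated as constants here).

        # 1) Policy-gradient term (negated so that minimizing raises expected return).
        loss_pg = -pg_advantages * log_pi_a
        # 2) Value-function baseline term: 0.5 * (vs - V(xs))^2.
        loss_baseline = 0.5 * (vs - values) ** 2
        # 3) Entropy regularizer (negative entropy: minimizing encourages exploration).
        loss_entropy = -entropy_pi

        return (weight_pg * loss_pg
                + weight_baseline * loss_baseline
                + weight_entropy * loss_entropy)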
check_input_spaces(input_spaces, action_space=None)[source]

Should check on the nature of all in-Socket Spaces of this Component. This method is called automatically by the Model when all these Spaces are known during the Model’s build time.

Args:
input_spaces (Dict[str, Space]): A dict with Space/shape information; keys=in-Socket name (str), values=the associated Space.
action_space (Optional[Space]): The action Space of the Agent/GraphBuilder. Can be used to construct and connect further Components (which rely on this information), eliminating the need to pass the action Space into many Components’ constructors.
loss(logits_actions_pi, action_probs_mu, values, actions, rewards, terminals)[source]

API method that computes the per-item losses from the original inputs and then averages them into the total loss over the batch.

Args: see self._graph_fn_loss_per_item.

Returns:
SingleDataOp: The tensor specifying the final loss (over the entire batch).
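
The reduction performed by this API method is simply the mean over the per-item losses; a trivial sketch with made-up numbers:

    import numpy as np

    # per_item would come from loss_per_item (e.g. the IMPALA sketch above).
    per_item = np.array([0.12, 0.40, 0.07, 0.31])  # one loss value per batch item
    total_loss = float(np.mean(per_item))          # final scalar loss over the batch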