# Source code for rlgraph.components.memories.replay_memory

# Copyright 2018 The RLgraph authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from rlgraph.components.memories.memory import Memory
from rlgraph.utils.ops import FlattenedDataOp
from rlgraph.utils.util import get_batch_size
from rlgraph.utils.decorators import rlgraph_api

class ReplayMemory(Memory):
    """
    Standard replay memory: records are written into a ring buffer and
    randomized batches are sampled from the slots written so far.
    """

    def __init__(self, capacity=1000, scope="replay-memory", **kwargs):
        """
        Args:
            capacity (int): Forwarded to the `Memory` constructor. The graph functions of this class use
                `self.capacity` as the modulus for all ring-buffer index arithmetic.
            scope (str): Forwarded to the `Memory` constructor as `scope`.
            **kwargs: Forwarded unchanged to the `Memory` constructor.
        """
        super(ReplayMemory, self).__init__(capacity, scope=scope, **kwargs)

        # Placeholders only; `index` and `size` are replaced by real variables in `create_variables`.
        self.index = None
        self.size = None
        self.states = None

    def create_variables(self, input_spaces, action_space=None):
        """
        Lets the parent class create its variables, then adds the two integer counters of the ring buffer.

        Args:
            input_spaces: Passed through to the parent's `create_variables`.
            action_space: Passed through to the parent's `create_variables`.
        """
        super(ReplayMemory, self).create_variables(input_spaces, action_space)

        # A replay memory requires a 'terminals' entry in its record space.
        assert 'terminals' in self.record_space

        # Both counters are non-trainable ints starting at zero.
        counter_settings = dict(dtype=int, trainable=False, initializer=0)
        # Next write position in the buffer.
        self.index = self.get_variable(name="index", **counter_settings)
        # Number of records currently stored.
        self.size = self.get_variable(name="size", **counter_settings)

    @rlgraph_api(flatten_ops=True)
    def _graph_fn_insert_records(self, records):
        """
        Writes a batch of records into the buffer, starting at `index` and wrapping around at `capacity`.

        Args:
            records: Flattened record dict; the batch size is taken from its "/terminals" entry.

        Returns:
            A no-op that depends on the record writes and on the `index`/`size` updates.
        """
        batch_size = get_batch_size(records["/terminals"])

        # Target slots: `batch_size` consecutive positions from `index` on, rolled over at `capacity`.
        write_slots = tf.range(start=self.index, limit=self.index + batch_size) % self.capacity

        # One scatter-update per registered record variable.
        scatter_ops = [
            self.scatter_update_variable(
                variable=self.record_registry[key],
                indices=write_slots,
                updates=records[key]
            )
            for key in self.record_registry
        ]

        # Only move the counters once all record variables have been written.
        with tf.control_dependencies(control_inputs=scatter_ops):
            next_index = (self.index + batch_size) % self.capacity
            advance_index = self.assign_variable(ref=self.index, value=next_index)
            # `size` grows with every insert but saturates at `capacity`.
            next_size = tf.minimum(x=(self.read_variable(self.size) + batch_size), y=self.capacity)
            advance_size = self.assign_variable(self.size, value=next_size)

        # Nothing to return except an op that forces both counter updates.
        with tf.control_dependencies(control_inputs=[advance_index, advance_size]):
            return tf.no_op()

    @rlgraph_api
    def _graph_fn_get_records(self, num_records=1):
        """
        Samples `num_records` records uniformly at random from the filled part of the buffer.

        Args:
            num_records (int): Number of records to sample.

        Returns:
            tuple:
                - The sampled records (as returned by `_read_records`).
                - The buffer indices of the sampled records.
                - Importance weights: a float32 tensor of ones shaped like the indices.
        """
        size = self.read_variable(self.size)
        index = self.read_variable(self.index)

        # Draw integer offsets in [0, size) and count them backwards from the most recently written
        # slot (`index - 1`), wrapping around at `capacity`.
        offsets = tf.random_uniform(shape=(num_records,), maxval=size, dtype=tf.int32)
        sampled_indices = (index - 1 - offsets) % self.capacity

        sampled_records = self._read_records(indices=sampled_indices)
        # Uniform sampling: every record gets the default importance weight of one.
        importance_weights = tf.ones_like(tensor=sampled_indices, dtype=tf.float32)
        return sampled_records, sampled_indices, importance_weights

    def _read_records(self, indices):
        """
        Obtains record values for the provided indices.

        Args:
            indices (Union[ndarray,tf.Tensor]): Indices to read. Assumed to be not contiguous.

        Returns:
            FlattenedDataOp: Record value dict, keyed like `self.record_registry`.
        """
        result = FlattenedDataOp()
        for record_name, record_variable in self.record_registry.items():
            result[record_name] = self.read_variable(record_variable, indices)
        return result