Source code for rlgraph.components.common.staging_area
# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from rlgraph import get_backend
from rlgraph.components.component import Component
from rlgraph.utils.decorators import rlgraph_api
from rlgraph.utils.ops import flatten_op, unflatten_op, FlattenedDataOp
from rlgraph.utils.util import dtype as dtype_

if get_backend() == "tf":
    import tensorflow as tf


class StagingArea(Component):
"""
    Stages an arbitrary number of incoming ops for next-step processing.

    This allows for more efficient handling of dequeued (or otherwise pipelined) data: A batch can
    be prepared and staged while a training step is still running; the next training step can then
    immediately consume the already staged data, and so on.
"""

    def __init__(self, num_data=1, device="/device:GPU:0", scope="staging-area", **kwargs):
        """
        Args:
            num_data (int): The number of data items to stage. Each item may be a ContainerDataOp,
                which is flattened automatically on staging and unflattened again on unstaging.
            device (str): The device on which to place the staging area (defaults to the first GPU).
        """
super(StagingArea, self).__init__(
graph_fn_num_outputs=dict(_graph_fn_unstage=num_data),
device=device,
scope=scope, **kwargs
)
# The actual backend-dependent StagingArea object.
self.area = None
# List of lists of flat keys of all input DataOps.
self.flat_keys = list()

    def create_variables(self, input_spaces, action_space=None):
# Store the original structure for later recovery.
dtypes = list()
shapes = list()
idx = 0
        while True:
            # One input Space is expected per incoming arg, keyed "inputs[0]", "inputs[1]", etc.
            key = "inputs[{}]".format(idx)
if key not in input_spaces:
break
flat_keys = list()
for flat_key, flat_space in input_spaces[key].flatten().items():
dtypes.append(dtype_(flat_space.dtype))
shapes.append(flat_space.get_shape(with_batch_rank=True, with_time_rank=True))
flat_keys.append(flat_key)
self.flat_keys.append(flat_keys)
idx += 1
if get_backend() == "tf":
self.area = tf.contrib.staging.StagingArea(dtypes, shapes)
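
        # Illustrative example (an assumption, not from the original file): for
        # num_data=1 and an input space like Dict(a=FloatBox(shape=(2,)), b=IntBox(3)),
        # flatten() yields one entry per leaf (flat keys along the lines of "/a" and
        # "/b"; exact format per rlgraph.utils.ops.flatten_op), so `dtypes`, `shapes`
        # and `self.flat_keys[0]` each collect two entries, in flattening order.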

    @rlgraph_api
def _graph_fn_stage(self, *inputs):
"""
Stages all incoming ops (after flattening them).
Args:
inputs (DataOp): The incoming ops to be (flattened and) staged.
        Returns:
            DataOp: The staging op. Note that it must actually be executed (fetched)
                for the data to be put into the staging area.
        """
# Flatten inputs and stage them.
        # TODO: Build an equivalent to nest.flatten().
flattened_ops = list()
for input_ in inputs:
flat_list = list(flatten_op(input_).values())
flattened_ops.extend(flat_list)
stage_op = self.area.put(flattened_ops)
return stage_op
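
    # Flattening illustration (an assumption, not original code): given two inputs,
    # a dict {"a": t0, "b": t1} and a plain tensor t2, the loop above produces
    # flattened_ops == [t0, t1, t2], which is what `self.area.put()` receives as a
    # single flat list.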

    @rlgraph_api
def _graph_fn_unstage(self):
"""
Unstages (and unflattens) all staged data.
Returns:
Tuple[DataOp]: All previously staged ops.
"""
unstaged_data = self.area.get()
unflattened_data = list()
idx = 0
# Unflatten all data and return.
for flat_key_list in self.flat_keys:
            flat_dict = FlattenedDataOp({
                flat_key: item for flat_key, item in
                zip(flat_key_list, unstaged_data[idx:idx + len(flat_key_list)])
            })
unflattened_data.append(unflatten_op(flat_dict))
idx += len(flat_key_list)
return tuple(unflattened_data)
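
    # Round-trip illustration (an assumption, not original code): with num_data=1
    # and self.flat_keys == [["/a", "/b"]], `self.area.get()` returns the flat list
    # [tensor_a, tensor_b]; the loop above rebuilds
    # FlattenedDataOp({"/a": tensor_a, "/b": tensor_b}) and unflatten_op() turns it
    # back into the original container structure {"a": tensor_a, "b": tensor_b}.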