Source code for rlgraph.components.common.batch_splitter
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from rlgraph import get_backend
from rlgraph.components import Component
from rlgraph.utils.decorators import rlgraph_api
from rlgraph.utils.ops import FlattenedDataOp, unflatten_op, DataOpTuple
if get_backend() == "tf":
import tensorflow as tf
class BatchSplitter(Component):
"""
Splits a number of incoming DataOps along their batch dimension.
"""
def __init__(self, num_shards, shard_size, **kwargs):
"""
Args:
num_shards (int): Number of shards to split the batch dimension into.
            shard_size (int): The number of samples in each (e.g. per-GPU) shard.
"""
super(BatchSplitter, self).__init__(
scope=kwargs.pop("scope", "batch-splitter"),
graph_fn_num_outputs=dict(_graph_fn_split_batch=num_shards),
**kwargs
)
        assert num_shards > 1, "ERROR: `num_shards` must be greater than 1 but is {}.".format(num_shards)
self.num_shards = num_shards
self.shard_size = shard_size
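    # `flatten_ops=True`: each incoming arg arrives in the graph_fn as a FlattenedDataOp
    # (a dict mapping flat keys such as "/states1" to the actual tensors).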
@rlgraph_api(flatten_ops=True)
def _graph_fn_split_batch(self, *inputs):
"""
        Splits all DataOps in *inputs along their batch dimension into `self.num_shards` equally sized
        shards of `self.shard_size` samples each. Superfluous items at the end of the incoming batch are
        discarded (effective batch size = num_shards x shard_size).
Args:
            *inputs (FlattenedDataOp): Input tensors, which must all have the same batch dimension.
        Returns:
            tuple: `self.num_shards` return values, one DataOpTuple per shard.
                Each DataOpTuple has len = number of input args; each of its items holds the (now
                sharded) batch data for the corresponding input piece, with flat keys (e.g. "/states1")
                re-nested back into their original structure.
                E.g. for 2 shards:
                (DataOpTuple(input1_shard1, input2_shard1, ...), DataOpTuple(input1_shard2, input2_shard2, ...))
"""
if get_backend() == "tf":
#batch_size = tf.shape(next(iter(inputs[0].values())))[0]
#shard_size = tf.cast(batch_size / self.num_shards, dtype=tf.int32)
# Must be evenly divisible so we slice out an evenly divisible tensor.
# E.g. 203 items in batch with 4 shards -> Only 4 x 50 = 200 are usable.
usable_size = self.shard_size * self.num_shards
# List (one item for each input arg). Each item in the list looks like:
# A FlattenedDataOp with (flat) keys (describing the input-piece (e.g. "/states1")) and values being
# lists of len n for the n shards' data.
inputs_flattened_and_split = list()
for input_arg_data in inputs:
shard_dict = FlattenedDataOp()
for flat_key, data in input_arg_data.items():
usable_input_tensor = data[:usable_size]
shard_dict[flat_key] = tf.split(value=usable_input_tensor, num_or_size_splits=self.num_shards)
inputs_flattened_and_split.append(shard_dict)
# Flip the list to generate a new list where each item represents one shard.
shard_list = list()
for shard_idx in range(self.num_shards):
# To be converted into FlattenedDataOps over the input-arg-pieces once complete.
input_arg_list = list()
for input_elem in range(len(inputs)):
sharded_data_dict = FlattenedDataOp()
for flat_key, shards in inputs_flattened_and_split[input_elem].items():
sharded_data_dict[flat_key] = shards[shard_idx]
input_arg_list.append(unflatten_op(sharded_data_dict))
# Must store everything as FlattenedDataOp otherwise the re-nesting will not work.
shard_list.append(DataOpTuple(input_arg_list))
# Return n values (n = number of batch shards).
return tuple(shard_list)
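

# --------------------------------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): a minimal, standalone demonstration of the
# slicing-and-splitting arithmetic that `_graph_fn_split_batch` applies to every flat key. It
# bypasses the RLgraph component/build machinery entirely and uses plain TensorFlow; the batch
# shape and the num_shards/shard_size values below are illustrative assumptions, not library values.
if __name__ == "__main__":
    import numpy as np
    import tensorflow as tf

    num_shards, shard_size = 4, 50
    batch = np.arange(203 * 3, dtype=np.float32).reshape(203, 3)  # 203 samples, 3 features each

    # Same arithmetic as above: drop the superfluous tail items, then split the now evenly
    # divisible tensor into `num_shards` equally sized shards.
    usable_size = num_shards * shard_size  # 4 x 50 = 200 usable samples; the last 3 are discarded
    shards = tf.split(value=tf.convert_to_tensor(batch)[:usable_size], num_or_size_splits=num_shards)

    print(len(shards), shards[0].shape)  # -> 4 (50, 3)
    # Inside the component, this slice/split runs for every flat key of every input arg; the
    # results are then "flipped" so that each of the `num_shards` return values bundles one
    # complete shard of all inputs.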