# Source code for rlgraph.components.helpers.mem_segment_tree
# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import operator
from rlgraph.utils.rlgraph_errors import RLGraphError
class MemSegmentTree(object):
    """
    In-memory Segment tree for prioritized replay.

    The tree lives implicitly in the flat `values` storage: slot 1 is the root, the children of
    slot `i` are the slots `2 * i` and `2 * i + 1`, and the leaf holding element `index` is the
    slot `capacity + index`.

    Note: The pure TensorFlow segment tree is much slower because variable updating is expensive,
    and in scenarios like Ape-X, memory and update are separated processes, so there is little to be gained
    from inserting into the graph.

    NOTE(review): the top-down descent in `index_of_prefixsum` presumably expects `capacity` to be
    a power of two - confirm against the callers that allocate `values`.
    """

    def __init__(
        self,
        values,
        capacity,
        operator=operator.add
    ):
        """
        Helper to represent a segment tree.

        Args:
            values (list): Storage for the segment tree. The leaves occupy the slots
                `capacity .. 2 * capacity - 1`, so at least `2 * capacity` slots are accessed.
            capacity (int): Capacity of segment tree (number of leaves).
            operator (callable): Reduce operation of the segment tree. It is applied to the two
                children of a node to compute the node's value on every insert.
        """
        self.values = values
        self.capacity = capacity
        self.operator = operator

    def insert(self, index, element):
        """
        Inserts an element into the segment tree by determining
        its position in the tree.

        Args:
            index (int): Insertion index, must lie in `[0, capacity)`.
            element (any): Element to insert.

        Raises:
            IndexError: If `index` lies outside of `[0, capacity)`.
        """
        # A negative index would end up on an internal node after the leaf offset is added and
        # silently corrupt the aggregates, so reject everything that is not a valid leaf.
        if not 0 <= index < self.capacity:
            raise IndexError(
                "Segment tree index {} is out of range [0, {}).".format(index, self.capacity)
            )

        index += self.capacity
        self.values[index] = element

        # Walk from the leaf's parent up to the root and recompute each node from its two
        # children. This mirrors the classic iterative segment tree update:
        #   for (t[p += n] = value; p > 1; p >>= 1) t[p >> 1] = t[p] + t[p ^ 1];
        # Bit shift should be slightly faster here than division.
        index >>= 1
        while index >= 1:
            # No shift because small multiplications are optimized.
            left_child = 2 * index
            self.values[index] = self.operator(
                self.values[left_child],
                self.values[left_child + 1]
            )
            index >>= 1

    def get(self, index):
        """
        Reads an item from the segment tree.

        Args:
            index (int): Index of the element (leaf) to read. The index is not validated here.

        Returns: The element.
        """
        return self.values[self.capacity + index]

    def index_of_prefixsum(self, prefix_sum):
        """
        Identifies the highest index which satisfies the condition that the sum
        over all elements from 0 till the index is <= prefix_sum.

        Args:
            prefix_sum (float): Upper bound on prefix we are allowed to select.

        Returns:
            int: Index satisfying prefix sum condition.
        """
        # Kept as an assert so callers keep seeing an AssertionError; note that it is
        # skipped when Python runs with -O.
        assert 0 <= prefix_sum <= self.get_sum() + 1e-5

        # Descend from the root: go left while the left subtree alone exceeds the remaining
        # prefix sum, otherwise consume the left subtree and continue in the right one.
        index = 1
        while index < self.capacity:
            left_child = 2 * index
            if self.values[left_child] > prefix_sum:
                index = left_child
            else:
                prefix_sum -= self.values[left_child]
                index = left_child + 1
        # Translate the leaf slot back into an element index.
        return index - self.capacity

    def reduce(self, start, limit=None, reduce_op=operator.add):
        """
        Applies an operation to specified segment.

        Internal nodes are written by `insert` using the tree's own operator, and this method
        reads those nodes, so the result is only meaningful if `reduce_op` is the operation the
        tree was built with.

        Args:
            start (int): Start index to apply reduction to (inclusive).
            limit (Optional[int]): End index to apply reduction to (exclusive). None means
                `capacity`; a negative value is counted from `capacity`.
            reduce_op (Union(operator.add, min, max)): Reduce op to apply.

        Returns:
            Number: Result of reduce operation

        Raises:
            RLGraphError: If `reduce_op` is not one of add, min, max.
        """
        if limit is None:
            limit = self.capacity
        if limit < 0:
            limit += self.capacity

        # Init result with neutral element of reduce op.
        # Note that all of these are commutative reduce ops.
        if reduce_op == operator.add:
            result = 0.0
        elif reduce_op == min:
            result = float('inf')
        elif reduce_op == max:
            result = float('-inf')
        else:
            raise RLGraphError("Unsupported reduce OP. Support ops are [add, min, max].")

        # Shift both borders to the leaf level, then move upwards. Whenever a border sits on a
        # node that is not fully covered by its parent (odd slot), fold that node into the
        # result and step the border inwards.
        start += self.capacity
        limit += self.capacity
        while start < limit:
            if start & 1:
                result = reduce_op(result, self.values[start])
                start += 1
            if limit & 1:
                limit -= 1
                result = reduce_op(result, self.values[limit])
            start >>= 1
            limit >>= 1
        return result

    def get_min_value(self, start=0, stop=None):
        """
        Returns min value of storage variable.

        Args:
            start (int): First index of the segment (inclusive).
            stop (Optional[int]): End of the segment (exclusive), None for the full capacity.
        """
        return self.reduce(start, stop, reduce_op=min)

    def get_sum(self, start=0, stop=None):
        """
        Returns sum value of storage variable.

        Args:
            start (int): First index of the segment (inclusive).
            stop (Optional[int]): End of the segment (exclusive), None for the full capacity.
        """
        return self.reduce(start, stop, reduce_op=operator.add)
class MinSumSegmentTree(object):
    """
    This class merges two segment trees' operations for performance reasons to avoid
    unnecessary duplication of the insert loops.
    """

    def __init__(
        self,
        sum_tree,
        min_tree,
        capacity,
    ):
        """
        Args:
            sum_tree: Tree object whose `values` storage is maintained with sums.
            min_tree: Tree object whose `values` storage is maintained with minima.
            capacity (int): Offset of the first leaf in both `values` storages, i.e. the number
                of leaves of each tree.
        """
        self.sum_segment_tree = sum_tree
        self.min_segment_tree = min_tree
        self.capacity = capacity

    def insert(self, index, element):
        """
        Inserts an element into both segment trees by determining
        its position in the trees.

        Args:
            index (int): Insertion index, must lie in `[0, capacity)`.
            element (any): Element to insert.

        Raises:
            IndexError: If `index` lies outside of `[0, capacity)`.
        """
        # A negative index would end up on an internal node after the leaf offset is added and
        # silently corrupt both trees, so reject everything that is not a valid leaf.
        if not 0 <= index < self.capacity:
            raise IndexError(
                "Segment tree index {} is out of range [0, {}).".format(index, self.capacity)
            )

        # Resolve the two storages once instead of on every loop iteration.
        sum_values = self.sum_segment_tree.values
        min_values = self.min_segment_tree.values

        index += self.capacity
        sum_values[index] = element
        min_values[index] = element

        # Walk from the leaf's parent up to the root, refreshing both aggregates in one pass.
        index >>= 1
        while index >= 1:
            left_child = 2 * index
            right_child = left_child + 1
            sum_values[index] = sum_values[left_child] + sum_values[right_child]
            min_values[index] = min(min_values[left_child], min_values[right_child])
            index >>= 1