# Source code for rlgraph.components.helpers.segment_tree

# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class SegmentTree(object):
    """
    TensorFlow segment tree used for prioritized replay.
    """

    def __init__(self, storage_variable, capacity=1048):
        """
        Helper to represent a segment tree in pure TensorFlow.

        Args:
            storage_variable (tf.Variable): TensorFlow variable to use for storage.
            capacity (int): Capacity of the segment tree.
        """
        # NOTE(review): the default of 1048 looks like a typo for 1024 -- segment
        # trees conventionally use a power-of-two capacity. Confirm with callers
        # before changing; the default is kept as-is here.
        self.values = storage_variable
        self.capacity = capacity
    def insert(self, index, element, insert_op=tf.add):
        """
        Inserts an element into the segment tree by determining
        its position in the tree.

        Args:
            index (int): Insertion index.
            element (any): Element to insert.
            insert_op (Union(tf.add, tf.minimum, tf.maximum)): Insert operation on the tree.
        """
        # Leaves are stored at offsets [capacity, 2 * capacity) in the flat tree tensor.
        index += self.capacity

        # Use a TensorArray to collect updates to the segment tree, then perform them all at once.
        index_updates = tf.TensorArray(
            dtype=tf.int32,
            infer_shape=False,
            size=1,
            dynamic_size=True,
            clear_after_read=False
        )
        element_updates = tf.TensorArray(
            dtype=tf.float32,
            infer_shape=False,
            size=1,
            dynamic_size=True,
            clear_after_read=False
        )
        # First update is the leaf itself.
        index_updates = index_updates.write(index=0, value=index)
        element_updates = element_updates.write(index=0, value=element)

        # Search and update values while index >= 1: walk from the leaf's parent up to the root.
        loop_update_index = tf.div(x=index, y=2)

        def insert_body(loop_update_index, index_updates, element_updates, call_index):
            # This is the index we just updated.
            prev_index = index_updates.read(call_index - 1)
            prev_val = element_updates.read(call_index - 1)

            # Combine the freshly-computed child value (held only in the TensorArray, not yet
            # written to the variable) with its sibling, which is still read from the variable.
            update_val = tf.where(
                condition=tf.greater(x=prev_index % 2, y=0),
                # Previous index was odd because of loop init -> 2 * index + 1 is in element_updates,
                # 2 * index is in variable values.
                x=insert_op(x=self.values[2 * loop_update_index], y=prev_val),
                # Previous index was even -> 2 * index is in element updates, 2 * index + 1 in variable values.
                y=insert_op(x=prev_val, y=self.values[2 * loop_update_index + 1])
            )
            index_updates = index_updates.write(call_index, loop_update_index)
            element_updates = element_updates.write(call_index, update_val)

            # Move one level up the tree.
            return tf.div(x=loop_update_index, y=2), index_updates, element_updates, call_index + 1

        def cond(loop_update_index, index_updates, element_updates, call_index):
            return loop_update_index >= 1

        # Run the upward sweep; parallel_iterations=1 because each step reads the
        # TensorArray entry written by the previous step.
        loop_update_index, index_updates, element_updates, _ = tf.while_loop(
            cond=cond,
            body=insert_body,
            loop_vars=[loop_update_index, index_updates, element_updates, 1],
            parallel_iterations=1,
            back_prop=False
        )

        # Apply all collected updates to the storage variable in a single scatter.
        indices = index_updates.stack()
        updates = element_updates.stack()
        assignment = tf.scatter_update(ref=self.values, indices=indices, updates=updates)
        with tf.control_dependencies(control_inputs=[assignment]):
            return tf.no_op()
# def insert(self, index, element, insert_op=tf.add): # """ # Inserts an element into the segment tree by determining # its position in the tree. # # Args: # index (int): Insertion index. # element (any): Element to insert. # insert_op (Union(tf.add, tf.minimum, tf, maximum)): Insert operation on the tree. # """ # # index = tf.Print(index, [index], summarize=100, message='index before add = ') # index += self.capacity # # index = tf.Print(index, [index], summarize=100, message='index after add = ') # # assignment = tf.assign(ref=self.values[index], value=element) # # # Search and update values while index >=1 # loop_update_index = tf.div(x=index, y=2) # # def insert_body(loop_update_index): # # update_val = insert_op( # x=self.values[2 * loop_update_index], # y=self.values[2 * loop_update_index + 1] # ) # loop_update_index = tf.Print(loop_update_index, # [loop_update_index,update_val], # summarize=100, message='index, update val= ') # assignment = tf.assign(ref=self.values[loop_update_index], value=update_val) # # with tf.control_dependencies(control_inputs=[assignment]): # return tf.div(x=loop_update_index, y=2) # # def cond(loop_update_index): # return loop_update_index >= 1 # # with tf.control_dependencies(control_inputs=[assignment]): # return tf.while_loop( # cond=cond, # body=insert_body, # loop_vars=[loop_update_index], # parallel_iterations=1, # back_prop=False # )
[docs] def get(self, index): """ Reads an item from the segment tree. Args: index (int): Returns: The element. """ return self.values[self.capacity + index]
    def index_of_prefixsum(self, prefix_sum):
        """
        Identifies the highest index which satisfies the condition that the sum
        over all elements from 0 till the index is <= prefix_sum.

        Args:
            prefix_sum (float): Upper bound on prefix we are allowed to select.

        Returns:
            int: Index/indices satisfying prefix sum condition.
        """
        assert_ops = list()
        # Validate: 0 <= prefix_sum <= sum(priorities).
        priority_sum = tf.reduce_sum(input_tensor=self.values, axis=0)
        # priority_sum_tensor = tf.fill(dims=tf.shape(prefix_sum), value=priority_sum)
        assert_ops.append(tf.Assert(
            condition=tf.less_equal(x=prefix_sum, y=priority_sum),
            data=[prefix_sum]
        ))

        # Vectorized loop -> initialize all indices matching elements in prefix-sum.
        # Start the descent at the root of the tree (index 1).
        index = 1

        def search_body(index, prefix_sum):
            # Is the value at position 2 * index > prefix sum?
            compare_value = self.values[2 * index]

            def update_prefix_sum_fn(index, prefix_sum):
                # 'Use up' values in this segment, then jump to next.
                prefix_sum -= self.values[2 * index]
                return 2 * index + 1, prefix_sum

            # Descend left when the left child already covers the prefix sum,
            # otherwise consume the left subtree and descend right.
            index, prefix_sum = tf.cond(
                pred=compare_value > prefix_sum,
                # If over prefix sum, jump index.
                true_fn=lambda: (2 * index, prefix_sum),
                # Else adjust prefix sum until done.
                false_fn=lambda: update_prefix_sum_fn(index, prefix_sum)
            )
            return index, prefix_sum

        def cond(index, prefix_sum):
            # Keep descending until a leaf (index >= capacity) is reached.
            return index < self.capacity

        with tf.control_dependencies(control_inputs=assert_ops):
            index, _ = tf.while_loop(cond=cond, body=search_body, loop_vars=[index, prefix_sum])

        # Convert the tree position back to a 0-based leaf index.
        return index - self.capacity
[docs] def reduce(self, start, limit, reduce_op=tf.add): """ Applies an operation to specified segment. Args: start (int): Start index to apply reduction to. limit (end): End index to apply reduction to. reduce_op (Union(tf.add, tf.minimum, tf.maximum)): Reduce op to apply. Returns: Number: Result of reduce operation """ # Init result with neutral element of reduce op. # Note that all of these are commutative reduce ops. if reduce_op == tf.add: result = 0.0 elif reduce_op == tf.minimum: result = float('inf') elif reduce_op == tf.maximum: result = float('-inf') else: raise ValueError("Unsupported reduce OP. Support ops are [tf.add, tf.minimum, tf.maximum]") start += self.capacity limit += self.capacity def reduce_body(start, limit, result): start_mod = tf.mod(x=start, y=2) def update_start_fn(start, result): result = reduce_op(x=result, y=self.values[start]) start += 1 return start, result start, result = tf.cond( pred=tf.equal(x=start_mod, y=0), true_fn=lambda: (start, result), false_fn=lambda: update_start_fn(start, result) ) end_mod = tf.mod(x=limit, y=2) def update_limit_fn(limit, result): limit -= 1 result = reduce_op(x=result, y=self.values[limit]) return limit, result limit, result = tf.cond( pred=tf.equal(x=end_mod, y=0), true_fn=lambda: (limit, result), false_fn=lambda: update_limit_fn(limit, result) ) return tf.div(x=start, y=2), tf.div(x=limit, y=2), result def cond(start, limit, result): return start < limit _, _, result = tf.while_loop(cond=cond, body=reduce_body, loop_vars=(start, limit, result)) return result
[docs] def get_min_value(self): """ Returns min value of storage variable. """ return self.reduce(0, self.capacity - 1, reduce_op=tf.minimum)
[docs] def get_sum(self): """ Returns sum value of storage variable. """ return self.reduce(0, self.capacity - 1, reduce_op=tf.add)