# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from six.moves import xrange as range_
import copy
import inspect
import numpy as np
import re
from rlgraph import get_backend
from rlgraph.utils.decorators import rlgraph_api, component_api_registry, component_graph_fn_registry, define_api_method, \
define_graph_fn
from rlgraph.utils.rlgraph_errors import RLGraphError
from rlgraph.utils.specifiable import Specifiable
from rlgraph.utils.ops import DataOpDict, FLAT_TUPLE_OPEN, FLAT_TUPLE_CLOSE
from rlgraph.utils import util
if get_backend() == "tf":
import tensorflow as tf
elif get_backend() == "tf-eager":
import tensorflow as tf
import tensorflow.contrib.eager as eager
[docs]class Component(Specifiable):
"""
Base class for a graph component (such as a layer, an entire function approximator, a memory, an optimizers, etc..).
A component can contain other components and/or its own graph-logic (e.g. tf ops).
A component's sub-components are connected to each other via in- and out-Sockets (similar to LEGO blocks
and deepmind's sonnet).
This base class implements the interface to add sub-components, create connections between
different sub-components and between a sub-component and this one and between this component
and an external component.
A component also has a variable registry, the ability to save the component's structure and variable-values to disk,
and supports adding its graph_fns to the overall computation graph.
"""
call_count = 0
# List of tuples (method_name, runtime)
call_times = []
def __init__(self, *sub_components, **kwargs):
"""
Args:
sub_components (Component): Specification dicts for sub-Components to be added to this one.
Keyword Args:
name (str): The name of this Component. Names of sub-components within a containing component
must be unique. Names are used to label exposed Sockets of the containing component.
If name is empty, use scope as name (as last resort).
scope (str): The scope of this Component for naming variables in the Graph.
device (str): Device this component will be assigned to. If None, defaults to CPU.
trainable (Optional[bool]): Whether to make the variables of this Component always trainable or not.
Use None for no specific preference.
#global_component (bool): In distributed mode, this flag indicates if the component is part of the
# shared global model or local to the worker. Defaults to False and will be ignored if set to
# True in non-distributed mode.
# TODO: remove when we have numpy-based Components (then we can do test calls to infer everything automatically)
graph_fn_num_outputs (dict): A dict specifying which graph_fns have how many return values.
This can be useful if graph_fns don't clearly have a fixed number of return values and the auto-inferral
utility function cannot determine the actual number of returned values.
switched_off_apis (Optional[Set[str]]): Set of API-method names that should NOT be build for this Component.
backend (str): The custom backend that this Component obliges to. None to use the RLGraph global backend.
Default: None.
"""
super(Component, self).__init__()
# Scope if used to create scope hierarchies inside the Graph.
# self.logger = logging.getLogger(__name__)
self.scope = kwargs.pop("scope", "")
assert re.match(r'^[\w\-]*$', self.scope), \
"ERROR: scope {} does not match scope-pattern! Needs to be \\w or '-'.".format(self.scope)
# The global scope string defining the exact nested position of this Component in the Graph.
# e.g. "/core/component1/sub-component-a"
self.global_scope = self.scope
# Shared variable scope.
self.reuse_variable_scope = kwargs.pop("reuse_variable_scope", None)
# Names of sub-components that exist (parallelly) inside a containing component must be unique.
self.name = kwargs.pop("name", self.scope) # if no name given, use scope
self.device = kwargs.pop("device", None)
self.trainable = kwargs.pop("trainable", None)
self.graph_fn_num_outputs = kwargs.pop("graph_fn_num_outputs", dict())
self.switched_off_apis = kwargs.pop("switched_off_apis", set())
self.backend = kwargs.pop("backend", None)
assert not kwargs, "ERROR: kwargs ({}) still contains items!".format(kwargs)
# Keep track of whether this Component has already been added to another Component and throw error
# if this is done twice. Each Component can only be added once to a parent Component.
self.parent_component = None # type: Component
# Dict of sub-components that live inside this one (key=sub-component's scope).
self.sub_components = OrderedDict()
# Link to the GraphBuilder object.
self.graph_builder = None
# `self.api_methods`: Dict holding information about which op-record-tuples go via which API
# methods into this Component and come out of it.
# keys=API method name; values=APIMethodRecord
self.api_methods = {}
# Maps names to callable API functions for eager calls.
self.api_fn_by_name = {}
# Maps names of methods to synthetically defined methods.
self.synthetic_methods = set()
# How this component executes its 'call' method.
self.execution_mode = "static_graph" if self.backend != "python" else "define_by_run"
# `self.api_method_inputs`: Registry for all unique API-method input parameter names and their Spaces.
# Two API-methods may share the same input if their input parameters have the same names.
# keys=input parameter name; values=Space that goes into that parameter
self.api_method_inputs = {}
# Registry for graph_fn records (only populated at build time when the graph_fns are actually called).
self.graph_fns = {}
self.register_api_methods_and_graph_fns()
# Set of op-rec-columns going into a graph_fn of this Component and not having 0 op-records.
# Helps during the build procedure to call these right away after the Component is input-complete.
self.no_input_graph_fn_columns = set()
# Set of op-records that are constant and thus can be processed right away at the beginning of the build
# procedure.
self.constant_op_records = set()
# Whether we know already all our in-Sockets' Spaces.
# Only then can we create our variables. Model will do this.
self.input_complete = False
# Whether all our sub-Components are input-complete. Only after that point, we can run our _variables graph_fn.
self.variable_complete = False
# All Variables that are held by this component (and its sub-components) by name.
# key=full-scope variable name (scope=component/sub-component scope)
# value=the actual variable
self.variables = {}
# All summary ops that are held by this component (and its sub-components) by name.
# key=full-scope summary name (scope=component/sub-component scope)
# value=the actual summary op
self.summaries = {}
# The regexp that a summary's full-scope name has to match in order for it to be generated and registered.
# This will be set by the GraphBuilder at build time.
self.summary_regexp = None
# Now add all sub-Components (also support all sub-Components being given in a single list).
sub_components = sub_components[0] if len(sub_components) == 1 and \
isinstance(sub_components[0], (list, tuple)) else sub_components
self.add_components(*sub_components)
[docs] def register_api_methods_and_graph_fns(self):
"""
Detects all methods of the Component that should be registered as API-methods for
this Component and complements `self.api_methods` and `self.api_method_inputs`.
Goes by the @api decorator before each API-method or graph_fn that should be
auto-thin-wrapped by an API-method.
"""
# Goes through the class hierarchy of `self` and tries to lookup all registered functions
# (by name) that should be turned into API-methods.
class_hierarchy = inspect.getmro(type(self))
for class_ in class_hierarchy[:-2]: # skip last two as its `Specifiable` and `object`
api_method_recs = component_api_registry.get(class_.__name__)
if api_method_recs is not None:
for api_method_rec in api_method_recs:
if api_method_rec.name not in self.api_methods:
define_api_method(self, api_method_rec)
graph_fn_recs = component_graph_fn_registry.get(class_.__name__)
if graph_fn_recs is not None:
for graph_fn_rec in graph_fn_recs:
if graph_fn_rec.name not in self.graph_fns:
define_graph_fn(self, graph_fn_rec)
[docs] def check_variable_completeness(self):
"""
Checks, whether this Component is input-complete AND all our sub-Components are input-complete.
At that point, all variables are defined and we can run the `_variables` graph_fn.
Returns:
bool: Whether this Component is "variables-complete".
"""
# We are already variable-complete -> shortcut return here.
if self.variable_complete:
return True
# We are not input-complete yet (our own variables have not been created) -> return False.
elif self.input_complete is False:
return False
# Simply check all sub-Components for input-completeness.
self.variable_complete = all(sc.input_complete for sc in self.sub_components.values())
return self.variable_complete
[docs] def create_variables(self, input_spaces, action_space=None):
"""
Should create all variables that are needed within this component,
unless a variable is only needed inside a single _graph_fn-method, in which case,
it should be created there.
Variables must be created via the backend-agnostic self.get_variable-method.
Note that for different scopes in which this component is being used, variables will not(!) be shared.
Args:
input_spaces (Dict[str,Space]): A dict with Space/shape information.
keys=in-Socket name (str); values=the associated Space
action_space (Optional[Space]): The action Space of the Agent/GraphBuilder. Can be used to construct and
connect more Components (which rely on this information). This eliminates the need to pass the action
Space information into many Components' constructors.
"""
pass
[docs] def register_variables(self, *variables):
"""
Adds already created Variables to our registry. This could be useful if the variables are not created
by our own `self.get_variable` method, but by some backend-specific object (e.g. tf.layers).
Also auto-creates summaries (regulated by `self.summary_regexp`) for the given variables.
Args:
# TODO check if we warp PytorchVariable
variables (Union[PyTorchVariable, SingleDataOp]): The Variable objects to register.
"""
for var in variables:
# Use our global_scope plus the var's name without anything in between.
# e.g. var.name = "dense-layer/dense/kernel:0" -> key = "dense-layer/kernel"
# key = re.sub(r'({}).*?([\w\-.]+):\d+$'.format(self.global_scope), r'\1/\2', var.name)
key = re.sub(r':\d+$', "", var.name)
# Already registered: Must be the same (shared) variable.
if key in self.variables:
assert self.variables[key] is var,\
"ERROR: Key '{}' in {}.variables already exists, but holds a different variable " \
"({} vs {})!".format(key, self.global_scope, self.variables[key], var)
# New variable: Register.
else:
self.variables[key] = var
# Auto-create the summary for the variable.
summary_name = var.name[len(self.global_scope) + (1 if self.global_scope else 0):]
summary_name = re.sub(r':\d+$', "", summary_name)
self.create_summary(summary_name, var)
[docs] def get_variable(self, name="", shape=None, dtype="float", initializer=None, trainable=True,
from_space=None, add_batch_rank=False, add_time_rank=False, time_major=False, flatten=False,
local=False, use_resource=False):
"""
Generates or returns a variable to use in the selected backend.
The generated variable is automatically registered in this component's (and all parent components')
variable-registry under its global-scoped name.
Args:
name (str): The name under which the variable is registered in this component.
shape (Optional[tuple]): The shape of the variable. Default: empty tuple.
dtype (Union[str,type]): The dtype (as string) of this variable.
initializer (Optional[any]): Initializer for this variable.
trainable (bool): Whether this variable should be trainable. This will be overwritten, if the Component
has its own `trainable` property set to either True or False.
from_space (Optional[Space,str]): Whether to create this variable from a Space object
(shape and dtype are not needed then). The Space object can be given directly or via the name
of the in-Socket holding the Space.
add_batch_rank (Optional[bool,int]): If True and `from_space` is given, will add a 0th (1st) rank (None) to
the created variable. If it is an int, will add that int instead of None.
Default: False.
add_time_rank (Optional[bool,int]): If True and `from_space` is given, will add a 1st (0th) rank (None) to
the created variable. If it is an int, will add that int instead of None.
Default: False.
time_major (bool): Only relevant if both `add_batch_rank` and `add_time_rank` are True.
Will make the time-rank the 0th rank and the batch-rank the 1st rank.
Otherwise, batch-rank will be 0th and time-rank will be 1st.
Default: False.
flatten (bool): Whether to produce a FlattenedDataOp with auto-keys.
local (bool): Whether the variable must not be shared across the network.
Default: False.
use_resource (bool): Whether to use the new tf resource-type variables.
Default: False.
Returns:
DataOp: The actual variable (dependent on the backend) or - if from
a ContainerSpace - a FlattenedDataOp or ContainerDataOp depending on the Space.
"""
# Overwrite the given trainable parameter, iff self.trainable is actually defined as a bool.
trainable = self.trainable if self.trainable is not None else (trainable if trainable is not None else True)
# Called as getter.
if shape is None and initializer is None and from_space is None:
if name not in self.variables:
raise KeyError(
"Variable with name '{}' not found in registry of Component '{}'!".format(name, self.name)
)
# TODO: Maybe try both the pure name AND the name with global-scope in front.
return self.variables[name]
# Called as setter.
var = None
# We are creating the variable using a Space as template.
if from_space is not None:
var = self._variable_from_space(
flatten, from_space, name, add_batch_rank, add_time_rank, time_major, trainable, initializer, local,
use_resource
)
# TODO: Figure out complete concept for python/numpy based Components (including their handling of variables).
# Assume that when using pytorch, we use Python/numpy collections to store data.
elif self.backend == "python" or get_backend() == "python" or get_backend() == "pytorch":
if add_batch_rank is not False and isinstance(add_batch_rank, int):
if isinstance(add_time_rank, int):
if time_major:
var = [[initializer for _ in range_(add_batch_rank)] for _ in range_(add_time_rank)]
else:
var = [[initializer for _ in range_(add_time_rank)] for _ in range_(add_batch_rank)]
else:
var = [initializer for _ in range_(add_batch_rank)]
elif add_time_rank is not False and isinstance(add_time_rank, int):
var = [initializer for _ in range_(add_time_rank)]
elif initializer is not None:
# Return
var = initializer
else:
var = []
return var
# Direct variable creation (using the backend).
elif get_backend() == "tf":
# Provide a shape, if initializer is not given or it is an actual Initializer object (rather than an array
# of fixed values, for which we then don't need a shape as it comes with one).
if initializer is None or isinstance(initializer, tf.keras.initializers.Initializer):
shape = tuple((() if add_batch_rank is False else
(None,) if add_batch_rank is True else (add_batch_rank,)) + (shape or ()))
# Numpyize initializer and give it correct dtype.
else:
shape = None
initializer = np.asarray(initializer, dtype=util.dtype(dtype, "np"))
var = tf.get_variable(
name=name, shape=shape, dtype=util.dtype(dtype), initializer=initializer, trainable=trainable,
collections=[tf.GraphKeys.GLOBAL_VARIABLES if local is False else tf.GraphKeys.LOCAL_VARIABLES],
use_resource=use_resource
)
elif get_backend() == "tf-eager":
shape = tuple(
(() if add_batch_rank is False else (None,) if add_batch_rank is True else (add_batch_rank,)) +
(shape or ())
)
var = eager.Variable(
name=name, shape=shape, dtype=util.dtype(dtype), initializer=initializer, trainable=trainable,
collections=[tf.GraphKeys.GLOBAL_VARIABLES if local is False else tf.GraphKeys.LOCAL_VARIABLES]
)
# TODO: what about python variables?
# Registers the new variable with this Component.
key = ((self.reuse_variable_scope + "/") if self.reuse_variable_scope else
(self.global_scope + "/") if self.global_scope else "") + name
# Container-var: Save individual Variables.
# TODO: What about a var from Tuple space?
if isinstance(var, OrderedDict):
for sub_key, v in var.items():
self.variables[key + sub_key] = v
else:
self.variables[key] = var
return var
def _variable_from_space(self, flatten, from_space, name, add_batch_rank, add_time_rank, time_major, trainable,
initializer, local=False, use_resource=False):
"""
Private variable from space helper, see 'get_variable' for API.
"""
# Variables should be returned in a flattened OrderedDict.
if get_backend() == "tf":
if self.reuse_variable_scope is not None:
with tf.variable_scope(name_or_scope=self.reuse_variable_scope, reuse=True):
if flatten:
return from_space.flatten(mapping=lambda key_, primitive: primitive.get_variable(
name=name + key_, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
time_major=time_major, trainable=trainable, initializer=initializer,
is_python=(self.backend == "python" or get_backend() == "python"),
local=local, use_resource=use_resource
))
# Normal, nested Variables from a Space (container or primitive).
else:
return from_space.get_variable(
name=name, add_batch_rank=add_batch_rank, trainable=trainable, initializer=initializer,
is_python=(self.backend == "python" or get_backend() == "python"),
local=local, use_resource=use_resource
)
else:
if flatten:
return from_space.flatten(mapping=lambda key_, primitive: primitive.get_variable(
name=name + key_, add_batch_rank=add_batch_rank, add_time_rank=add_time_rank,
time_major=time_major, trainable=trainable, initializer=initializer,
is_python=(self.backend == "python" or get_backend() == "python"),
local=local, use_resource=use_resource
))
# Normal, nested Variables from a Space (container or primitive).
else:
return from_space.get_variable(
name=name, add_batch_rank=add_batch_rank, trainable=trainable, initializer=initializer,
is_python=(self.backend == "python" or get_backend() == "python"),
local=local, use_resource=use_resource
)
[docs] def get_variables(self, *names, **kwargs):
"""
Utility method to get one or more component variable(s) by name(s).
Args:
names (List[str]): Lookup name strings for variables. None for all.
Keyword Args:
collections (set): A set of collections to which the variables have to belong in order to be returned here.
Default: tf.GraphKeys.TRAINABLE_VARIABLES
custom_scope_separator (str): The separator to use in the returned dict for scopes.
Default: '/'.
global_scope (bool): Whether to use keys in the returned dict that include the global-scopes of the
Variables. Default: False.
Returns:
dict: A dict mapping variable names to their get_backend variables.
"""
if get_backend() == "tf":
collections = kwargs.pop("collections", None) or tf.GraphKeys.GLOBAL_VARIABLES
custom_scope_separator = kwargs.pop("custom_scope_separator", "/")
global_scope = kwargs.pop("global_scope", True)
assert not kwargs, "{}".format(kwargs)
if len(names) == 1 and isinstance(names[0], list):
names = names[0]
names = util.force_list(names)
# Return all variables of this Component (for some collection).
if len(names) == 0:
collection_variables = tf.get_collection(collections)
ret = dict()
for v in collection_variables:
lookup = re.sub(r':\d+$', "", v.name)
if lookup in self.variables:
if global_scope:
# Replace the scope separator with a custom one.
ret[re.sub(r'(/|{}|{})'.format(FLAT_TUPLE_CLOSE, FLAT_TUPLE_OPEN),
custom_scope_separator, lookup)] = v
else:
ret[re.sub(r'^.+/', "", lookup)] = v
return ret
# Return only variables of this Component by name.
else:
return self.get_variables_by_name(
*names, custom_scope_separator=custom_scope_separator, global_scope=global_scope
)
elif get_backend() == "pytorch":
# Just return variables for this component.
custom_scope_separator = kwargs.pop("custom_scope_separator", "/")
global_scope = kwargs.pop("global_scope", True)
return self.get_variables_by_name(
*names, custom_scope_separator=custom_scope_separator, global_scope=global_scope
)
[docs] def get_variables_by_name(self, *names, **kwargs):
"""
Retrieves this components variables by name.
Args:
names (List[str]): List of names of Variable to return.
Keyword Args:
custom_scope_separator (str): The separator to use in the returned dict for scopes.
Default: '/'.
global_scope (bool): Whether to use keys in the returned dict that include the global-scopes of the
Variables. Default: False.
Returns:
dict: Dict containing the requested names as keys and variables as values.
"""
custom_scope_separator = kwargs.pop("custom_scope_separator", "/")
global_scope = kwargs.pop("global_scope", False)
assert not kwargs
variables = dict()
if get_backend() == "tf":
for name in names:
global_scope_name = ((self.global_scope + "/") if self.global_scope else "") + name
if name in self.variables:
variables[re.sub(r'/', custom_scope_separator, name)] = self.variables[name]
elif global_scope_name in self.variables:
if global_scope:
variables[re.sub(r'/', custom_scope_separator, global_scope_name)] = self.variables[
global_scope_name]
else:
variables[name] = self.variables[global_scope_name]
elif get_backend() == "pytorch":
for name in names:
global_scope_name = ((self.global_scope + "/") if self.global_scope else "") + name
if name in self.variables:
pytorch_var = self.variables[name]
variables[re.sub(r'/', custom_scope_separator, name)] = pytorch_var.get_value()
elif global_scope_name in self.variables:
if global_scope:
pytorch_var = self.variables[global_scope_name]
variables[re.sub(r'/', custom_scope_separator, global_scope_name)] = pytorch_var.get_value()
else:
variables[name] = self.variables[global_scope_name].get_value()
return variables
[docs] def create_summary(self, name, values, type_="histogram"):
"""
Creates a summary op (and adds it to the graph).
Skips those, whose full name does not match `self.summary_regexp`.
Args:
name (str): The name for the summary. This has to match `self.summary_regexp`.
The name should not contain a "summary"-prefix or any global scope information
(both will be added automatically by this method).
values (op): The op to summarize.
type\_ (str): The summary type to create. Currently supported are:
"histogram", "scalar" and "text".
"""
# Prepend the "summaries/"-prefix.
name = "summaries/" + name
# Get global name.
global_name = ((self.global_scope + "/") if self.global_scope else "") + name
# Skip non matching summaries (all if summary_regexp is None).
if self.summary_regexp is None or not re.search(self.summary_regexp, global_name):
return
summary = None
if get_backend() == "tf":
ctor = getattr(tf.summary, type_)
summary = ctor(name, values)
# Registers the new summary with this Component.
if global_name in self.summaries:
raise RLGraphError("ERROR: Summary with name '{}' already exists in {}'s summary "
"registry!".format(global_name, self.name))
self.summaries[global_name] = summary
self.propagate_summary(global_name)
[docs] def propagate_summary(self, key_):
"""
Propagates a single summary op of this Component to its parents' summaries registries.
Args:
key\_ (str): The lookup key for the summary to propagate.
"""
# Return if there is no parent.
if self.parent_component is None:
return
# If already there -> Error.
if key_ in self.parent_component.summaries:
raise RLGraphError("ERROR: Summary registry of '{}' already has a summary under key '{}'!".
format(self.parent_component.name, key_))
self.parent_component.summaries[key_] = self.summaries[key_]
# Recurse up the container hierarchy.
self.parent_component.propagate_summary(key_)
[docs] def add_components(self, *components, **kwargs):
"""
Adds sub-components to this one.
Args:
components (List[Component]): The list of Component objects to be added into this one.
Keyword Args:
expose_apis (Optional[Set[str]]): An optional set of strings with API-methods of the child component
that should be exposed as the parent's API via a simple wrapper API-method for the parent (that
calls the child's API-method).
#exposed_must_be_complete (bool): Whether the exposed API methods must be input-complete or not.
"""
expose_apis = kwargs.pop("expose_apis", set())
if isinstance(expose_apis, str):
expose_apis = {expose_apis}
for component in components:
# Try to create Component from spec.
if not isinstance(component, Component):
component = Component.from_spec(component)
# Make sure no two components with the same name are added to this one (own scope doesn't matter).
if component.name in self.sub_components:
raise RLGraphError("ERROR: Sub-Component with name '{}' already exists in this one!".
format(component.name))
# Make sure each Component can only be added once to a parent/container Component.
elif component.parent_component is not None:
raise RLGraphError("ERROR: Sub-Component with name '{}' has already been added once to a container "
"Component! Each Component can only be added once to a parent.".format(component.name))
# Make sure we don't add to ourselves.
elif component is self:
raise RLGraphError("ERROR: Cannot add a Component ({}) as a sub-Component to itself!".format(self.name))
component.parent_component = self
self.sub_components[component.name] = component
# Add own reusable scope to front of sub-components'.
if self.reuse_variable_scope is not None:
# Propagate reuse_variable_scope down to the added Component's sub-components.
component.propagate_sub_component_properties(
properties=dict(reuse_variable_scope=self.reuse_variable_scope)
)
# Fix the sub-component's (and sub-sub-component's etc..) scope(s).
self.propagate_scope(component)
# Execution modes must be coherent within one component subgraph.
self.propagate_sub_component_properties(
properties=dict(execution_mode=self.execution_mode), component=component
)
# Should we expose some API-methods of the child?
for api_method_name, api_method_rec in component.api_methods.items():
if api_method_name in expose_apis:
# Hold these here to avoid fixtures (when Components get copied).
name_ = api_method_name
component_name = component.name
must_be_complete = api_method_rec.must_be_complete
@rlgraph_api(component=self, name=api_method_name, must_be_complete=must_be_complete)
def exposed_api_method_wrapper(self, *inputs):
# Complicated way to lookup sub-component's method to avoid fixtures when original
# component gets copied.
return getattr(self.sub_components[component_name], name_)(*inputs)
[docs] def get_all_sub_components(self, list_=None, level_=0):
"""
Returns all sub-Components (including self) sorted by their nesting-level (... grand-children before children
before parents).
Args:
list\_ (Optional[List[Component]])): A list of already collected components to append to.
level\_ (int): The slot indicating the Component level depth in `list_` at which we are currently.
Returns:
List[Component]: A list with all the sub-components in `self` and `self` itself.
"""
return_ = False
if list_ is None:
list_ = dict()
return_ = True
if level_ not in list_:
list_[level_] = list()
list_[level_].append(self)
level_ += 1
for sub_component in self.sub_components.values():
sub_component.get_all_sub_components(list_, level_)
if return_:
ret = list()
for l in sorted(list_.keys(), reverse=True):
ret.extend(sorted(list_[l], key=lambda c: c.scope))
return ret
[docs] def get_sub_component_by_global_scope(self, scope):
"""
Returns a sub-Component (or None if not found) by scope. The sub-coponent's scope should be given
as global scope of the sub-component (not local scope with respect to this Component).
Args:
scope (str): The global scope of the sub-Component we are looking for.
Returns:
Component: The sub-Component with the given global scope if found, None if not found.
"""
# TODO: make method more efficient.
components = self.get_all_sub_components()
for component in components:
if component.global_scope == scope:
return component
return None
[docs] def get_sub_component_by_name(self, name):
"""
Returns a sub-Component (or None if not found) by its name (local scope). The sub-Component must be a direct
sub-Component of `self`.
Args:
name (str): The name (local scope) of the sub-Component we are looking for.
Returns:
Component: The sub-Component with the given name if found, None if not found.
Raises:
RLGraphError: If a sub-Component by that name could not be found.
"""
sub_component = self.sub_components.get(name, None)
if sub_component is None:
raise RLGraphError("ERROR: sub-Component with name '{}' not found in '{}'!".format(name, self.global_scope))
return sub_component
[docs] def remove_sub_component_by_name(self, name):
"""
Removes a sub-component from this one by its name. Thereby sets the `parent_component` property of the
removed Component to None.
Raises an error if the sub-component does not exist.
Args:
name (str): The name of the sub-component to be removed.
Returns:
Component: The removed component.
"""
assert name in self.sub_components, "ERROR: Component {} cannot be removed because it is not" \
"a sub-component. Sub-components by name are: {}.".format(name, list(self.sub_components.keys()))
removed_component = self.sub_components.pop(name)
# Set parent of the removed component to None.
removed_component.parent_component = None
return removed_component
[docs] def get_parents(self):
"""
Returns a list of parent and grand-parents of this component.
Returns:
List[Component]: A list (may be empty if this component has no parents) of all parent and grand-parents.
"""
ret = list()
component = self
while component.parent_component is not None:
ret.append(component.parent_component)
component = component.parent_component
return ret
[docs] def propagate_scope(self, sub_component):
"""
Fixes all the sub-Component's (and its sub-Component's) global_scopes.
Args:
sub_component (Optional[Component]): The sub-Component object whose global_scope needs to be updated.
Use None for this Component itself.
"""
# TODO this should be moved to use generic method below, but checking if global scope if set
# TODO does not work well within that.
if sub_component is None:
sub_component = self
elif self.global_scope:
sub_component.global_scope = self.global_scope + (
("/" + sub_component.scope) if sub_component.scope else "")
# Recurse.
for sc in sub_component.sub_components.values():
sub_component.propagate_scope(sc)
[docs] def propagate_sub_component_properties(self, properties, component=None):
"""
Recursively updates properties of component and its sub-components.
Args:
properties (dict): Dict with names of properties and their values to recursively update
sub-components with.
component (Optional([Component])): Component to recursively update. Uses self if None.
"""
if component is None:
component = self
properties_scoped = copy.deepcopy(properties)
for name, value in properties.items():
# Property is some scope (value is then that scope of the parent component).
# Have to align it with sub-component's local scope.
if value and (name == "global_scope" or name == "reuse_variable_scope"):
value += (("/" + component.scope) if component.scope else "")
properties_scoped[name] = value
setattr(component, name, value)
# Normal property: Set to static given value.
else:
setattr(component, name, value)
for sc in component.sub_components.values():
component.propagate_sub_component_properties(properties_scoped, sc)
[docs] def propagate_variables(self, keys=None):
"""
Propagates all variable from this Component to its parents' variable registries.
Args:
keys (Optional[List[str]]): An optional list of variable names to propagate. Should only be used in
internal, recursive calls to this same method.
"""
# Return if there is no parent.
if self.parent_component is None:
return
# Add all our variables to parent's variable registry.
keys = keys or self.variables.keys()
for key in keys:
# If already there (bubbled up from some child component that was input complete before us)
# -> Make sure the variable object is identical.
if key in self.parent_component.variables:
if self.variables[key] is not self.parent_component.variables[key]:
raise RLGraphError("ERROR: Variable registry of '{}' already has a variable under key '{}'!". \
format(self.parent_component.name, key))
self.parent_component.variables[key] = self.variables[key]
# Recurse up the container hierarchy.
self.parent_component.propagate_variables(keys)
[docs] def copy(self, name=None, scope=None, device=None, trainable=None,
reuse_variable_scope=None, reuse_variable_scope_for_sub_components=None):
"""
Copies this component and returns a new component with possibly another name and another scope.
The new component has its own variables (they are not shared with the variables of this component as they
will be created after this copy anyway, during the build phase).
and is initially not connected to any other component.
Args:
name (str): The name of the new Component. If None, use the value of scope.
scope (str): The scope of the new Component. If None, use the same scope as this component.
device (str): The device of the new Component. If None, use the same device as this one.
trainable (Optional[bool]): Whether to make all variables in this component trainable or not. Use None
for no specific preference.
reuse_variable_scope (Optional[str]): If not None, variables of the copy will be shared under this scope.
reuse_variable_scope_for_sub_components (Optional[str]): If not None, variables only of the sub-components
of the copy will be shared under this scope.
Returns:
Component: The copied component object.
"""
# Make sure we are still in the assembly phase (should not copy actual ops).
assert self.input_complete is False, "ERROR: Cannot copy a Component ('{}') that has already been built!". \
format(self.name)
if scope is None:
scope = self.scope
if name is None:
name = scope
if device is None:
device = self.device
if trainable is None:
trainable = self.trainable
# Simply deepcopy self and change name and scope.
new_component = copy.deepcopy(self)
new_component.name = name
new_component.scope = scope
# Change global_scope for the copy and all its sub-components.
new_component.global_scope = scope
new_component.propagate_scope(sub_component=None)
# Propagate reusable scope, device and other trainable.
new_component.propagate_sub_component_properties(
properties=dict(device=device, trainable=trainable)
)
if reuse_variable_scope:
new_component.propagate_sub_component_properties(dict(reuse_variable_scope=reuse_variable_scope))
# Gives us the chance to skip new_component's scope.
elif reuse_variable_scope_for_sub_components:
for sc in new_component.sub_components.values():
sc.propagate_sub_component_properties(dict(reuse_variable_scope=reuse_variable_scope_for_sub_components))
# Erase the parent pointer.
new_component.parent_component = None
return new_component
[docs] @staticmethod
def scatter_update_variable(variable, indices, updates):
"""
Updates a variable. Optionally returns the operation depending on the backend.
Args:
variable (any): Variable to update.
indices (array): Indices to update.
updates (any): Update values.
Returns:
Optional[op]: The graph operation representing the update (or None).
"""
if get_backend() == "tf":
return tf.scatter_update(ref=variable, indices=indices, updates=updates)
[docs] @staticmethod
def assign_variable(ref, value):
"""
Assigns a variable to a value.
Args:
ref (any): The variable to assign to.
value (any): The value to use for the assignment.
Returns:
Optional[op]: None or the graph operation representing the assginment.
"""
if get_backend() == "tf":
if type(value).__name__ == "Variable":
return tf.assign(ref=ref, value=value.read_value())
else:
return tf.assign(ref=ref, value=value)
elif get_backend() == "pytorch":
ref.set_value(value)
[docs] @staticmethod
def read_variable(variable, indices=None):
"""
Reads a variable.
Args:
variable (DataOp): The variable whose value to read.
indices (Optional[np.ndarray,tf.Tensor]): Indices (if any) to fetch from the variable.
Returns:
any: Variable values.
"""
if get_backend() == "tf":
if indices is not None:
# Could be redundant, question is if there may be special read operations
# in other backends, or read from remote variable requiring extra args.
return tf.gather(params=variable, indices=indices)
else:
return variable
elif get_backend() == "pytorch":
# This is not a useful call but for interop in some components
return variable.get_value()
[docs] def sub_component_by_name(self, scope_name):
"""
Returns a sub-component of this component by its name.
Args:
scope_name (str): Name of the component. This is typically its scope.
Returns:
Component: Sub-component if it exists.
Raises:
ValueError: Error if no sub-component with this name exists.
"""
if scope_name not in self.sub_components:
raise RLGraphError(
"Name {} is not a valid sub-component name for component {}. Sub-Components are: {}".
format(scope_name, self.__str__(), self.sub_components.keys())
)
return self.sub_components[scope_name]
def _post_build(self, component):
component._post_define_by_run_build()
for sub_component in component.sub_components.values():
self._post_build(sub_component)
def _post_define_by_run_build(self):
"""
Optionally execute post-build calls.
"""
pass
@rlgraph_api(returns=1)
def _graph_fn__variables(self):
"""
Outputs all of this Component's variables in a DataOpDict (API-method "_variables").
This can be used e.g. to sync this Component's variables into another Component, which owns
a Synchronizable() as a sub-component. The returns values of this graph_fn are then sent into
the other Component's API-method `sync` (parameter: "values") for syncing.
Returns:
DataOpDict: Dict with keys=variable names and values=variable (SingleDataOp).
"""
# Must use custom_scope_separator here b/c RLGraph doesn't allow Dict with '/'-chars in the keys.
# '/' could collide with a FlattenedDataOp's keys and mess up the un-flatten process.
variables_dict = self.get_variables(custom_scope_separator="-")
return DataOpDict(variables_dict)
[docs] @staticmethod
def reset_profile():
"""
Sets profiling values to 0.
"""
Component.call_count = 0
Component.call_times = []
def __str__(self):
return "{}('{}' api={})".format(type(self).__name__, self.name, str(list(self.api_methods.keys())))