# Copyright 2018 The Rlgraph Authors, All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import inspect
import re
import time
from rlgraph.spaces.space_utils import get_space_from_op
from rlgraph.utils.op_records import GraphFnRecord, APIMethodRecord, DataOpRecord, DataOpRecordColumnIntoAPIMethod, \
DataOpRecordColumnFromAPIMethod, DataOpRecordColumnIntoGraphFn, DataOpRecordColumnFromGraphFn
from rlgraph.utils.rlgraph_errors import RLGraphError, RLGraphAPICallParamError
from rlgraph.utils import util
# Global registries for Component classes' API-methods and graph_fns.
# Both map a class name (derived from the decorated function's `__qualname__`) to a list of records:
# - `component_api_registry`: APIMethodRecords collected by `@rlgraph_api` on class methods.
# - `component_graph_fn_registry`: GraphFnRecords collected by `@graph_fn`.
# They are collected at class-definition time so they can later be defined for each constructed instance.
component_api_registry = {}
component_graph_fn_registry = {}
def rlgraph_api(api_method=None, *, component=None, name=None, returns=None,
                flatten_ops=False, split_ops=False, add_auto_key_as_first_param=False,
                must_be_complete=True, ok_to_overwrite=False):
    """
    API-method decorator used to tag any Component's methods as API-methods.

    Can be used bare (`@rlgraph_api`), with keyword options (`@rlgraph_api(returns=2)`), or be called directly
    on a function together with `component=...` to attach an API-method to a single Component instance.

    Args:
        api_method (callable): The actual function/method to tag as an API method.
        component (Optional[Component]): The Component that the method should belong to. None if `api_method` is
            decorated inside a Component class.
        name (Optional[str]): The name under which the API-method should be registered. This is only necessary if
            the API-method is automatically generated as a thin-wrapper around a graph_fn.
        returns (Optional[int]): If the function is a graph_fn, we may specify, how many return values
            it returns. If None, will try to get this number from looking at the source code or from the Component's
            `num_graph_fn_return_values` property.
        flatten_ops (Union[bool,Set[str]]): Whether to flatten all or some DataOps by creating
            a FlattenedDataOp (with automatic key names).
            Can also be a set of in-Socket names to flatten explicitly (True for all).
            (default: False).
        split_ops (Union[bool,Set[str]]): Whether to split all or some of the already flattened DataOps
            and send the SingleDataOps one by one through the graph_fn.
            Example: Spaces=A=Dict (container), B=int (primitive)
            The graph_fn should then expect for each primitive Space in A:
            _graph_fn(primitive-in-A (Space), B (int))
            NOTE that B will be the same in all calls for all primitive-in-A's.
            (default: False).
        add_auto_key_as_first_param (bool): If `split_ops` is not False, whether to send the
            automatically generated flat key as the very first parameter into each call of the graph_fn.
            Example: Spaces=A=float (primitive), B=Tuple (container)
            The graph_fn should then expect for each primitive Space in B:
            _graph_fn(key, A (float), primitive-in-B (Space))
            NOTE that A will be the same in all calls for all primitive-in-B's.
            The key can now be used to index into variables equally structured as B.
            Has no effect if `split_ops` is False.
            (default: False).
        must_be_complete (bool): Whether the exposed API methods must be input-complete or not.
        ok_to_overwrite (bool): Set to True to indicate that this API-decorator will overwrite an already existing
            API-method in the Component. Default: False.

    Returns:
        callable: The decorator function if `api_method` is None, otherwise the wrapped API-method.
    """
    _sanity_check_decorator_options(flatten_ops, split_ops, add_auto_key_as_first_param)

    def decorator_func(wrapped_func):

        def api_method_wrapper(self, *args, **kwargs):
            # Raises RLGraphAPICallParamError if too few or too many call params are given.
            api_name = name or re.sub(r'^_graph_fn_', "", wrapped_func.__name__)
            return_ops = kwargs.pop("return_ops", False)

            # Direct evaluation of function.
            if self.execution_mode == "define_by_run":
                type(self).call_count += 1
                start = time.perf_counter()
                # Check with owner if extra args needed.
                if api_name in self.api_methods and self.api_methods[api_name].add_auto_key_as_first_param:
                    output = wrapped_func(self, "", *args, **kwargs)
                else:
                    output = wrapped_func(self, *args, **kwargs)
                # Store runtime for this method.
                type(self).call_times.append(  # Component.call_times
                    (self.name, wrapped_func.__name__, time.perf_counter() - start)
                )
                return output

            api_method_rec = self.api_methods[api_name]

            # Create op-record column to call API method with. Ignore None input params. These should not be sent
            # to the API-method.
            in_op_column = DataOpRecordColumnIntoAPIMethod(
                component=self, api_method_rec=api_method_rec, args=args, kwargs=kwargs
            )
            # Add the column to the API-method record.
            api_method_rec.in_op_columns.append(in_op_column)

            # Check minimum number of passed args.
            minimum_num_call_params = len(in_op_column.api_method_rec.non_args_kwargs) - \
                len(in_op_column.api_method_rec.default_args)
            if len(in_op_column.op_records) < minimum_num_call_params:
                raise RLGraphAPICallParamError(
                    "Number of call params ({}) for call to API-method '{}' is too low. Needs to be at least {} "
                    "params!".format(len(in_op_column.op_records), api_method_rec.name, minimum_num_call_params)
                )

            # Link from incoming op_recs into the new column or populate new column with ops/Spaces (this happens
            # if this call was made from within a graph_fn such that ops and Spaces are already known).
            all_args = [(i, a) for i, a in enumerate(args) if a is not None] + \
                       [(k, v) for k, v in sorted(kwargs.items()) if v is not None]
            # Index into `all_args` of the first item collected by a var-positional (*args) param.
            flex_start = None
            # Position of that var-positional param within `api_method_rec.input_names`. This is tracked
            # separately from `flex_start` because dropped None-args make the two indices differ.
            flex_slot = None
            for i, (key, value) in enumerate(all_args):
                # Named arg/kwarg -> get input_name from that and peel op_rec.
                if isinstance(key, str):
                    param_name = key
                    slot = None
                # Positional arg -> get input_name from input_names list.
                else:
                    # Once the *args param has been reached, all further positional args belong to it.
                    slot = key if flex_slot is None else flex_slot
                    if slot >= len(api_method_rec.input_names):
                        raise RLGraphAPICallParamError(
                            "Too many input args given in call to API-method '{}'!".format(api_method_rec.name)
                        )
                    param_name = api_method_rec.input_names[slot]

                # Var-positional arg, attach the actual position to input_name string.
                # `.get()`: keyword args not declared in the signature (e.g. collected by `**kwargs`) have no
                # entry yet; they are registered further below.
                if self.api_method_inputs.get(param_name) == "*flex":
                    if flex_start is None:
                        flex_start = i
                        flex_slot = slot
                    param_name += "[{}]".format(i - flex_start)

                # We are already in building phase (params may be coming from inside graph_fn).
                if self.graph_builder is not None and self.graph_builder.phase == "building":
                    self.api_method_inputs[param_name] = in_op_column.op_records[i].space
                    # Check input-completeness of Component (but not strict as we are only calling API,
                    # not a graph_fn).
                    if self.input_complete is False:
                        # Check Spaces and create variables.
                        self.graph_builder.build_component_when_input_complete(self)
                # A DataOpRecord from the meta-graph: register the param (Space still unknown).
                elif isinstance(value, DataOpRecord):
                    if param_name not in self.api_method_inputs:
                        self.api_method_inputs[param_name] = None
                # Fixed value (not an op-record): record its Space unless one is already known.
                elif self.api_method_inputs.get(param_name) is None:
                    self.api_method_inputs[param_name] = in_op_column.op_records[i].space

            # Regular API-method: Call it here.
            args_, kwargs_ = in_op_column.get_args_and_kwargs()
            if api_method_rec.is_graph_fn_wrapper is False:
                return_values = wrapped_func(self, *args_, **kwargs_)
            # Wrapped graph_fn: Call it through yet another wrapper.
            else:
                return_values = graph_fn_wrapper(
                    self, wrapped_func, returns, dict(
                        flatten_ops=flatten_ops, split_ops=split_ops,
                        add_auto_key_as_first_param=add_auto_key_as_first_param
                    ), *args_, **kwargs_
                )

            # Exact type check (not isinstance): only plain dicts are treated as keyword-style return values.
            returns_dict = type(return_values) == dict

            # Process the results (push into a column) and append it to the api-method-rec.
            out_op_column = DataOpRecordColumnFromAPIMethod(
                component=self,
                api_method_name=api_name,
                args=None if returns_dict else util.force_tuple(return_values),
                kwargs=return_values if returns_dict else None
            )
            api_method_rec.out_op_columns.append(out_op_column)

            # Do we need to return the raw ops or the op-recs?
            # Only the direct caller's function name is needed, so read it from the calling frame instead of
            # `inspect.stack()`, which builds FrameInfo objects (incl. source context) for the entire stack.
            caller_frame = inspect.currentframe().f_back
            caller_name = caller_frame.f_code.co_name
            del caller_frame  # Don't keep a frame reference alive (see `inspect` docs on reference cycles).

            # Direct parent caller is a `_graph_fn_...` (or raw ops were requested): Return raw ops.
            if return_ops is True or re.match(r'^_graph_fn_.+$', caller_name):
                if returns_dict:
                    return {k: v.op for k, v in out_op_column.get_args_and_kwargs()[1].items()}
                ops = tuple(rec.op for rec in out_op_column.get_args_and_kwargs()[0])
                return ops[0] if len(ops) == 1 else ops

            # Parent caller is non-graph_fn: Return op-recs.
            if returns_dict:
                return return_values
            op_recs = out_op_column.get_args_and_kwargs()[0]
            return op_recs[0] if len(op_recs) == 1 else op_recs

        func_type = util.get_method_type(wrapped_func)
        is_graph_fn_wrapper = (func_type == "graph_fn")
        name_ = name or (re.sub(r'^_graph_fn_', "", wrapped_func.__name__) if is_graph_fn_wrapper else
                         wrapped_func.__name__)

        api_method_rec = APIMethodRecord(
            func=wrapped_func, wrapper_func=api_method_wrapper,
            name=name_,
            must_be_complete=must_be_complete, ok_to_overwrite=ok_to_overwrite,
            is_graph_fn_wrapper=is_graph_fn_wrapper, is_class_method=(component is None),
            flatten_ops=flatten_ops, split_ops=split_ops, add_auto_key_as_first_param=add_auto_key_as_first_param
        )

        # Registers the given method with the Component (if not already done so).
        if component is not None:
            define_api_method(component, api_method_rec, copy_=False)
        # Registers the given function with the Component sub-class so we can define it for each
        # constructed instance of that sub-class.
        else:
            cls = wrapped_func.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0]
            component_api_registry.setdefault(cls, []).append(api_method_rec)

        return api_method_wrapper

    if api_method is None:
        return decorator_func
    return decorator_func(api_method)
def graph_fn(graph_fn=None, *, returns=None,
             flatten_ops=False, split_ops=False, add_auto_key_as_first_param=False):
    """
    Graph_fn decorator used to tag any Component's graph_fn (that is not directly wrapped by an API-method) as such.

    Can be used bare (`@graph_fn`) or with keyword options (`@graph_fn(returns=2)`).

    Args:
        graph_fn (callable): The actual graph_fn to tag.
        returns (Optional[int]): How many return values it returns. If None, will try to get this number from
            looking at the source code or from the Component's `num_graph_fn_return_values` property.
        flatten_ops (Union[bool,Set[str]]): Whether to flatten all or some DataOps by creating
            a FlattenedDataOp (with automatic key names).
            Can also be a set of in-Socket names to flatten explicitly (True for all).
            (default: False).
        split_ops (Union[bool,Set[str]]): Whether to split all or some of the already flattened DataOps
            and send the SingleDataOps one by one through the graph_fn.
            Example: Spaces=A=Dict (container), B=int (primitive)
            The graph_fn should then expect for each primitive Space in A:
            _graph_fn(primitive-in-A (Space), B (int))
            NOTE that B will be the same in all calls for all primitive-in-A's.
            (default: False).
        add_auto_key_as_first_param (bool): If `split_ops` is not False, whether to send the
            automatically generated flat key as the very first parameter into each call of the graph_fn.
            Example: Spaces=A=float (primitive), B=Tuple (container)
            The graph_fn should then expect for each primitive Space in B:
            _graph_fn(key, A (float), primitive-in-B (Space))
            NOTE that A will be the same in all calls for all primitive-in-B's.
            The key can now be used to index into variables equally structured as B.
            Has no effect if `split_ops` is False.
            (default: False).

    Returns:
        callable: The decorator function if `graph_fn` is None, otherwise the wrapped graph_fn.
    """
    _sanity_check_decorator_options(flatten_ops, split_ops, add_auto_key_as_first_param)

    def decorator_func(wrapped_func):

        def _graph_fn_wrapper(self, *args, **kwargs):
            # Eager (define-by-run) mode: just execute the function.
            if self.execution_mode == "define_by_run":
                return wrapped_func(self, *args, **kwargs)
            # Otherwise go through the op-record-column machinery of `graph_fn_wrapper`.
            return graph_fn_wrapper(
                self, wrapped_func, returns, dict(
                    flatten_ops=flatten_ops, split_ops=split_ops,
                    add_auto_key_as_first_param=add_auto_key_as_first_param
                ), *args, **kwargs
            )

        graph_fn_rec = GraphFnRecord(
            func=wrapped_func, wrapper_func=_graph_fn_wrapper, is_class_method=True,
            flatten_ops=flatten_ops, split_ops=split_ops, add_auto_key_as_first_param=add_auto_key_as_first_param
        )

        # TODO: allow graph_fn to be defined outside a Component class as well (e.g. via a `component` arg,
        #  analogous to `rlgraph_api`).
        # Registers the given function with the Component sub-class (keyed by class name taken from
        # `__qualname__`) so we can define it for each constructed instance of that sub-class.
        cls = wrapped_func.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0]
        component_graph_fn_registry.setdefault(cls, []).append(graph_fn_rec)

        return _graph_fn_wrapper

    if graph_fn is None:
        return decorator_func
    return decorator_func(graph_fn)
def define_api_method(component, api_method_record, copy_=True):
    """
    Registers an API-method with a Component instance.

    Binds the record's wrapper function to `component` (as an attribute under the record's name), stores the
    record in `component.api_methods` / `component.api_fn_by_name`, and registers all of the method's
    parameters in `component.api_method_inputs` (unless already present there).

    Args:
        component (Component): The Component object to register the API method with.
        api_method_record (APIMethodRecord): The APIMethodRecord describing the to-be-registered API-method.
        copy_ (bool): Whether to deepcopy the APIMethodRecord prior to handing it to the Component for storing.

    Raises:
        RLGraphError: If `ok_to_overwrite` is False and the name is already taken by another API-method or
            (for non-class methods) by another non-None attribute of `component`.
    """
    # Deep copy the record (in case this got registered the normal way via decorating a class method).
    if copy_:
        api_method_record = copy.deepcopy(api_method_record)
    api_method_record.component = component
    api_name = api_method_record.name

    # Raise errors if `name` already taken in this Component.
    if not api_method_record.ok_to_overwrite:
        # There already is an API-method with that name.
        if api_name in component.api_methods:
            raise RLGraphError("API-method with name '{}' already defined!".format(api_name))
        # There already is another object property with that name (avoid accidental overriding).
        elif not api_method_record.is_class_method and getattr(component, api_name, None) is not None:
            raise RLGraphError(
                "Component '{}' already has a property called '{}'. Cannot define an API-method with "
                "the same name!".format(component.name, api_name)
            )

    # Do not build this API as per ctor instructions.
    if api_name in component.switched_off_apis:
        return

    component.synthetic_methods.add(api_name)
    setattr(component, api_name, api_method_record.wrapper_func.__get__(component, component.__class__))
    setattr(api_method_record.wrapper_func, "__name__", api_name)
    component.api_methods[api_name] = api_method_record
    # Direct callable for eager/define by run.
    component.api_fn_by_name[api_name] = api_method_record.wrapper_func

    # Update the api_method_inputs dict (with empty Spaces if not defined yet).
    skip_args = 1  # self
    # Graph_fn wrappers that receive the auto-generated flat key also skip that (first) key param.
    if api_method_record.is_graph_fn_wrapper and api_method_record.add_auto_key_as_first_param:
        skip_args += 1
    param_list = list(inspect.signature(api_method_record.func).parameters.values())[skip_args:]
    for param in param_list:
        api_method_record.input_names.append(param.name)
        if param.name in component.api_method_inputs:
            continue
        # This param has a default value. (`is not` on purpose: `!=` would misbehave for defaults with
        # element-wise comparison, e.g. numpy arrays.)
        if param.default is not inspect.Parameter.empty:
            # Default is None. Set to "flex" (to signal that this Space is not needed for input-completeness)
            # and wait for first call using this parameter (only then set it to that Space).
            if param.default is None:
                component.api_method_inputs[param.name] = "flex"
            # Default is some python value (e.g. a bool). Use that as the assigned Space.
            else:
                component.api_method_inputs[param.name] = get_space_from_op(param.default)
        # This param is an *args param. Store as "*flex". Then with upcoming API calls, we determine the Spaces
        # for the single items in *args and set them under "param[0]", "param[1]", etc..
        elif param.kind == inspect.Parameter.VAR_POSITIONAL:
            component.api_method_inputs[param.name] = "*flex"
        # This param is a **kwargs param. Store as "**flex". Then with upcoming API calls, we determine the Spaces
        # for the single items in **kwargs and set them under "param[some-key]", "param[some-other-key]", etc..
        elif param.kind == inspect.Parameter.VAR_KEYWORD:
            component.api_method_inputs[param.name] = "**flex"
        # Regular parameter without a default (positional-or-keyword or keyword-only). Store as None (needed).
        else:
            component.api_method_inputs[param.name] = None
def define_graph_fn(component, graph_fn_record):
    """
    Registers a graph_fn record with a Component instance.

    A deep copy of the record is attached to `component`. The record's wrapper function is bound to
    `component` and set as an attribute under the record's name, and the copy is stored in
    `component.graph_fns` under that same name.

    Args:
        component (Component): The Component object to register the graph_fn with.
        graph_fn_record (GraphFnRecord): The record describing the graph_fn. The record object itself is
            deep-copied, so the passed-in record's `component` is left untouched.
    """
    record = copy.deepcopy(graph_fn_record)
    record.component = component
    fn_name = record.name

    bound_wrapper = record.wrapper_func.__get__(component, component.__class__)
    setattr(component, fn_name, bound_wrapper)
    # Tag the raw graph_fn with its owning Component. NOTE: `copy.deepcopy` returns function objects
    # unchanged, so this attribute lands on the function object shared with the original record.
    setattr(record.func, "__self__", component)

    component.graph_fns[fn_name] = record
def graph_fn_wrapper(component, wrapped_func, returns, options, *args, **kwargs):
    """
    Executes a dry run through a graph_fn (without calling it) just generating the empty
    op-record-columns around the graph_fn (incoming and outgoing). Except if the GraphBuilder
    is already in the "building" phase, in which case the graph_fn is actually called.

    Args:
        component (Component): The Component that this graph_fn belongs to.
        wrapped_func (callable): The graph_fn to be called during the build process.
        returns (Optional[int]): The number of return values of the graph_fn.
        options (Dict): Dict with the following keys (optionally) set. The dict is only read, not modified:
            - flatten_ops (Union[bool,Set[str]]): Whether to flatten all or some DataOps by creating
                a FlattenedDataOp (with automatic key names).
                Can also be a set of in-Socket names to flatten explicitly (True for all).
                (default: False).
            - split_ops (Union[bool,Set[str]]): Whether to split all or some of the already flattened DataOps
                and send the SingleDataOps one by one through the graph_fn.
                Example: Spaces=A=Dict (container), B=int (primitive)
                The graph_fn should then expect for each primitive Space in A:
                _graph_fn(primitive-in-A (Space), B (int))
                NOTE that B will be the same in all calls for all primitive-in-A's.
                (default: False).
            - add_auto_key_as_first_param (bool): If `split_ops` is not False, whether to send the
                automatically generated flat key as the very first parameter into each call of the graph_fn.
                Example: Spaces=A=float (primitive), B=Tuple (container)
                The graph_fn should then expect for each primitive Space in B:
                _graph_fn(key, A (float), primitive-in-B (Space))
                NOTE that A will be the same in all calls for all primitive-in-B's.
                The key can now be used to index into variables equally structured as B.
                Has no effect if `split_ops` is False.
                (default: False).
        *args (Union[DataOpRecord,np.array,numeric]): The DataOpRecords to be used for calling the method.
        **kwargs: Keyword DataOpRecords/values to be used for calling the method.

    Returns:
        The single op-record of the out-going op-rec-column if it holds exactly one record, otherwise a tuple
        of all of its op-records.
    """
    # Read (instead of pop) the options so the caller's dict is left untouched.
    flatten_ops = options.get("flatten_ops", False)
    split_ops = options.get("split_ops", False)
    add_auto_key_as_first_param = options.get("add_auto_key_as_first_param", False)
    graph_fn_name = wrapped_func.__name__

    # Store a graph_fn record in this component for better in/out-op-record-column reference.
    if graph_fn_name not in component.graph_fns:
        component.graph_fns[graph_fn_name] = GraphFnRecord(
            func=wrapped_func, wrapper_func=graph_fn_wrapper, component=component
        )

    # Generate in-going op-rec-column.
    in_graph_fn_column = DataOpRecordColumnIntoGraphFn(
        component=component, graph_fn=wrapped_func,
        flatten_ops=flatten_ops, split_ops=split_ops,
        add_auto_key_as_first_param=add_auto_key_as_first_param,
        args=args, kwargs=kwargs
    )
    # Add the column to the `graph_fns` record.
    component.graph_fns[graph_fn_name].in_op_columns.append(in_graph_fn_column)

    # We are already building: Actually call the graph_fn after asserting that its Component is input-complete.
    if component.graph_builder and component.graph_builder.phase == "building":
        # Check Spaces and create variables (if not already input-complete), then assert input-completeness.
        component.graph_builder.build_component_when_input_complete(component)
        assert component.input_complete
        # TODO: A variable-completeness check for "_graph_fn__variables" should go in here, but fails for
        #  multi-GPU DQN runs.
        # Call the graph_fn.
        out_graph_fn_column = component.graph_builder.run_through_graph_fn_with_device_and_scope(
            in_graph_fn_column, create_new_out_column=True
        )
        # Check again, in case we are now also variable-complete.
        component.graph_builder.build_component_when_input_complete(component)

    # We are still in the assembly phase: Don't actually call the graph_fn. Only generate op-rec-columns
    # around it (in-coming and out-going).
    else:
        # Create 2 op-record columns, one going into the graph_fn and one getting out of there and link
        # them together via the graph_fn (w/o calling it).
        # TODO: remove when we have numpy-based Components (then we can do test calls to infer everything
        #  automatically)
        if graph_fn_name in component.graph_fn_num_outputs:
            num_graph_fn_return_values = component.graph_fn_num_outputs[graph_fn_name]
        elif returns is not None:
            num_graph_fn_return_values = returns
        else:
            num_graph_fn_return_values = util.get_num_return_values(wrapped_func)
            component.logger.debug("Graph_fn '{}' has {} return values (inferred).".format(
                graph_fn_name, num_graph_fn_return_values)
            )

        # If in-column is empty, add it to the "empty in-column" set.
        if len(in_graph_fn_column.op_records) == 0:
            component.no_input_graph_fn_columns.add(in_graph_fn_column)

        # Generate the out-op-column from the number of return values (guessed during assembly phase or
        # actually measured during build phase).
        out_graph_fn_column = DataOpRecordColumnFromGraphFn(
            num_op_records=num_graph_fn_return_values,
            component=component, graph_fn_name=graph_fn_name,
            in_graph_fn_column=in_graph_fn_column
        )
        in_graph_fn_column.out_graph_fn_column = out_graph_fn_column

    component.graph_fns[graph_fn_name].out_op_columns.append(out_graph_fn_column)

    if len(out_graph_fn_column.op_records) == 1:
        return out_graph_fn_column.op_records[0]
    return tuple(out_graph_fn_column.op_records)
def _sanity_check_call_parameters(self, params, method, method_type, add_auto_key_as_first_param):
raw_signature_parameters = inspect.signature(method).parameters
actual_params = list(raw_signature_parameters.values())
if add_auto_key_as_first_param is True:
actual_params = actual_params[1:]
if len(params) != len(actual_params):
# Check whether the last arg is var_positional (e.g. *inputs; in that case it's ok if the number of params
# is larger than that of the actual graph_fn params or its one smaller).
if actual_params[-1].kind == inspect.Parameter.VAR_POSITIONAL and (len(params) > len(actual_params) > 0 or
len(params) == len(actual_params) - 1):
pass
# Some actual params have default values: Number of given params must be at least as large as the number
# of non-default actual params but maximally as large as the number of actual_parameters.
elif len(actual_params) >= len(params) >= sum(
[p.default is inspect.Parameter.empty for p in actual_params]):
pass
else:
raise RLGraphError(
"ERROR: {} '{}/{}' has {} input-parameters, but {} ({}) were being provided in the "
"`Component.call` method!".format(method_type, self.name, method.__name__,
len(actual_params), len(params), params)
)
def _sanity_check_decorator_options(flatten_ops, split_ops, add_auto_key_as_first_param):
if split_ops:
assert flatten_ops,\
"ERROR in decorator options: `split_ops` cannot be True if `flatten_ops` is False!"
if add_auto_key_as_first_param:
assert split_ops,\
"ERROR in decorator options: `add_auto_key_as_first_param` cannot be True if `split_ops` is False!"