Source code for rlgraph
# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from rlgraph.version import __version__
import json
import os
import logging
# Libraries should add NullHandler() by default, as its the application code's
# responsibility to configure log handlers.
# https://docs.python.org/3/howto/logging.html#configuring-logging-for-a-library
try:
from logging import NullHandler
except ImportError:
class NullHandler(logging.Handler):
def emit(self, record):
pass
logging.getLogger(__name__).addHandler(NullHandler())
if "RLGRAPH_HOME" in os.environ:
rl_graph_dir = os.environ.get("RLGRAPH_HOME")
else:
rl_graph_dir = os.path.expanduser('~')
rl_graph_dir = os.path.join(rl_graph_dir, ".rlgraph")
# TODO "tensorflow" for tensorflow?
# Default backend ('tf' for tensorflow or 'pytorch' for PyTorch)
BACKEND = "tf"
# Default distributed backend is distributed ray.
DISTRIBUTED_BACKEND = "distributed_tf"
distributed_compatible_backends = dict(
tf=["distributed_tf", "ray", "horovod"],
pytorch=["ray", "horovod"]
)
config_path = os.path.expanduser(os.path.join(rl_graph_dir, 'rlgraph.json'))
if os.path.exists(config_path):
try:
with open(config_path) as f:
config = json.load(f)
except ValueError:
config = dict()
# Read from config or leave defaults.
backend = config.get("BACKEND", None)
if backend is not None:
BACKEND = backend
distributed_backend = config.get("DISTRIBUTED_BACKEND", None)
if distributed_backend is not None:
DISTRIBUTED_BACKEND = distributed_backend
# Create dir if necessary:
if not os.path.exists(rl_graph_dir):
try:
os.makedirs(rl_graph_dir)
except OSError:
pass
# Write to file if there was none:
if not os.path.exists(config_path):
_config = {
"BACKEND": BACKEND,
"DISTRIBUTED_BACKEND": DISTRIBUTED_BACKEND,
}
try:
with open(config_path, 'w') as f:
f.write(json.dumps(_config, indent=4))
except IOError:
# Except permission denied.
pass
# Overwrite backend if set in ENV.
if 'RLGRAPH_BACKEND' in os.environ:
backend = os.environ.get('RLGRAPH_BACKEND', None)
if backend is not None:
logging.info("Setting BACKEND to '{}' per environment variable 'RLGRAPH_BACKEND'.".format(backend))
BACKEND = backend
# Overwrite distributed-backend if set in ENV.
if 'RLGRAPH_DISTRIBUTED_BACKEND' in os.environ:
distributed_backend = os.environ.get('RLGRAPH_DISTRIBUTED_BACKEND', None)
if distributed_backend is not None:
logging.info(
"Setting DISTRIBUTED_BACKEND to '{}' per environment variable "
"'RLGRAPH_DISTRIBUTED_BACKEND'.".format(distributed_backend)
)
DISTRIBUTED_BACKEND = distributed_backend
# Test compatible backend.
if DISTRIBUTED_BACKEND not in distributed_compatible_backends[BACKEND]:
raise ValueError("Distributed backend {} not compatible with backend {}. Compatible backends"
"are: {}".format(DISTRIBUTED_BACKEND, BACKEND, distributed_compatible_backends[BACKEND]))
# Test imports.
if DISTRIBUTED_BACKEND == 'distributed_tf':
assert BACKEND == "tf"
try:
import tensorflow
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"INIT ERROR: Cannot run distributed_tf without backend (tensorflow)! Please install tensorflow first "
"via `pip install tensorflow` or `pip install tensorflow-gpu`."
)
elif DISTRIBUTED_BACKEND == "horovod":
try:
import horovod
except ModuleNotFoundError as e:
raise ValueError("INIT ERROR: Cannot run RLGraph with distributed backend Horovod.")
elif DISTRIBUTED_BACKEND == "ray":
try:
import ray
except ModuleNotFoundError as e:
raise ValueError("INIT ERROR: Cannot run RLGraph with distributed backend Ray.")
else:
raise ValueError("Distributed backend {} not supported".format(DISTRIBUTED_BACKEND))
[docs]def get_backend():
return BACKEND
[docs]def get_distributed_backend():
return DISTRIBUTED_BACKEND
__all__ = [
"__version__", "get_backend", "get_distributed_backend"
]