Source code for rlgraph.agents.dqn_agent

# Copyright 2018 The RLgraph authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from rlgraph.agents import Agent
from rlgraph.components import Synchronizable, Memory, PrioritizedReplay, DQNLossFunction, DictMerger, \
    ContainerSplitter
from rlgraph.spaces import FloatBox, BoolBox
from rlgraph.utils.decorators import rlgraph_api
from rlgraph.utils.util import strip_list


class DQNAgent(Agent):
    """
    A collection of DQN algorithms published in the following papers:

    [1] Human-level control through deep reinforcement learning. Mnih, Kavukcuoglu, Silver et al. - 2015
    [2] Deep Reinforcement Learning with Double Q-learning. v. Hasselt, Guez, Silver - 2015
    [3] Dueling Network Architectures for Deep Reinforcement Learning, Wang et al. - 2016
    [4] https://en.wikipedia.org/wiki/Huber_loss
    """

    def __init__(self, double_q=True, dueling_q=True, huber_loss=False, n_step=1, memory_spec=None,
                 store_last_memory_batch=False, store_last_q_table=False, **kwargs):
        """
        Args:
            double_q (bool): Whether to use the double DQN loss function (see [2]).
            dueling_q (bool): Whether to use a dueling layer in the ActionAdapter (see [3]).
            huber_loss (bool): Whether to apply a Huber loss (see [4]).
            n_step (Optional[int]): n-step adjustment to discounting.
            memory_spec (Optional[dict,Memory]): The spec for the Memory to use for the DQN algorithm.
            store_last_memory_batch (bool): Whether to store the last pulled batch from the memory in
                `self.last_memory_batch` for debugging purposes. Default: False.
            store_last_q_table (bool): Whether to store the Q(s,a) values for the last received batch
                (memory or external) in `self.last_q_table` for debugging purposes. Default: False.
        """
        # Fix action-adapter before passing it to the super constructor.
        action_adapter_spec = kwargs.pop("action_adapter_spec", dict())

        # Use a DuelingActionAdapter (instead of a basic ActionAdapter) if option is set.
        if dueling_q is True:
            action_adapter_spec["type"] = "dueling-action-adapter"
            assert "units_state_value_stream" in action_adapter_spec
            assert "units_advantage_stream" in action_adapter_spec

        super(DQNAgent, self).__init__(
            action_adapter_spec=action_adapter_spec, name=kwargs.pop("name", "dqn-agent"), **kwargs
        )

        self.double_q = double_q
        self.dueling_q = dueling_q
        self.huber_loss = huber_loss

        # Debugging tools.
        self.store_last_memory_batch = store_last_memory_batch
        self.last_memory_batch = None
        self.store_last_q_table = store_last_q_table
        self.last_q_table = None

        # Extend input Space definitions to this Agent's specific API-methods.
        preprocessed_state_space = self.preprocessed_state_space.with_batch_rank()
        reward_space = FloatBox(add_batch_rank=True)
        terminal_space = BoolBox(add_batch_rank=True)
        weight_space = FloatBox(add_batch_rank=True)

        self.input_spaces.update(dict(
            actions=self.action_space.with_batch_rank(),
            weights="variables:policy",
            time_step=int,
            use_exploration=bool,
            preprocessed_states=preprocessed_state_space,
            rewards=reward_space,
            terminals=terminal_space,
            next_states=preprocessed_state_space,
            preprocessed_next_states=preprocessed_state_space,
            importance_weights=weight_space,
            # TODO: This is currently necessary for multi-GPU handling (as the update_from_external_batch
            # TODO: gets overridden by a generic function with args=*inputs)
            #inputs=[preprocessed_state_space, self.action_space.with_batch_rank(), reward_space, terminal_space,
            #        preprocessed_state_space, weight_space]
        ))

        # The merger to merge inputs into one record Dict going into the memory.
        self.merger = DictMerger("states", "actions", "rewards", "next_states", "terminals")
        # The replay memory.
        self.memory = Memory.from_spec(memory_spec)
        # The splitter for splitting up the records coming from the memory.
        self.splitter = ContainerSplitter("states", "actions", "rewards", "terminals", "next_states")

        # Copy our Policy (target-net), make target-net synchronizable.
        self.target_policy = self.policy.copy(scope="target-policy", trainable=False)
        self.target_policy.add_components(Synchronizable(), expose_apis="sync")
        # Number of steps since the last target-net synching from the main policy.
        self.steps_since_target_net_sync = 0

        use_importance_weights = isinstance(self.memory, PrioritizedReplay)
        self.loss_function = DQNLossFunction(
            discount=self.discount, double_q=self.double_q, huber_loss=self.huber_loss,
            importance_weights=use_importance_weights, n_step=n_step
        )

        # Add all our sub-components to the core.
        sub_components = [self.preprocessor, self.merger, self.memory, self.splitter, self.policy,
                          self.target_policy, self.exploration, self.loss_function, self.optimizer]
        self.root_component.add_components(*sub_components)

        # Define the Agent's (root-Component's) API.
        self.define_api_methods("policy", "preprocessor-stack", self.optimizer.scope, *sub_components)

        # markup = get_graph_markup(self.graph_builder.root_component)
        # print(markup)
        if self.auto_build:
            self._build_graph([self.root_component], self.input_spaces, optimizer=self.optimizer,
                              batch_size=self.update_spec["batch_size"])
            self.graph_built = True

            # TODO: What should the external batch be? 0s.
            #if "multi-gpu-sync-optimizer" in self.root_component.sub_components:
            #    # Get 1st return op of API-method `calculate_update_from_external_batch`
            #    # (which is the group of stage-ops).
            #    stage_op = self.root_component.sub_components["multi-gpu-sync-optimizer"].\
            #        api_methods["calculate_update_from_external_batch"].\
            #        out_op_columns[0].op_records[0].op
            #    # Initialize the stage.
            #    self.graph_executor.monitored_session.run_step_fn(
            #        lambda step_context: step_context.session.run(stage_op)
            #    )
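    # NOTE (illustrative sketch, not part of the original module): when `dueling_q=True`, the
    # constructor above requires the dueling stream sizes to be present in `action_adapter_spec`.
    # A config of roughly the following shape would satisfy the asserts; the concrete numbers,
    # the memory capacity and the "prioritized-replay" type string are assumptions chosen for
    # illustration only.
    #
    #   agent_config = dict(
    #       dueling_q=True,
    #       action_adapter_spec=dict(
    #           units_state_value_stream=128,
    #           units_advantage_stream=128
    #       ),
    #       memory_spec=dict(type="prioritized-replay", capacity=50000)
    #   )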
    def get_action(self, states, internals=None, use_exploration=True, apply_preprocessing=True,
                   extra_returns=None):
        """
        Args:
            extra_returns (Optional[Set[str],str]): Optional string or set of strings for additional return
                values (besides the actions). Possible values are:
                - 'preprocessed_states': The preprocessed states after passing the given states through the
                    preprocessor stack.
                - 'internal_states': The internal states returned by the RNNs in the NN pipeline.
                - 'used_exploration': Whether epsilon- or noise-based exploration was used or not.

        Returns:
            tuple or single value depending on `extra_returns`:
                - action
                - the preprocessed states
        """
        extra_returns = {extra_returns} if isinstance(extra_returns, str) else (extra_returns or set())
        # States come in without preprocessing -> use state space.
        if apply_preprocessing:
            call_method = "get_preprocessed_state_and_action"
            batched_states = self.state_space.force_batch(states)
        else:
            call_method = "action_from_preprocessed_state"
            batched_states = states
        remove_batch_rank = batched_states.ndim == np.asarray(states).ndim + 1

        # Increase timesteps by the batch size (number of states in batch).
        batch_size = len(batched_states)
        self.timesteps += batch_size

        # Control which return value to "pull" (depending on `extra_returns`).
        return_ops = [1, 0] if "preprocessed_states" in extra_returns else [1]
        ret = self.graph_executor.execute((
            call_method,
            [batched_states, self.timesteps, use_exploration],
            # 0=preprocessed_states, 1=action
            return_ops
        ))
        if remove_batch_rank:
            return strip_list(ret)
        else:
            return ret
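    # Example (illustrative sketch, not part of the original module): asking `get_action` for the
    # preprocessed states in addition to the action, as described in the docstring above. The
    # `env` and `agent` objects are assumed to exist already; only the `extra_returns` argument
    # and the (action, preprocessed states) return order follow the method above.
    #
    #   state = env.reset()
    #   action, preprocessed_state = agent.get_action(
    #       states=state, use_exploration=True, extra_returns="preprocessed_states"
    #   )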
    def define_api_methods(self, policy_scope, pre_processor_scope, optimizer_scope, *sub_components):
        super(DQNAgent, self).define_api_methods(policy_scope, pre_processor_scope, optimizer_scope)

        preprocessor, merger, memory, splitter, policy, target_policy, exploration, loss_function, optimizer = \
            sub_components

        # Reset operation (resets preprocessor).
        if self.preprocessing_required:
            @rlgraph_api(component=self.root_component)
            def reset_preprocessor(self):
                reset_op = preprocessor.reset()
                return reset_op

        # Act from preprocessed states.
        @rlgraph_api(component=self.root_component)
        def action_from_preprocessed_state(self, preprocessed_states, time_step=0, use_exploration=True):
            sample_deterministic = policy.get_max_likelihood_action(preprocessed_states)
            actions = exploration.get_action(sample_deterministic["action"], time_step, use_exploration)
            return preprocessed_states, actions

        # State (from environment) to action with preprocessing.
        @rlgraph_api(component=self.root_component)
        def get_preprocessed_state_and_action(self, states, time_step=0, use_exploration=True):
            preprocessed_states = preprocessor.preprocess(states)
            return self.action_from_preprocessed_state(preprocessed_states, time_step, use_exploration)

        # Insert into memory.
        @rlgraph_api(component=self.root_component)
        def insert_records(self, preprocessed_states, actions, rewards, next_states, terminals):
            records = merger.merge(preprocessed_states, actions, rewards, next_states, terminals)
            return memory.insert_records(records)

        # Syncing target-net.
        @rlgraph_api(component=self.root_component)
        def sync_target_qnet(self):
            # If we are a multi-GPU root:
            # Simply feed everything into the multi-GPU sync optimizer's method and return.
            if "multi-gpu-sync-optimizer" in self.sub_components:
                multi_gpu_syncer = self.sub_components["multi-gpu-sync-optimizer"]
                return multi_gpu_syncer.sync_target_qnets()
            else:
                policy_vars = self.get_sub_component_by_name(policy_scope)._variables()
                return self.get_sub_component_by_name("target-policy").sync(policy_vars)

        # Learn from memory.
        @rlgraph_api(component=self.root_component)
        def update_from_memory(self_):
            # Non-prioritized memory will just return weight 1.0 for all samples.
            records, sample_indices, importance_weights = memory.get_records(self.update_spec["batch_size"])
            preprocessed_s, actions, rewards, terminals, preprocessed_s_prime = splitter.split(records)

            step_op, loss, loss_per_item, q_values_s = self_.update_from_external_batch(
                preprocessed_s, actions, rewards, terminals, preprocessed_s_prime, importance_weights
            )

            # TODO: this is really annoying; will be solved once we have dict returns.
            if isinstance(memory, PrioritizedReplay):
                update_pr_step_op = memory.update_records(sample_indices, loss_per_item)
                return step_op, loss, loss_per_item, records, q_values_s, update_pr_step_op
            else:
                return step_op, loss, loss_per_item, records, q_values_s

        # Learn from an external batch.
        @rlgraph_api(component=self.root_component)
        def update_from_external_batch(
                self_, preprocessed_states, actions, rewards, terminals, preprocessed_next_states,
                importance_weights
        ):
            # If we are a multi-GPU root:
            # Simply feed everything into the multi-GPU sync optimizer's method and return.
            if "multi-gpu-sync-optimizer" in self_.sub_components:
                main_policy_vars = self_.get_sub_component_by_name(policy_scope)._variables()
                # TODO: This may be called differently in other agents (replace by root-policy).
                grads_and_vars, loss, loss_per_item, q_values_s = \
                    self_.sub_components["multi-gpu-sync-optimizer"].calculate_update_from_external_batch(
                        main_policy_vars, preprocessed_states, actions, rewards, terminals,
                        preprocessed_next_states, importance_weights
                    )
                step_op = self_.get_sub_component_by_name(optimizer_scope).apply_gradients(grads_and_vars)
                step_and_sync_op = self_.sub_components["multi-gpu-sync-optimizer"].sync_policy_weights_to_towers(
                    step_op, main_policy_vars
                )
                return step_and_sync_op, loss, loss_per_item, q_values_s

            # Get the different Q-values.
            q_values_s = self_.get_sub_component_by_name(policy_scope).get_logits_probabilities_log_probs(
                preprocessed_states
            )["logits"]
            qt_values_sp = self_.get_sub_component_by_name(target_policy.scope).get_logits_probabilities_log_probs(
                preprocessed_next_states
            )["logits"]

            q_values_sp = None
            if self.double_q:
                q_values_sp = self_.get_sub_component_by_name(policy_scope).get_logits_probabilities_log_probs(
                    preprocessed_next_states
                )["logits"]

            loss, loss_per_item = self_.get_sub_component_by_name(loss_function.scope).loss(
                q_values_s, actions, rewards, terminals, qt_values_sp, q_values_sp, importance_weights
            )

            # Args are passed in again because some device strategies may want to split them to different devices.
            policy_vars = self_.get_sub_component_by_name(policy_scope)._variables()

            # TODO: for a fully automated multi-GPU strategy, we would have to make sure that:
            # TODO: - every agent (root_component) has an update_from_external_batch method
            # TODO: - this if-check is somehow automated and not necessary anymore (local optimizer must be
            # TODO:   called with a different API-method, not step)
            if hasattr(self_, "is_multi_gpu_tower") and self_.is_multi_gpu_tower is True:
                grads_and_vars = self_.get_sub_component_by_name(optimizer_scope).calculate_gradients(
                    policy_vars, loss
                )
                return grads_and_vars, loss, loss_per_item, q_values_s
            else:
                step_op, loss, loss_per_item = optimizer.step(policy_vars, loss, loss_per_item)
                return step_op, loss, loss_per_item, q_values_s

        # TODO: for testing.
        @rlgraph_api(component=self.root_component)
        def get_td_loss(self_, preprocessed_states, actions, rewards, terminals, preprocessed_next_states,
                        importance_weights):
            # Get the different Q-values.
            q_values_s = policy.get_logits_probabilities_log_probs(preprocessed_states)["logits"]
            qt_values_sp = target_policy.get_logits_probabilities_log_probs(preprocessed_next_states)["logits"]

            q_values_sp = None
            if self.double_q:
                q_values_sp = policy.get_logits_probabilities_log_probs(preprocessed_next_states)["logits"]

            loss, loss_per_item = loss_function.loss(
                q_values_s, actions, rewards, terminals, qt_values_sp, q_values_sp, importance_weights
            )
            return loss, loss_per_item
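    # The API-methods registered above can be invoked by name through the graph executor, which is
    # exactly what `update()` and `_observe_graph()` below do. An illustrative sketch (assuming an
    # already built agent) of triggering one of them directly:
    #
    #   agent.graph_executor.execute("sync_target_qnet")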
    def _observe_graph(self, preprocessed_states, actions, internals, rewards, next_states, terminals):
        self.graph_executor.execute(
            ("insert_records", [preprocessed_states, actions, rewards, next_states, terminals])
        )
    def update(self, batch=None):
        # Should we sync the target net?
        self.steps_since_target_net_sync += self.update_spec["update_interval"]
        if self.steps_since_target_net_sync >= self.update_spec["sync_interval"]:
            sync_call = "sync_target_qnet"
            self.steps_since_target_net_sync = 0
        else:
            sync_call = None

        # [0]=no-op step; [1]=the loss; [2]=loss-per-item, [3]=memory-batch (if pulled); [4]=q-values
        return_ops = [0, 1, 2]
        q_table = None

        if batch is None:
            # Add some additional return-ops to pull (left out normally for performance reasons).
            if self.store_last_q_table is True:
                return_ops += [3, 4]  # 3=batch, 4=q-values
            elif self.store_last_memory_batch is True:
                return_ops += [3]  # 3=batch

            ret = self.graph_executor.execute(("update_from_memory", None, return_ops), sync_call)
            # print("Loss: {}".format(ret["update_from_memory"][1]))

            # Remove unnecessary return dicts (e.g. sync-op).
            if isinstance(ret, dict):
                ret = ret["update_from_memory"]

            # Store the last Q-table?
            if self.store_last_q_table is True:
                q_table = dict(
                    states=ret[3]["states"],
                    q_values=ret[4]
                )
        else:
            # Add some additional return-ops to pull (left out normally for performance reasons).
            if self.store_last_q_table is True:
                return_ops += [3]  # 3=q-values

            batch_input = [batch["states"], batch["actions"], batch["rewards"], batch["terminals"],
                           batch["next_states"], batch["importance_weights"]]
            ret = self.graph_executor.execute(("update_from_external_batch", batch_input, return_ops), sync_call)

            # Remove unnecessary return dicts (e.g. sync-op).
            if isinstance(ret, dict):
                ret = ret["update_from_external_batch"]

            # Store the last Q-table?
            if self.store_last_q_table is True:
                q_table = dict(
                    states=batch["states"],
                    q_values=ret[3]
                )

        # Store the latest pulled memory batch?
        if self.store_last_memory_batch is True and batch is None:
            self.last_memory_batch = ret[2]
        if self.store_last_q_table is True:
            self.last_q_table = q_table

        # [1]=the loss (0=update noop)
        # [2]=loss per item for external update, records for update from memory
        return ret[1], ret[2]
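    # Usage note (illustrative, not part of the original module): target-net syncing in `update()`
    # above is driven purely by the `update_interval` and `sync_interval` entries of
    # `self.update_spec`. For example, a (hypothetical) spec of
    #
    #   update_spec = dict(update_interval=4, sync_interval=1000, batch_size=32)
    #
    # would add 4 to `steps_since_target_net_sync` per call and hence append the
    # "sync_target_qnet" call roughly every 250 calls to `update()`.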
    def reset(self):
        """
        Resets our preprocessor, but only if it contains stateful PreprocessLayer Components (meaning
        the PreprocessorStack has at least one variable defined).
        """
        if self.preprocessing_required and len(self.preprocessor.variables) > 0:
            self.graph_executor.execute("reset_preprocessor")
    def __repr__(self):
        return "DQNAgent(doubleQ={} duelingQ={})".format(self.double_q, self.dueling_q)
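
# Illustrative end-to-end sketch (not part of the original module): a minimal act/observe/update
# loop with this agent. The environment object `env`, the spaces, the `agent_config` dict and the
# exact `observe(...)` signature of the Agent base class are assumptions for illustration; only
# `get_action` and `update` are defined in this file.
#
#   from rlgraph.agents import DQNAgent
#
#   agent = DQNAgent(state_space=state_space, action_space=action_space, **agent_config)
#   state = env.reset()
#   for t in range(10000):
#       action = agent.get_action(states=state, use_exploration=True)
#       next_state, reward, terminal, _ = env.step(action)
#       agent.observe(preprocessed_states=state, actions=action, internals=[],
#                     rewards=reward, next_states=next_state, terminals=terminal)
#       loss, _ = agent.update()
#       state = env.reset() if terminal else next_state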