# Copyright 2019 DeepMind Technologies Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DQN agent implemented in TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import random

import numpy as np
import sonnet as snt
import tensorflow as tf

Transition = collections.namedtuple(
    "Transition",
    "info_state action reward next_info_state is_final_step legal_actions_mask")

StepOutput = collections.namedtuple("step_output", ["action", "probs"])

ILLEGAL_ACTION_LOGITS_PENALTY = -1e9


class ReplayBuffer(object):
  """ReplayBuffer of fixed size with a FIFO replacement policy.

  Stored transitions can be sampled uniformly. The underlying data structure
  is a ring buffer, allowing O(1) adding and sampling.
  """

  def __init__(self, replay_buffer_capacity):
    self._replay_buffer_capacity = replay_buffer_capacity
    self._data = []
    self._next_entry_index = 0

  def add(self, element):
    """Adds `element` to the buffer.

    If the buffer is full, the oldest element will be replaced.

    Args:
      element: data to be added to the buffer.
    """
    if len(self._data) < self._replay_buffer_capacity:
      self._data.append(element)
    else:
      self._data[self._next_entry_index] = element
      self._next_entry_index += 1
      self._next_entry_index %= self._replay_buffer_capacity

  def sample(self, num_samples):
    """Returns `num_samples` uniformly sampled from the buffer.

    Args:
      num_samples: `int`, number of samples to draw.

    Returns:
      An iterable over `num_samples` random elements of the buffer.

    Raises:
      ValueError: If there are fewer than `num_samples` elements in the buffer.
    """
    if len(self._data) < num_samples:
      raise ValueError("{} elements could not be sampled from size {}".format(
          num_samples, len(self._data)))
    return random.sample(self._data, num_samples)

  def __len__(self):
    return len(self._data)

  def __iter__(self):
    return iter(self._data)
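

# A minimal usage sketch of the ring buffer above (the values are illustrative
# only, not taken from the training code):
#
#   buf = ReplayBuffer(replay_buffer_capacity=3)
#   for x in range(5):
#     buf.add(x)        # once full, the oldest entries (0, 1) are overwritten
#   sorted(buf)         # -> [2, 3, 4]
#   buf.sample(2)       # -> two elements drawn uniformly at random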


class DQN:
  """DQN Agent implementation in TensorFlow."""

  def __init__(self,
               session,
               player_id,
               state_representation_size,
               num_actions,
               hidden_layers_sizes,
               replay_buffer_capacity=10000,
               batch_size=128,
               replay_buffer_class=ReplayBuffer,
               learning_rate=0.01,
               update_target_network_every=200,
               learn_every=10,
               discount_factor=1.0,
               min_buffer_size_to_learn=1000,
               epsilon_start=1.0,
               epsilon_end=0.1,
               epsilon_decay_duration=int(1e6),
               optimizer_str="sgd",
               loss_str="mse",
               max_global_gradient_norm=None):
    """Initialize the DQN agent."""
    self.player_id = player_id
    self._session = session
    self._num_actions = num_actions
    self._layer_sizes = hidden_layers_sizes + [num_actions]
    self._batch_size = batch_size
    self._update_target_network_every = update_target_network_every
    self._learn_every = learn_every
    self._min_buffer_size_to_learn = min_buffer_size_to_learn
    self._discount_factor = discount_factor
    self._epsilon_start = epsilon_start
    self._epsilon_end = epsilon_end
    self._epsilon_decay_duration = epsilon_decay_duration

    # TODO Allow for optional replay buffer config.
    self._replay_buffer = replay_buffer_class(replay_buffer_capacity)
    self._prev_timestep = None
    self._prev_action = None

    # Step counter to keep track of learning, eps decay and target network.
    self._step_counter = 0

    # Keep track of the last training loss achieved in an update step.
    self._last_loss_value = None
    # Create required TensorFlow placeholders to perform the Q-network updates.
    self._info_state_ph = tf.placeholder(
        shape=[None, state_representation_size],
        dtype=tf.float32,
        name="info_state_ph")
    self._action_ph = tf.placeholder(
        shape=[None], dtype=tf.int32, name="action_ph")
    self._reward_ph = tf.placeholder(
        shape=[None], dtype=tf.float32, name="reward_ph")
    self._is_final_step_ph = tf.placeholder(
        shape=[None], dtype=tf.float32, name="is_final_step_ph")
    self._next_info_state_ph = tf.placeholder(
        shape=[None, state_representation_size],
        dtype=tf.float32,
        name="next_info_state_ph")
    self._legal_actions_mask_ph = tf.placeholder(
        shape=[None, num_actions],
        dtype=tf.float32,
        name="legal_actions_mask_ph")

    self._q_network = snt.nets.MLP(output_sizes=self._layer_sizes)
    self._q_values = self._q_network(self._info_state_ph)
    self._target_q_network = snt.nets.MLP(output_sizes=self._layer_sizes)
    self._target_q_values = self._target_q_network(self._next_info_state_ph)
    # Stop gradient to prevent updates to the target network while learning.
    self._target_q_values = tf.stop_gradient(self._target_q_values)
    self._update_target_network = self._create_target_network_update_op(
        self._q_network, self._target_q_network)
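    # Note: the target network is a structurally identical copy of the online
    # Q-network. It is never trained directly: gradients are blocked by the
    # stop_gradient above, and its weights are only refreshed by running
    # `self._update_target_network` every `update_target_network_every` steps.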
    # Create the loss operations.
    # Sum a large negative constant to illegal action logits before taking the
    # max. This prevents illegal action values from being considered as target.
    illegal_actions = 1 - self._legal_actions_mask_ph
    illegal_logits = illegal_actions * ILLEGAL_ACTION_LOGITS_PENALTY
    max_next_q = tf.reduce_max(
        tf.math.add(tf.stop_gradient(self._target_q_values), illegal_logits),
        axis=-1)
    target = (
        self._reward_ph +
        (1 - self._is_final_step_ph) * self._discount_factor * max_next_q)
    action_indices = tf.stack(
        [tf.range(tf.shape(self._q_values)[0]), self._action_ph], axis=-1)
    predictions = tf.gather_nd(self._q_values, action_indices)

    if loss_str == "mse":
      loss_class = tf.losses.mean_squared_error
    elif loss_str == "huber":
      loss_class = tf.losses.huber_loss
    else:
      raise ValueError("Not implemented, choose from 'mse', 'huber'.")

    self._loss = tf.reduce_mean(
        loss_class(labels=target, predictions=predictions))
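    # The loss above is the standard 1-step Q-learning objective: `target` is
    # the Bellman backup r_t + (1 - done) * gamma * max_a Q_target(s_{t+1}, a)
    # with illegal next actions masked out, and `predictions` uses gather_nd
    # to pick Q(s_t, a_t) for the action actually taken in each transition.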
    if optimizer_str == "adam":
      optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    elif optimizer_str == "sgd":
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    else:
      raise ValueError("Not implemented, choose from 'adam' and 'sgd'.")

    def minimize_with_clipping(optimizer, loss):
      grads_and_vars = optimizer.compute_gradients(loss)
      if max_global_gradient_norm is not None:
        grads, variables = zip(*grads_and_vars)
        grads, _ = tf.clip_by_global_norm(grads, max_global_gradient_norm)
        grads_and_vars = list(zip(grads, variables))
      return optimizer.apply_gradients(grads_and_vars)

    self._learn_step = minimize_with_clipping(optimizer, self._loss)
    # self._ckp = tf.train.Checkpoint(module=self._q_network)
    self._saver = tf.train.Saver(var_list=self._q_network.variables)

  def step(self, time_step, is_evaluation=False, add_transition_record=True):
    """Returns the action to be taken and updates the Q-network if needed.

    Args:
      time_step: an instance of TimeStep.
      is_evaluation: bool, whether this is a training or evaluation call.
      add_transition_record: Whether to add to the replay buffer on this step.

    Returns:
      A `StepOutput` containing the action probs and chosen action, or `None`
      when a terminal `time_step` ends a training episode.
    """
    # Act step: don't act at terminal info states or if it's not our turn.
    if (not time_step.last()) and (
        self.player_id == time_step.current_player()):
      info_state = time_step.observations["info_state"][self.player_id]
      legal_actions = time_step.observations["legal_actions"][self.player_id]
      epsilon = self._get_epsilon(is_evaluation)
      action, probs = self._epsilon_greedy(info_state, legal_actions, epsilon)
    else:
      action = None
      probs = []

    # Don't modify the agent's state during evaluation.
    if not is_evaluation:
      self._step_counter += 1

      if self._step_counter % self._learn_every == 0:
        self._last_loss_value = self.learn()

      if self._step_counter % self._update_target_network_every == 0:
        self._session.run(self._update_target_network)

      if self._prev_timestep and add_transition_record:
        # We may omit record adding here if it's done elsewhere.
        if self._prev_action is not None:
          self.add_transition(self._prev_timestep, self._prev_action, time_step)

      if time_step.last():  # Prepare for the next episode.
        self._prev_timestep = None
        self._prev_action = None
        return
      else:
        self._prev_timestep = time_step
        self._prev_action = action

    return StepOutput(action=action, probs=probs)
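  # Intended usage of `step` (sketch only, assuming an rl_environment-style
  # `env` that yields TimeStep objects; the exact environment API is not
  # defined in this file):
  #
  #   time_step = env.reset()
  #   while not time_step.last():
  #     agent_output = agent.step(time_step)
  #     time_step = env.step([agent_output.action])
  #   agent.step(time_step)   # let the agent record the terminal transition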

  def add_transition(self, prev_time_step, prev_action, time_step):
    """Adds the new transition using `time_step` to the replay buffer.

    Adds the transition from `self._prev_timestep` to `time_step` by
    `self._prev_action`.

    Args:
      prev_time_step: prev ts, an instance of rl_environment.TimeStep.
      prev_action: int, action taken at `prev_time_step`.
      time_step: current ts, an instance of rl_environment.TimeStep.
    """
    assert prev_time_step is not None
    legal_actions = (
        prev_time_step.observations["legal_actions"][self.player_id])
    legal_actions_mask = np.zeros(self._num_actions)
    legal_actions_mask[legal_actions] = 1.0
    transition = Transition(
        info_state=(
            prev_time_step.observations["info_state"][self.player_id][:]),
        action=prev_action,
        reward=time_step.rewards[self.player_id],
        next_info_state=time_step.observations["info_state"][self.player_id][:],
        is_final_step=float(time_step.last()),
        legal_actions_mask=legal_actions_mask)
    self._replay_buffer.add(transition)

  def _create_target_network_update_op(self, q_network, target_q_network):
    """Create TF ops copying the params of the Q-network to the target network.

    Args:
      q_network: `snt.AbstractModule`. Values are copied from this network.
      target_q_network: `snt.AbstractModule`. Values are copied to this network.

    Returns:
      A `tf.Operation` that updates the variables of the target.
    """
    variables = q_network.get_variables()
    target_variables = target_q_network.get_variables()
    return tf.group([
        tf.assign(target_v, v)
        for (target_v, v) in zip(target_variables, variables)
    ])
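  # Note: this is a hard update (a direct tf.assign of every weight), not a
  # Polyak/soft average; the online and target networks are exactly equal
  # immediately after the op runs.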

  def _epsilon_greedy(self, info_state, legal_actions, epsilon):
    """Returns a valid epsilon-greedy action and valid action probs.

    With probability `epsilon` a legal action is drawn uniformly at random;
    otherwise the legal action with the highest predicted Q-value is chosen.
    The returned probabilities reflect that choice (uniform over legal actions
    when exploring, a one-hot vector on the greedy action otherwise).

    Args:
      info_state: hashable representation of the information state.
      legal_actions: list of legal actions at `info_state`.
      epsilon: float, probability of taking an exploratory action.

    Returns:
      A valid epsilon-greedy action and valid action probabilities.
    """
    probs = np.zeros(self._num_actions)
    if np.random.rand() < epsilon:
      action = np.random.choice(legal_actions)
      probs[legal_actions] = 1.0 / len(legal_actions)
    else:
      info_state = np.reshape(info_state, [1, -1])
      q_values = self._session.run(
          self._q_values, feed_dict={self._info_state_ph: info_state})[0]
      legal_q_values = q_values[legal_actions]
      action = legal_actions[np.argmax(legal_q_values)]
      probs[action] = 1.0
    return action, probs

  def _get_epsilon(self, is_evaluation, power=1.0):
    """Returns the evaluation or decayed epsilon value."""
    if is_evaluation:
      return 0.0
    decay_steps = min(self._step_counter, self._epsilon_decay_duration)
    decayed_epsilon = (
        self._epsilon_end + (self._epsilon_start - self._epsilon_end) *
        (1 - decay_steps / self._epsilon_decay_duration)**power)
    return decayed_epsilon
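  # Worked example with the constructor defaults (epsilon_start=1.0,
  # epsilon_end=0.1, epsilon_decay_duration=1e6, power=1.0, i.e. linear decay):
  # epsilon is 1.0 at step 0, 0.55 after 500k steps, and stays at 0.1 from
  # step 1e6 onwards. Evaluation calls always use epsilon = 0.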

  def learn(self):
    """Compute the loss on sampled transitions and perform a Q-network update.

    If there are not enough elements in the buffer, no loss is computed and
    `None` is returned instead.

    Returns:
      The average loss obtained on this batch of transitions or `None`.
    """
    if (len(self._replay_buffer) < self._batch_size or
        len(self._replay_buffer) < self._min_buffer_size_to_learn):
      return None

    transitions = self._replay_buffer.sample(self._batch_size)
    info_states = [t.info_state for t in transitions]
    actions = [t.action for t in transitions]
    rewards = [t.reward for t in transitions]
    next_info_states = [t.next_info_state for t in transitions]
    are_final_steps = [t.is_final_step for t in transitions]
    legal_actions_mask = [t.legal_actions_mask for t in transitions]
    loss, _ = self._session.run(
        [self._loss, self._learn_step],
        feed_dict={
            self._info_state_ph: info_states,
            self._action_ph: actions,
            self._reward_ph: rewards,
            self._is_final_step_ph: are_final_steps,
            self._next_info_state_ph: next_info_states,
            self._legal_actions_mask_ph: legal_actions_mask,
        })
    return loss

  def save(self, checkpoint_root, checkpoint_name):
    """Saves the Q-network variables under `checkpoint_root/checkpoint_name`."""
    save_prefix = os.path.join(checkpoint_root, checkpoint_name)
    self._saver.save(sess=self._session, save_path=save_prefix)

  def restore(self, save_path):
    """Restores the Q-network variables from `save_path`."""
    self._saver.restore(self._session, save_path)

  @property
  def q_values(self):
    return self._q_values

  @property
  def replay_buffer(self):
    return self._replay_buffer

  @property
  def info_state_ph(self):
    return self._info_state_ph

  @property
  def loss(self):
    return self._last_loss_value

  @property
  def prev_timestep(self):
    return self._prev_timestep

  @property
  def prev_action(self):
    return self._prev_action

  @property
  def step_counter(self):
    return self._step_counter
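

# A minimal smoke test, kept behind a __main__ guard so importing this module
# has no side effects. This is only a sketch: the state size (25) and action
# count (26) are illustrative stand-ins, not values taken from the mini_go
# environment, and it bypasses `step()` by filling the replay buffer with
# random transitions directly before running one learning update.
if __name__ == "__main__":
  STATE_SIZE = 25
  NUM_ACTIONS = 26
  with tf.Session() as sess:
    agent = DQN(
        session=sess,
        player_id=0,
        state_representation_size=STATE_SIZE,
        num_actions=NUM_ACTIONS,
        hidden_layers_sizes=[64, 64],
        batch_size=32,
        min_buffer_size_to_learn=32,
        learning_rate=0.1)
    sess.run(tf.global_variables_initializer())
    # Fill the buffer with random transitions, then run a single Q update.
    for _ in range(64):
      agent.replay_buffer.add(
          Transition(
              info_state=np.random.rand(STATE_SIZE),
              action=np.random.randint(NUM_ACTIONS),
              reward=float(np.random.rand()),
              next_info_state=np.random.rand(STATE_SIZE),
              is_final_step=0.0,
              legal_actions_mask=np.ones(NUM_ACTIONS)))
    print("smoke-test loss:", agent.learn())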