gabehubner's picture
add requirements
ee1c253
raw
history blame
11 kB
import os
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from captum.attr import (IntegratedGradients, LayerConductance, NeuronAttribution)
class OUActionNoise(object):
    """Ornstein-Uhlenbeck process: temporally correlated noise.

    Calling the instance advances the process by one step and returns the
    new sample. The sample has the shape of ``mu``.

    Args:
        mu: Array the process reverts towards; its ``shape`` sets the
            shape of every sample.
        sigma: Scale applied to the Gaussian term.
        theta: Rate at which the process is pulled back towards ``mu``.
        dt: Step size used for both the drift and the Gaussian term.
        x0: Optional starting value; zeros shaped like ``mu`` when ``None``.
    """

    def __init__(self, mu, sigma=0.15, theta=0.2, dt=1e-2, x0=None):
        self.mu = mu
        self.sigma = sigma
        self.theta = theta
        self.dt = dt
        self.x0 = x0
        self.reset()

    def __call__(self):
        """Advance the process one step and return the new value."""
        # Deterministic pull of the previous value towards mu.
        drift = self.theta * (self.mu - self.x_prev) * self.dt
        # Random kick, scaled by sigma * sqrt(dt).
        kick = self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape)
        sample = self.x_prev + drift + kick
        self.x_prev = sample
        return sample

    def reset(self):
        """Restore the process to ``x0`` (or zeros when no ``x0`` was given)."""
        if self.x0 is None:
            self.x_prev = np.zeros_like(self.mu)
        else:
            self.x_prev = self.x0
class ReplayBuffer(object):
    """Fixed-capacity ring buffer of transitions.

    Once ``max_size`` transitions have been stored, new ones overwrite the
    oldest slots.

    Args:
        max_size: Number of transitions the buffer can hold.
        input_shape: Shape of one observation (unpacked after the
            capacity dimension).
        n_actions: Length of one action vector.
    """

    def __init__(self, max_size, input_shape, n_actions):
        self.mem_size = max_size
        self.mem_cntr = 0  # total number of transitions ever stored
        obs_shape = (self.mem_size, *input_shape)
        self.state_memory = np.zeros(obs_shape)
        self.new_state_memory = np.zeros(obs_shape)
        self.action_memory = np.zeros((self.mem_size, n_actions))
        self.reward_memory = np.zeros(self.mem_size)
        # Holds ``1 - done``: 0.0 for terminal transitions, 1.0 otherwise.
        self.terminal_memory = np.zeros(self.mem_size, dtype=np.float32)

    def store_transition(self, state, action, reward, state_, done):
        """Write one transition into the next slot, wrapping around when full."""
        slot = self.mem_cntr % self.mem_size
        self.state_memory[slot] = state
        self.action_memory[slot] = action
        self.reward_memory[slot] = reward
        self.new_state_memory[slot] = state_
        self.terminal_memory[slot] = 1 - done
        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        """Draw ``batch_size`` stored transitions uniformly at random.

        Indices are drawn with ``np.random.choice`` over the filled part of
        the buffer only.

        Returns:
            Tuple ``(states, actions, rewards, new_states, terminal)`` where
            ``terminal`` holds the stored ``1 - done`` values.
        """
        # Only the first min(mem_cntr, mem_size) slots hold real data.
        filled = min(self.mem_cntr, self.mem_size)
        picks = np.random.choice(filled, batch_size)
        return (
            self.state_memory[picks],
            self.action_memory[picks],
            self.reward_memory[picks],
            self.new_state_memory[picks],
            self.terminal_memory[picks],
        )
class CriticNetwork(nn.Module):
    """State-action value network (the DDPG critic).

    The state goes through two linear layers, each followed by LayerNorm.
    The action goes through one linear layer. Both are summed and projected
    to a single value per sample.

    Args:
        beta: Learning rate for the Adam optimizer.
        input_dims: Sequence unpacked into the first ``nn.Linear``; in
            practice a one-element sequence such as ``(obs_dim,)``.
        fc1_dims: Width of the first hidden layer.
        fc2_dims: Width of the second hidden layer and of the action branch.
        n_actions: Length of the action vector.
        name: Prefix of the checkpoint file (``<name>_ddpg``).
        chkpt_dir: Directory that holds the checkpoint file.
    """

    def __init__(self, beta, input_dims, fc1_dims, fc2_dims, n_actions, name, chkpt_dir="tmp/ddpg"):
        super(CriticNetwork, self).__init__()
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name+'_ddpg')

        self.fc1 = nn.Linear(*self.input_dims, self.fc1_dims)
        # NOTE(review): nn.Linear weights are (out_features, in_features), so
        # size()[0] is the layer's output width. The DDPG paper scales by
        # fan-in; kept unchanged to preserve the existing initialisation.
        f1 = 1./np.sqrt(self.fc1.weight.data.size()[0])
        T.nn.init.uniform_(self.fc1.weight.data, -f1, f1)
        T.nn.init.uniform_(self.fc1.bias.data, -f1, f1)
        # Named "bn" but this is LayerNorm, not BatchNorm.
        self.bn1 = nn.LayerNorm(self.fc1_dims)

        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        f2 = 1./np.sqrt(self.fc2.weight.data.size()[0])
        T.nn.init.uniform_(self.fc2.weight.data, -f2, f2)
        T.nn.init.uniform_(self.fc2.bias.data, -f2, f2)
        self.bn2 = nn.LayerNorm(self.fc2_dims)

        # Action branch keeps PyTorch's default initialisation.
        self.action_value = nn.Linear(self.n_actions, self.fc2_dims)

        f3 = 0.003  # From paper
        self.q = nn.Linear(self.fc2_dims, 1)
        T.nn.init.uniform_(self.q.weight.data, -f3, f3)
        T.nn.init.uniform_(self.q.bias.data, -f3, f3)

        self.optimizer = optim.Adam(self.parameters(), lr=beta, weight_decay=0.01)
        self.device = T.device("cpu")
        self.to(self.device)

    def forward(self, state, action):
        """Return the critic's value for each ``(state, action)`` pair.

        Returns:
            Tensor whose last dimension is 1 (output of the ``q`` layer).
        """
        state_value = self.fc1(state)
        state_value = self.bn1(state_value)
        state_value = F.relu(state_value)
        state_value = self.fc2(state_value)
        # No ReLU here: the activation is applied after adding the action branch.
        state_value = self.bn2(state_value)
        action_value = F.relu(self.action_value(action))
        state_action_value = F.relu(T.add(state_value, action_value))
        state_action_value = self.q(state_action_value)
        return state_action_value

    def save_checkpoint(self):
        """Write the state dict to ``checkpoint_file``, creating its directory."""
        print('... saving checkpoint ...')
        # T.save does not create missing parent directories.
        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load the state dict from ``checkpoint_file`` onto ``self.device``."""
        print('... loading checkpoint ...')
        # map_location lets checkpoints written on another device load here.
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))
class ActorNetwork(nn.Module):
    """Deterministic policy network (the DDPG actor).

    Two linear layers, each followed by LayerNorm and ReLU, then a linear
    head squashed by ``tanh`` so every action component lies in [-1, 1].

    Args:
        alpha: Learning rate for the Adam optimizer.
        input_dims: Sequence unpacked into the first ``nn.Linear``; in
            practice a one-element sequence such as ``(obs_dim,)``.
        fc1_dims: Width of the first hidden layer.
        fc2_dims: Width of the second hidden layer.
        n_actions: Length of the action vector produced.
        name: Prefix of the checkpoint file (``<name>_ddpg``).
        chkpt_dir: Directory that holds the checkpoint file.
    """

    def __init__(self, alpha, input_dims, fc1_dims, fc2_dims, n_actions, name, chkpt_dir="tmp/ddpg"):
        super(ActorNetwork, self).__init__()
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name+'_ddpg')

        self.fc1 = nn.Linear(*self.input_dims, self.fc1_dims)
        # NOTE(review): nn.Linear weights are (out_features, in_features), so
        # size()[0] is the layer's output width. The DDPG paper scales by
        # fan-in; kept unchanged to preserve the existing initialisation.
        f1 = 1./np.sqrt(self.fc1.weight.data.size()[0])
        T.nn.init.uniform_(self.fc1.weight.data, -f1, f1)
        T.nn.init.uniform_(self.fc1.bias.data, -f1, f1)
        # Named "bn" but this is LayerNorm, not BatchNorm.
        self.bn1 = nn.LayerNorm(self.fc1_dims)

        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        f2 = 1./np.sqrt(self.fc2.weight.data.size()[0])
        T.nn.init.uniform_(self.fc2.weight.data, -f2, f2)
        T.nn.init.uniform_(self.fc2.bias.data, -f2, f2)
        self.bn2 = nn.LayerNorm(self.fc2_dims)

        f3 = 0.003  # From paper
        self.mu = nn.Linear(self.fc2_dims, self.n_actions)
        T.nn.init.uniform_(self.mu.weight.data, -f3, f3)
        # The bias is initialised once; the original repeated this call.
        T.nn.init.uniform_(self.mu.bias.data, -f3, f3)

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        self.device = T.device("cpu")
        self.to(self.device)

    def forward(self, state):
        """Map ``state`` to an action with every component in [-1, 1]."""
        x = self.fc1(state)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = T.tanh(self.mu(x))
        return x

    def save_checkpoint(self):
        """Write the state dict to ``checkpoint_file``, creating its directory."""
        print('... saving checkpoint ...')
        # T.save does not create missing parent directories.
        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load the state dict from ``checkpoint_file`` onto ``self.device``."""
        print('... loading checkpoint ...')
        # map_location lets checkpoints written on another device load here.
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))
class Agent(object):
    """DDPG agent: actor, critic, their target copies, replay memory and noise.

    Args:
        alpha: Learning rate passed to the actor networks.
        beta: Learning rate passed to the critic networks.
        input_dims: Observation shape, forwarded to the networks and buffer.
        tau: Soft-update coefficient for the target networks.
        env: Accepted for interface compatibility; not used by this class.
        gamma: Discount factor applied to the bootstrapped value.
        n_actions: Length of the action vector.
        max_size: Replay buffer capacity.
        layer1_size: Width of the first hidden layer in every network.
        layer2_size: Width of the second hidden layer in every network.
        batch_size: Number of transitions sampled per ``learn`` call.
    """

    def __init__(self, alpha, beta, input_dims, tau, env, gamma=0.99, n_actions=2,
                 max_size=1000000, layer1_size=400, layer2_size=300, batch_size=64):
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.memory = ReplayBuffer(max_size, input_dims, n_actions)
        self.actor = ActorNetwork(alpha, input_dims, layer1_size, layer2_size, n_actions=n_actions, name="actor")
        self.critic = CriticNetwork(beta, input_dims, layer1_size, layer2_size, n_actions=n_actions, name="critic")
        self.target_actor = ActorNetwork(alpha, input_dims, layer1_size, layer2_size, n_actions=n_actions, name="target_actor")
        self.target_critic = CriticNetwork(beta, input_dims, layer1_size, layer2_size, n_actions=n_actions, name="target_critic")
        self.noise = OUActionNoise(mu=np.zeros(n_actions))
        # Attributions collected by choose_action while self.ig is set.
        self.attributions = []
        # Attribution object; stays None (attribution disabled) until a caller assigns it.
        self.ig: IntegratedGradients = None
        # tau=1 makes the targets start as exact copies of the online networks.
        self.update_network_parameters(tau=1)

    def choose_action(self, observation, baseline: T.Tensor = None):
        """Return the actor's action for ``observation`` plus exploration noise.

        When ``self.ig`` is set, an attribution for the observation is also
        computed against ``baseline`` and appended to ``self.attributions``.

        Returns:
            NumPy array holding the noisy action.
        """
        self.actor.eval()
        observation = T.tensor(observation, dtype=T.float).to(self.actor.device)
        mu = self.actor(observation).to(self.actor.device)
        if self.ig is not None:
            attribution = self.ig.attribute(observation, baselines=baseline, n_steps=1)
            self.attributions.append(attribution)
        mu_prime = mu + T.tensor(self.noise(), dtype=T.float).to(self.actor.device)
        self.actor.train()
        return mu_prime.cpu().detach().numpy()

    def remember(self, state, action, reward, new_state, done):
        """Store one transition in the replay buffer."""
        self.memory.store_transition(state, action, reward, new_state, done)

    def learn(self):
        """Run one critic update, one actor update and a soft target update.

        Does nothing until the buffer has received at least ``batch_size``
        transitions.
        """
        if self.memory.mem_cntr < self.batch_size:
            return
        state, action, reward, new_state, done = self.memory.sample_buffer(self.batch_size)
        device = self.critic.device
        reward = T.tensor(reward, dtype=T.float).to(device)
        # The buffer stores ``1 - done``, so this is 0 for terminal transitions.
        done = T.tensor(done, dtype=T.float).to(device)
        new_state = T.tensor(new_state, dtype=T.float).to(device)
        action = T.tensor(action, dtype=T.float).to(device)
        state = T.tensor(state, dtype=T.float).to(device)

        self.target_actor.eval()
        self.target_critic.eval()
        self.critic.eval()
        # The TD target is a constant for the critic loss: compute it in one
        # vectorised expression without tracking gradients.
        with T.no_grad():
            target_actions = self.target_actor.forward(new_state)
            critic_value_ = self.target_critic.forward(new_state, target_actions)
            target = reward.view(-1, 1) + self.gamma * critic_value_ * done.view(-1, 1)
        critic_value = self.critic.forward(state, action)

        self.critic.train()
        self.critic.optimizer.zero_grad()
        critic_loss = F.mse_loss(target, critic_value)
        critic_loss.backward()
        self.critic.optimizer.step()

        self.critic.eval()
        self.actor.optimizer.zero_grad()
        mu = self.actor.forward(state)
        self.actor.train()
        # Gradient ascent on the critic's value of the actor's own actions.
        actor_loss = -self.critic.forward(state, mu)
        actor_loss = T.mean(actor_loss)
        actor_loss.backward()
        self.actor.optimizer.step()

        self.update_network_parameters()

    @staticmethod
    def _soft_update(online, target, tau):
        """Blend ``online`` parameters into ``target`` in place (Polyak averaging)."""
        with T.no_grad():
            for param, target_param in zip(online.parameters(), target.parameters()):
                target_param.copy_(tau * param + (1 - tau) * target_param)

    def update_network_parameters(self, tau=None):
        """Move both target networks towards their online counterparts.

        Each target parameter becomes ``tau * online + (1 - tau) * target``.

        Args:
            tau: Blend factor; defaults to ``self.tau``. ``tau=1`` copies
                the online parameters.
        """
        if tau is None:
            tau = self.tau
        self._soft_update(self.critic, self.target_critic, tau)
        self._soft_update(self.actor, self.target_actor, tau)

    def save_models(self):
        """Checkpoint all four networks."""
        self.actor.save_checkpoint()
        self.target_actor.save_checkpoint()
        self.critic.save_checkpoint()
        self.target_critic.save_checkpoint()

    def load_models(self):
        """Restore all four networks from their checkpoints."""
        self.actor.load_checkpoint()
        self.target_actor.load_checkpoint()
        self.critic.load_checkpoint()
        self.target_critic.load_checkpoint()