import os
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
# Only IntegratedGradients is used below; the other two are imported for experimentation.
from captum.attr import (IntegratedGradients, LayerConductance, NeuronAttribution)

class OUActionNoise(object):
    """Ornstein-Uhlenbeck process: temporally correlated exploration noise."""
    def __init__(self, mu, sigma=0.2, theta=0.15, dt=1e-2, x0=None):
        # theta=0.15, sigma=0.2 are the defaults from the DDPG paper (Lillicrap et al., 2016)
        self.theta = theta
        self.mu = mu
        self.sigma = sigma
        self.dt = dt
        self.x0 = x0
        self.reset()

    def __call__(self):
        # Euler-Maruyama step of dx = theta*(mu - x)*dt + sigma*dW
        x = self.x_prev + self.theta * (self.mu - self.x_prev) * self.dt + \
            self.sigma * np.sqrt(self.dt) * np.random.normal(size=self.mu.shape)
        self.x_prev = x
        return x

    def reset(self):
        self.x_prev = self.x0 if self.x0 is not None else np.zeros_like(self.mu)
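
# A minimal sketch (an addition, not part of the original) showing the point of
# OU noise: successive samples are correlated through x_prev, unlike i.i.d.
# Gaussian noise, which makes exploration smoother in physical control tasks.
def _demo_ou_noise():
    noise = OUActionNoise(mu=np.zeros(1))
    samples = np.array([noise() for _ in range(5)]).ravel()
    print("temporally correlated OU samples:", samples)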

class ReplayBuffer(object):
    def __init__(self, max_size, input_shape, n_actions):
        self.mem_size = max_size
        self.mem_cntr = 0
        self.state_memory = np.zeros((self.mem_size, *input_shape))
        self.new_state_memory = np.zeros((self.mem_size, *input_shape))
        self.action_memory = np.zeros((self.mem_size, n_actions))
        self.reward_memory = np.zeros(self.mem_size)
        self.terminal_memory = np.zeros(self.mem_size, dtype=np.float32)

    def store_transition(self, state, action, reward, state_, done):
        index = self.mem_cntr % self.mem_size  # wrap around when the buffer is full
        self.state_memory[index] = state
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.new_state_memory[index] = state_
        self.terminal_memory[index] = 1 - done  # 0 at terminal states, 1 otherwise
        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        max_mem = min(self.mem_cntr, self.mem_size)  # only sample filled slots
        batch = np.random.choice(max_mem, batch_size)
        states = self.state_memory[batch]
        actions = self.action_memory[batch]
        rewards = self.reward_memory[batch]
        new_states = self.new_state_memory[batch]
        terminal = self.terminal_memory[batch]
        return states, actions, rewards, new_states, terminal
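
# Quick sanity check (a hypothetical helper, not in the original): store a few
# transitions and sample a batch to confirm the buffer's array shapes.
def _demo_replay_buffer():
    buf = ReplayBuffer(max_size=100, input_shape=(4,), n_actions=2)
    for _ in range(10):
        buf.store_transition(np.ones(4), np.zeros(2), 1.0, np.ones(4), False)
    states, actions, rewards, states_, terminal = buf.sample_buffer(batch_size=5)
    print(states.shape, actions.shape, rewards.shape)  # (5, 4) (5, 2) (5,)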

class CriticNetwork(nn.Module):
    def __init__(self, beta, input_dims, fc1_dims, fc2_dims, n_actions, name,
                 chkpt_dir="tmp/ddpg"):
        super(CriticNetwork, self).__init__()
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.checkpoint_dir = chkpt_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)  # saving fails otherwise
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name + '_ddpg')

        self.fc1 = nn.Linear(*self.input_dims, self.fc1_dims)
        # Fan-in initialisation from the DDPG paper: U(-1/sqrt(f), 1/sqrt(f)),
        # where f is the number of inputs to the layer (weight.size()[1]).
        f1 = 1. / np.sqrt(self.fc1.weight.data.size()[1])
        T.nn.init.uniform_(self.fc1.weight.data, -f1, f1)
        T.nn.init.uniform_(self.fc1.bias.data, -f1, f1)
        self.bn1 = nn.LayerNorm(self.fc1_dims)

        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        f2 = 1. / np.sqrt(self.fc2.weight.data.size()[1])
        T.nn.init.uniform_(self.fc2.weight.data, -f2, f2)
        T.nn.init.uniform_(self.fc2.bias.data, -f2, f2)
        self.bn2 = nn.LayerNorm(self.fc2_dims)

        # Actions are injected after the second hidden layer, as in the paper.
        self.action_value = nn.Linear(self.n_actions, self.fc2_dims)
        f3 = 0.003  # final-layer init from the paper
        self.q = nn.Linear(self.fc2_dims, 1)
        T.nn.init.uniform_(self.q.weight.data, -f3, f3)
        T.nn.init.uniform_(self.q.bias.data, -f3, f3)

        self.optimizer = optim.Adam(self.parameters(), lr=beta, weight_decay=0.01)
        self.device = T.device("cpu")
        self.to(self.device)

    def forward(self, state, action):
        state_value = self.fc1(state)
        state_value = self.bn1(state_value)
        state_value = F.relu(state_value)
        state_value = self.fc2(state_value)
        state_value = self.bn2(state_value)
        action_value = F.relu(self.action_value(action))
        # Fuse the state and action streams into a single Q-value.
        state_action_value = F.relu(T.add(state_value, action_value))
        state_action_value = self.q(state_action_value)
        return state_action_value

    def save_checkpoint(self):
        print('... saving checkpoint ...')
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        print('... loading checkpoint ...')
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))

class ActorNetwork(nn.Module):
    def __init__(self, alpha, input_dims, fc1_dims, fc2_dims, n_actions, name,
                 chkpt_dir="tmp/ddpg"):
        super(ActorNetwork, self).__init__()
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.checkpoint_dir = chkpt_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name + '_ddpg')

        self.fc1 = nn.Linear(*self.input_dims, self.fc1_dims)
        f1 = 1. / np.sqrt(self.fc1.weight.data.size()[1])  # fan-in init
        T.nn.init.uniform_(self.fc1.weight.data, -f1, f1)
        T.nn.init.uniform_(self.fc1.bias.data, -f1, f1)
        self.bn1 = nn.LayerNorm(self.fc1_dims)

        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        f2 = 1. / np.sqrt(self.fc2.weight.data.size()[1])
        T.nn.init.uniform_(self.fc2.weight.data, -f2, f2)
        T.nn.init.uniform_(self.fc2.bias.data, -f2, f2)
        self.bn2 = nn.LayerNorm(self.fc2_dims)

        f3 = 0.003  # final-layer init from the paper
        self.mu = nn.Linear(self.fc2_dims, self.n_actions)
        T.nn.init.uniform_(self.mu.weight.data, -f3, f3)
        T.nn.init.uniform_(self.mu.bias.data, -f3, f3)

        self.optimizer = optim.Adam(self.parameters(), lr=alpha)
        self.device = T.device("cpu")
        self.to(self.device)

    def forward(self, state):
        x = self.fc1(state)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = T.tanh(self.mu(x))  # bound actions to [-1, 1]
        return x

    def save_checkpoint(self):
        print('... saving checkpoint ...')
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        print('... loading checkpoint ...')
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))
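
# A shape-check sketch (assumed names and sizes, not from the original): runs a
# dummy batch through both networks to show how the actor's bounded action feeds
# the critic and comes out as one Q-value per sample.
def _demo_network_shapes():
    actor = ActorNetwork(alpha=1e-4, input_dims=(8,), fc1_dims=400, fc2_dims=300,
                         n_actions=2, name="demo_actor")
    critic = CriticNetwork(beta=1e-3, input_dims=(8,), fc1_dims=400, fc2_dims=300,
                           n_actions=2, name="demo_critic")
    state = T.randn(16, 8)
    action = actor(state)        # (16, 2), squashed to [-1, 1] by tanh
    q = critic(state, action)    # (16, 1) state-action values
    print(action.shape, q.shape)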

class Agent(object):
    def __init__(self, alpha, beta, input_dims, tau, env, gamma=0.99, n_actions=2,
                 max_size=1000000, layer1_size=400, layer2_size=300, batch_size=64):
        self.gamma = gamma
        self.tau = tau
        self.batch_size = batch_size
        self.memory = ReplayBuffer(max_size, input_dims, n_actions)
        self.actor = ActorNetwork(alpha, input_dims, layer1_size, layer2_size,
                                  n_actions=n_actions, name="actor")
        self.critic = CriticNetwork(beta, input_dims, layer1_size, layer2_size,
                                    n_actions=n_actions, name="critic")
        self.target_actor = ActorNetwork(alpha, input_dims, layer1_size, layer2_size,
                                         n_actions=n_actions, name="target_actor")
        self.target_critic = CriticNetwork(beta, input_dims, layer1_size, layer2_size,
                                           n_actions=n_actions, name="target_critic")
        self.noise = OUActionNoise(mu=np.zeros(n_actions))
        self.attributions = []
        self.ig: IntegratedGradients = None  # set externally to record attributions
        # Hard-copy the online networks into the targets.
        self.update_network_parameters(tau=1)

    def choose_action(self, observation, baseline: T.Tensor = None):
        self.actor.eval()
        observation = T.tensor(observation, dtype=T.float).to(self.actor.device)
        mu = self.actor(observation).to(self.actor.device)
        if self.ig is not None:
            # Captum expects a batched input; for actors with more than one output,
            # attribute() also needs a target index (the first action is used here).
            attribution = self.ig.attribute(observation.unsqueeze(0),
                                            baselines=baseline, target=0, n_steps=1)
            self.attributions.append(attribution)
        # Add exploration noise to the deterministic policy output.
        mu_prime = mu + T.tensor(self.noise(), dtype=T.float).to(self.actor.device)
        self.actor.train()
        return mu_prime.cpu().detach().numpy()
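
    # Convenience sketch (an addition): self.ig is never assigned in __init__,
    # so attribution recording stays off until something like this is called.
    def enable_attribution(self):
        self.ig = IntegratedGradients(self.actor)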

    def remember(self, state, action, reward, new_state, done):
        self.memory.store_transition(state, action, reward, new_state, done)

    def learn(self):
        if self.memory.mem_cntr < self.batch_size:
            return  # wait until a full batch has been collected
        state, action, reward, new_state, done = self.memory.sample_buffer(self.batch_size)
        reward = T.tensor(reward, dtype=T.float).to(self.critic.device)
        done = T.tensor(done, dtype=T.float).to(self.critic.device)
        new_state = T.tensor(new_state, dtype=T.float).to(self.critic.device)
        action = T.tensor(action, dtype=T.float).to(self.critic.device)
        state = T.tensor(state, dtype=T.float).to(self.critic.device)

        self.target_actor.eval()
        self.target_critic.eval()
        self.critic.eval()
        # TD target from the target networks; no gradients flow through it.
        with T.no_grad():
            target_actions = self.target_actor.forward(new_state)
            critic_value_ = self.target_critic.forward(new_state, target_actions)
            # done holds 1 - terminal, so terminal states get no bootstrap term
            target = reward + self.gamma * critic_value_.view(-1) * done
            target = target.view(self.batch_size, 1)
        critic_value = self.critic.forward(state, action)

        self.critic.train()
        self.critic.optimizer.zero_grad()
        critic_loss = F.mse_loss(critic_value, target)
        critic_loss.backward()
        self.critic.optimizer.step()

        # Actor update: gradient ascent on the critic's valuation of mu(state).
        self.critic.eval()
        self.actor.train()
        self.actor.optimizer.zero_grad()
        mu = self.actor.forward(state)
        actor_loss = -T.mean(self.critic.forward(state, mu))
        actor_loss.backward()
        self.actor.optimizer.step()

        self.update_network_parameters()

    def update_network_parameters(self, tau=None):
        # Soft (Polyak) update: theta_target <- tau*theta + (1 - tau)*theta_target.
        # Called with tau=1 at construction so the targets start as exact copies.
        if tau is None:
            tau = self.tau
        actor_params = self.actor.named_parameters()
        critic_params = self.critic.named_parameters()
        target_actor_params = self.target_actor.named_parameters()
        target_critic_params = self.target_critic.named_parameters()
        critic_state_dict = dict(critic_params)
        actor_state_dict = dict(actor_params)
        target_critic_state_dict = dict(target_critic_params)
        target_actor_state_dict = dict(target_actor_params)
        for name in critic_state_dict:
            critic_state_dict[name] = tau*critic_state_dict[name].clone() + \
                (1 - tau)*target_critic_state_dict[name].clone()
        self.target_critic.load_state_dict(critic_state_dict)
        for name in actor_state_dict:
            actor_state_dict[name] = tau*actor_state_dict[name].clone() + \
                (1 - tau)*target_actor_state_dict[name].clone()
        self.target_actor.load_state_dict(actor_state_dict)

    def save_models(self):
        self.actor.save_checkpoint()
        self.target_actor.save_checkpoint()
        self.critic.save_checkpoint()
        self.target_critic.save_checkpoint()

    def load_models(self):
        self.actor.load_checkpoint()
        self.target_actor.load_checkpoint()
        self.critic.load_checkpoint()
        self.target_critic.load_checkpoint()
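
# A minimal end-to-end training sketch, assuming `gymnasium` is installed and
# using the single-action Pendulum-v1 environment; the episode count and
# learning rates are illustrative, not tuned values from the original.
if __name__ == "__main__":
    import gymnasium as gym

    env = gym.make("Pendulum-v1")
    agent = Agent(alpha=0.000025, beta=0.00025, input_dims=(3,), tau=0.001,
                  env=env, n_actions=1)
    for episode in range(5):
        observation, _ = env.reset()
        done = False
        score = 0.0
        while not done:
            # tanh actions lie in [-1, 1]; Pendulum accepts up to +/-2,
            # so they are left unscaled here for simplicity.
            action = agent.choose_action(observation)
            observation_, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.remember(observation, action, reward, observation_, done)
            agent.learn()
            score += reward
            observation = observation_
        print(f"episode {episode}: score {score:.1f}")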