File size: 895 Bytes
569299e
 
 
 
 
 
 
 
 
ee1c253
ec3a146
569299e
 
ee1c253
569299e
 
 
 
 
 
 
 
ee1c253
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from ddpg import Agent
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import torch
import argparse
from train import TrainingLoop
from captum.attr import (IntegratedGradients, LayerConductance, NeuronAttribution)

# Command-line entry point: pick one TrainingLoop action to run,
# e.g. ``python <this_script>.py train``.
#
# Arguments are parsed *before* the environment/agent are constructed so that
# ``--help``, a missing argument, or an invalid choice exits immediately
# instead of first paying for TrainingLoop construction and create_agent().
parser = argparse.ArgumentParser(description="Choose a function to run.")
parser.add_argument("function", choices=["train", "load-trained", "attribute", "video"], help="The function to run.")

args = parser.parse_args()

# These keyword arguments are forwarded unchanged to TrainingLoop; presumably
# continuous=True selects the continuous-action variant and gravity sets the
# environment's gravity — confirm in train.TrainingLoop.
# NOTE(review): newer Gymnasium releases register "LunarLander-v3"; "v2" may be
# unavailable depending on the installed gymnasium version.
training_loop = TrainingLoop(env_spec="LunarLander-v2", continuous=True, gravity=-10)
training_loop.create_agent()

if args.function == "train":
    training_loop.train()
elif args.function == "load-trained":
    training_loop.load_trained()
elif args.function == "attribute":
    # NOTE(review): the meaning of option="2" is defined by
    # TrainingLoop.explain_trained — confirm there. The returned frames and
    # attributions are kept at module level but not otherwise used here.
    frames, attributions = training_loop.explain_trained(option="2", num_iterations=10)
elif args.function == "video":
    # NOTE(review): 20 is presumably an episode/frame count — confirm in
    # TrainingLoop.render_video.
    training_loop.render_video(20)