gabehubner's picture
fix app.py
9bb89e8
raw
history blame
2.37 kB
import gradio as gr
from train import TrainingLoop
from scipy.special import softmax
import numpy as np
# Shared mutable state populated by the Gradio callbacks below.
train = None
frames, attributions = None, None

# Maps a LunarLander-v2 observation index to a human-readable feature name,
# used to label the per-feature attribution scores.
lunar_lander_spec_conversion = dict(
    enumerate(
        [
            "X-coordinate",
            "Y-coordinate",
            "Linear velocity in the X-axis",
            "Linear velocity in the Y-axis",
            "Angle",
            "Angular velocity",
            "Left leg touched the floor",
            "Right leg touched the floor",
        ]
    )
)
def create_training_loop(env_spec):
    """Build a TrainingLoop for *env_spec*, create its agent, and publish it.

    The loop is stored in the module-level ``train`` so the other callbacks
    (attribution, frame lookup) can reach it. Returns the environment's spec
    for display in the UI's JSON panel.
    """
    global train
    loop = TrainingLoop(env_spec=env_spec)
    loop.create_agent()
    train = loop
    return loop.env.spec
def display_softmax(inputs):
    """Turn raw attribution scores into a feature-name -> probability mapping.

    Applies a softmax over *inputs* and pairs each probability with the
    corresponding LunarLander observation name.
    """
    probabilities = softmax(np.asarray(inputs))
    names = lunar_lander_spec_conversion.values()
    return dict(zip(names, map(float, probabilities)))
def generate_output(num_iterations, option):
    """Run attribution on the trained agent and cache the results.

    Populates the module-level ``frames``/``attributions`` that
    ``get_frame_and_attribution`` reads, and widens the key-frame slider to
    cover the collected frames.
    """
    global frames, attributions
    frames, attributions = train.explain_trained(num_iterations=num_iterations, option=option)
    # Valid indices are 0..len(frames)-1; the original set the maximum to
    # len(frames), one past the last frame.
    # NOTE(review): mutating a component attribute after render may not be
    # reflected in the live UI — returning gr.update(maximum=...) from the
    # click handler is the documented path; confirm against the Gradio
    # version in use.
    slider.maximum = len(frames) - 1
def get_frame_and_attribution(slider_value):
    """Return the frame and softmax-normalized attribution at *slider_value*.

    Clamps the requested index into the valid range. Before an attribution
    run has happened (``frames`` is still None) the original raised a
    TypeError from ``len(None)``; return a 1x1 black placeholder image and an
    empty mapping instead so the live interface stays usable.
    """
    global frames, attributions
    if frames is None or len(frames) == 0:
        return np.zeros((1, 1, 3), dtype=np.uint8), {}
    # Clamp both ends: the original only capped the upper bound.
    index = max(0, min(slider_value, len(frames) - 1))
    return frames[index], display_softmax(attributions[index])
# Top-level Gradio UI: environment creation, attribution run, and a live
# key-frame viewer.
with gr.Blocks() as demo:
    gr.Markdown("# Introspection in Deep Reinforcement Learning")

    with gr.Tab(label="Attribute"):
        env_spec = gr.Dropdown(
            choices=["LunarLander-v2"],
            type="value",
            multiselect=False,
            label="Environment Specification (e.g.: LunarLander-v2)",
        )
        env = gr.Interface(
            title="Create the Environment",
            allow_flagging="never",
            inputs=env_spec,
            fn=create_training_loop,
            outputs=gr.JSON(),
        )

        with gr.Row():
            # type="index" passes the selected option's position (0 or 1).
            option = gr.Dropdown(choices=["Torch Tensor of 0's", "Running Average"], type="index")
            baselines = gr.Slider(
                label="Number of Baseline Iterations",
                interactive=True,
                minimum=0,
                maximum=100,
                value=10,
                step=5,
                info="Baseline inputs to collect for the average",
                render=True,
            )

        gr.Button("ATTRIBUTE").click(fn=generate_output, inputs=[baselines, option])

        slider = gr.Slider(label="Key Frame", minimum=0, maximum=1000, step=1, value=0)
        # get_frame_and_attribution returns (frame, attribution_dict); the
        # original listed only the Image output, so the attribution mapping
        # had no component to render into. A Label displays the
        # name -> probability dict.
        gr.Interface(
            fn=get_frame_and_attribution,
            inputs=slider,
            live=True,
            outputs=[gr.Image(label="Timestep"), gr.Label(label="Attribution")],
        )

demo.launch()