Spaces:
Sleeping
Sleeping
import gradio as gr | |
from train import TrainingLoop | |
from scipy.special import softmax | |
import numpy as np | |
# Process-wide state shared by the Gradio callbacks in this module.
train = None  # replaced by create_training_loop()
frames = attributions = None  # filled in by generate_output()

# Index -> human-readable feature name. display_softmax pairs these names with
# attribution scores in this insertion order, so the order matters.
lunar_lander_spec_conversion = dict(
    enumerate(
        [
            "X-coordinate",
            "Y-coordinate",
            "Linear velocity in the X-axis",
            "Linear velocity in the Y-axis",
            "Angle",
            "Angular velocity",
            "Left leg touched the floor",
            "Right leg touched the floor",
        ]
    )
)
def create_training_loop(env_spec):
    """Build a new TrainingLoop for ``env_spec`` and make it the active loop.

    The loop is published to the module-level ``train`` first, and only then
    is ``create_agent()`` called on it.

    Args:
        env_spec: Environment identifier chosen in the UI dropdown.

    Returns:
        ``env.spec`` of the newly created loop, used as the output of the
        environment-creation interface.
    """
    global train
    loop = TrainingLoop(env_spec=env_spec)
    train = loop
    loop.create_agent()
    return loop.env.spec
def display_softmax(inputs, labels=None):
    """Turn raw attribution scores into a ``{feature name: probability}`` dict.

    The scores are passed through a softmax so that the returned values are
    non-negative and sum to 1.

    Args:
        inputs: Array-like of attribution scores. It is flattened first, so a
            ``(1, n)`` row is handled the same as a length-``n`` vector.
        labels: Optional iterable of names, one per score. Defaults to the
            values of ``lunar_lander_spec_conversion``.

    Returns:
        Dict mapping each label to a Python ``float`` probability. As with
        ``zip``, surplus scores or surplus labels are ignored.
    """
    if labels is None:
        labels = lunar_lander_spec_conversion.values()
    # softmax(axis=None) normalises over every element, so flattening only
    # fixes the iteration order; the probabilities themselves are unchanged.
    scores = np.asarray(inputs).ravel()
    probabilities = softmax(scores)
    return {name: float(prob) for name, prob in zip(labels, probabilities)}
def generate_output(num_iterations, option):
    """Run attribution on the active training loop and cache the results.

    Stores the returned frames and attributions in the module-level
    ``frames`` / ``attributions`` globals for the key-frame viewer to read.

    Args:
        num_iterations: Value of the "Number of Baseline Iterations" slider,
            forwarded as ``num_iterations``.
        option: Selected baseline choice; the dropdown uses ``type="index"``,
            so this is an index (or None if nothing was chosen). Forwarded
            unchanged.

    Raises:
        gr.Error: If no training loop has been created yet.
    """
    global frames, attributions
    if train is None:
        # Without this guard the user only sees an AttributeError on None.
        raise gr.Error("Create the environment before running attribution.")
    frames, attributions = train.explain_trained(
        num_iterations=num_iterations, option=option
    )
    # NOTE: the previous `slider.maximum = len(frames)` was dropped. It was
    # off by one, and mutating a component after launch() does not reach the
    # browser. The slider range has to be changed by returning gr.update(...)
    # from an event that lists the slider among its outputs.
def get_frame_and_attribution(slider_value):
    """Return the cached frame and its softmaxed attributions for a key frame.

    Args:
        slider_value: Requested frame index from the "Key Frame" slider. It is
            cast to ``int`` and clamped into ``[0, len(frames) - 1]``.

    Returns:
        ``(frame, attribution_dict)``, or ``(None, None)`` when attribution
        has not been run yet and nothing is cached.
    """
    # The viewer is live, so it can fire before ATTRIBUTE has populated the
    # cache; previously this crashed with len(None).
    if frames is None or len(frames) == 0:
        return None, None
    # The slider may deliver a float, and its maximum can exceed the number of
    # cached frames, so coerce and clamp before indexing.
    index = min(max(int(slider_value), 0), len(frames) - 1)
    return frames[index], display_softmax(attributions[index])
def _attribute_and_resize_slider(num_iterations, option):
    """Run attribution, then fit the key-frame slider to the frames produced.

    The new range reaches the browser only because this is returned as a
    ``gr.update`` from an event that lists the slider as an output. Resetting
    the value to 0 also triggers the live viewer to show the first frame.
    """
    generate_output(num_iterations, option)
    if frames is None:
        return gr.update()
    # Valid indices are 0 .. len(frames) - 1.
    return gr.update(maximum=max(len(frames) - 1, 0), value=0)


with gr.Blocks() as demo:
    gr.Markdown("# Introspection in Deep Reinforcement Learning")
    with gr.Tab(label="Attribute"):
        env_spec = gr.Dropdown(
            choices=["LunarLander-v2"],
            type="value",
            multiselect=False,
            label="Environment Specification (e.g.: LunarLander-v2)",
        )
        env = gr.Interface(
            title="Create the Environment",
            allow_flagging="never",
            inputs=env_spec,
            fn=create_training_loop,
            outputs=gr.JSON(),
        )
        with gr.Row():
            option = gr.Dropdown(
                choices=["Torch Tensor of 0's", "Running Average"],
                type="index",
                label="Baseline",
            )
            baselines = gr.Slider(
                label="Number of Baseline Iterations",
                interactive=True,
                minimum=0,
                maximum=100,
                value=10,
                step=5,
                info="Baseline inputs to collect for the average",
                render=True,
            )
        attribute_button = gr.Button("ATTRIBUTE")
        slider = gr.Slider(label="Key Frame", minimum=0, maximum=1000, step=1, value=0)
        # Wired after the slider exists so the button can resize it, while
        # keeping the original on-page order (button above slider).
        attribute_button.click(
            fn=_attribute_and_resize_slider,
            inputs=[baselines, option],
            outputs=slider,
        )
        # get_frame_and_attribution returns (frame, attributions), so the
        # viewer needs one output component per returned value.
        gr.Interface(
            fn=get_frame_and_attribution,
            inputs=slider,
            live=True,
            outputs=[gr.Image(label="Timestep"), gr.Label(label="Attributions")],
        )

if __name__ == "__main__":
    demo.launch()