skar0 commited on
Commit
4d7aced
·
1 Parent(s): 60177de

Removed share=True from launch

Browse files
Files changed (1) hide show
  1. app.py +54 -54
app.py CHANGED
@@ -1,54 +1,54 @@
1
- import glob
2
- import gradio as gr
3
- import gym
4
- import sys
5
- from torch.utils.tensorboard import SummaryWriter
6
- import yaml
7
- import torch
8
- from cartpole import (
9
- make_env, reset_env, Agent, rollout_phase, get_action_shape
10
- )
11
-
12
- MAIN = __name__ == "__main__"
13
- examples = [0, 1, 31415, 'Hello, World!', 'This is a seed...']
14
-
15
- def generate_video(
16
- string: str, wandb_path='wandb/run-20230303_211416-ox4d1p0u/files'
17
- ):
18
- with open(f'{wandb_path}/config.yaml') as f_cfg:
19
- config = yaml.safe_load(f_cfg)
20
- seed = hash(string) % ((sys.maxsize + 1) * 2)
21
- num_envs = config['num_envs']['value']
22
- num_steps = config['num_steps']['value']
23
- assert seed >= 0
24
- assert isinstance(seed, int)
25
- run_name = f'seed{seed}'
26
- log_dir = f'generate/{run_name}'
27
- writer = SummaryWriter(log_dir)
28
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- envs = gym.vector.SyncVectorEnv([
30
- make_env("CartPole-v1", seed, i, True, run_name)
31
- for i in range(num_envs)
32
- ])
33
- action_shape = get_action_shape(envs)
34
- next_obs, next_done = reset_env(envs, device)
35
- global_step = 0
36
- agent = Agent(envs).to(device)
37
- agent.load_state_dict(torch.load(f'{wandb_path}/model_state_dict.pt'))
38
- rollout_phase(
39
- next_obs, next_done, agent, envs, writer, device,
40
- global_step, action_shape, num_envs, num_steps,
41
- )
42
- video_path = glob.glob(f'videos/{run_name}/*.mp4')[0]
43
- return video_path
44
-
45
- if MAIN:
46
- demo = gr.Interface(
47
- fn=generate_video,
48
- inputs=[
49
- gr.components.Textbox(lines=1, label="Seed"),
50
- ],
51
- outputs=gr.components.Video(label="Generated Video"),
52
- examples=examples,
53
- )
54
- demo.launch(share=True)
 
1
+ import glob
2
+ import gradio as gr
3
+ import gym
4
+ import sys
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ import yaml
7
+ import torch
8
+ from cartpole import (
9
+ make_env, reset_env, Agent, rollout_phase, get_action_shape
10
+ )
11
+
12
+ MAIN = __name__ == "__main__"
13
+ examples = [0, 1, 31415, 'Hello, World!', 'This is a seed...']
14
+
15
+ def generate_video(
16
+ string: str, wandb_path='wandb/run-20230303_211416-ox4d1p0u/files'
17
+ ):
18
+ with open(f'{wandb_path}/config.yaml') as f_cfg:
19
+ config = yaml.safe_load(f_cfg)
20
+ seed = hash(string) % ((sys.maxsize + 1) * 2)
21
+ num_envs = config['num_envs']['value']
22
+ num_steps = config['num_steps']['value']
23
+ assert seed >= 0
24
+ assert isinstance(seed, int)
25
+ run_name = f'seed{seed}'
26
+ log_dir = f'generate/{run_name}'
27
+ writer = SummaryWriter(log_dir)
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ envs = gym.vector.SyncVectorEnv([
30
+ make_env("CartPole-v1", seed, i, True, run_name)
31
+ for i in range(num_envs)
32
+ ])
33
+ action_shape = get_action_shape(envs)
34
+ next_obs, next_done = reset_env(envs, device)
35
+ global_step = 0
36
+ agent = Agent(envs).to(device)
37
+ agent.load_state_dict(torch.load(f'{wandb_path}/model_state_dict.pt'))
38
+ rollout_phase(
39
+ next_obs, next_done, agent, envs, writer, device,
40
+ global_step, action_shape, num_envs, num_steps,
41
+ )
42
+ video_path = glob.glob(f'videos/{run_name}/*.mp4')[0]
43
+ return video_path
44
+
45
+ if MAIN:
46
+ demo = gr.Interface(
47
+ fn=generate_video,
48
+ inputs=[
49
+ gr.components.Textbox(lines=1, label="Seed"),
50
+ ],
51
+ outputs=gr.components.Video(label="Generated Video"),
52
+ examples=examples,
53
+ )
54
+ demo.launch()