Spaces:
Runtime error
Runtime error
Changed git commands over to Hugging Face
Browse files- cartpole.py +37 -25
cartpole.py
CHANGED
@@ -11,7 +11,7 @@
|
|
11 |
# name: python3
|
12 |
# ---
|
13 |
|
14 |
-
# + id="QAY_RQOLcRtA" executionInfo={"status": "ok", "timestamp":
|
15 |
MAIN = __name__ == "__main__"
|
16 |
if MAIN:
|
17 |
print('Mounting drive...')
|
@@ -19,23 +19,32 @@ if MAIN:
|
|
19 |
drive.mount('/content/drive')
|
20 |
# %cd /content/drive/MyDrive/Colab Notebooks/cartpole-demo
|
21 |
|
22 |
-
# + colab={"base_uri": "https://localhost:8080/"} id="GgSNZRJh4EjV" executionInfo={"status": "ok", "timestamp":
|
23 |
# !pip install einops
|
24 |
# !pip install wandb
|
25 |
# !pip install jupytext
|
26 |
# !pip install pygame
|
27 |
# !pip install torchtyping
|
28 |
# !pip install gradio
|
|
|
29 |
|
30 |
-
# + colab={"base_uri": "https://localhost:8080/"} id="1g58HZUb8Ltl" executionInfo={"status": "ok", "timestamp":
|
|
|
|
|
31 |
# !git config --global user.email "[email protected]"
|
32 |
-
# !
|
33 |
-
# !cat pat.txt | xargs git remote set-url origin
|
34 |
# !jupytext --to py cartpole.ipynb
|
35 |
# !git fetch
|
|
|
36 |
# !git status
|
37 |
|
38 |
-
# + id="
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
import os
|
40 |
import glob
|
41 |
import sys
|
@@ -66,7 +75,7 @@ from typeguard import typechecked
|
|
66 |
# + id="K7T8bs1Y76ZK" executionInfo={"status": "ok", "timestamp": 1677942330521, "user_tz": 0, "elapsed": 8, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} colab={"base_uri": "https://localhost:8080/"} outputId="f59ffef0-7156-4f27-d992-a392d59a1c73"
|
67 |
# %env "WANDB_NOTEBOOK_NAME" "cartpole.py"
|
68 |
|
69 |
-
# + id="Q5E93-BGRjuy"
|
70 |
def make_env(
|
71 |
env_id: str, seed: int, idx: int, capture_video: bool, run_name: str
|
72 |
):
|
@@ -93,7 +102,7 @@ def make_env(
|
|
93 |
return thunk
|
94 |
|
95 |
|
96 |
-
# + id="Kf152ROwHjM_"
|
97 |
def test_minibatch_indexes(minibatch_indexes):
|
98 |
for n in range(5):
|
99 |
frac, minibatch_size = np.random.randint(1, 8, size=(2,))
|
@@ -105,7 +114,7 @@ def test_minibatch_indexes(minibatch_indexes):
|
|
105 |
np.testing.assert_equal(np.sort(np.stack(indices).flatten()), np.arange(batch_size))
|
106 |
|
107 |
|
108 |
-
# + id="mhvduVeOHkln"
|
109 |
def test_calc_entropy_bonus(calc_entropy_bonus):
|
110 |
probs = Categorical(logits=t.randn((3, 4)))
|
111 |
ent_coef = 0.5
|
@@ -114,7 +123,7 @@ def test_calc_entropy_bonus(calc_entropy_bonus):
|
|
114 |
t.testing.assert_close(expected, actual)
|
115 |
|
116 |
|
117 |
-
# + id="Aya60GeCGA5X"
|
118 |
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
|
119 |
t.nn.init.orthogonal_(layer.weight, std)
|
120 |
t.nn.init.constant_(layer.bias, bias_const)
|
@@ -146,7 +155,7 @@ class Agent(nn.Module):
|
|
146 |
|
147 |
|
148 |
|
149 |
-
# + id="6PwPZHlLGDYu"
|
150 |
# %%
|
151 |
@t.inference_mode()
|
152 |
def compute_advantages(
|
@@ -190,7 +199,7 @@ def compute_advantages(
|
|
190 |
|
191 |
|
192 |
|
193 |
-
# + id="uYSSMnF-GPvm"
|
194 |
# %%
|
195 |
@dataclass
|
196 |
class Minibatch:
|
@@ -252,7 +261,7 @@ def make_minibatches(
|
|
252 |
|
253 |
|
254 |
|
255 |
-
# + id="K7wXDJ9MGOWu"
|
256 |
# %%
|
257 |
def calc_policy_loss(
|
258 |
probs: Categorical, mb_action: t.Tensor, mb_advantages: t.Tensor,
|
@@ -277,7 +286,7 @@ def calc_policy_loss(
|
|
277 |
|
278 |
|
279 |
|
280 |
-
# + id="CmyxU6JWGMsG"
|
281 |
# %%
|
282 |
def calc_value_function_loss(
|
283 |
critic: nn.Sequential, mb_obs: t.Tensor, mb_returns: t.Tensor, v_coef: float
|
@@ -294,7 +303,7 @@ def calc_value_function_loss(
|
|
294 |
|
295 |
|
296 |
|
297 |
-
# + id="npyWs6xjGLkP"
|
298 |
# %%
|
299 |
def calc_entropy_loss(probs: Categorical, ent_coef: float):
|
300 |
'''Return the entropy loss term.
|
@@ -310,7 +319,7 @@ if MAIN:
|
|
310 |
test_calc_entropy_bonus(calc_entropy_loss)
|
311 |
|
312 |
|
313 |
-
# + id="nqJeg1kZGKSG"
|
314 |
# %%
|
315 |
class PPOScheduler:
|
316 |
def __init__(self, optimizer: optim.Adam, initial_lr: float, end_lr: float, num_updates: int):
|
@@ -345,7 +354,7 @@ def make_optimizer(
|
|
345 |
|
346 |
|
347 |
|
348 |
-
# + id="mgZ7-wsRCxJW"
|
349 |
@dataclass
|
350 |
class PPOArgs:
|
351 |
exp_name: str = 'cartpole.py'
|
@@ -373,7 +382,7 @@ class PPOArgs:
|
|
373 |
minibatch_size: int = 128
|
374 |
|
375 |
|
376 |
-
# + id="xeIu-J3ZwGyq"
|
377 |
def wandb_init(name: str, args: PPOArgs):
|
378 |
wandb.init(
|
379 |
project=args.wandb_project_name,
|
@@ -387,14 +396,14 @@ def wandb_init(name: str, args: PPOArgs):
|
|
387 |
)
|
388 |
|
389 |
|
390 |
-
# + id="gMYWqhsryYHy"
|
391 |
def set_seed(seed: int):
|
392 |
random.seed(seed)
|
393 |
np.random.seed(seed)
|
394 |
torch.manual_seed(seed)
|
395 |
|
396 |
|
397 |
-
# + id="T9j_L0Wpyrgz"
|
398 |
@typechecked
|
399 |
def rollout_phase(
|
400 |
next_obs: t.Tensor, next_done: t.Tensor,
|
@@ -472,14 +481,14 @@ def rollout_phase(
|
|
472 |
)
|
473 |
|
474 |
|
475 |
-
# + id="xdDhABIk5jyb"
|
476 |
def reset_env(envs, device):
|
477 |
next_obs = torch.Tensor(envs.reset()).to(device)
|
478 |
next_done = torch.zeros(envs.num_envs).to(device)
|
479 |
return next_obs, next_done
|
480 |
|
481 |
|
482 |
-
# + id="5CoMpUVU7rFT"
|
483 |
def get_action_shape(envs: gym.vector.SyncVectorEnv):
|
484 |
action_shape = envs.single_action_space.shape
|
485 |
assert action_shape is not None
|
@@ -489,7 +498,7 @@ def get_action_shape(envs: gym.vector.SyncVectorEnv):
|
|
489 |
return action_shape
|
490 |
|
491 |
|
492 |
-
# + id="FHmn5kSUGFFu"
|
493 |
# %%
|
494 |
def train_ppo(args: PPOArgs):
|
495 |
t0 = int(time.time())
|
@@ -628,8 +637,11 @@ if MAIN:
|
|
628 |
args = PPOArgs()
|
629 |
train_ppo(args)
|
630 |
|
631 |
-
# + colab={"base_uri": "https://localhost:8080/"} id="xJW6KL7QIj4s" outputId="7c529849-6d46-4a6a-def5-e1c0ef652c64"
|
632 |
# !python demo.py
|
633 |
|
634 |
-
# + id="P7ZfUlAqImIr"
|
|
|
|
|
|
|
635 |
|
|
|
11 |
# name: python3
|
12 |
# ---
|
13 |
|
14 |
+
# + id="QAY_RQOLcRtA" executionInfo={"status": "ok", "timestamp": 1677945244865, "user_tz": 0, "elapsed": 19712, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} colab={"base_uri": "https://localhost:8080/"} outputId="be179435-1667-40af-8a80-7bc63a472715"
|
15 |
MAIN = __name__ == "__main__"
|
16 |
if MAIN:
|
17 |
print('Mounting drive...')
|
|
|
19 |
drive.mount('/content/drive')
|
20 |
# %cd /content/drive/MyDrive/Colab Notebooks/cartpole-demo
|
21 |
|
22 |
+
# + colab={"base_uri": "https://localhost:8080/"} id="GgSNZRJh4EjV" executionInfo={"status": "ok", "timestamp": 1677945316689, "user_tz": 0, "elapsed": 57846, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} outputId="6aeb7bf3-e186-449d-cdc4-c66f778244b2"
|
23 |
# !pip install einops
|
24 |
# !pip install wandb
|
25 |
# !pip install jupytext
|
26 |
# !pip install pygame
|
27 |
# !pip install torchtyping
|
28 |
# !pip install gradio
|
29 |
+
# !pip install huggingface_hub
|
30 |
|
31 |
+
# + colab={"base_uri": "https://localhost:8080/"} id="1g58HZUb8Ltl" executionInfo={"status": "ok", "timestamp": 1677945458077, "user_tz": 0, "elapsed": 16862, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} outputId="62ffc9cd-ff0b-4473-c17a-4593a14526cf"
|
32 |
+
# !git config --global credential.helper store
|
33 |
+
# !git config --global user.name "skar0"
|
34 |
# !git config --global user.email "[email protected]"
|
35 |
+
# !huggingface-cli login
|
|
|
36 |
# !jupytext --to py cartpole.ipynb
|
37 |
# !git fetch
|
38 |
+
# # !chmod +x .git/hooks/pre-push
|
39 |
# !git status
|
40 |
|
41 |
+
# + id="dYeFdxVIWOqc" executionInfo={"status": "ok", "timestamp": 1677945546175, "user_tz": 0, "elapsed": 318, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}}
|
42 |
+
|
43 |
+
|
44 |
+
# + colab={"base_uri": "https://localhost:8080/"} id="5xFqBnKzVN60" executionInfo={"status": "ok", "timestamp": 1677945556589, "user_tz": 0, "elapsed": 7558, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} outputId="535e6c5e-17f6-4342-8a9d-ff54f4c82187"
|
45 |
+
# !git push
|
46 |
+
|
47 |
+
# + id="vEczQ48wC40O"
|
48 |
import os
|
49 |
import glob
|
50 |
import sys
|
|
|
75 |
# + id="K7T8bs1Y76ZK" executionInfo={"status": "ok", "timestamp": 1677942330521, "user_tz": 0, "elapsed": 8, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} colab={"base_uri": "https://localhost:8080/"} outputId="f59ffef0-7156-4f27-d992-a392d59a1c73"
|
76 |
# %env "WANDB_NOTEBOOK_NAME" "cartpole.py"
|
77 |
|
78 |
+
# + id="Q5E93-BGRjuy"
|
79 |
def make_env(
|
80 |
env_id: str, seed: int, idx: int, capture_video: bool, run_name: str
|
81 |
):
|
|
|
102 |
return thunk
|
103 |
|
104 |
|
105 |
+
# + id="Kf152ROwHjM_"
|
106 |
def test_minibatch_indexes(minibatch_indexes):
|
107 |
for n in range(5):
|
108 |
frac, minibatch_size = np.random.randint(1, 8, size=(2,))
|
|
|
114 |
np.testing.assert_equal(np.sort(np.stack(indices).flatten()), np.arange(batch_size))
|
115 |
|
116 |
|
117 |
+
# + id="mhvduVeOHkln"
|
118 |
def test_calc_entropy_bonus(calc_entropy_bonus):
|
119 |
probs = Categorical(logits=t.randn((3, 4)))
|
120 |
ent_coef = 0.5
|
|
|
123 |
t.testing.assert_close(expected, actual)
|
124 |
|
125 |
|
126 |
+
# + id="Aya60GeCGA5X"
|
127 |
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
|
128 |
t.nn.init.orthogonal_(layer.weight, std)
|
129 |
t.nn.init.constant_(layer.bias, bias_const)
|
|
|
155 |
|
156 |
|
157 |
|
158 |
+
# + id="6PwPZHlLGDYu"
|
159 |
# %%
|
160 |
@t.inference_mode()
|
161 |
def compute_advantages(
|
|
|
199 |
|
200 |
|
201 |
|
202 |
+
# + id="uYSSMnF-GPvm"
|
203 |
# %%
|
204 |
@dataclass
|
205 |
class Minibatch:
|
|
|
261 |
|
262 |
|
263 |
|
264 |
+
# + id="K7wXDJ9MGOWu"
|
265 |
# %%
|
266 |
def calc_policy_loss(
|
267 |
probs: Categorical, mb_action: t.Tensor, mb_advantages: t.Tensor,
|
|
|
286 |
|
287 |
|
288 |
|
289 |
+
# + id="CmyxU6JWGMsG"
|
290 |
# %%
|
291 |
def calc_value_function_loss(
|
292 |
critic: nn.Sequential, mb_obs: t.Tensor, mb_returns: t.Tensor, v_coef: float
|
|
|
303 |
|
304 |
|
305 |
|
306 |
+
# + id="npyWs6xjGLkP"
|
307 |
# %%
|
308 |
def calc_entropy_loss(probs: Categorical, ent_coef: float):
|
309 |
'''Return the entropy loss term.
|
|
|
319 |
test_calc_entropy_bonus(calc_entropy_loss)
|
320 |
|
321 |
|
322 |
+
# + id="nqJeg1kZGKSG"
|
323 |
# %%
|
324 |
class PPOScheduler:
|
325 |
def __init__(self, optimizer: optim.Adam, initial_lr: float, end_lr: float, num_updates: int):
|
|
|
354 |
|
355 |
|
356 |
|
357 |
+
# + id="mgZ7-wsRCxJW"
|
358 |
@dataclass
|
359 |
class PPOArgs:
|
360 |
exp_name: str = 'cartpole.py'
|
|
|
382 |
minibatch_size: int = 128
|
383 |
|
384 |
|
385 |
+
# + id="xeIu-J3ZwGyq"
|
386 |
def wandb_init(name: str, args: PPOArgs):
|
387 |
wandb.init(
|
388 |
project=args.wandb_project_name,
|
|
|
396 |
)
|
397 |
|
398 |
|
399 |
+
# + id="gMYWqhsryYHy"
|
400 |
def set_seed(seed: int):
|
401 |
random.seed(seed)
|
402 |
np.random.seed(seed)
|
403 |
torch.manual_seed(seed)
|
404 |
|
405 |
|
406 |
+
# + id="T9j_L0Wpyrgz"
|
407 |
@typechecked
|
408 |
def rollout_phase(
|
409 |
next_obs: t.Tensor, next_done: t.Tensor,
|
|
|
481 |
)
|
482 |
|
483 |
|
484 |
+
# + id="xdDhABIk5jyb"
|
485 |
def reset_env(envs, device):
|
486 |
next_obs = torch.Tensor(envs.reset()).to(device)
|
487 |
next_done = torch.zeros(envs.num_envs).to(device)
|
488 |
return next_obs, next_done
|
489 |
|
490 |
|
491 |
+
# + id="5CoMpUVU7rFT"
|
492 |
def get_action_shape(envs: gym.vector.SyncVectorEnv):
|
493 |
action_shape = envs.single_action_space.shape
|
494 |
assert action_shape is not None
|
|
|
498 |
return action_shape
|
499 |
|
500 |
|
501 |
+
# + id="FHmn5kSUGFFu"
|
502 |
# %%
|
503 |
def train_ppo(args: PPOArgs):
|
504 |
t0 = int(time.time())
|
|
|
637 |
args = PPOArgs()
|
638 |
train_ppo(args)
|
639 |
|
640 |
+
# + colab={"base_uri": "https://localhost:8080/"} id="xJW6KL7QIj4s" executionInfo={"status": "ok", "timestamp": 1677942639015, "user_tz": 0, "elapsed": 105286, "user": {"displayName": "Oskar Hollinsworth", "userId": "00307706571197304608"}} outputId="7c529849-6d46-4a6a-def5-e1c0ef652c64"
|
641 |
# !python demo.py
|
642 |
|
643 |
+
# + id="P7ZfUlAqImIr"
|
644 |
+
# !pip freeze > requirements.txt
|
645 |
+
|
646 |
+
# + id="x_bhyL3GLnhr"
|
647 |
|