Spaces:
Runtime error
Runtime error
vitaliykinakh
commited on
Commit
Β·
d872920
1
Parent(s):
d6b8e04
Add model weights as submodule
Browse files- .gitmodules +3 -0
- galaxy-zoo-generation +1 -0
- src/app/compare_models.py +1 -16
- src/app/explore_biggan.py +1 -10
- src/app/explore_cvae.py +1 -9
- src/app/explore_infoscc_gan.py +1 -8
- src/app/interpolate_labels.py +1 -18
- src/app/params.py +4 -11
.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "galaxy-zoo-generation"]
|
2 |
+
path = galaxy-zoo-generation
|
3 |
+
url = https://huggingface.co/vitaliykinakh/galaxy-zoo-generation
|
galaxy-zoo-generation
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit af017a826b2be3dec5f364ac4e232ede6cc0e04f
|
src/app/compare_models.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
from pathlib import Path
|
2 |
import math
|
3 |
|
4 |
import streamlit as st
|
@@ -14,7 +13,7 @@ from src.models import ConditionalGenerator as InfoSCC_GAN
|
|
14 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
15 |
from src.models import ConditionalDecoder as cVAE
|
16 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
17 |
-
from src.utils import
|
18 |
|
19 |
|
20 |
device = params.device
|
@@ -76,25 +75,14 @@ def load_model(model_type: str):
|
|
76 |
y_size=params.shape_label,
|
77 |
z_size=params.noise_dim)
|
78 |
|
79 |
-
if not Path(params.path_infoscc_gan).exists():
|
80 |
-
download_file(params.drive_id_infoscc_gan, params.path_infoscc_gan)
|
81 |
-
|
82 |
ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
|
83 |
g.load_state_dict(ckpt['g_ema'])
|
84 |
elif model_type == 'BigGAN':
|
85 |
g = BigGAN2Generator()
|
86 |
-
|
87 |
-
if not Path(params.path_biggan).exists():
|
88 |
-
download_file(params.drive_id_biggan, params.path_biggan)
|
89 |
-
|
90 |
ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
|
91 |
g.load_state_dict(ckpt)
|
92 |
elif model_type == 'cVAE':
|
93 |
g = cVAE()
|
94 |
-
|
95 |
-
if not Path(params.path_cvae).exists():
|
96 |
-
download_file(params.drive_id_cvae, params.path_cvae)
|
97 |
-
|
98 |
ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
|
99 |
g.load_state_dict(ckpt)
|
100 |
else:
|
@@ -107,9 +95,6 @@ def load_model(model_type: str):
|
|
107 |
def get_labels() -> torch.Tensor:
|
108 |
path_labels = params.path_labels
|
109 |
|
110 |
-
if not Path(path_labels).exists():
|
111 |
-
download_file(params.drive_id_labels, path_labels)
|
112 |
-
|
113 |
labels_train = get_labels_train(path_labels)
|
114 |
return labels_train
|
115 |
|
|
|
|
|
1 |
import math
|
2 |
|
3 |
import streamlit as st
|
|
|
13 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
14 |
from src.models import ConditionalDecoder as cVAE
|
15 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
16 |
+
from src.utils import sample_labels
|
17 |
|
18 |
|
19 |
device = params.device
|
|
|
75 |
y_size=params.shape_label,
|
76 |
z_size=params.noise_dim)
|
77 |
|
|
|
|
|
|
|
78 |
ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
|
79 |
g.load_state_dict(ckpt['g_ema'])
|
80 |
elif model_type == 'BigGAN':
|
81 |
g = BigGAN2Generator()
|
|
|
|
|
|
|
|
|
82 |
ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
|
83 |
g.load_state_dict(ckpt)
|
84 |
elif model_type == 'cVAE':
|
85 |
g = cVAE()
|
|
|
|
|
|
|
|
|
86 |
ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
|
87 |
g.load_state_dict(ckpt)
|
88 |
else:
|
|
|
95 |
def get_labels() -> torch.Tensor:
|
96 |
path_labels = params.path_labels
|
97 |
|
|
|
|
|
|
|
98 |
labels_train = get_labels_train(path_labels)
|
99 |
return labels_train
|
100 |
|
src/app/explore_biggan.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import math
|
2 |
-
from pathlib import Path
|
3 |
|
4 |
import streamlit as st
|
5 |
import numpy as np
|
@@ -12,7 +11,7 @@ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4
|
|
12 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
13 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
14 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
15 |
-
from src.utils import
|
16 |
|
17 |
|
18 |
# global parameters
|
@@ -25,7 +24,6 @@ dim_z = params.dim_z
|
|
25 |
bs = 16 # number of samples to generate
|
26 |
n_cols = int(math.sqrt(bs))
|
27 |
model_path = params.path_biggan
|
28 |
-
drive_id = params.drive_id_biggan
|
29 |
path_labels = params.path_labels
|
30 |
|
31 |
# manual labels
|
@@ -90,9 +88,6 @@ def get_eps(n: int) -> torch.Tensor:
|
|
90 |
|
91 |
@st.cache
|
92 |
def get_labels() -> torch.Tensor:
|
93 |
-
if not Path(path_labels).exists():
|
94 |
-
download_file(params.drive_id_labels, path_labels)
|
95 |
-
|
96 |
labels_train = get_labels_train(path_labels)
|
97 |
return labels_train
|
98 |
|
@@ -102,10 +97,6 @@ def app():
|
|
102 |
|
103 |
st.title('Explore BigGAN')
|
104 |
st.markdown('This demo shows BigGAN for conditional galaxy generation')
|
105 |
-
|
106 |
-
if not Path(model_path).exists():
|
107 |
-
download_file(drive_id, model_path)
|
108 |
-
|
109 |
model = load_model(model_path)
|
110 |
eps = get_eps(bs)
|
111 |
labels_train = get_labels()
|
|
|
1 |
import math
|
|
|
2 |
|
3 |
import streamlit as st
|
4 |
import numpy as np
|
|
|
11 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
12 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
13 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
14 |
+
from src.utils import sample_labels
|
15 |
|
16 |
|
17 |
# global parameters
|
|
|
24 |
bs = 16 # number of samples to generate
|
25 |
n_cols = int(math.sqrt(bs))
|
26 |
model_path = params.path_biggan
|
|
|
27 |
path_labels = params.path_labels
|
28 |
|
29 |
# manual labels
|
|
|
88 |
|
89 |
@st.cache
|
90 |
def get_labels() -> torch.Tensor:
|
|
|
|
|
|
|
91 |
labels_train = get_labels_train(path_labels)
|
92 |
return labels_train
|
93 |
|
|
|
97 |
|
98 |
st.title('Explore BigGAN')
|
99 |
st.markdown('This demo shows BigGAN for conditional galaxy generation')
|
|
|
|
|
|
|
|
|
100 |
model = load_model(model_path)
|
101 |
eps = get_eps(bs)
|
102 |
labels_train = get_labels()
|
src/app/explore_cvae.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import math
|
2 |
-
from pathlib import Path
|
3 |
|
4 |
import streamlit as st
|
5 |
import numpy as np
|
@@ -12,7 +11,7 @@ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4
|
|
12 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
13 |
from src.models import ConditionalDecoder
|
14 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
15 |
-
from src.utils import
|
16 |
|
17 |
|
18 |
# global parameters
|
@@ -25,7 +24,6 @@ dim_z = params.dim_z
|
|
25 |
bs = 16 # number of samples to generate
|
26 |
n_cols = int(math.sqrt(bs))
|
27 |
model_path = params.path_cvae
|
28 |
-
drive_id = params.drive_id_cvae
|
29 |
path_labels = params.path_labels
|
30 |
|
31 |
# manual labels
|
@@ -90,9 +88,6 @@ def get_eps(n: int) -> torch.Tensor:
|
|
90 |
|
91 |
@st.cache
|
92 |
def get_labels() -> torch.Tensor:
|
93 |
-
if not Path(path_labels).exists():
|
94 |
-
download_file(params.drive_id_labels, path_labels)
|
95 |
-
|
96 |
labels_train = get_labels_train(path_labels)
|
97 |
return labels_train
|
98 |
|
@@ -103,9 +98,6 @@ def app():
|
|
103 |
st.title('Explore cVAE')
|
104 |
st.markdown('This demo shows cVAE for conditional galaxy generation')
|
105 |
|
106 |
-
if not Path(model_path).exists():
|
107 |
-
download_file(drive_id, model_path)
|
108 |
-
|
109 |
model = load_model(model_path)
|
110 |
eps = get_eps(bs)
|
111 |
labels_train = get_labels()
|
|
|
1 |
import math
|
|
|
2 |
|
3 |
import streamlit as st
|
4 |
import numpy as np
|
|
|
11 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
12 |
from src.models import ConditionalDecoder
|
13 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
14 |
+
from src.utils import sample_labels
|
15 |
|
16 |
|
17 |
# global parameters
|
|
|
24 |
bs = 16 # number of samples to generate
|
25 |
n_cols = int(math.sqrt(bs))
|
26 |
model_path = params.path_cvae
|
|
|
27 |
path_labels = params.path_labels
|
28 |
|
29 |
# manual labels
|
|
|
88 |
|
89 |
@st.cache
|
90 |
def get_labels() -> torch.Tensor:
|
|
|
|
|
|
|
91 |
labels_train = get_labels_train(path_labels)
|
92 |
return labels_train
|
93 |
|
|
|
98 |
st.title('Explore cVAE')
|
99 |
st.markdown('This demo shows cVAE for conditional galaxy generation')
|
100 |
|
|
|
|
|
|
|
101 |
model = load_model(model_path)
|
102 |
eps = get_eps(bs)
|
103 |
labels_train = get_labels()
|
src/app/explore_infoscc_gan.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
from pathlib import Path
|
2 |
import math
|
3 |
|
4 |
import numpy as np
|
@@ -12,7 +11,7 @@ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4
|
|
12 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
13 |
from src.models import ConditionalGenerator
|
14 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
15 |
-
from src.utils import
|
16 |
|
17 |
# global parameters
|
18 |
device = params.device
|
@@ -27,7 +26,6 @@ y_type = params.y_type
|
|
27 |
bs = 16 # number of samples to generate
|
28 |
n_cols = int(math.sqrt(bs))
|
29 |
model_path = params.path_infoscc_gan # path to the model
|
30 |
-
drive_id = params.drive_id_infoscc_gan # google drive id of the model
|
31 |
path_labels = params.path_labels
|
32 |
|
33 |
# manual labels
|
@@ -87,8 +85,6 @@ def load_model(model_path: str) -> ConditionalGenerator:
|
|
87 |
|
88 |
@st.cache
|
89 |
def get_labels() -> torch.Tensor:
|
90 |
-
if not Path(path_labels).exists():
|
91 |
-
download_file(params.drive_id_labels, path_labels)
|
92 |
labels_train = get_labels_train(path_labels)
|
93 |
return labels_train
|
94 |
|
@@ -100,9 +96,6 @@ def app():
|
|
100 |
st.markdown('This demo shows InfoSCC-GAN for conditional galaxy generation')
|
101 |
st.subheader(r'<- Use sidebar to explore $z_1, ..., z_k$ latent variables')
|
102 |
|
103 |
-
if not Path(model_path).exists():
|
104 |
-
download_file(drive_id, model_path)
|
105 |
-
|
106 |
model = load_model(model_path)
|
107 |
eps = model.sample_eps(bs).to(device)
|
108 |
labels_train = get_labels()
|
|
|
|
|
1 |
import math
|
2 |
|
3 |
import numpy as np
|
|
|
11 |
q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
|
12 |
from src.models import ConditionalGenerator
|
13 |
from src.data import get_labels_train, make_galaxy_labels_hierarchical
|
14 |
+
from src.utils import sample_labels
|
15 |
|
16 |
# global parameters
|
17 |
device = params.device
|
|
|
26 |
bs = 16 # number of samples to generate
|
27 |
n_cols = int(math.sqrt(bs))
|
28 |
model_path = params.path_infoscc_gan # path to the model
|
|
|
29 |
path_labels = params.path_labels
|
30 |
|
31 |
# manual labels
|
|
|
85 |
|
86 |
@st.cache
|
87 |
def get_labels() -> torch.Tensor:
|
|
|
|
|
88 |
labels_train = get_labels_train(path_labels)
|
89 |
return labels_train
|
90 |
|
|
|
96 |
st.markdown('This demo shows InfoSCC-GAN for conditional galaxy generation')
|
97 |
st.subheader(r'<- Use sidebar to explore $z_1, ..., z_k$ latent variables')
|
98 |
|
|
|
|
|
|
|
99 |
model = load_model(model_path)
|
100 |
eps = model.sample_eps(bs).to(device)
|
101 |
labels_train = get_labels()
|
src/app/interpolate_labels.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
from pathlib import Path
|
2 |
import math
|
3 |
|
4 |
import numpy as np
|
@@ -12,7 +11,7 @@ from src.models import ConditionalGenerator as InfoSCC_GAN
|
|
12 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
13 |
from src.models import ConditionalDecoder as cVAE
|
14 |
from src.data import get_labels_train
|
15 |
-
from src.utils import
|
16 |
|
17 |
|
18 |
device = params.device
|
@@ -31,26 +30,14 @@ def load_model(model_type: str):
|
|
31 |
g = InfoSCC_GAN(size=params.size,
|
32 |
y_size=params.shape_label,
|
33 |
z_size=params.noise_dim)
|
34 |
-
|
35 |
-
if not Path(params.path_infoscc_gan).exists():
|
36 |
-
download_file(params.drive_id_infoscc_gan, params.path_infoscc_gan)
|
37 |
-
|
38 |
ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
|
39 |
g.load_state_dict(ckpt['g_ema'])
|
40 |
elif model_type == 'BigGAN':
|
41 |
g = BigGAN2Generator()
|
42 |
-
|
43 |
-
if not Path(params.path_biggan).exists():
|
44 |
-
download_file(params.drive_id_biggan, params.path_biggan)
|
45 |
-
|
46 |
ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
|
47 |
g.load_state_dict(ckpt)
|
48 |
elif model_type == 'cVAE':
|
49 |
g = cVAE()
|
50 |
-
|
51 |
-
if not Path(params.path_cvae).exists():
|
52 |
-
download_file(params.drive_id_cvae, params.path_cvae)
|
53 |
-
|
54 |
ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
|
55 |
g.load_state_dict(ckpt)
|
56 |
else:
|
@@ -62,10 +49,6 @@ def load_model(model_type: str):
|
|
62 |
@st.cache
|
63 |
def get_labels() -> torch.Tensor:
|
64 |
path_labels = params.path_labels
|
65 |
-
|
66 |
-
if not Path(path_labels).exists():
|
67 |
-
download_file(params.drive_id_labels, path_labels)
|
68 |
-
|
69 |
labels_train = get_labels_train(path_labels)
|
70 |
return labels_train
|
71 |
|
|
|
|
|
1 |
import math
|
2 |
|
3 |
import numpy as np
|
|
|
11 |
from src.models.big.BigGAN2 import Generator as BigGAN2Generator
|
12 |
from src.models import ConditionalDecoder as cVAE
|
13 |
from src.data import get_labels_train
|
14 |
+
from src.utils import sample_labels
|
15 |
|
16 |
|
17 |
device = params.device
|
|
|
30 |
g = InfoSCC_GAN(size=params.size,
|
31 |
y_size=params.shape_label,
|
32 |
z_size=params.noise_dim)
|
|
|
|
|
|
|
|
|
33 |
ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
|
34 |
g.load_state_dict(ckpt['g_ema'])
|
35 |
elif model_type == 'BigGAN':
|
36 |
g = BigGAN2Generator()
|
|
|
|
|
|
|
|
|
37 |
ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
|
38 |
g.load_state_dict(ckpt)
|
39 |
elif model_type == 'cVAE':
|
40 |
g = cVAE()
|
|
|
|
|
|
|
|
|
41 |
ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
|
42 |
g.load_state_dict(ckpt)
|
43 |
else:
|
|
|
49 |
@st.cache
|
50 |
def get_labels() -> torch.Tensor:
|
51 |
path_labels = params.path_labels
|
|
|
|
|
|
|
|
|
52 |
labels_train = get_labels_train(path_labels)
|
53 |
return labels_train
|
54 |
|
src/app/params.py
CHANGED
@@ -12,14 +12,7 @@ n_basis = 6 # size of additional z vectors in InfoSCC-GAN
|
|
12 |
y_type = 'real' # type of labels in InfoSCC-GAN
|
13 |
dim_z = 128 # z vector size in BigGAN and cVAE
|
14 |
|
15 |
-
path_infoscc_gan = './models/InfoSCC-GAN/generator.pt'
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
drive_id_biggan = '1sMSDdnQ5GjHcno5knHTDSKAKhhoHh_4z'
|
20 |
-
|
21 |
-
path_cvae = './models/CVAE/generator.pth'
|
22 |
-
drive_id_cvae = '17FmLvhwXq8PQMrD1CtjqyoAy5BobYMTE'
|
23 |
-
|
24 |
-
path_labels = './data/training_solutions_rev1.csv'
|
25 |
-
drive_id_labels = '1dzsB_HdGtmSHE4pCppamISpFaJBfPF7E'
|
|
|
12 |
y_type = 'real' # type of labels in InfoSCC-GAN
|
13 |
dim_z = 128 # z vector size in BigGAN and cVAE
|
14 |
|
15 |
+
path_infoscc_gan = './galaxy-zoo-generation/models/InfoSCC-GAN/generator.pt'
|
16 |
+
path_biggan = './galaxy-zoo-generation/models/BigGAN/generator.pth'
|
17 |
+
path_cvae = './galaxy-zoo-generation/models/CVAE/generator.pth'
|
18 |
+
path_labels = './galaxy-zoo-generation/data/training_solutions_rev1.csv'
|
|
|
|
|
|
|
|
|
|
|
|
|
|