vitaliykinakh commited on
Commit
d872920
Β·
1 Parent(s): d6b8e04

Add model weights as submodule

Browse files
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "galaxy-zoo-generation"]
2
+ path = galaxy-zoo-generation
3
+ url = https://huggingface.co/vitaliykinakh/galaxy-zoo-generation
galaxy-zoo-generation ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit af017a826b2be3dec5f364ac4e232ede6cc0e04f
src/app/compare_models.py CHANGED
@@ -1,4 +1,3 @@
1
- from pathlib import Path
2
  import math
3
 
4
  import streamlit as st
@@ -14,7 +13,7 @@ from src.models import ConditionalGenerator as InfoSCC_GAN
14
  from src.models.big.BigGAN2 import Generator as BigGAN2Generator
15
  from src.models import ConditionalDecoder as cVAE
16
  from src.data import get_labels_train, make_galaxy_labels_hierarchical
17
- from src.utils import download_file, sample_labels
18
 
19
 
20
  device = params.device
@@ -76,25 +75,14 @@ def load_model(model_type: str):
76
  y_size=params.shape_label,
77
  z_size=params.noise_dim)
78
 
79
- if not Path(params.path_infoscc_gan).exists():
80
- download_file(params.drive_id_infoscc_gan, params.path_infoscc_gan)
81
-
82
  ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
83
  g.load_state_dict(ckpt['g_ema'])
84
  elif model_type == 'BigGAN':
85
  g = BigGAN2Generator()
86
-
87
- if not Path(params.path_biggan).exists():
88
- download_file(params.drive_id_biggan, params.path_biggan)
89
-
90
  ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
91
  g.load_state_dict(ckpt)
92
  elif model_type == 'cVAE':
93
  g = cVAE()
94
-
95
- if not Path(params.path_cvae).exists():
96
- download_file(params.drive_id_cvae, params.path_cvae)
97
-
98
  ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
99
  g.load_state_dict(ckpt)
100
  else:
@@ -107,9 +95,6 @@ def load_model(model_type: str):
107
  def get_labels() -> torch.Tensor:
108
  path_labels = params.path_labels
109
 
110
- if not Path(path_labels).exists():
111
- download_file(params.drive_id_labels, path_labels)
112
-
113
  labels_train = get_labels_train(path_labels)
114
  return labels_train
115
 
 
 
1
  import math
2
 
3
  import streamlit as st
 
13
  from src.models.big.BigGAN2 import Generator as BigGAN2Generator
14
  from src.models import ConditionalDecoder as cVAE
15
  from src.data import get_labels_train, make_galaxy_labels_hierarchical
16
+ from src.utils import sample_labels
17
 
18
 
19
  device = params.device
 
75
  y_size=params.shape_label,
76
  z_size=params.noise_dim)
77
 
 
 
 
78
  ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
79
  g.load_state_dict(ckpt['g_ema'])
80
  elif model_type == 'BigGAN':
81
  g = BigGAN2Generator()
 
 
 
 
82
  ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
83
  g.load_state_dict(ckpt)
84
  elif model_type == 'cVAE':
85
  g = cVAE()
 
 
 
 
86
  ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
87
  g.load_state_dict(ckpt)
88
  else:
 
95
  def get_labels() -> torch.Tensor:
96
  path_labels = params.path_labels
97
 
 
 
 
98
  labels_train = get_labels_train(path_labels)
99
  return labels_train
100
 
src/app/explore_biggan.py CHANGED
@@ -1,5 +1,4 @@
1
  import math
2
- from pathlib import Path
3
 
4
  import streamlit as st
5
  import numpy as np
@@ -12,7 +11,7 @@ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4
12
  q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
13
  from src.models.big.BigGAN2 import Generator as BigGAN2Generator
14
  from src.data import get_labels_train, make_galaxy_labels_hierarchical
15
- from src.utils import download_file, sample_labels
16
 
17
 
18
  # global parameters
@@ -25,7 +24,6 @@ dim_z = params.dim_z
25
  bs = 16 # number of samples to generate
26
  n_cols = int(math.sqrt(bs))
27
  model_path = params.path_biggan
28
- drive_id = params.drive_id_biggan
29
  path_labels = params.path_labels
30
 
31
  # manual labels
@@ -90,9 +88,6 @@ def get_eps(n: int) -> torch.Tensor:
90
 
91
  @st.cache
92
  def get_labels() -> torch.Tensor:
93
- if not Path(path_labels).exists():
94
- download_file(params.drive_id_labels, path_labels)
95
-
96
  labels_train = get_labels_train(path_labels)
97
  return labels_train
98
 
@@ -102,10 +97,6 @@ def app():
102
 
103
  st.title('Explore BigGAN')
104
  st.markdown('This demo shows BigGAN for conditional galaxy generation')
105
-
106
- if not Path(model_path).exists():
107
- download_file(drive_id, model_path)
108
-
109
  model = load_model(model_path)
110
  eps = get_eps(bs)
111
  labels_train = get_labels()
 
1
  import math
 
2
 
3
  import streamlit as st
4
  import numpy as np
 
11
  q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
12
  from src.models.big.BigGAN2 import Generator as BigGAN2Generator
13
  from src.data import get_labels_train, make_galaxy_labels_hierarchical
14
+ from src.utils import sample_labels
15
 
16
 
17
  # global parameters
 
24
  bs = 16 # number of samples to generate
25
  n_cols = int(math.sqrt(bs))
26
  model_path = params.path_biggan
 
27
  path_labels = params.path_labels
28
 
29
  # manual labels
 
88
 
89
  @st.cache
90
  def get_labels() -> torch.Tensor:
 
 
 
91
  labels_train = get_labels_train(path_labels)
92
  return labels_train
93
 
 
97
 
98
  st.title('Explore BigGAN')
99
  st.markdown('This demo shows BigGAN for conditional galaxy generation')
 
 
 
 
100
  model = load_model(model_path)
101
  eps = get_eps(bs)
102
  labels_train = get_labels()
src/app/explore_cvae.py CHANGED
@@ -1,5 +1,4 @@
1
  import math
2
- from pathlib import Path
3
 
4
  import streamlit as st
5
  import numpy as np
@@ -12,7 +11,7 @@ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4
12
  q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
13
  from src.models import ConditionalDecoder
14
  from src.data import get_labels_train, make_galaxy_labels_hierarchical
15
- from src.utils import download_file, sample_labels
16
 
17
 
18
  # global parameters
@@ -25,7 +24,6 @@ dim_z = params.dim_z
25
  bs = 16 # number of samples to generate
26
  n_cols = int(math.sqrt(bs))
27
  model_path = params.path_cvae
28
- drive_id = params.drive_id_cvae
29
  path_labels = params.path_labels
30
 
31
  # manual labels
@@ -90,9 +88,6 @@ def get_eps(n: int) -> torch.Tensor:
90
 
91
  @st.cache
92
  def get_labels() -> torch.Tensor:
93
- if not Path(path_labels).exists():
94
- download_file(params.drive_id_labels, path_labels)
95
-
96
  labels_train = get_labels_train(path_labels)
97
  return labels_train
98
 
@@ -103,9 +98,6 @@ def app():
103
  st.title('Explore cVAE')
104
  st.markdown('This demo shows cVAE for conditional galaxy generation')
105
 
106
- if not Path(model_path).exists():
107
- download_file(drive_id, model_path)
108
-
109
  model = load_model(model_path)
110
  eps = get_eps(bs)
111
  labels_train = get_labels()
 
1
  import math
 
2
 
3
  import streamlit as st
4
  import numpy as np
 
11
  q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
12
  from src.models import ConditionalDecoder
13
  from src.data import get_labels_train, make_galaxy_labels_hierarchical
14
+ from src.utils import sample_labels
15
 
16
 
17
  # global parameters
 
24
  bs = 16 # number of samples to generate
25
  n_cols = int(math.sqrt(bs))
26
  model_path = params.path_cvae
 
27
  path_labels = params.path_labels
28
 
29
  # manual labels
 
88
 
89
  @st.cache
90
  def get_labels() -> torch.Tensor:
 
 
 
91
  labels_train = get_labels_train(path_labels)
92
  return labels_train
93
 
 
98
  st.title('Explore cVAE')
99
  st.markdown('This demo shows cVAE for conditional galaxy generation')
100
 
 
 
 
101
  model = load_model(model_path)
102
  eps = get_eps(bs)
103
  labels_train = get_labels()
src/app/explore_infoscc_gan.py CHANGED
@@ -1,4 +1,3 @@
1
- from pathlib import Path
2
  import math
3
 
4
  import numpy as np
@@ -12,7 +11,7 @@ from src.app.questions import q1, q1_options, q2, q2_options, q3, q3_options, q4
12
  q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
13
  from src.models import ConditionalGenerator
14
  from src.data import get_labels_train, make_galaxy_labels_hierarchical
15
- from src.utils import download_file, sample_labels
16
 
17
  # global parameters
18
  device = params.device
@@ -27,7 +26,6 @@ y_type = params.y_type
27
  bs = 16 # number of samples to generate
28
  n_cols = int(math.sqrt(bs))
29
  model_path = params.path_infoscc_gan # path to the model
30
- drive_id = params.drive_id_infoscc_gan # google drive id of the model
31
  path_labels = params.path_labels
32
 
33
  # manual labels
@@ -87,8 +85,6 @@ def load_model(model_path: str) -> ConditionalGenerator:
87
 
88
  @st.cache
89
  def get_labels() -> torch.Tensor:
90
- if not Path(path_labels).exists():
91
- download_file(params.drive_id_labels, path_labels)
92
  labels_train = get_labels_train(path_labels)
93
  return labels_train
94
 
@@ -100,9 +96,6 @@ def app():
100
  st.markdown('This demo shows InfoSCC-GAN for conditional galaxy generation')
101
  st.subheader(r'<- Use sidebar to explore $z_1, ..., z_k$ latent variables')
102
 
103
- if not Path(model_path).exists():
104
- download_file(drive_id, model_path)
105
-
106
  model = load_model(model_path)
107
  eps = model.sample_eps(bs).to(device)
108
  labels_train = get_labels()
 
 
1
  import math
2
 
3
  import numpy as np
 
11
  q6, q6_options, q7, q7_options, q8, q8_options, q9, q9_options, q10, q10_options, q11, q11_options
12
  from src.models import ConditionalGenerator
13
  from src.data import get_labels_train, make_galaxy_labels_hierarchical
14
+ from src.utils import sample_labels
15
 
16
  # global parameters
17
  device = params.device
 
26
  bs = 16 # number of samples to generate
27
  n_cols = int(math.sqrt(bs))
28
  model_path = params.path_infoscc_gan # path to the model
 
29
  path_labels = params.path_labels
30
 
31
  # manual labels
 
85
 
86
  @st.cache
87
  def get_labels() -> torch.Tensor:
 
 
88
  labels_train = get_labels_train(path_labels)
89
  return labels_train
90
 
 
96
  st.markdown('This demo shows InfoSCC-GAN for conditional galaxy generation')
97
  st.subheader(r'<- Use sidebar to explore $z_1, ..., z_k$ latent variables')
98
 
 
 
 
99
  model = load_model(model_path)
100
  eps = model.sample_eps(bs).to(device)
101
  labels_train = get_labels()
src/app/interpolate_labels.py CHANGED
@@ -1,4 +1,3 @@
1
- from pathlib import Path
2
  import math
3
 
4
  import numpy as np
@@ -12,7 +11,7 @@ from src.models import ConditionalGenerator as InfoSCC_GAN
12
  from src.models.big.BigGAN2 import Generator as BigGAN2Generator
13
  from src.models import ConditionalDecoder as cVAE
14
  from src.data import get_labels_train
15
- from src.utils import download_file, sample_labels
16
 
17
 
18
  device = params.device
@@ -31,26 +30,14 @@ def load_model(model_type: str):
31
  g = InfoSCC_GAN(size=params.size,
32
  y_size=params.shape_label,
33
  z_size=params.noise_dim)
34
-
35
- if not Path(params.path_infoscc_gan).exists():
36
- download_file(params.drive_id_infoscc_gan, params.path_infoscc_gan)
37
-
38
  ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
39
  g.load_state_dict(ckpt['g_ema'])
40
  elif model_type == 'BigGAN':
41
  g = BigGAN2Generator()
42
-
43
- if not Path(params.path_biggan).exists():
44
- download_file(params.drive_id_biggan, params.path_biggan)
45
-
46
  ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
47
  g.load_state_dict(ckpt)
48
  elif model_type == 'cVAE':
49
  g = cVAE()
50
-
51
- if not Path(params.path_cvae).exists():
52
- download_file(params.drive_id_cvae, params.path_cvae)
53
-
54
  ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
55
  g.load_state_dict(ckpt)
56
  else:
@@ -62,10 +49,6 @@ def load_model(model_type: str):
62
  @st.cache
63
  def get_labels() -> torch.Tensor:
64
  path_labels = params.path_labels
65
-
66
- if not Path(path_labels).exists():
67
- download_file(params.drive_id_labels, path_labels)
68
-
69
  labels_train = get_labels_train(path_labels)
70
  return labels_train
71
 
 
 
1
  import math
2
 
3
  import numpy as np
 
11
  from src.models.big.BigGAN2 import Generator as BigGAN2Generator
12
  from src.models import ConditionalDecoder as cVAE
13
  from src.data import get_labels_train
14
+ from src.utils import sample_labels
15
 
16
 
17
  device = params.device
 
30
  g = InfoSCC_GAN(size=params.size,
31
  y_size=params.shape_label,
32
  z_size=params.noise_dim)
 
 
 
 
33
  ckpt = torch.load(params.path_infoscc_gan, map_location=torch.device('cpu'))
34
  g.load_state_dict(ckpt['g_ema'])
35
  elif model_type == 'BigGAN':
36
  g = BigGAN2Generator()
 
 
 
 
37
  ckpt = torch.load(params.path_biggan, map_location=torch.device('cpu'))
38
  g.load_state_dict(ckpt)
39
  elif model_type == 'cVAE':
40
  g = cVAE()
 
 
 
 
41
  ckpt = torch.load(params.path_cvae, map_location=torch.device('cpu'))
42
  g.load_state_dict(ckpt)
43
  else:
 
49
  @st.cache
50
  def get_labels() -> torch.Tensor:
51
  path_labels = params.path_labels
 
 
 
 
52
  labels_train = get_labels_train(path_labels)
53
  return labels_train
54
 
src/app/params.py CHANGED
@@ -12,14 +12,7 @@ n_basis = 6 # size of additional z vectors in InfoSCC-GAN
12
  y_type = 'real' # type of labels in InfoSCC-GAN
13
  dim_z = 128 # z vector size in BigGAN and cVAE
14
 
15
- path_infoscc_gan = './models/InfoSCC-GAN/generator.pt'
16
- drive_id_infoscc_gan = '1_kIujc497OH0ZJ7PNPwS5_otNlS7jMLI'
17
-
18
- path_biggan = './models/BigGAN/generator.pth'
19
- drive_id_biggan = '1sMSDdnQ5GjHcno5knHTDSKAKhhoHh_4z'
20
-
21
- path_cvae = './models/CVAE/generator.pth'
22
- drive_id_cvae = '17FmLvhwXq8PQMrD1CtjqyoAy5BobYMTE'
23
-
24
- path_labels = './data/training_solutions_rev1.csv'
25
- drive_id_labels = '1dzsB_HdGtmSHE4pCppamISpFaJBfPF7E'
 
12
  y_type = 'real' # type of labels in InfoSCC-GAN
13
  dim_z = 128 # z vector size in BigGAN and cVAE
14
 
15
+ path_infoscc_gan = './galaxy-zoo-generation/models/InfoSCC-GAN/generator.pt'
16
+ path_biggan = './galaxy-zoo-generation/models/BigGAN/generator.pth'
17
+ path_cvae = './galaxy-zoo-generation/models/CVAE/generator.pth'
18
+ path_labels = './galaxy-zoo-generation/data/training_solutions_rev1.csv'