r3gm commited on
Commit
c71cd1c
·
verified ·
1 Parent(s): 1eb4ae4

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +269 -270
utils.py CHANGED
@@ -1,270 +1,269 @@
1
-
2
- import os
3
- import re
4
- import gradio as gr
5
- from constants import (
6
- DIFFUSERS_FORMAT_LORAS,
7
- CIVITAI_API_KEY,
8
- HF_TOKEN,
9
- MODEL_TYPE_CLASS,
10
- DIRECTORY_LORAS,
11
- )
12
- from huggingface_hub import HfApi
13
- from diffusers import DiffusionPipeline
14
- from huggingface_hub import model_info as model_info_data
15
- from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
16
- from pathlib import PosixPath
17
-
18
-
19
- def download_things(directory, url, hf_token="", civitai_api_key=""):
20
- url = url.strip()
21
-
22
- if "drive.google.com" in url:
23
- original_dir = os.getcwd()
24
- os.chdir(directory)
25
- os.system(f"gdown --fuzzy {url}")
26
- os.chdir(original_dir)
27
- elif "huggingface.co" in url:
28
- url = url.replace("?download=true", "")
29
- # url = urllib.parse.quote(url, safe=':/') # fix encoding
30
- if "/blob/" in url:
31
- url = url.replace("/blob/", "/resolve/")
32
- user_header = f'"Authorization: Bearer {hf_token}"'
33
- if hf_token:
34
- os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
35
- else:
36
- os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
37
- elif "civitai.com" in url:
38
- if "?" in url:
39
- url = url.split("?")[0]
40
- if civitai_api_key:
41
- url = url + f"?token={civitai_api_key}"
42
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
43
- else:
44
- print("\033[91mYou need an API key to download Civitai models.\033[0m")
45
- else:
46
- os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
47
-
48
-
49
- def get_model_list(directory_path):
50
- model_list = []
51
- valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
52
-
53
- for filename in os.listdir(directory_path):
54
- if os.path.splitext(filename)[1] in valid_extensions:
55
- # name_without_extension = os.path.splitext(filename)[0]
56
- file_path = os.path.join(directory_path, filename)
57
- # model_list.append((name_without_extension, file_path))
58
- model_list.append(file_path)
59
- print('\033[34mFILE: ' + file_path + '\033[0m')
60
- return model_list
61
-
62
-
63
- def extract_parameters(input_string):
64
- parameters = {}
65
- input_string = input_string.replace("\n", "")
66
-
67
- if "Negative prompt:" not in input_string:
68
- if "Steps:" in input_string:
69
- input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
70
- else:
71
- print("Invalid metadata")
72
- parameters["prompt"] = input_string
73
- return parameters
74
-
75
- parm = input_string.split("Negative prompt:")
76
- parameters["prompt"] = parm[0].strip()
77
- if "Steps:" not in parm[1]:
78
- print("Steps not detected")
79
- parameters["neg_prompt"] = parm[1].strip()
80
- return parameters
81
- parm = parm[1].split("Steps:")
82
- parameters["neg_prompt"] = parm[0].strip()
83
- input_string = "Steps:" + parm[1]
84
-
85
- # Extracting Steps
86
- steps_match = re.search(r'Steps: (\d+)', input_string)
87
- if steps_match:
88
- parameters['Steps'] = int(steps_match.group(1))
89
-
90
- # Extracting Size
91
- size_match = re.search(r'Size: (\d+x\d+)', input_string)
92
- if size_match:
93
- parameters['Size'] = size_match.group(1)
94
- width, height = map(int, parameters['Size'].split('x'))
95
- parameters['width'] = width
96
- parameters['height'] = height
97
-
98
- # Extracting other parameters
99
- other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
100
- for param in other_parameters:
101
- parameters[param[0]] = param[1].strip('"')
102
-
103
- return parameters
104
-
105
-
106
- def get_my_lora(link_url):
107
- for url in [url.strip() for url in link_url.split(',')]:
108
- if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
109
- download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
110
- new_lora_model_list = get_model_list(DIRECTORY_LORAS)
111
- new_lora_model_list.insert(0, "None")
112
- new_lora_model_list = new_lora_model_list + DIFFUSERS_FORMAT_LORAS
113
-
114
- return gr.update(
115
- choices=new_lora_model_list
116
- ), gr.update(
117
- choices=new_lora_model_list
118
- ), gr.update(
119
- choices=new_lora_model_list
120
- ), gr.update(
121
- choices=new_lora_model_list
122
- ), gr.update(
123
- choices=new_lora_model_list
124
- ),
125
-
126
-
127
- def info_html(json_data, title, subtitle):
128
- return f"""
129
- <div style='padding: 0; border-radius: 10px;'>
130
- <p style='margin: 0; font-weight: bold;'>{title}</p>
131
- <details>
132
- <summary>Details</summary>
133
- <p style='margin: 0; font-weight: bold;'>{subtitle}</p>
134
- </details>
135
- </div>
136
- """
137
-
138
-
139
- def get_model_type(repo_id: str):
140
- api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
141
- default = "SD 1.5"
142
- try:
143
- model = api.model_info(repo_id=repo_id, timeout=5.0)
144
- tags = model.tags
145
- for tag in tags:
146
- if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
147
- except Exception:
148
- return default
149
- return default
150
-
151
-
152
- def restart_space(repo_id: str, factory_reboot: bool, token: str):
153
- api = HfApi(token=token)
154
- api.restart_space(repo_id=repo_id, factory_reboot=factory_reboot)
155
-
156
-
157
- def extract_exif_data(image):
158
- if image is None: return ""
159
-
160
- try:
161
- metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
162
-
163
- for key in metadata_keys:
164
- if key in image.info:
165
- return image.info[key]
166
-
167
- return str(image.info)
168
-
169
- except Exception as e:
170
- return f"Error extracting metadata: {str(e)}"
171
-
172
-
173
- def create_mask_now(img, invert):
174
- import numpy as np
175
- import time
176
-
177
- time.sleep(0.5)
178
-
179
- transparent_image = img["layers"][0]
180
-
181
- # Extract the alpha channel
182
- alpha_channel = np.array(transparent_image)[:, :, 3]
183
-
184
- # Create a binary mask by thresholding the alpha channel
185
- binary_mask = alpha_channel > 1
186
-
187
- if invert:
188
- print("Invert")
189
- # Invert the binary mask so that the drawn shape is white and the rest is black
190
- binary_mask = np.invert(binary_mask)
191
-
192
- # Convert the binary mask to a 3-channel RGB mask
193
- rgb_mask = np.stack((binary_mask,) * 3, axis=-1)
194
-
195
- # Convert the mask to uint8
196
- rgb_mask = rgb_mask.astype(np.uint8) * 255
197
-
198
- return img["background"], rgb_mask
199
-
200
-
201
- def download_diffuser_repo(repo_name: str, model_type: str, revision: str = "main", token=True):
202
-
203
- variant = None
204
- if token is True and not os.environ.get("HF_TOKEN"):
205
- token = None
206
-
207
- if model_type == "SDXL":
208
- info = model_info_data(
209
- repo_name,
210
- token=token,
211
- revision=revision,
212
- timeout=5.0,
213
- )
214
-
215
- filenames = {sibling.rfilename for sibling in info.siblings}
216
- model_filenames, variant_filenames = variant_compatible_siblings(
217
- filenames, variant="fp16"
218
- )
219
-
220
- if len(variant_filenames):
221
- variant = "fp16"
222
-
223
- cached_folder = DiffusionPipeline.download(
224
- pretrained_model_name=repo_name,
225
- force_download=False,
226
- token=token,
227
- revision=revision,
228
- # mirror="https://hf-mirror.com",
229
- variant=variant,
230
- use_safetensors=True,
231
- trust_remote_code=False,
232
- timeout=5.0,
233
- )
234
-
235
- if isinstance(cached_folder, PosixPath):
236
- cached_folder = cached_folder.as_posix()
237
-
238
- # Task model
239
- # from huggingface_hub import hf_hub_download
240
- # hf_hub_download(
241
- # task_model,
242
- # filename="diffusion_pytorch_model.safetensors", # fix fp16 variant
243
- # )
244
-
245
- return cached_folder
246
-
247
-
248
- def progress_step_bar(step, total):
249
- # Calculate the percentage for the progress bar width
250
- percentage = min(100, ((step / total) * 100))
251
-
252
- return f"""
253
- <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
254
- <div style="width: {percentage}%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
255
- <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 13px;">
256
- {int(percentage)}%
257
- </div>
258
- </div>
259
- """
260
-
261
-
262
- def html_template_message(msg):
263
- return f"""
264
- <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
265
- <div style="width: 0%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
266
- <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 14px; font-weight: bold; text-shadow: 1px 1px 2px black;">
267
- {msg}
268
- </div>
269
- </div>
270
- """
 
1
+ import os
2
+ import re
3
+ import gradio as gr
4
+ from constants import (
5
+ DIFFUSERS_FORMAT_LORAS,
6
+ CIVITAI_API_KEY,
7
+ HF_TOKEN,
8
+ MODEL_TYPE_CLASS,
9
+ DIRECTORY_LORAS,
10
+ )
11
+ from huggingface_hub import HfApi
12
+ from diffusers import DiffusionPipeline
13
+ from huggingface_hub import model_info as model_info_data
14
+ from diffusers.pipelines.pipeline_loading_utils import variant_compatible_siblings
15
+ from pathlib import PosixPath
16
+
17
+
18
+ def download_things(directory, url, hf_token="", civitai_api_key=""):
19
+ url = url.strip()
20
+
21
+ if "drive.google.com" in url:
22
+ original_dir = os.getcwd()
23
+ os.chdir(directory)
24
+ os.system(f"gdown --fuzzy {url}")
25
+ os.chdir(original_dir)
26
+ elif "huggingface.co" in url:
27
+ url = url.replace("?download=true", "")
28
+ # url = urllib.parse.quote(url, safe=':/') # fix encoding
29
+ if "/blob/" in url:
30
+ url = url.replace("/blob/", "/resolve/")
31
+ user_header = f'"Authorization: Bearer {hf_token}"'
32
+ if hf_token:
33
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 --header={user_header} -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
34
+ else:
35
+ os.system(f"aria2c --optimize-concurrent-downloads --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 {url} -d {directory} -o {url.split('/')[-1]}")
36
+ elif "civitai.com" in url:
37
+ if "?" in url:
38
+ url = url.split("?")[0]
39
+ if civitai_api_key:
40
+ url = url + f"?token={civitai_api_key}"
41
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
42
+ else:
43
+ print("\033[91mYou need an API key to download Civitai models.\033[0m")
44
+ else:
45
+ os.system(f"aria2c --console-log-level=error --summary-interval=10 -c -x 16 -k 1M -s 16 -d {directory} {url}")
46
+
47
+
48
+ def get_model_list(directory_path):
49
+ model_list = []
50
+ valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
51
+
52
+ for filename in os.listdir(directory_path):
53
+ if os.path.splitext(filename)[1] in valid_extensions:
54
+ # name_without_extension = os.path.splitext(filename)[0]
55
+ file_path = os.path.join(directory_path, filename)
56
+ # model_list.append((name_without_extension, file_path))
57
+ model_list.append(file_path)
58
+ print('\033[34mFILE: ' + file_path + '\033[0m')
59
+ return model_list
60
+
61
+
62
+ def extract_parameters(input_string):
63
+ parameters = {}
64
+ input_string = input_string.replace("\n", "")
65
+
66
+ if "Negative prompt:" not in input_string:
67
+ if "Steps:" in input_string:
68
+ input_string = input_string.replace("Steps:", "Negative prompt: Steps:")
69
+ else:
70
+ print("Invalid metadata")
71
+ parameters["prompt"] = input_string
72
+ return parameters
73
+
74
+ parm = input_string.split("Negative prompt:")
75
+ parameters["prompt"] = parm[0].strip()
76
+ if "Steps:" not in parm[1]:
77
+ print("Steps not detected")
78
+ parameters["neg_prompt"] = parm[1].strip()
79
+ return parameters
80
+ parm = parm[1].split("Steps:")
81
+ parameters["neg_prompt"] = parm[0].strip()
82
+ input_string = "Steps:" + parm[1]
83
+
84
+ # Extracting Steps
85
+ steps_match = re.search(r'Steps: (\d+)', input_string)
86
+ if steps_match:
87
+ parameters['Steps'] = int(steps_match.group(1))
88
+
89
+ # Extracting Size
90
+ size_match = re.search(r'Size: (\d+x\d+)', input_string)
91
+ if size_match:
92
+ parameters['Size'] = size_match.group(1)
93
+ width, height = map(int, parameters['Size'].split('x'))
94
+ parameters['width'] = width
95
+ parameters['height'] = height
96
+
97
+ # Extracting other parameters
98
+ other_parameters = re.findall(r'(\w+): (.*?)(?=, \w+|$)', input_string)
99
+ for param in other_parameters:
100
+ parameters[param[0]] = param[1].strip('"')
101
+
102
+ return parameters
103
+
104
+
105
+ def get_my_lora(link_url):
106
+ for url in [url.strip() for url in link_url.split(',')]:
107
+ if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
108
+ download_things(DIRECTORY_LORAS, url, HF_TOKEN, CIVITAI_API_KEY)
109
+ new_lora_model_list = get_model_list(DIRECTORY_LORAS)
110
+ new_lora_model_list.insert(0, "None")
111
+ new_lora_model_list = new_lora_model_list + DIFFUSERS_FORMAT_LORAS
112
+
113
+ return gr.update(
114
+ choices=new_lora_model_list
115
+ ), gr.update(
116
+ choices=new_lora_model_list
117
+ ), gr.update(
118
+ choices=new_lora_model_list
119
+ ), gr.update(
120
+ choices=new_lora_model_list
121
+ ), gr.update(
122
+ choices=new_lora_model_list
123
+ ),
124
+
125
+
126
+ def info_html(json_data, title, subtitle):
127
+ return f"""
128
+ <div style='padding: 0; border-radius: 10px;'>
129
+ <p style='margin: 0; font-weight: bold;'>{title}</p>
130
+ <details>
131
+ <summary>Details</summary>
132
+ <p style='margin: 0; font-weight: bold;'>{subtitle}</p>
133
+ </details>
134
+ </div>
135
+ """
136
+
137
+
138
+ def get_model_type(repo_id: str):
139
+ api = HfApi(token=os.environ.get("HF_TOKEN")) # if use private or gated model
140
+ default = "SD 1.5"
141
+ try:
142
+ model = api.model_info(repo_id=repo_id, timeout=5.0)
143
+ tags = model.tags
144
+ for tag in tags:
145
+ if tag in MODEL_TYPE_CLASS.keys(): return MODEL_TYPE_CLASS.get(tag, default)
146
+ except Exception:
147
+ return default
148
+ return default
149
+
150
+
151
+ def restart_space(repo_id: str, factory_reboot: bool, token: str):
152
+ api = HfApi(token=token)
153
+ api.restart_space(repo_id=repo_id, factory_reboot=factory_reboot)
154
+
155
+
156
+ def extract_exif_data(image):
157
+ if image is None: return ""
158
+
159
+ try:
160
+ metadata_keys = ['parameters', 'metadata', 'prompt', 'Comment']
161
+
162
+ for key in metadata_keys:
163
+ if key in image.info:
164
+ return image.info[key]
165
+
166
+ return str(image.info)
167
+
168
+ except Exception as e:
169
+ return f"Error extracting metadata: {str(e)}"
170
+
171
+
172
+ def create_mask_now(img, invert):
173
+ import numpy as np
174
+ import time
175
+
176
+ time.sleep(0.5)
177
+
178
+ transparent_image = img["layers"][0]
179
+
180
+ # Extract the alpha channel
181
+ alpha_channel = np.array(transparent_image)[:, :, 3]
182
+
183
+ # Create a binary mask by thresholding the alpha channel
184
+ binary_mask = alpha_channel > 1
185
+
186
+ if invert:
187
+ print("Invert")
188
+ # Invert the binary mask so that the drawn shape is white and the rest is black
189
+ binary_mask = np.invert(binary_mask)
190
+
191
+ # Convert the binary mask to a 3-channel RGB mask
192
+ rgb_mask = np.stack((binary_mask,) * 3, axis=-1)
193
+
194
+ # Convert the mask to uint8
195
+ rgb_mask = rgb_mask.astype(np.uint8) * 255
196
+
197
+ return img["background"], rgb_mask
198
+
199
+
200
+ def download_diffuser_repo(repo_name: str, model_type: str, revision: str = "main", token=True):
201
+
202
+ variant = None
203
+ if token is True and not os.environ.get("HF_TOKEN"):
204
+ token = None
205
+
206
+ if model_type == "SDXL":
207
+ info = model_info_data(
208
+ repo_name,
209
+ token=token,
210
+ revision=revision,
211
+ timeout=5.0,
212
+ )
213
+
214
+ filenames = {sibling.rfilename for sibling in info.siblings}
215
+ model_filenames, variant_filenames = variant_compatible_siblings(
216
+ filenames, variant="fp16"
217
+ )
218
+
219
+ if len(variant_filenames):
220
+ variant = "fp16"
221
+
222
+ cached_folder = DiffusionPipeline.download(
223
+ pretrained_model_name=repo_name,
224
+ force_download=False,
225
+ token=token,
226
+ revision=revision,
227
+ # mirror="https://hf-mirror.com",
228
+ variant=variant,
229
+ use_safetensors=True,
230
+ trust_remote_code=False,
231
+ timeout=5.0,
232
+ )
233
+
234
+ if isinstance(cached_folder, PosixPath):
235
+ cached_folder = cached_folder.as_posix()
236
+
237
+ # Task model
238
+ # from huggingface_hub import hf_hub_download
239
+ # hf_hub_download(
240
+ # task_model,
241
+ # filename="diffusion_pytorch_model.safetensors", # fix fp16 variant
242
+ # )
243
+
244
+ return cached_folder
245
+
246
+
247
+ def progress_step_bar(step, total):
248
+ # Calculate the percentage for the progress bar width
249
+ percentage = min(100, ((step / total) * 100))
250
+
251
+ return f"""
252
+ <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
253
+ <div style="width: {percentage}%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
254
+ <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 13px;">
255
+ {int(percentage)}%
256
+ </div>
257
+ </div>
258
+ """
259
+
260
+
261
+ def html_template_message(msg):
262
+ return f"""
263
+ <div style="position: relative; width: 100%; background-color: gray; border-radius: 5px; overflow: hidden;">
264
+ <div style="width: 0%; height: 17px; background-color: #800080; transition: width 0.5s;"></div>
265
+ <div style="position: absolute; width: 100%; text-align: center; color: white; top: 0; line-height: 19px; font-size: 14px; font-weight: bold; text-shadow: 1px 1px 2px black;">
266
+ {msg}
267
+ </div>
268
+ </div>
269
+ """