Evel anzorq commited on
Commit
ca957fb
·
0 Parent(s):

Duplicate from anzorq/sd-to-diffusers

Browse files

Co-authored-by: AQ <[email protected]>

Files changed (6) hide show
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +170 -0
  4. hf_utils.py +50 -0
  5. requirements.txt +7 -0
  6. utils.py +6 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SD To Diffusers
3
+ emoji: 🎨➡️🧨
4
+ colorFrom: indigo
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 3.9.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: anzorq/sd-to-diffusers
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ from huggingface_hub import HfApi, upload_folder
4
+ import gradio as gr
5
+ import hf_utils
6
+ import utils
7
+
8
+ subprocess.run(["git", "clone", "https://github.com/huggingface/diffusers.git", "diffs"])
9
+
10
+ def error_str(error, title="Error"):
11
+ return f"""#### {title}
12
+ {error}""" if error else ""
13
+
14
+ def on_token_change(token):
15
+ model_names, error = hf_utils.get_my_model_names(token)
16
+ if model_names:
17
+ model_names.append("Other")
18
+
19
+ return gr.update(visible=bool(model_names)), gr.update(choices=model_names, value=model_names[0] if model_names else None), gr.update(visible=bool(model_names)), gr.update(value=error_str(error))
20
+
21
+ def url_to_model_id(model_id_str):
22
+ return model_id_str.split("/")[-2] + "/" + model_id_str.split("/")[-1] if model_id_str.startswith("https://huggingface.co/") else model_id_str
23
+
24
+ def get_ckpt_names(token, radio_model_names, input_model):
25
+
26
+ model_id = url_to_model_id(input_model) if radio_model_names == "Other" else radio_model_names
27
+
28
+ if token == "" or model_id == "":
29
+ return error_str("Please enter both a token and a model name.", title="Invalid input"), gr.update(choices=[]), gr.update(visible=False)
30
+
31
+ try:
32
+ api = HfApi(token=token)
33
+ ckpt_files = [f for f in api.list_repo_files(repo_id=model_id) if f.endswith(".ckpt")]
34
+
35
+ if not ckpt_files:
36
+ return error_str("No checkpoint files found in the model repo."), gr.update(choices=[]), gr.update(visible=False)
37
+
38
+ return None, gr.update(choices=ckpt_files, value=ckpt_files[0], visible=True), gr.update(visible=True)
39
+
40
+ except Exception as e:
41
+ return error_str(e), gr.update(choices=[]), None
42
+
43
+ def convert_and_push(radio_model_names, input_model, ckpt_name, token):
44
+
45
+ model_id = url_to_model_id(input_model) if radio_model_names == "Other" else radio_model_names
46
+
47
+ try:
48
+ model_id = url_to_model_id(model_id)
49
+
50
+ # 1. Download the checkpoint file
51
+ ckpt_path, revision = hf_utils.download_file(repo_id=model_id, filename=ckpt_name, token=token)
52
+
53
+ # 2. Run the conversion script
54
+ os.makedirs(model_id)
55
+ subprocess.run(
56
+ [
57
+ "python3",
58
+ "./diffs/scripts/convert_original_stable_diffusion_to_diffusers.py",
59
+ "--checkpoint_path",
60
+ ckpt_path,
61
+ "--dump_path" ,
62
+ model_id,
63
+ ]
64
+ )
65
+
66
+ # 3. Push to the model repo
67
+ commit_message="Add Diffusers weights"
68
+ upload_folder(
69
+ folder_path=model_id,
70
+ repo_id=model_id,
71
+ token=token,
72
+ create_pr=True,
73
+ commit_message=commit_message,
74
+ commit_description=f"Add Diffusers weights converted from checkpoint `{ckpt_name}` in revision {revision}",
75
+ )
76
+
77
+ # # 4. Delete the downloaded checkpoint file, yaml files, and the converted model folder
78
+ hf_utils.delete_file(revision)
79
+ subprocess.run(["rm", "-rf", model_id.split('/')[0]])
80
+ import glob
81
+ for f in glob.glob("*.yaml*"):
82
+ subprocess.run(["rm", "-rf", f])
83
+
84
+ return f"""Successfully converted the checkpoint and opened a PR to add the weights to the model repo.
85
+ You can view and merge the PR [here]({hf_utils.get_pr_url(HfApi(token=token), model_id, commit_message)})."""
86
+
87
+ except Exception as e:
88
+ return error_str(e)
89
+
90
+
91
+ DESCRIPTION = """### Convert a stable diffusion checkpoint to Diffusers🧨
92
+ With this space, you can easily convert a CompVis stable diffusion checkpoint to Diffusers and automatically create a pull request to the model repo.
93
+ You can choose to convert a checkpoint from one of your own models, or from any other model on the Hub.
94
+ You can skip the queue by running the app in the colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/qunash/f0f3152c5851c0c477b68b7b98d547fe/convert-sd-to-diffusers.ipynb)"""
95
+
96
+ with gr.Blocks() as demo:
97
+
98
+ gr.Markdown(DESCRIPTION)
99
+ with gr.Row():
100
+
101
+ with gr.Column(scale=11):
102
+ with gr.Column():
103
+ gr.Markdown("## 1. Load model info")
104
+ input_token = gr.Textbox(
105
+ max_lines=1,
106
+ label="Enter your Hugging Face token",
107
+ placeholder="READ permission is enough",
108
+ )
109
+ gr.Markdown("You can get a token [here](https://huggingface.co/settings/tokens)")
110
+ with gr.Group(visible=False) as group_model:
111
+ radio_model_names = gr.Radio(label="Choose a model")
112
+ input_model = gr.Textbox(
113
+ max_lines=1,
114
+ label="Model name or URL",
115
+ placeholder="username/model_name",
116
+ visible=False,
117
+ )
118
+
119
+ btn_get_ckpts = gr.Button("Load", visible=False)
120
+
121
+ with gr.Column(scale=10):
122
+ with gr.Column(visible=False) as group_convert:
123
+ gr.Markdown("## 2. Convert to Diffusers🧨")
124
+ radio_ckpts = gr.Radio(label="Choose the checkpoint to convert", visible=False)
125
+ gr.Markdown("Conversion may take a few minutes.")
126
+ btn_convert = gr.Button("Convert & Push")
127
+
128
+ error_output = gr.Markdown(label="Output")
129
+
130
+ input_token.change(
131
+ fn=on_token_change,
132
+ inputs=input_token,
133
+ outputs=[group_model, radio_model_names, btn_get_ckpts, error_output],
134
+ queue=False,
135
+ scroll_to_output=True)
136
+
137
+ radio_model_names.change(
138
+ lambda x: gr.update(visible=x == "Other"),
139
+ inputs=radio_model_names,
140
+ outputs=input_model,
141
+ queue=False,
142
+ scroll_to_output=True)
143
+
144
+ btn_get_ckpts.click(
145
+ fn=get_ckpt_names,
146
+ inputs=[input_token, radio_model_names, input_model],
147
+ outputs=[error_output, radio_ckpts, group_convert],
148
+ scroll_to_output=True,
149
+ queue=False
150
+ )
151
+
152
+ btn_convert.click(
153
+ fn=convert_and_push,
154
+ inputs=[radio_model_names, input_model, radio_ckpts, input_token],
155
+ outputs=error_output,
156
+ scroll_to_output=True
157
+ )
158
+
159
+ # gr.Markdown("""<img src="https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/imgs/diffusers_library.jpg" width="150"/>""")
160
+ gr.HTML("""
161
+ <div style="border-top: 1px solid #303030;">
162
+ <br>
163
+ <p>Space by: <a href="https://twitter.com/hahahahohohe"><img src="https://img.shields.io/twitter/follow/hahahahohohe?label=%40anzorq&style=social" alt="Twitter Follow"></a></p><br>
164
+ <a href="https://www.buymeacoffee.com/anzorq" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" style="height: 45px !important;width: 162px !important;" ></a><br><br>
165
+ <p><img src="https://visitor-badge.glitch.me/badge?page_id=anzorq.sd-to-diffusers" alt="visitors"></p>
166
+ </div>
167
+ """)
168
+
169
+ demo.queue()
170
+ demo.launch(share=utils.is_google_colab())
hf_utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import get_hf_file_metadata, hf_hub_url, hf_hub_download, scan_cache_dir, whoami, list_models
2
+
3
+
4
+ def get_my_model_names(token):
5
+
6
+ try:
7
+ author = whoami(token=token)
8
+ model_infos = list_models(author=author["name"], use_auth_token=token)
9
+ return [model.modelId for model in model_infos], None
10
+
11
+ except Exception as e:
12
+ return [], e
13
+
14
+ def download_file(repo_id: str, filename: str, token: str):
15
+ """Download a file from a repo on the Hugging Face Hub.
16
+
17
+ Returns:
18
+ file_path (:obj:`str`): The path to the downloaded file.
19
+ revision (:obj:`str`): The commit hash of the file.
20
+ """
21
+
22
+ md = get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename=filename), token=token)
23
+ revision = md.commit_hash
24
+
25
+ file_path = hf_hub_download(repo_id=repo_id, filename=filename, revision=revision, token=token)
26
+
27
+ return file_path, revision
28
+
29
+ def delete_file(revision: str):
30
+ """Delete a file from local cache.
31
+
32
+ Args:
33
+ revision (:obj:`str`): The commit hash of the file.
34
+ Returns:
35
+ None
36
+ """
37
+ scan_cache_dir().delete_revisions(revision).execute()
38
+
39
+ def get_pr_url(api, repo_id, title):
40
+ try:
41
+ discussions = api.get_repo_discussions(repo_id=repo_id)
42
+ except Exception:
43
+ return None
44
+ for discussion in discussions:
45
+ if (
46
+ discussion.status == "open"
47
+ and discussion.is_pull_request
48
+ and discussion.title == title
49
+ ):
50
+ return f"https://huggingface.co/{repo_id}/discussions/{discussion.num}"
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/huggingface_hub@main
2
+ git+https://github.com/huggingface/diffusers.git
3
+ torch
4
+ transformers
5
+ pytorch_lightning
6
+ OmegaConf
7
+ ftfy
utils.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ def is_google_colab():
2
+ try:
3
+ import google.colab
4
+ return True
5
+ except:
6
+ return False