Spaces:
Sleeping
Sleeping
File size: 3,888 Bytes
1b621fc 912e4bf b521cf1 d740066 6efcc11 495558a 15ad924 b521cf1 912e4bf dcdad58 8e39690 dd47c8c 379abaf 912e4bf 379abaf 6897b6a 1b621fc 6897b6a 379abaf 1b621fc e47add2 12cd0d5 9c73f19 3052ca7 379abaf 6897b6a 1b621fc 379abaf 12cd0d5 3052ca7 12cd0d5 379abaf 0e69513 379abaf 6a1f2ac 3052ca7 dcdad58 379abaf 0e69513 912e4bf 0e69513 6efcc11 0e69513 912e4bf 379abaf 0e69513 ea54467 379abaf 4dda767 379abaf 0e69513 3052ca7 dcdad58 6a1f2ac 0e69513 6a1f2ac 1b621fc 912e4bf 5decfa4 912e4bf 5decfa4 6efcc11 379abaf 6efcc11 4dda767 0f2e02c 5decfa4 0f2e02c 5decfa4 0f2e02c 379abaf 5decfa4 379abaf 9151d79 379abaf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import logging
import shutil
import subprocess
import sys
from pathlib import Path

import gradio as gr
import spaces
from huggingface_hub import HfApi


# NOTE(review): this GPU-decorated function only writes a text file; the actual
# merge runs in a subprocess started by merge_and_upload. The decorator is kept
# as-is because a ZeroGPU Space may require at least one @spaces.GPU function
# at startup — confirm before moving or removing it.
@spaces.GPU
def write_repo(base_model: str, model_to_merge: str) -> None:
    """Write the two model identifiers to ``repo.txt`` for ``hf_merge.py``.

    The file is overwritten on every call and contains the base model on the
    first line and the model to merge on the second (no trailing newline).

    Args:
        base_model: Identifier of the base model.
        model_to_merge: Identifier of the model merged into the base.
    """
    # Strip stray whitespace/newlines pasted into the UI text boxes so they
    # cannot leak into the identifiers or add blank lines to the file.
    entries = (base_model.strip(), model_to_merge.strip())
    with open("repo.txt", "w", encoding="utf-8") as repo:
        repo.write("\n".join(entries))


def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token):
    """Merge two models with ``hf_merge.py`` (DARE) and upload the result to the Hub.

    This is a generator handler: Gradio displays each yielded value in the
    single ``gr.Markdown`` output the "Merge" button is wired to, replacing
    the previous one. A ``return value`` inside a generator is not displayed,
    so errors and the final result are yielded as Markdown strings.

    Args:
        base_model: Identifier of the base model (written to ``repo.txt``).
        model_to_merge: Identifier of the model merged into the base.
        scaling_factor: Passed to ``hf_merge.py`` as ``-lambda``.
        weight_drop_prob: Passed to ``hf_merge.py`` as ``-p``.
        repo_name: Target repo, ``namespace/name`` or a bare ``name`` (put
            under the token owner's namespace); empty means
            ``<user>/default-repo``.
        token: Hugging Face token with write access.

    Yields:
        Markdown status strings. The last one reports either success (with a
        link to the uploaded repo) or the error that stopped the run.
    """
    # Textbox values may be None or carry stray whitespace from copy/paste.
    base_model = (base_model or "").strip()
    model_to_merge = (model_to_merge or "").strip()
    repo_name = (repo_name or "").strip()
    token = (token or "").strip()

    if not base_model or not model_to_merge:
        yield "Error: please provide both a base model and a model to merge."
        return
    if not token:
        yield "Error: a Hugging Face write token is required to upload the merged model."
        return

    logging.basicConfig(level=logging.INFO)

    # Fixed output path, cleared so files left by a previous run are not uploaded.
    outpath = Path("/tmp/output")
    if outpath.is_dir():
        shutil.rmtree(outpath)

    yield "Merging models..."
    write_repo(base_model, model_to_merge)
    result = _run_hf_merge(scaling_factor, weight_drop_prob, outpath)
    if result.returncode != 0:
        yield "Error in merging models:\n\n" + _as_markdown_code(result.stderr)
        return

    yield "Merging completed. Uploading to Hugging Face Hub..."
    try:
        repo_id = _upload_to_hub(outpath, repo_name, token)
    except Exception as e:  # UI boundary: surface any Hub failure to the user.
        logging.exception("Upload to Hugging Face Hub failed")
        yield f"Error uploading to Hugging Face Hub: {e}"
        return

    repo_url = f"https://huggingface.co/{repo_id}"
    yield f"Model merged and uploaded successfully! [{repo_id}]({repo_url})"


def _run_hf_merge(scaling_factor, weight_drop_prob, outpath):
    """Run ``hf_merge.py`` on ``repo.txt`` and return the completed process."""
    # Use the interpreter running this app so the subprocess sees the same
    # installed packages; fall back to "python3" if it cannot be determined.
    command = [
        sys.executable or "python3", "hf_merge.py",
        "-p", str(weight_drop_prob),
        "-lambda", str(scaling_factor),
        "repo.txt", str(outpath),
    ]
    result = subprocess.run(command, capture_output=True, text=True, check=False)
    if result.stdout:
        logging.info("hf_merge.py stdout:\n%s", result.stdout)
    if result.stderr:
        # stderr may hold warnings even when the merge succeeds; only report
        # it as an error when the process actually failed.
        level = logging.ERROR if result.returncode != 0 else logging.WARNING
        logging.log(level, "hf_merge.py stderr:\n%s", result.stderr)
    return result


def _upload_to_hub(outpath, repo_name, token):
    """Create (or reuse) the target model repo, upload *outpath*, and return the repo id."""
    api = HfApi(token=token)
    # whoami also validates the token before anything is created.
    user = api.whoami(token=token)["name"]
    if not repo_name:
        repo_id = f"{user}/default-repo"
    elif "/" not in repo_name:
        # Qualify bare names so the Hub calls and the displayed URL use the full id.
        repo_id = f"{user}/{repo_name}"
    else:
        repo_id = repo_name
    api.create_repo(repo_id=repo_id, repo_type="model", token=token, exist_ok=True)
    api.upload_folder(
        folder_path=str(outpath),
        repo_id=repo_id,
        repo_type="model",
        token=token,
    )
    return repo_id


def _as_markdown_code(text):
    """Wrap *text* in a fenced block so Markdown shows it verbatim."""
    return "```\n" + ((text or "").strip() or "(no output)") + "\n```"


# Define the Gradio interface. The "Merge" button streams merge_and_upload's
# yielded Markdown status messages into the single `repo_url` component.
with gr.Blocks(theme="Ytheme/Minecraft", fill_width=True, delete_cache=(60, 3600)) as demo:
    gr.Markdown("# Model Merger and Uploader")
    # One Markdown component so the bullets render as a single list.
    gr.Markdown(
        "Combine any two models using a Super Mario merge (DARE) as described in the "
        "[linked whitepaper](https://arxiv.org/abs/2311.03099).\n\n"
        "Works with:\n\n"
        "* Stable Diffusion (1.5, XL/XL Turbo)\n"
        "* LLMs (Mistral, Llama, etc)\n"
        "* LoRAs (must be same size)\n"
        "* Any two homologous models"
    )
    with gr.Column():
        with gr.Row():
            # Password field: a write token should not be displayed on screen.
            token = gr.Textbox(
                label="Your HF write token",
                placeholder="hf_...",
                value="",
                max_lines=1,
                type="password",
            )
        with gr.Row():
            base_model = gr.Textbox(label="Base Model", placeholder=".safetensors")
        with gr.Row():
            model_to_merge = gr.Textbox(label="Merge Model", placeholder=".bin/.safetensors")
        with gr.Row():
            repo_name = gr.Textbox(
                label="New Model",
                placeholder="SDXL-",
                info="If empty, auto-complete",
                value="",
                max_lines=1,
            )
        with gr.Row():
            scaling_factor = gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor")
        with gr.Row():
            weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")
        repo_url = gr.Markdown(label="Repository URL")
        gr.Button("Merge").click(
            merge_and_upload,
            inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token],
            outputs=[repo_url],
        )

if __name__ == "__main__":
    demo.launch()
|