mrcuddle's picture
Update app.py
9151d79 verified
raw
history blame
3.89 kB
import gradio as gr
from huggingface_hub import HfApi
import spaces
import shutil
import logging
import subprocess
from pathlib import Path
# NOTE(review): the GPU decorator is attached to this small file-writing
# helper rather than to the merge step -- confirm that this is intentional.
@spaces.GPU
def write_repo(base_model, model_to_merge):
    """Write the two model identifiers to ``repo.txt``, one per line.

    The file is created (or truncated) in the current working directory.
    The base model goes on the first line and the model to merge on the
    second; no trailing newline is written.

    Args:
        base_model: Identifier of the base model (string).
        model_to_merge: Identifier of the model merged into the base (string).
    """
    listing = "\n".join((base_model, model_to_merge))
    with open("repo.txt", mode="w") as handle:
        handle.write(listing)
def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token):
    """Merge two models with ``hf_merge.py`` and upload the result to the Hub.

    This function is a generator used as a Gradio event handler. Every status
    update -- progress, errors and the final result -- is *yielded* as a
    single Markdown string, matching the single output component the "Merge"
    button is wired to. A value attached to ``return`` inside a generator is
    never displayed by Gradio, so nothing is reported that way.

    Args:
        base_model: Identifier of the base model, written to ``repo.txt``.
        model_to_merge: Identifier of the model to merge into the base.
        scaling_factor: Passed to ``hf_merge.py`` as ``-lambda``.
        weight_drop_prob: Passed to ``hf_merge.py`` as ``-p``.
        repo_name: Target Hub repository. When empty, ``<user>/default-repo``
            is used; when it has no namespace, the token owner's user name is
            prefixed.
        token: Hugging Face token used for ``whoami``, repo creation and
            upload.

    Yields:
        str: Markdown status text. An error message ends the generator.
    """
    base_model = (base_model or "").strip()
    model_to_merge = (model_to_merge or "").strip()
    # Validate before touching the filesystem or spawning the merge process.
    if not base_model or not model_to_merge:
        yield "Error in merging models: both a base model and a model to merge are required."
        return

    # Fixed output path; clear it so files from a previous run are not uploaded.
    outpath = Path("/tmp/output")
    if outpath.is_dir():
        shutil.rmtree(outpath)

    write_repo(base_model, model_to_merge)

    # Construct the command to run hf_merge.py (list form, no shell involved).
    command = [
        "python3", "hf_merge.py",
        "-p", str(weight_drop_prob),
        "-lambda", str(scaling_factor),
        "repo.txt", str(outpath),
    ]

    # basicConfig is a no-op once the root logger already has handlers.
    logging.basicConfig(level=logging.INFO)

    # Run the merge and capture its output as text.
    result = subprocess.run(command, capture_output=True, text=True)
    merge_failed = result.returncode != 0
    if result.stdout:
        logging.info(result.stdout)
    if result.stderr:
        # stderr of a successful run is not necessarily an error.
        logging.log(logging.ERROR if merge_failed else logging.WARNING, result.stderr)

    if merge_failed:
        yield f"Error in merging models: {result.stderr}"
        return

    yield "Merging completed. Uploading to Hugging Face Hub..."

    api = HfApi(token=token)
    try:
        # Get the username of the user who owns the token.
        user = api.whoami(token=token)["name"]
        repo_name = (repo_name or "").strip()
        if not repo_name:
            # Autofill the repo name if none is provided.
            repo_name = f"{user}/default-repo"
        elif "/" not in repo_name:
            # A bare name would give a wrong URL and an ambiguous upload
            # target, so qualify it with the user's namespace.
            repo_name = f"{user}/{repo_name}"
        # Create a new repo or reuse an existing one, then upload the folder.
        api.create_repo(repo_id=repo_name, token=token, exist_ok=True)
        api.upload_folder(
            folder_path=str(outpath),
            repo_id=repo_name,
            repo_type="model",
            token=token,
        )
    except Exception as e:  # UI boundary: report the failure instead of crashing
        logging.exception("Upload to Hugging Face Hub failed")
        yield f"Error uploading to Hugging Face Hub: {str(e)}"
        return

    repo_url = f"https://huggingface.co/{repo_name}"
    yield f"Model merged and uploaded successfully! [{repo_url}]({repo_url})"
# Define the Gradio interface: a column of input rows, one Markdown output
# that receives the status text from merge_and_upload, and a "Merge" button.
with gr.Blocks(theme="Ytheme/Minecraft", fill_width=True, delete_cache=(60, 3600)) as demo:
    gr.Markdown("# Model Merger and Uploader")
    gr.Markdown("Combine any two models using a Super Mario merge(DARE) as described in the linked whitepaper.")
    gr.Markdown("Works with:")
    gr.Markdown("* Stable Diffusion (1.5, XL/XL Turbo)")
    gr.Markdown("* LLMs (Mistral, Llama, etc)")
    gr.Markdown("* LoRas (must be same size)")
    gr.Markdown("* Any two homologous models")
    with gr.Column():
        with gr.Row():
            # The write token is a secret: mask it instead of echoing it
            # in clear text on screen.
            token = gr.Textbox(
                label="Your HF write token",
                placeholder="hf_...",
                value="",
                max_lines=1,
                type="password",
            )
        with gr.Row():
            base_model = gr.Textbox(label="Base Model", placeholder=".safetensors")
        with gr.Row():
            model_to_merge = gr.Textbox(label="Merge Model", placeholder=".bin/.safetensors")
        with gr.Row():
            repo_name = gr.Textbox(label="New Model", placeholder="SDXL-", info="If empty, auto-complete", value="", max_lines=1)
        with gr.Row():
            scaling_factor = gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor")
        with gr.Row():
            weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")
    repo_url = gr.Markdown(label="Repository URL")
    # The input order must match the parameter order of merge_and_upload.
    gr.Button("Merge").click(
        merge_and_upload,
        inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token],
        outputs=[repo_url],
    )

demo.launch()