import gradio as gr
import os
import shutil
import logging
import subprocess
from pathlib import Path
from merge_script import ModelMerger, get_max_vocab_size, download_json_files
import spaces
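
# This Space performs a DARE-style "Super Mario" merge: for each pair of
# weights it drops a random fraction of the delta between the two models
# (weight_drop_prob) and rescales what survives (scaling_factor), then
# optionally uploads the result to the Hugging Face Hub. The merge itself is
# implemented in merge_script, which is not shown here.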


def merge_models(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message):
    # Validate input formats before doing any work
    for field, value in (("base_model", base_model), ("model_to_merge", model_to_merge)):
        is_valid, error = validate_model_format(field, value)
        if not is_valid:
            yield 0.0, error
            return

    # Define staging and output paths
    staging_path = "/tmp/staging"
    output_path = "/tmp/output"
    os.makedirs(staging_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)

    # Initialize ModelMerger and prepare base model
    model_merger = ModelMerger(staging_path, repo_name, token)
    model_merger.prepare_base_model(base_model, os.path.join(staging_path, "base_model"))

    # Merge models and report progress. The loop variable must not be named
    # `repo_name`, or it would shadow the upload target parameter.
    for source_repo in [base_model, model_to_merge]:
        model_merger.merge_repo(source_repo, os.path.join(staging_path, "staging_model"), weight_drop_prob, scaling_factor)
        yield 0.25, f"Merged {source_repo}"
    # Finalize merge and handle vocab size
    model_merger.finalize_merge(output_path)
    yield 0.5, "Finalizing merge and handling vocab size..."
    max_vocab_size, repo_with_max_vocab = get_max_vocab_size([base_model, model_to_merge])
    if max_vocab_size > 0:
        download_json_files(repo_with_max_vocab, ['config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json'], output_path)
    # Upload merged model to Hugging Face Hub; announce the upload before it
    # starts so the progress message matches what is actually happening
    if repo_name:
        yield 0.75, "Uploading merged model to Hugging Face Hub..."
        model_merger.upload_model(output_path, repo_name, commit_message)
        repo_url = f"https://huggingface.co/{repo_name}"
        yield 1.0, f"Model merged and uploaded successfully! {repo_url}"
    else:
        yield 1.0, "Model merged successfully! No upload performed."
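

# A minimal sketch of the tensor-level operation that ModelMerger.merge_repo
# is assumed to perform (a DARE-style "Super Mario" merge: randomly drop delta
# weights, rescale the survivors). `dare_merge_tensor` is a hypothetical
# illustration; it is not part of merge_script's API and is not called by
# this app.
def dare_merge_tensor(base, merged, weight_drop_prob, scaling_factor):
    import torch  # local import keeps the sketch self-contained
    delta = merged - base                              # task vector
    keep = torch.rand_like(delta) >= weight_drop_prob  # random keep mask
    # Rescale kept deltas so their expected magnitude matches the original
    # delta; assumes weight_drop_prob < 1.
    rescaled = delta * keep / (1.0 - weight_drop_prob)
    return base + scaling_factor * rescaled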


def get_model_type_info(model_name):
    model_types = {
        "base_model": "Base model should be in .safetensors format.",
        "model_to_merge": "Model to merge can be in .safetensors or .bin format."
    }
    return model_types.get(model_name, "No specific info available.")


def validate_model_format(model_name, model_path):
    if model_name == "base_model":
        if not model_path.endswith(".safetensors"):
            return False, "Base model must be in .safetensors format."
    elif model_name == "model_to_merge":
        if not model_path.endswith((".safetensors", ".bin")):
            return False, "Model to merge must be in .safetensors or .bin format."
    return True, None
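
# Example (hypothetical filenames):
#   validate_model_format("base_model", "model.safetensors") -> (True, None)
#   validate_model_format("base_model", "model.bin")
#       -> (False, "Base model must be in .safetensors format.")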


def merge_and_upload_interface():
    with gr.Blocks(theme="Ytheme/XRainbow", fill_width=True) as demo:
        gr.Markdown("# Model Merger and Uploader")
        gr.Markdown("Combine and upload models with real-time progress updates.")
        gr.Markdown("**Model Compatibility:**")
        gr.Markdown("Combine any two models using a Super Mario merge.")
        gr.Markdown("Works with:")
        gr.Markdown("* Stable Diffusion (1.5, XL/XL Turbo)")
        gr.Markdown("* LLMs (Mistral, Llama, etc.)")
        gr.Markdown("* LoRAs (must be the same size)")
        gr.Markdown("* Any two homologous models")
        with gr.Column():
            token = gr.Textbox(label="HuggingFace Token")
            base_model = gr.Textbox(label="Base Model")
            base_model_info = gr.HTML(get_model_type_info("base_model"))
            model_to_merge = gr.Textbox(label="Model to Merge")
            model_to_merge_info = gr.HTML(get_model_type_info("model_to_merge"))
            repo_name = gr.Textbox(label="New Model Name")
            commit_message = gr.Textbox(label="Commit Message")
            scaling_factor = gr.Slider(minimum=0, maximum=10, label="Scaling Factor")
            weight_drop_prob = gr.Slider(minimum=0, maximum=1, label="Weight Drop Probability")
        progress_output = gr.Textbox(label="Progress/Output")
        log_output = gr.Textbox(label="Log")

        # Input formats are validated inside merge_models before any work runs
        gr.Button("Merge and Upload").click(
            merge_models,
            inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message],
            outputs=[progress_output, log_output],
        )

    demo.launch()


if __name__ == "__main__":
    merge_and_upload_interface()
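
# Note: merge_models is a generator, so progress streams into the output
# textboxes as values are yielded. On Gradio 3.x this requires calling
# demo.queue() before launch(); Gradio 4 enables queuing by default.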