mrcuddle's picture
Update app.py
e69d387 verified
raw
history blame
4.22 kB
import gradio as gr
import os
import shutil
import logging
import subprocess
from pathlib import Path
from merge_script import ModelMerger, get_max_vocab_size, download_json_files
import spaces
@spaces.GPU
def merge_models(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message):
    """Merge two Hugging Face models and optionally upload the result.

    Generator yielding ``(progress, message)`` tuples so the Gradio UI can
    show incremental status updates.

    Args:
        base_model: Repo id of the base model (.safetensors expected).
        model_to_merge: Repo id of the model merged into the base.
        scaling_factor: Scaling applied during the merge.
        weight_drop_prob: Probability of dropping weights (DARE-style merge).
        repo_name: Destination repo for the merged model; if falsy, no upload.
        token: Hugging Face token used for the upload.
        commit_message: Commit message for the Hub upload.
    """
    # Define staging and output paths
    staging_path = "/tmp/staging"
    output_path = "/tmp/output"
    os.makedirs(staging_path, exist_ok=True)
    os.makedirs(output_path, exist_ok=True)
    # Initialize ModelMerger and prepare base model
    model_merger = ModelMerger(staging_path, repo_name, token)
    model_merger.prepare_base_model(base_model, os.path.join(staging_path, "base_model"))
    # Merge models and handle progress updates.
    # BUG FIX: the original loop reused `repo_name` as its loop variable,
    # clobbering the destination-repo parameter — after the loop the upload
    # below targeted `model_to_merge` instead of the user-chosen repo.
    for source_repo in [base_model, model_to_merge]:
        model_merger.merge_repo(source_repo, os.path.join(staging_path, "staging_model"), weight_drop_prob, scaling_factor)
        yield 0.25, f"Merged {source_repo}"
    # Finalize merge and handle vocab size
    model_merger.finalize_merge(output_path)
    yield 0.5, "Finalizing merge and handling vocab size..."
    max_vocab_size, repo_with_max_vocab = get_max_vocab_size([base_model, model_to_merge])
    if max_vocab_size > 0:
        # Copy tokenizer/config files from whichever source repo has the
        # larger vocabulary so the merged model stays loadable.
        download_json_files(repo_with_max_vocab, ['config.json', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json'], output_path)
    # Upload merged model to Hugging Face Hub (skipped when no repo name given)
    if repo_name:
        model_merger.upload_model(output_path, repo_name, commit_message)
        yield 0.75, "Uploading merged model to Hugging Face Hub..."
        repo_url = f"https://huggingface.co/{repo_name}"
        yield 1.0, f"Model merged and uploaded successfully! {repo_url}"
    else:
        yield 1.0, "Model merged successfully! No upload performed."
def get_model_type_info(model_name):
    """Return a human-readable format hint for a model input slot.

    Args:
        model_name: Either "base_model" or "model_to_merge".

    Returns:
        A short description of the accepted file format, or a generic
        fallback message for unknown slot names.
    """
    if model_name == "base_model":
        return "Base model should be in .safetensors format."
    if model_name == "model_to_merge":
        return "Model to merge can be in .safetensors or .bin format."
    return "No specific info available."
def validate_model_format(model_name, model_path):
    """Validate that a model path has an extension allowed for its slot.

    Args:
        model_name: Either "base_model" or "model_to_merge"; any other
            value is accepted unconditionally.
        model_path: Path or repo file name to check.

    Returns:
        Tuple ``(True, None)`` when valid, otherwise ``(False, message)``.
    """
    # Allowed suffixes and the error message for each known slot.
    rules = {
        "base_model": (
            (".safetensors",),
            "Base model must be in .safetensors format.",
        ),
        "model_to_merge": (
            (".safetensors", ".bin"),
            "Model to merge must be in .safetensors or .bin format.",
        ),
    }
    if model_name in rules:
        suffixes, error_message = rules[model_name]
        if not model_path.endswith(suffixes):
            return False, error_message
    return True, None
def merge_and_upload_interface():
    """Build and launch the Gradio UI for merging and uploading models."""
    with gr.Blocks(theme="Ytheme/XRainbow", fill_width=True) as demo:
        gr.Markdown("# Model Merger and Uploader")
        gr.Markdown("Combine and upload models with real-time progress updates.")
        gr.Markdown("**Model Compatibility:**")
        gr.Markdown("Combine any two models using a Super Mario merge.")
        gr.Markdown("Works with:")
        gr.Markdown("* Stable Diffusion (1.5, XL/XL Turbo)")
        gr.Markdown("* LLMs (Mistral, Llama, etc)")
        gr.Markdown("* LoRas (must be same size)")
        gr.Markdown("* Any two homologous models")
        with gr.Column():
            token = gr.Textbox(label="HuggingFace Token")
            base_model = gr.Textbox(label="Base Model")
            base_model_info = gr.HTML(get_model_type_info("base_model"))
            model_to_merge = gr.Textbox(label="Model to Merge")
            model_to_merge_info = gr.HTML(get_model_type_info("model_to_merge"))
            repo_name = gr.Textbox(label="New Model Name")
            # BUG FIX: `commit_message` was listed in click() inputs but never
            # defined, raising NameError as soon as the interface was built.
            commit_message = gr.Textbox(label="Commit Message")
            scaling_factor = gr.Slider(minimum=0, maximum=10, label="Scaling Factor")
            weight_drop_prob = gr.Slider(minimum=0, maximum=1, label="Weight Drop Probability")
        # BUG FIX: the original passed `pre_process=validate_model_format`,
        # but Gradio's event listeners have no such parameter (TypeError).
        # NOTE(review): input validation should instead be performed inside
        # merge_models or wired as a separate change/submit event handler.
        gr.Button("Merge and Upload").click(
            merge_models,
            inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message],
            outputs=[
                gr.Textbox(label="Progress/Output"),
                gr.Textbox(label="Log")
            ]
        )
    demo.launch()
# Launch the Gradio app only when this file is executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    merge_and_upload_interface()