File size: 3,356 Bytes
1b621fc
6efcc11
a26bb9f
 
 
e69d387
15ad924
c4a29da
a26bb9f
 
7d75a06
a26bb9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c118f8c
a26bb9f
 
 
 
 
fae0f74
a26bb9f
 
 
 
 
 
 
406418b
a26bb9f
406418b
a26bb9f
406418b
a26bb9f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import subprocess
import os
import logging
from pathlib import Path
import spaces

@spaces.GPU()
def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message):
    """Merge two models with ``hf_merge.py`` and upload the result to the Hub.

    The merge and upload are delegated to the ``hf_merge.py`` script, which is
    run as a subprocess with the ``-U`` flag.

    Args:
        base_model: Hub id or path of the base model (first positional arg).
        model_to_merge: Hub id or path of the model to merge (second positional arg).
        scaling_factor: Value passed to ``hf_merge.py`` as ``-lambda``.
        weight_drop_prob: Value passed to ``hf_merge.py`` as ``-p``.
        repo_name: Target repository, passed as ``--repo``.
        token: Hugging Face write token, passed as ``--token``.
        commit_message: Passed as ``--commit-message``; a blank value falls
            back to ``"Upload merged model"``.

    Returns:
        A 3-tuple ``(repo_url, status_message, log_output)``. ``repo_url`` is
        ``None`` when input validation fails or the subprocess cannot be
        started or exits with a non-zero status.
    """
    import sys

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Gradio textboxes hand back "" (and None is tolerated too) for empty
    # fields; stray whitespace in a model id or repo name would break lookups.
    base_model = (base_model or "").strip()
    model_to_merge = (model_to_merge or "").strip()
    repo_name = (repo_name or "").strip()
    token = (token or "").strip()
    commit_message = (commit_message or "").strip() or "Upload merged model"

    missing = [
        label
        for label, value in (
            ("HF write token", token),
            ("Base Model", base_model),
            ("Merge Model", model_to_merge),
            ("New Model", repo_name),
        )
        if not value
    ]
    if missing:
        return None, f"Missing required field(s): {', '.join(missing)}", ""

    # sys.executable runs the script under the same interpreter (and installed
    # packages) as this app instead of whatever "python3" is first on PATH.
    # "hf_merge.py" is resolved relative to the current working directory.
    # NOTE(review): the token is visible in the process argument list; this is
    # dictated by hf_merge.py's CLI (--token).
    command = [
        sys.executable, "hf_merge.py",
        base_model,
        model_to_merge,
        "-p", str(weight_drop_prob),
        "-lambda", str(scaling_factor),
        "--token", token,
        "--repo", repo_name,
        "--commit-message", commit_message,
        "-U",
    ]

    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except OSError as err:
        logger.exception("Failed to start hf_merge.py")
        return None, f"Error in merging models: {err}", ""

    log_output = f"{result.stdout}\n{result.stderr}\n"
    if result.stdout:
        logger.info(result.stdout)

    if result.returncode != 0:
        logger.error(result.stderr)
        return None, f"Error in merging models: {result.stderr}", log_output

    if result.stderr:
        # On success stderr is treated as informational (tools commonly
        # write progress output and warnings there), not as an error.
        logger.info(result.stderr)

    # Assumes hf_merge.py with -U uploaded to exactly this repo id.
    repo_url = f"https://huggingface.co/{repo_name}"
    return repo_url, "Model merged and uploaded successfully!", log_output

# Define the Gradio interface
with gr.Blocks(theme="Ytheme/Minecraft", fill_width=True, delete_cache=(60, 3600)) as demo:
    gr.Markdown("# SuperMario Safetensors Merger")
    gr.Markdown("Combine any two models using a Super Mario merge (DARE)")
    gr.Markdown("Based on: https://github.com/martyn/safetensors-merge-supermario")
    gr.Markdown("Works with:")
    gr.Markdown("* Stable Diffusion (1.5, XL/XL Turbo)")
    gr.Markdown("* LLMs (Mistral, Llama, etc) (also works with Llava, Vision models)")
    gr.Markdown("* LoRAs (must be same size)")
    gr.Markdown("* Any two homologous models")

    with gr.Column():
        with gr.Row():
            # type="password" masks the write token in the browser.
            token = gr.Textbox(label="Your HF write token", placeholder="hf_...", value="", max_lines=1, type="password")
        with gr.Row():
            base_model = gr.Textbox(label="Base Model", placeholder="meta-llama/Llama-3.2-11B-Vision-Instruct", info="Safetensors format")
        with gr.Row():
            model_to_merge = gr.Textbox(label="Merge Model", placeholder="Qwen/Qwen2.5-Coder-7B-Instruct", info="Safetensors or .bin")
        with gr.Row():
            repo_name = gr.Textbox(label="New Model", placeholder="Llama-Qwen-Vision_Instruct", info="your-username/new-model-name", value="", max_lines=1)
        with gr.Row():
            scaling_factor = gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor")
        with gr.Row():
            weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")
        with gr.Row():
            commit_message = gr.Textbox(label="Commit Message", value="Upload merged model", max_lines=1)

        repo_url = gr.Markdown(label="Repository URL")
        output = gr.Textbox(label="Output")
        # merge_and_upload returns (repo_url, status_message, log_output);
        # this component receives the combined stdout/stderr of hf_merge.py.
        logs = gr.Textbox(label="Logs", lines=10, max_lines=20)

        gr.Button("Merge").click(
            merge_and_upload,
            inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, commit_message],
            outputs=[repo_url, output, logs],
        )

demo.launch()