File size: 3,888 Bytes
1b621fc
912e4bf
b521cf1
d740066
 
6efcc11
495558a
15ad924
b521cf1
912e4bf
 
 
 
dcdad58
8e39690
dd47c8c
379abaf
 
 
912e4bf
379abaf
6897b6a
1b621fc
6897b6a
 
 
379abaf
1b621fc
e47add2
12cd0d5
9c73f19
3052ca7
379abaf
6897b6a
1b621fc
379abaf
12cd0d5
3052ca7
 
12cd0d5
 
379abaf
0e69513
379abaf
6a1f2ac
3052ca7
 
dcdad58
379abaf
0e69513
912e4bf
0e69513
6efcc11
 
 
 
 
 
 
0e69513
912e4bf
379abaf
0e69513
ea54467
379abaf
 
4dda767
379abaf
0e69513
3052ca7
dcdad58
6a1f2ac
0e69513
6a1f2ac
1b621fc
912e4bf
5decfa4
912e4bf
5decfa4
 
6efcc11
 
 
 
379abaf
6efcc11
4dda767
0f2e02c
 
5decfa4
0f2e02c
5decfa4
 
 
0f2e02c
 
 
 
379abaf
5decfa4
 
 
379abaf
 
9151d79
379abaf
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import logging
import shutil
import subprocess
import sys
from pathlib import Path

import gradio as gr
import spaces
from huggingface_hub import HfApi

@spaces.GPU
def write_repo(base_model, model_to_merge, path="repo.txt"):
    """Write the two model identifiers to the list file consumed by hf_merge.py.

    The file contains ``base_model`` on the first line and ``model_to_merge``
    on the second, with no trailing newline.

    Args:
        base_model: Identifier of the base model.
        model_to_merge: Identifier of the model to merge into the base.
        path: Destination file. Defaults to ``repo.txt`` in the working
            directory, which is the file name merge_and_upload passes to
            hf_merge.py.
    """
    # NOTE(review): the @spaces.GPU decorator on this function only wraps this
    # small file write; the actual merge runs in a subprocess launched from
    # merge_and_upload. Confirm the decorator is on the intended function.
    #
    # Strip surrounding whitespace: a newline or space pasted into the textbox
    # would otherwise corrupt the one-identifier-per-line layout.
    lines = [base_model.strip(), model_to_merge.strip()]
    with open(path, "w", encoding="utf-8") as repo:
        repo.write("\n".join(lines))

def _run_merge(scaling_factor, weight_drop_prob, outpath):
    """Run hf_merge.py on ``repo.txt`` writing into *outpath*; log and return the result."""
    command = [
        # Use the interpreter running this app so hf_merge.py sees the same
        # installed packages, rather than whatever "python3" is on PATH.
        sys.executable, "hf_merge.py",
        "-p", str(weight_drop_prob),
        "-lambda", str(scaling_factor),
        "repo.txt", str(outpath),
    ]
    # Argument list without a shell: nothing here is shell-parsed.
    result = subprocess.run(command, capture_output=True, text=True)

    if result.stdout:
        logging.info(result.stdout)
    if result.stderr:
        # stderr may also carry warnings or progress output, so only log it as
        # an error when the merge actually failed.
        level = logging.ERROR if result.returncode != 0 else logging.WARNING
        logging.log(level, result.stderr)
    return result


def _upload_to_hub(outpath, repo_name, token):
    """Create (or reuse) the target model repo, upload *outpath*, and return its URL."""
    api = HfApi(token=token)
    user = api.whoami(token=token)["name"]

    if not repo_name:
        # Autofill the repo name if none is provided.
        repo_name = f"{user}/default-repo"
    elif "/" not in repo_name:
        # Qualify a bare name with the token owner's namespace so the URL we
        # report points at the right repo.
        repo_name = f"{user}/{repo_name}"

    api.create_repo(repo_id=repo_name, token=token, exist_ok=True)
    api.upload_folder(
        folder_path=str(outpath),
        repo_id=repo_name,
        repo_type="model",
        token=token,
    )
    return f"https://huggingface.co/{repo_name}"


def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token):
    """Merge two models with hf_merge.py and upload the result to the Hugging Face Hub.

    This is a generator meant to be used as a streaming Gradio event handler.
    Every yielded value is a single Markdown status string, matching the one
    output component the UI wires to it. Results are always yielded: a
    ``return value`` inside a generator is discarded and would never reach
    the UI.

    Args:
        base_model: Identifier of the base model (written to ``repo.txt``).
        model_to_merge: Identifier of the model merged into the base.
        scaling_factor: Passed to hf_merge.py as ``-lambda``.
        weight_drop_prob: Passed to hf_merge.py as ``-p``.
        repo_name: Target repo id. If empty, ``<user>/default-repo`` is used.
            A name without ``/`` is placed under the token owner's namespace.
        token: Hugging Face token with write access.

    Yields:
        Markdown status messages. The last one is the final result or error.
    """
    # Validate up front so bad requests fail before touching the filesystem.
    base_model = (base_model or "").strip()
    model_to_merge = (model_to_merge or "").strip()
    repo_name = (repo_name or "").strip()
    token = (token or "").strip()
    if not base_model or not model_to_merge:
        yield "Error: both a base model and a merge model are required."
        return
    if not token:
        yield "Error: a Hugging Face write token is required to upload the result."
        return

    logging.basicConfig(level=logging.INFO)

    # Fixed output path; start every run from a clean slate.
    outpath = Path('/tmp/output')
    if outpath.is_dir():
        shutil.rmtree(outpath)
    elif outpath.exists():
        outpath.unlink()

    write_repo(base_model, model_to_merge)

    yield "Merging models..."
    result = _run_merge(scaling_factor, weight_drop_prob, outpath)
    if result.returncode != 0:
        yield f"Error in merging models:\n\n```\n{result.stderr}\n```"
        return

    yield "Merging completed. Uploading to Hugging Face Hub..."
    try:
        repo_url = _upload_to_hub(outpath, repo_name, token)
    except Exception as e:  # UI boundary: report any Hub failure to the user.
        logging.exception("Upload to Hugging Face Hub failed")
        yield f"Error uploading to Hugging Face Hub: {e}"
        return

    yield f"Model merged and uploaded successfully! [{repo_url}]({repo_url})"

# Define the Gradio interface.
# delete_cache=(60, 3600): every 60 s, Gradio deletes temporary files this
# Blocks instance created more than 3600 s ago.
with gr.Blocks(theme="Ytheme/Minecraft", fill_width=True, delete_cache=(60, 3600)) as demo:
    gr.Markdown("# Model Merger and Uploader")
    gr.Markdown(
        "Combine any two models using a Super Mario merge (DARE) as described in "
        "[the linked whitepaper](https://arxiv.org/abs/2311.03099)."
    )
    gr.Markdown("Works with:")
    gr.Markdown("* Stable Diffusion (1.5, XL/XL Turbo)")
    gr.Markdown("* LLMs (Mistral, Llama, etc)")
    gr.Markdown("* LoRas (must be same size)")
    gr.Markdown("* Any two homologous models")

    with gr.Column():
        with gr.Row():
            # type="password" masks the write token instead of showing the
            # secret in plain text in the browser.
            token = gr.Textbox(
                label="Your HF write token",
                placeholder="hf_...",
                value="",
                max_lines=1,
                type="password",
            )
        with gr.Row():
            base_model = gr.Textbox(label="Base Model", placeholder=".safetensors")
        with gr.Row():
            model_to_merge = gr.Textbox(label="Merge Model", placeholder=".bin/.safetensors")
        with gr.Row():
            repo_name = gr.Textbox(label="New Model", placeholder="SDXL-", info="If empty, auto-complete", value="", max_lines=1)
        with gr.Row():
            scaling_factor = gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor")
        with gr.Row():
            weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")

        # Single output: receives whatever merge_and_upload produces.
        repo_url = gr.Markdown(label="Repository URL")

        gr.Button("Merge").click(
            merge_and_upload,
            inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token],
            outputs=[repo_url]
        )

demo.launch()