Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -11,7 +11,7 @@ def write_repo(base_model, model_to_merge):
|
|
11 |
with open("repo.txt", "w") as repo:
|
12 |
repo.write(base_model + "\n" + model_to_merge)
|
13 |
|
14 |
-
def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token):
|
15 |
# Define a fixed output path
|
16 |
outpath = Path('/tmp/output')
|
17 |
if outpath.exists() and outpath.is_dir():
|
@@ -29,17 +29,23 @@ def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_pro
|
|
29 |
|
30 |
# Set up logging
|
31 |
logging.basicConfig(level=logging.INFO)
|
|
|
32 |
|
33 |
# Run the command and capture the output
|
34 |
result = subprocess.run(command, capture_output=True, text=True)
|
35 |
|
36 |
# Log the output
|
|
|
|
|
37 |
logging.info(result.stdout)
|
38 |
logging.error(result.stderr)
|
39 |
|
40 |
# Check if the merge was successful
|
41 |
if result.returncode != 0:
|
42 |
-
return f"Error in merging models: {result.stderr}"
|
|
|
|
|
|
|
43 |
|
44 |
# Upload the result to Hugging Face Hub
|
45 |
api = HfApi(token=token)
|
@@ -61,9 +67,10 @@ def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_pro
|
|
61 |
repo_type="model",
|
62 |
token=token
|
63 |
)
|
64 |
-
|
|
|
65 |
except Exception as e:
|
66 |
-
return f"Error uploading to Hugging Face Hub: {str(e)}"
|
67 |
|
68 |
# Define the Gradio interface
|
69 |
with gr.Blocks() as demo:
|
@@ -88,13 +95,16 @@ with gr.Blocks() as demo:
|
|
88 |
with gr.Row():
|
89 |
weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")
|
90 |
|
|
|
91 |
gr.Button("Merge and Upload").click(
|
92 |
merge_and_upload,
|
93 |
inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token],
|
94 |
-
|
95 |
)
|
96 |
|
97 |
with gr.Column():
|
|
|
|
|
98 |
output = gr.Textbox(label="Output")
|
99 |
|
100 |
demo.launch()
|
|
|
11 |
with open("repo.txt", "w") as repo:
|
12 |
repo.write(base_model + "\n" + model_to_merge)
|
13 |
|
14 |
+
def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token, progress=gr.Progress()):
|
15 |
# Define a fixed output path
|
16 |
outpath = Path('/tmp/output')
|
17 |
if outpath.exists() and outpath.is_dir():
|
|
|
29 |
|
30 |
# Set up logging
|
31 |
logging.basicConfig(level=logging.INFO)
|
32 |
+
log_output = ""
|
33 |
|
34 |
# Run the command and capture the output
|
35 |
result = subprocess.run(command, capture_output=True, text=True)
|
36 |
|
37 |
# Log the output
|
38 |
+
log_output += result.stdout + "\n"
|
39 |
+
log_output += result.stderr + "\n"
|
40 |
logging.info(result.stdout)
|
41 |
logging.error(result.stderr)
|
42 |
|
43 |
# Check if the merge was successful
|
44 |
if result.returncode != 0:
|
45 |
+
return log_output, None, f"Error in merging models: {result.stderr}"
|
46 |
+
|
47 |
+
# Update progress bar
|
48 |
+
progress(0.5, desc="Merging completed. Uploading to Hugging Face Hub...")
|
49 |
|
50 |
# Upload the result to Hugging Face Hub
|
51 |
api = HfApi(token=token)
|
|
|
67 |
repo_type="model",
|
68 |
token=token
|
69 |
)
|
70 |
+
repo_url = f"https://huggingface.co/{repo_name}"
|
71 |
+
return log_output, repo_url, "Model merged and uploaded successfully!"
|
72 |
except Exception as e:
|
73 |
+
return log_output, None, f"Error uploading to Hugging Face Hub: {str(e)}"
|
74 |
|
75 |
# Define the Gradio interface
|
76 |
with gr.Blocks() as demo:
|
|
|
95 |
with gr.Row():
|
96 |
weight_drop_prob = gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability")
|
97 |
|
98 |
+
progress = gr.Progress()
|
99 |
gr.Button("Merge and Upload").click(
|
100 |
merge_and_upload,
|
101 |
inputs=[base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name, token],
|
102 |
+
outputs=[log_output, repo_url, output]
|
103 |
)
|
104 |
|
105 |
with gr.Column():
|
106 |
+
log_output = gr.Textbox(label="Log Output")
|
107 |
+
repo_url = gr.Markdown(label="Repository URL")
|
108 |
output = gr.Textbox(label="Output")
|
109 |
|
110 |
demo.launch()
|