Update app.py
app.py CHANGED
@@ -4,6 +4,7 @@ import signal
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 import gradio as gr
 import tempfile
+import torch

 from huggingface_hub import HfApi, ModelCard, whoami
 from gradio_huggingfacehub_search import HuggingfaceHubSearch
@@ -11,6 +12,69 @@ from pathlib import Path
 from textwrap import dedent


+###########
+
+import subprocess
+import threading
+from queue import Queue, Empty
+
+def stream_output(pipe, queue):
+    """Read output from pipe and put it in the queue."""
+    for line in iter(pipe.readline, b''):
+        queue.put(line.decode('utf-8').rstrip())
+    pipe.close()
+
+def run_command(command):
+    # Create process with pipes for stdout and stderr
+    process = subprocess.Popen(
+        command,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        bufsize=1,
+        universal_newlines=False
+    )
+
+    # Create queues to store output
+    stdout_queue = Queue()
+    stderr_queue = Queue()
+
+    # Create and start threads to read output
+    stdout_thread = threading.Thread(target=stream_output, args=(process.stdout, stdout_queue))
+    stderr_thread = threading.Thread(target=stream_output, args=(process.stderr, stderr_queue))
+    stdout_thread.daemon = True
+    stderr_thread.daemon = True
+    stdout_thread.start()
+    stderr_thread.start()
+
+    output_stdout = ""
+    output_stderr = ""
+    # Monitor output in real-time
+    while process.poll() is None:
+        # Check stdout
+        try:
+            stdout_line = stdout_queue.get_nowait()
+            print(f"STDOUT: {stdout_line}")
+            output_stdout += stdout_line + "\n"
+        except Empty:
+            pass
+
+        # Check stderr
+        try:
+            stderr_line = stderr_queue.get_nowait()
+            print(f"STDERR: {stderr_line}")
+            output_stderr += stderr_line + "\n"
+        except Empty:
+            pass
+
+    # Get remaining lines
+    stdout_thread.join()
+    stderr_thread.join()
+
+    return (process.returncode, output_stdout, output_stderr)
+
+###########
+
+
 def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo, oauth_token: gr.OAuthToken | None):
     if oauth_token is None or oauth_token.token is None:
         raise gr.Error("You must be logged in")
@@ -23,16 +87,20 @@ def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo,
     api = HfApi(token=oauth_token.token)

     with tempfile.TemporaryDirectory(dir="outputs") as outputdir:
-
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        (returncode, output_stdout, output_stderr) = run_command([
             "mergekit-extract-lora",
             ft_model_id,
             base_model_id,
             outputdir,
             f"--rank={rank}",
-
-
-
-
+            f"--device={device}"
+        ])
+        print("returncode", returncode)
+        print("output_stdout", output_stdout)
+        print("output_stderr", output_stderr)
+        if returncode != 0:
+            raise Exception(f"Error converting to LoRA PEFT {q_method}: {output_stderr}")
         print("Model converted to LoRA PEFT successfully!")
         print(f"Converted model path: {outputdir}")

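The heart of this change is the new run_command helper: instead of a blocking subprocess call, stdout and stderr are drained on daemon threads into queues, so mergekit-extract-lora progress is echoed to the Space logs while the extraction is still running and the full output is returned for the error check. As a minimal sketch (not part of the commit), the helper could be smoke-tested on its own, assuming app.py is importable and substituting an arbitrary command for the real mergekit invocation:

# Minimal smoke test for the run_command helper (hypothetical usage, not in the commit).
from app import run_command  # assumes the Space file is importable as app.py

# Arbitrary stand-in command that writes to both stdout and stderr.
returncode, stdout_text, stderr_text = run_command(
    ["python", "-c", "import sys; print('to stdout'); print('to stderr', file=sys.stderr)"]
)
print("exit code:", returncode)          # 0 on success
print("captured stdout:", stdout_text)   # also echoed live with an "STDOUT: " prefix
print("captured stderr:", stderr_text)   # also echoed live with an "STDERR: " prefix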
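For context on the result: mergekit-extract-lora writes the extracted LoRA adapter into outputdir, which the Space then reports as "converted to LoRA PEFT". A hypothetical follow-up, not part of this commit and assuming the extraction succeeded and left a standard PEFT adapter (adapter_config.json plus weights) in that directory, would be to apply the adapter back onto the base model:

# Hypothetical use of the extracted adapter; not part of this commit.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(base_model_id)   # same base model id passed to process_model
model = PeftModel.from_pretrained(base, outputdir)           # outputdir holds the extracted LoRA adapter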