ngxson (HF staff) committed
Commit 337b381 · verified · 1 parent: f8d9bd8

Update app.py

Files changed (1): app.py (+73, -5)
app.py CHANGED
@@ -4,6 +4,7 @@ import signal
 os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
 import gradio as gr
 import tempfile
+import torch
 
 from huggingface_hub import HfApi, ModelCard, whoami
 from gradio_huggingfacehub_search import HuggingfaceHubSearch
@@ -11,6 +12,69 @@ from pathlib import Path
 from textwrap import dedent
 
 
+###########
+
+import subprocess
+import threading
+from queue import Queue, Empty
+
+def stream_output(pipe, queue):
+    """Read output from pipe and put it in the queue."""
+    for line in iter(pipe.readline, b''):
+        queue.put(line.decode('utf-8').rstrip())
+    pipe.close()
+
+def run_command(command):
+    # Create process with pipes for stdout and stderr
+    process = subprocess.Popen(
+        command,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        bufsize=1,
+        universal_newlines=False
+    )
+
+    # Create queues to store output
+    stdout_queue = Queue()
+    stderr_queue = Queue()
+
+    # Create and start threads to read output
+    stdout_thread = threading.Thread(target=stream_output, args=(process.stdout, stdout_queue))
+    stderr_thread = threading.Thread(target=stream_output, args=(process.stderr, stderr_queue))
+    stdout_thread.daemon = True
+    stderr_thread.daemon = True
+    stdout_thread.start()
+    stderr_thread.start()
+
+    output_stdout = ""
+    output_stderr = ""
+    # Monitor output in real-time
+    while process.poll() is None:
+        # Check stdout
+        try:
+            stdout_line = stdout_queue.get_nowait()
+            print(f"STDOUT: {stdout_line}")
+            output_stdout += stdout_line + "\n"
+        except Empty:
+            pass
+
+        # Check stderr
+        try:
+            stderr_line = stderr_queue.get_nowait()
+            print(f"STDERR: {stderr_line}")
+            output_stderr += stderr_line + "\n"
+        except Empty:
+            pass
+
+    # Get remaining lines
+    stdout_thread.join()
+    stderr_thread.join()
+
+    return (process.returncode, output_stdout, output_stderr)
+
+###########
+
+
 def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo, oauth_token: gr.OAuthToken | None):
     if oauth_token is None or oauth_token.token is None:
         raise gr.Error("You must be logged in")
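
Note on the new run_command helper above: bufsize=1 requests line buffering, which Python only honors for text-mode pipes, so with universal_newlines=False recent interpreters emit a RuntimeWarning and fall back to the default buffer size. The polling loop also busy-waits (no sleep between get_nowait attempts), and once process.poll() returns non-None, any lines still sitting in the queues are never folded into output_stdout/output_stderr, so the returned strings can miss the tail of the output. A minimal drain sketch (illustrative only, not part of this commit) that would capture the leftover lines after the two join() calls:

    # Hypothetical tail-drain, placed after stdout_thread.join()/stderr_thread.join();
    # both reader threads have finished here, so empty()/get_nowait() cannot race.
    while not stdout_queue.empty():
        output_stdout += stdout_queue.get_nowait() + "\n"
    while not stderr_queue.empty():
        output_stderr += stderr_queue.get_nowait() + "\n"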
@@ -23,16 +87,20 @@ def process_model(ft_model_id: str, base_model_id: str, rank: str, private_repo,
     api = HfApi(token=oauth_token.token)
 
     with tempfile.TemporaryDirectory(dir="outputs") as outputdir:
-        result = subprocess.run([
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        (returncode, output_stdout, output_stderr) = run_command([
             "mergekit-extract-lora",
             ft_model_id,
             base_model_id,
             outputdir,
             f"--rank={rank}",
-        ], shell=False, capture_output=True)
-        print(result)
-        if result.returncode != 0:
-            raise Exception(f"Error converting to LoRA PEFT {q_method}: {result.stderr}")
+            f"--device={device}"
+        ])
+        print("returncode", returncode)
+        print("output_stdout", output_stdout)
+        print("output_stderr", output_stderr)
+        if returncode != 0:
+            raise Exception(f"Error converting to LoRA PEFT {q_method}: {output_stderr}")
         print("Model converted to LoRA PEFT successfully!")
         print(f"Converted model path: {outputdir}")
 
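One more observation on the error path: the exception message interpolates q_method, which is not defined anywhere in the code shown (it carries over unchanged from the pre-existing line), so a failed conversion would raise a NameError instead of the intended Exception. A hedged sketch of a fix, using a name that is actually in scope (the choice of ft_model_id here is an assumption, not part of the commit):

    # Sketch only: q_method is out of scope in process_model; interpolate an
    # identifier that exists, e.g. the fine-tuned model id being converted.
    if returncode != 0:
        raise Exception(f"Error converting to LoRA PEFT {ft_model_id}: {output_stderr}")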