KingNish commited on
Commit
a80ba5c
·
verified ·
1 Parent(s): fc21d85

Delete spaces

Browse files
spaces/__init__.py DELETED
@@ -1,30 +0,0 @@
1
- """
2
- """
3
-
4
- import sys
5
-
6
-
7
- if sys.version_info.minor < 8: # pragma: no cover
8
- raise RuntimeError("Importing PySpaces requires Python 3.8+")
9
-
10
-
11
- # Prevent gradio from importing spaces
12
- if (gr := sys.modules.get('gradio')) is not None: # pragma: no cover
13
- try:
14
- gr.Blocks
15
- except AttributeError:
16
- raise ImportError
17
-
18
-
19
- from .zero.decorator import GPU
20
- from .gradio import gradio_auto_wrap
21
- from .gradio import disable_gradio_auto_wrap
22
- from .gradio import enable_gradio_auto_wrap
23
-
24
-
25
- __all__ = [
26
- 'GPU',
27
- 'gradio_auto_wrap',
28
- 'disable_gradio_auto_wrap',
29
- 'enable_gradio_auto_wrap',
30
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/config.py DELETED
@@ -1,37 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import os
6
- from pathlib import Path
7
-
8
- from .utils import boolean
9
-
10
-
11
- ZEROGPU_OFFLOAD_DIR_DEFAULT = str(Path.home() / '.zerogpu' / 'tensors')
12
-
13
-
14
- class Settings:
15
- def __init__(self):
16
- self.zero_gpu = boolean(
17
- os.getenv('SPACES_ZERO_GPU'))
18
- self.zero_device_api_url = (
19
- os.getenv('SPACES_ZERO_DEVICE_API_URL'))
20
- self.gradio_auto_wrap = boolean(
21
- os.getenv('SPACES_GRADIO_AUTO_WRAP'))
22
- self.zero_patch_torch_device = boolean(
23
- os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE'))
24
- self.zero_gpu_v2 = boolean(
25
- os.getenv('ZEROGPU_V2'))
26
- self.zerogpu_offload_dir = (
27
- os.getenv('ZEROGPU_OFFLOAD_DIR', ZEROGPU_OFFLOAD_DIR_DEFAULT))
28
-
29
-
30
- Config = Settings()
31
-
32
-
33
- if Config.zero_gpu:
34
- assert Config.zero_device_api_url is not None, (
35
- 'SPACES_ZERO_DEVICE_API_URL env must be set '
36
- 'on ZeroGPU Spaces (identified by SPACES_ZERO_GPU=true)'
37
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/gradio.py DELETED
@@ -1,55 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- from typing import Callable
6
- from typing import Generator
7
- from typing import TypeVar
8
- from typing import overload
9
- from typing_extensions import ParamSpec
10
-
11
- from .config import Config
12
- from .zero.decorator import GPU
13
-
14
-
15
- Param = ParamSpec('Param')
16
- Res = TypeVar('Res')
17
-
18
-
19
- gradio_auto_wrap_enabled = Config.gradio_auto_wrap
20
-
21
-
22
- def disable_gradio_auto_wrap():
23
- global gradio_auto_wrap_enabled
24
- gradio_auto_wrap_enabled = False
25
-
26
- def enable_gradio_auto_wrap():
27
- global gradio_auto_wrap_enabled
28
- gradio_auto_wrap_enabled = True
29
-
30
-
31
- @overload
32
- def gradio_auto_wrap(
33
- task:
34
- Callable[Param, Res],
35
- ) -> Callable[Param, Res]:
36
- ...
37
- @overload
38
- def gradio_auto_wrap(
39
- task:
40
- None,
41
- ) -> None:
42
- ...
43
- def gradio_auto_wrap(
44
- task:
45
- Callable[Param, Res]
46
- | None,
47
- ) -> (Callable[Param, Res]
48
- | None):
49
- """
50
- """
51
- if not gradio_auto_wrap_enabled:
52
- return task
53
- if not callable(task):
54
- return task
55
- return GPU(task) # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/utils.py DELETED
@@ -1,85 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import ctypes
6
- import sys
7
- from functools import lru_cache as cache
8
- from functools import partial
9
-
10
- import multiprocessing
11
- from multiprocessing.queues import SimpleQueue as _SimpleQueue
12
- from pathlib import Path
13
- from pickle import PicklingError
14
- from typing import Callable
15
- from typing import TypeVar
16
-
17
-
18
- GRADIO_VERSION_ERROR_MESSAGE = "Make sure Gradio version is at least 3.46"
19
-
20
-
21
- T = TypeVar('T')
22
-
23
-
24
- @cache
25
- def self_cgroup_device_path() -> str:
26
- cgroup_content = Path('/proc/self/cgroup').read_text()
27
- for line in cgroup_content.strip().split('\n'):
28
- contents = line.split(':devices:')
29
- if len(contents) != 2:
30
- continue # pragma: no cover
31
- return contents[1]
32
- raise Exception # pragma: no cover
33
-
34
-
35
- if sys.version_info.minor < 9: # pragma: no cover
36
- _SimpleQueue.__class_getitem__ = classmethod(lambda cls, _: cls) # type: ignore
37
-
38
- class SimpleQueue(_SimpleQueue[T]):
39
- def __init__(self, *args):
40
- super().__init__(*args, ctx=multiprocessing.get_context('fork'))
41
- def put(self, obj: T):
42
- try:
43
- super().put(obj)
44
- except PicklingError:
45
- raise # pragma: no cover
46
- # https://bugs.python.org/issue29187
47
- except Exception as e:
48
- message = str(e)
49
- if not "pickle" in message:
50
- raise # pragma: no cover
51
- raise PicklingError(message)
52
- def close(self): # Python 3.8 static typing trick
53
- super().close() # type: ignore
54
- def wlock_release(self):
55
- if (lock := getattr(self, '_wlock', None)) is None:
56
- return # pragma: no cover
57
- try:
58
- lock.release()
59
- except ValueError:
60
- pass
61
-
62
-
63
- def drop_params(fn: Callable[[], T]) -> Callable[..., T]:
64
- def drop(*args):
65
- return fn()
66
- return drop
67
-
68
-
69
- def boolean(value: str | None) -> bool:
70
- return value is not None and value.lower() in ("1", "t", "true")
71
-
72
-
73
- def gradio_request_var():
74
- try:
75
- from gradio.context import LocalContext
76
- except ImportError: # pragma: no cover
77
- raise RuntimeError(GRADIO_VERSION_ERROR_MESSAGE)
78
- return LocalContext.request
79
-
80
-
81
- def malloc_trim():
82
- ctypes.CDLL("libc.so.6").malloc_trim(0)
83
-
84
-
85
- debug = partial(print, 'SPACES_ZERO_GPU_DEBUG')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/zero/__init__.py DELETED
@@ -1,21 +0,0 @@
1
- """
2
- """
3
-
4
- from pathlib import Path
5
-
6
- from ..config import Config
7
-
8
-
9
- if Config.zero_gpu:
10
-
11
- from . import gradio
12
- from . import torch
13
-
14
- if torch.is_in_bad_fork():
15
- raise RuntimeError(
16
- "CUDA has been initialized before importing the `spaces` package"
17
- )
18
-
19
- torch.patch()
20
- gradio.one_launch(torch.pack)
21
- Path(Config.zerogpu_offload_dir).mkdir(parents=True, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/zero/api.py DELETED
@@ -1,156 +0,0 @@
1
- """
2
- Synced with huggingface/pyspaces:spaces/zero/api.py
3
- """
4
- from __future__ import annotations
5
-
6
- from datetime import timedelta
7
- from typing import Any
8
- from typing import Generator
9
- from typing import Literal
10
- from typing import NamedTuple
11
- from typing import Optional
12
- from typing import overload
13
-
14
- import httpx
15
- from pydantic import BaseModel
16
- from typing_extensions import assert_never
17
-
18
-
19
- AllowToken = str
20
- NvidiaIndex = int # TODO: Migrate to GpuIndex (less confusing for MIG)
21
- NvidiaUUID = str
22
- CGroupPath = str
23
- VisitorId = str
24
- Score = float
25
-
26
- AuthLevel = Literal['regular', 'pro']
27
-
28
-
29
- AUTHENTICATED_HEADER = 'X-Authenticated'
30
-
31
-
32
- class ScheduleResponse(BaseModel):
33
- idle: bool
34
- nvidiaIndex: int
35
- nvidiaUUID: str
36
- allowToken: str
37
-
38
-
39
- class QuotaInfos(BaseModel):
40
- left: int
41
- wait: timedelta
42
-
43
-
44
- class ReportUsageMonitoringParams(NamedTuple):
45
- nvidia_index: int
46
- visitor_id: str
47
- duration: timedelta
48
-
49
-
50
- class QueueEvent(BaseModel):
51
- event: Literal['ping', 'failed', 'succeeded']
52
- data: Optional[ScheduleResponse] = None
53
-
54
-
55
- def sse_parse(text: str):
56
- event, *data = text.strip().splitlines()
57
- assert event.startswith('event:')
58
- event = event[6:].strip()
59
- if event in ('ping', 'failed'):
60
- return QueueEvent(event=event)
61
- assert event == 'succeeded'
62
- (data,) = data
63
- assert data.startswith('data:')
64
- data = data[5:].strip()
65
- return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data))
66
-
67
-
68
- def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]:
69
- for text in res.iter_text():
70
- if len(text) == 0:
71
- break # pragma: no cover
72
- try:
73
- yield sse_parse(text)
74
- except GeneratorExit:
75
- res.close()
76
- break
77
-
78
-
79
- class APIClient:
80
-
81
- def __init__(self, client: httpx.Client):
82
- self.client = client
83
-
84
- def startup_report(self) -> httpx.codes:
85
- res = self.client.post('/startup-report')
86
- return httpx.codes(res.status_code)
87
-
88
- def schedule(
89
- self,
90
- cgroup_path: str,
91
- task_id: int = 0,
92
- token: str | None = None,
93
- duration_seconds: int | None = None,
94
- enable_queue: bool = True,
95
- ):
96
- params: dict[str, str | int | bool] = {
97
- 'cgroupPath': cgroup_path,
98
- 'taskId': task_id,
99
- 'enableQueue': enable_queue,
100
- }
101
- if duration_seconds is not None:
102
- params['durationSeconds'] = duration_seconds
103
- if token is not None:
104
- params['token'] = token
105
- res = self.client.send(
106
- request=self.client.build_request(
107
- method='POST',
108
- url='/schedule',
109
- params=params,
110
- ),
111
- stream=True,
112
- )
113
- status = httpx.codes(res.status_code)
114
- auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER)
115
- if (status is not httpx.codes.OK and
116
- status is not httpx.codes.TOO_MANY_REQUESTS
117
- ):
118
- res.close()
119
- return status, auth
120
- if "text/event-stream" in res.headers['content-type']:
121
- return sse_stream(res), auth
122
- res.read()
123
- if status is httpx.codes.TOO_MANY_REQUESTS:
124
- return QuotaInfos(**res.json()), auth # pragma: no cover
125
- if status is httpx.codes.OK:
126
- return ScheduleResponse(**res.json()), auth
127
- assert_never(status)
128
-
129
- def allow(
130
- self,
131
- allow_token: str,
132
- pid: int,
133
- ):
134
- res = self.client.post('/allow', params={
135
- 'allowToken': allow_token,
136
- 'pid': pid,
137
- })
138
- return httpx.codes(res.status_code)
139
-
140
- def release(
141
- self,
142
- allow_token: str,
143
- fail: bool = False,
144
- ) -> httpx.codes:
145
- res = self.client.post('/release', params={
146
- 'allowToken': allow_token,
147
- 'fail': fail,
148
- })
149
- return httpx.codes(res.status_code)
150
-
151
- def get_queue_size(self) -> int:
152
- res = self.client.get('/queue-size')
153
- assert res.status_code == 200, res.status_code
154
- size = res.json()
155
- assert isinstance(size, int)
156
- return size
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/zero/client.py DELETED
@@ -1,239 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import os
6
- import time
7
- import warnings
8
- from datetime import timedelta
9
-
10
- import gradio as gr
11
- import httpx
12
- from packaging import version
13
- from typing_extensions import assert_never
14
-
15
- from .. import utils
16
- from ..config import Config
17
- from .api import APIClient
18
- from .api import AuthLevel
19
- from .api import QuotaInfos
20
- from .api import ScheduleResponse
21
- from .gradio import HTMLError
22
- from .gradio import get_event
23
- from .gradio import supports_auth
24
-
25
-
26
- TOKEN_HEADER = 'X-IP-Token'
27
- DEFAULT_SCHEDULE_DURATION = 60
28
-
29
- QUOTA_MESSAGE = "You have exceeded your GPU quota"
30
- UNUSED_MESSAGE = "GPU device not used"
31
- NO_GPU_MESSAGE_REGULAR = "No GPU was available"
32
- NO_GPU_MESSAGE_INQUEUE = "No GPU was available after 60s"
33
-
34
- SIGNUP_ON_HF_TXT = "Create a free account"
35
- SIGNUP_ON_HF_URL = "https://huggingface.co/join"
36
- SUBSCRIBE_TO_PRO_TXT = "Subscribe to Pro"
37
- SUBSCRIBE_TO_PRO_URL = "https://huggingface.co/settings/billing/subscription"
38
-
39
-
40
- def api_client():
41
- assert Config.zero_device_api_url is not None
42
- httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False)
43
- return APIClient(httpx_client)
44
-
45
-
46
- def startup_report():
47
- retries, max_retries = 0, 2
48
- client = api_client()
49
- while (status := client.startup_report()) is httpx.codes.NOT_FOUND: # pragma: no cover
50
- time.sleep(1)
51
- if (retries := retries + 1) > max_retries:
52
- raise RuntimeError("Error while initializing ZeroGPU: NotFound")
53
- if status is not httpx.codes.OK: # pragma: no cover
54
- raise RuntimeError("Error while initializing ZeroGPU: Unknown")
55
-
56
-
57
- def html_string(html_contents: str, text_contents: str): # pragma: no cover
58
- class HTMLString(str):
59
- def __str__(self):
60
- return text_contents
61
- return HTMLString(html_contents)
62
-
63
-
64
- def _toast_action(
65
- auth: AuthLevel | None,
66
- supports_html: bool,
67
- pro_message: str,
68
- unlogged_desc: str,
69
- logged_desc: str,
70
- ending: str,
71
- ) -> tuple[str, str]: # pragma: no cover
72
- if not supports_auth() or auth == 'pro':
73
- return pro_message, pro_message
74
- html = ""
75
- link = SIGNUP_ON_HF_URL if auth is None else SUBSCRIBE_TO_PRO_URL
76
- text = SIGNUP_ON_HF_TXT if auth is None else SUBSCRIBE_TO_PRO_TXT
77
- desc = unlogged_desc if auth is None else logged_desc
78
- desc += f" {ending}."
79
- style = ";".join([
80
- "white-space: nowrap",
81
- "text-underline-offset: 2px",
82
- "color: var(--body-text-color)",
83
- ])
84
- if supports_html:
85
- html += f'<a style="{style}" href="{link}">'
86
- html += text
87
- if supports_html:
88
- html += '</a> '
89
- html += desc
90
- markdown = f'[{text}]({link}) {desc}'
91
- return html, markdown
92
-
93
-
94
- def schedule(
95
- task_id: int,
96
- request: gr.Request | None = None,
97
- duration: timedelta | None = None,
98
- _first_attempt: bool = True,
99
- ) -> ScheduleResponse:
100
-
101
- if not (gradio_version := version.parse(gr.__version__)).major >= 4: # pragma: no cover
102
- raise RuntimeError("ZeroGPU is only compatible with Gradio 4+")
103
-
104
- GRADIO_HTML_TOASTS = gradio_version.minor >= 39
105
-
106
- res, auth = api_client().schedule(
107
- cgroup_path=utils.self_cgroup_device_path(),
108
- task_id=task_id,
109
- token=_get_token(request),
110
- duration_seconds=duration.seconds if duration is not None else None,
111
- )
112
-
113
- if isinstance(res, ScheduleResponse):
114
- return res
115
-
116
- if isinstance(res, QuotaInfos): # pragma: no cover
117
- requested = duration.seconds if duration is not None else DEFAULT_SCHEDULE_DURATION
118
- if res.wait < timedelta(0):
119
- raise gr.Error(
120
- f"The requested GPU duration ({requested}s) "
121
- f"is larger than the maximum allowed"
122
- )
123
- else:
124
- gpu = "Pro GPU" if auth == 'pro' else ("free GPU" if auth == 'regular' else "GPU")
125
- message = (
126
- f"You have exceeded your {gpu} quota "
127
- f"({requested}s requested vs. {res.left}s left)."
128
- )
129
- details_html, details_markdown = _toast_action(
130
- auth=auth,
131
- supports_html=GRADIO_HTML_TOASTS,
132
- pro_message=f"Try again in {res.wait}",
133
- unlogged_desc="to get more",
134
- logged_desc="to get 5x more",
135
- ending="usage quota",
136
- )
137
- message_html = f"{message} {details_html}"
138
- message_text = f"{message} {details_markdown}"
139
- raise HTMLError(html_string(message_html, message_text))
140
-
141
- if not isinstance(res, httpx.codes): # pragma: no cover
142
- gr.Info("Waiting for a GPU to become available")
143
- # TODO: Sign-up message if not authenticated (after some time ?)
144
- connection_event = get_event()
145
- if connection_event is None and request is not None:
146
- warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
147
- while True:
148
- try:
149
- event = next(res)
150
- except StopIteration:
151
- raise RuntimeError("Unexpected end of stream")
152
- except httpx.RemoteProtocolError:
153
- if not _first_attempt:
154
- raise RuntimeError("Error while re-trying after queue disconnect")
155
- return schedule(task_id, request, duration, _first_attempt=False)
156
- if event.event == 'ping':
157
- if connection_event is not None and not connection_event.alive:
158
- res.close()
159
- raise RuntimeError("Connection closed by visitor while queueing")
160
- continue
161
- if event.event == 'failed':
162
- details_html, details_markdown = _toast_action(
163
- auth=auth,
164
- supports_html=GRADIO_HTML_TOASTS,
165
- pro_message="Retry later",
166
- unlogged_desc="to get a higher",
167
- logged_desc="to get the highest",
168
- ending="priority in ZeroGPU queues",
169
- )
170
- message_html = f"{NO_GPU_MESSAGE_INQUEUE}. {details_html}"
171
- message_text = f"{NO_GPU_MESSAGE_INQUEUE} {details_markdown}"
172
- raise HTMLError(html_string(message_html, message_text))
173
- if event.event == 'succeeded':
174
- assert event.data is not None
175
- if connection_event is not None and not connection_event.alive:
176
- release(event.data.allowToken)
177
- raise RuntimeError("Connection closed by visitor on queue success")
178
- gr.Info("Successfully acquired a GPU")
179
- return event.data
180
-
181
- if res is httpx.codes.SERVICE_UNAVAILABLE:
182
- raise gr.Error(NO_GPU_MESSAGE_REGULAR)
183
-
184
- # TODO: Find a way to log 'detail' response field
185
- raise RuntimeError(f"ZeroGPU API /schedule error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
186
-
187
-
188
- def allow(allow_token: str) -> None:
189
- pid = os.getpid()
190
- assert pid != 1, "Allowing PID 1 on ZeroGPU will end up killing your Space"
191
- assert api_client().allow(allow_token=allow_token, pid=pid) is httpx.codes.OK
192
-
193
-
194
- def release(
195
- allow_token: str, *,
196
- fail: bool = False,
197
- allow_404: bool = False,
198
- ) -> None:
199
-
200
- res = api_client().release(
201
- allow_token=allow_token,
202
- fail=fail,
203
- )
204
-
205
- if res is httpx.codes.NO_CONTENT: # pragma: no cover
206
- try:
207
- gr.Warning(UNUSED_MESSAGE)
208
- except AttributeError:
209
- pass
210
- warnings.warn(UNUSED_MESSAGE, RuntimeWarning)
211
- return None
212
-
213
- if res is httpx.codes.NOT_FOUND:
214
- if not allow_404:
215
- warnings.warn("ZeroGPU API /release warning: 404 Not Found")
216
- return None
217
-
218
- if httpx.codes.is_success(res):
219
- return None
220
-
221
- # TODO: Find a way to log 'detail' response field
222
- # TODO: Only raise in dev environment. Simply warn in production ?
223
- raise RuntimeError(f"ZeroGPU API /release error: {res} ({httpx.codes.get_reason_phrase(res)})") # pragma: no cover
224
-
225
-
226
- def _get_token(request: gr.Request | None) -> str | None:
227
-
228
- if request is None:
229
- return None
230
-
231
- headers = getattr(request, 'headers', None)
232
- if headers is None or not hasattr(headers, '__dict__'):
233
- raise gr.Error("Internal Gradio error")
234
-
235
- # Compatibility trick
236
- if not hasattr(headers, 'get'):
237
- headers = headers.__dict__ # pragma: no cover
238
-
239
- return headers.get(TOKEN_HEADER.lower())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/zero/decorator.py DELETED
@@ -1,113 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import inspect
6
- import sys
7
- import warnings
8
- from datetime import timedelta
9
- from functools import partial
10
- from typing import Callable
11
- from typing import TypeVar
12
- from typing import overload
13
- from typing_extensions import ParamSpec
14
- from typing_extensions import Unpack
15
-
16
- from ..config import Config
17
- from .types import DynamicDuration
18
- from .types import EmptyKwargs
19
-
20
-
21
- P = ParamSpec('P')
22
- R = TypeVar('R')
23
-
24
-
25
- decorated_cache: dict[Callable, Callable] = {}
26
-
27
-
28
- @overload
29
- def GPU(
30
- task: None = None, *,
31
- duration: DynamicDuration[P] = None,
32
- ) -> Callable[[Callable[P, R]], Callable[P, R]]:
33
- ...
34
- @overload
35
- def GPU(
36
- task: Callable[P, R], *,
37
- duration: DynamicDuration[P] = None,
38
- ) -> Callable[P, R]:
39
- ...
40
- def GPU(
41
- task: Callable[P, R] | None = None, *,
42
- duration: DynamicDuration[P] = None,
43
- **kwargs: Unpack[EmptyKwargs],
44
- ) -> Callable[[Callable[P, R]], Callable[P, R]] | Callable[P, R]:
45
- """
46
- ZeroGPU decorator
47
-
48
- Basic usage:
49
- ```
50
- @spaces.GPU
51
- def fn(...):
52
- # CUDA is available here
53
- pass
54
- ```
55
-
56
- With custom duration:
57
- ```
58
- @spaces.GPU(duration=45) # Expressed in seconds
59
- def fn(...):
60
- # CUDA is available here
61
- pass
62
- ```
63
-
64
- Args:
65
- task (`Callable | None`): Python function that requires CUDA
66
- duration (`int | datetime.timedelta`): Estimated duration in seconds or `datetime.timedelta`
67
-
68
- Returns:
69
- `Callable`: GPU-ready function
70
- """
71
- if "enable_queue" in kwargs:
72
- warnings.warn("`enable_queue` parameter is now ignored and always set to `True`")
73
- if task is None:
74
- return partial(_GPU, duration=duration)
75
- return _GPU(task, duration)
76
-
77
-
78
- def _GPU(
79
- task: Callable[P, R],
80
- duration: DynamicDuration[P],
81
- ) -> Callable[P, R]:
82
-
83
- if not Config.zero_gpu:
84
- return task
85
-
86
- from . import client
87
- from .wrappers import regular_function_wrapper
88
- from .wrappers import generator_function_wrapper
89
-
90
- if sys.version_info.minor < 9: # pragma: no cover
91
- raise RuntimeError("Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+")
92
-
93
- if task in decorated_cache:
94
- # TODO: Assert same duration ?
95
- return decorated_cache[task] # type: ignore
96
-
97
- if inspect.iscoroutinefunction(task):
98
- raise NotImplementedError
99
-
100
- if inspect.isgeneratorfunction(task):
101
- decorated = generator_function_wrapper(task, duration)
102
- else:
103
- decorated = regular_function_wrapper(task, duration)
104
-
105
- setattr(decorated, 'zerogpu', None)
106
-
107
- client.startup_report()
108
- decorated_cache.update({
109
- task: decorated,
110
- decorated: decorated,
111
- })
112
-
113
- return decorated # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/zero/gradio.py DELETED
@@ -1,150 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- from functools import wraps
6
- from packaging import version
7
- from typing import Callable
8
- from typing import NamedTuple
9
- from typing import TYPE_CHECKING
10
- import warnings
11
-
12
- import gradio as gr
13
- from gradio.context import Context
14
- from gradio.context import LocalContext
15
- from gradio.helpers import Progress
16
- from gradio.helpers import TrackedIterable
17
- from gradio.queueing import Queue
18
- from typing_extensions import ParamSpec
19
-
20
- from ..utils import SimpleQueue
21
- from .types import GeneratorResQueueResult
22
- from .types import GradioQueueEvent
23
- from .types import RegularResQueueResult
24
-
25
-
26
- QUEUE_RPC_METHODS = [
27
- "set_progress",
28
- "log_message",
29
- ]
30
-
31
-
32
- class GradioPartialContext(NamedTuple):
33
- event_id: str | None
34
- in_event_listener: bool
35
- progress: Progress | None
36
-
37
- @staticmethod
38
- def get():
39
- TrackedIterable.__reduce__ = tracked_iterable__reduce__
40
- return GradioPartialContext(
41
- event_id=LocalContext.event_id.get(),
42
- in_event_listener=LocalContext.in_event_listener.get(),
43
- progress=LocalContext.progress.get(),
44
- )
45
-
46
- @staticmethod
47
- def apply(context: 'GradioPartialContext'):
48
- LocalContext.event_id.set(context.event_id)
49
- LocalContext.in_event_listener.set(context.in_event_listener)
50
- LocalContext.progress.set(context.progress)
51
-
52
-
53
- def get_queue_instance():
54
- blocks = LocalContext.blocks.get()
55
- if blocks is None: # pragma: no cover
56
- return None
57
- return blocks._queue
58
-
59
-
60
- def get_event():
61
- queue = get_queue_instance()
62
- event_id = LocalContext.event_id.get()
63
- if queue is None:
64
- return None
65
- if event_id is None: # pragma: no cover
66
- return None
67
- for job in queue.active_jobs:
68
- if job is None: # pragma: no cover
69
- continue
70
- for event in job:
71
- if event._id == event_id:
72
- return event
73
-
74
-
75
- def get_server_port() -> int | None:
76
- from_request_context = True
77
- if (blocks := LocalContext.blocks.get()) is None: # Request
78
- from_request_context = False
79
- if (blocks := Context.root_block) is None: # Caching
80
- return None
81
- if (server := getattr(blocks, 'server', None)) is None:
82
- if from_request_context:
83
- warnings.warn("Gradio: No blocks.server inside a request") # pragma: no cover
84
- return -1
85
- if TYPE_CHECKING:
86
- assert (server := blocks.server)
87
- return server.config.port
88
-
89
-
90
- def try_process_queue_event(method_name: str, *args, **kwargs):
91
- queue = get_queue_instance()
92
- if queue is None: # pragma: no cover
93
- warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance")
94
- return
95
- method = getattr(queue, method_name, None)
96
- assert callable(method)
97
- method(*args, **kwargs)
98
-
99
-
100
- def patch_gradio_queue(
101
- res_queue: SimpleQueue[RegularResQueueResult | None] | SimpleQueue[GeneratorResQueueResult | None],
102
- ):
103
-
104
- def rpc_method(method_name: str):
105
- def method(*args, **kwargs):
106
- if args and isinstance(args[0], Queue):
107
- args = args[1:] # drop `self`
108
- res_queue.put(GradioQueueEvent(method_name, args, kwargs))
109
- return method
110
-
111
- for method_name in QUEUE_RPC_METHODS:
112
- if (method := getattr(Queue, method_name, None)) is None: # pragma: no cover
113
- warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute")
114
- continue
115
- if not callable(method): # pragma: no cover
116
- warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable")
117
- continue
118
- setattr(Queue, method_name, rpc_method(method_name))
119
-
120
- TrackedIterable.__reduce__ = tracked_iterable__reduce__
121
-
122
-
123
- def tracked_iterable__reduce__(self):
124
- res: tuple = super(TrackedIterable, self).__reduce__() # type: ignore
125
- cls, base, state, *_ = res
126
- return cls, base,{**state, **{
127
- 'iterable': None,
128
- '_tqdm': None,
129
- }}
130
-
131
-
132
- def supports_auth():
133
- return version.parse(gr.__version__) >= version.Version('4.27.0')
134
-
135
-
136
- Param = ParamSpec('Param')
137
-
138
- def one_launch(task: Callable[Param, None], *task_args: Param.args, **task_kwargs: Param.kwargs):
139
- _launch = gr.Blocks.launch
140
- @wraps(gr.Blocks.launch)
141
- def launch(*args, **kwargs):
142
- task(*task_args, **task_kwargs)
143
- gr.Blocks.launch = _launch
144
- return gr.Blocks.launch(*args, **kwargs)
145
- gr.Blocks.launch = launch
146
-
147
-
148
- class HTMLError(gr.Error):
149
- def __str__(self): # pragma: no cover
150
- return self.message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/zero/torch/__init__.py DELETED
@@ -1,42 +0,0 @@
1
- """
2
- """
3
-
4
- from ...config import Config
5
-
6
-
7
- try:
8
-
9
- import torch
10
-
11
- except ImportError:
12
-
13
- _patch = lambda *args, **kwargs: None
14
- _unpatch = lambda *args, **kwargs: None
15
- _pack = lambda *args, **kwargs: None
16
- _init = lambda *args, **kwargs: None
17
- _size = lambda *args, **kwargs: 0
18
- _move = lambda *args, **kwargs: None
19
- _is_in_bad_fork = lambda *args, **kwargs: False
20
-
21
- else:
22
-
23
- if Config.zero_gpu_v2:
24
- from . import patching as _patching
25
- else: # pragma: no cover
26
- from . import patching_legacy as _patching
27
-
28
- _patch = _patching.patch
29
- _unpatch = _patching.unpatch
30
- _pack = _patching.pack
31
- _init = _patching.init
32
- _size = _patching.size
33
- _move = _patching.move
34
- _is_in_bad_fork = _patching.is_in_bad_fork
35
-
36
- patch = _patch
37
- unpatch = _unpatch
38
- pack = _pack
39
- init = _init
40
- size = _size
41
- move = _move
42
- is_in_bad_fork = _is_in_bad_fork
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/zero/torch/bitsandbytes.py DELETED
@@ -1,162 +0,0 @@
1
- """
2
- """
3
- # pyright: reportPrivateImportUsage=false
4
-
5
- from __future__ import annotations
6
-
7
- import importlib
8
- from contextlib import contextmanager
9
- from importlib import metadata
10
- from types import ModuleType
11
- from typing import TYPE_CHECKING
12
- from typing import Tuple
13
-
14
- import torch
15
- from packaging import version
16
-
17
- if TYPE_CHECKING:
18
- import torch as Torch
19
-
20
-
21
- @contextmanager
22
- def cuda_unavailable(torch: ModuleType):
23
- _is_available = torch.cuda.is_available
24
- torch.cuda.is_available = lambda: False
25
- yield
26
- torch.cuda.is_available = _is_available
27
-
28
-
29
- def maybe_import_bitsandbytes():
30
- try:
31
- import torch
32
- except ImportError: # pragma: no cover
33
- return None
34
- with cuda_unavailable(torch):
35
- try:
36
- import bitsandbytes
37
- except ImportError:
38
- bitsandbytes = None
39
- else:
40
- if (bnb_version := version.parse(metadata.version('bitsandbytes'))) < version.parse('0.40.0'):
41
- raise RuntimeError(f"ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})") # pragma: no cover
42
- print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑")
43
- return bitsandbytes
44
-
45
-
46
- if (bnb := maybe_import_bitsandbytes()):
47
-
48
- from torch.utils.weak import WeakTensorKeyDictionary
49
-
50
- with cuda_unavailable(torch):
51
- from bitsandbytes import cextension
52
- from bitsandbytes import functional
53
- try: # bitsandbytes < 0.44
54
- from bitsandbytes.cuda_setup.main import CUDASetup
55
- except ModuleNotFoundError: # pragma: no cover
56
- CUDASetup = None
57
- from bitsandbytes.nn import Int8Params
58
- from bitsandbytes.nn import Params4bit
59
-
60
- _param_to_8bit = Int8Params.to # type: ignore
61
- _param_cuda_8bit = Int8Params.cuda
62
- _param_to_4bit = Params4bit.to # type: ignore
63
- _param_cuda_4bit = Params4bit.cuda
64
-
65
- TensorToArgs = Tuple[torch.device, torch.dtype, bool, torch.memory_format]
66
-
67
- to_ops_8bit: dict[Int8Params, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
68
- to_ops_4bit: dict[Params4bit, TensorToArgs | None] = WeakTensorKeyDictionary() # type: ignore
69
-
70
- def _to_op_register_8bit(self: Int8Params, *args, **kwargs):
71
- parsed = torch._C._nn._parse_to(*args, **kwargs)
72
- device, *_ = parsed
73
- if not isinstance(device, torch.device): # pragma: no cover
74
- return _param_to_8bit(self, *args, **kwargs)
75
- if device.type != 'cuda':
76
- return _param_to_8bit(self, *args, **kwargs)
77
- to_ops_8bit[self] = parsed
78
- return self
79
-
80
- def _to_op_register_4bit(self: Params4bit, *args, **kwargs):
81
- parsed = torch._C._nn._parse_to(*args, **kwargs)
82
- device, *_ = parsed
83
- if not isinstance(device, torch.device): # pragma: no cover
84
- return _param_to_4bit(self, *args, **kwargs)
85
- if device.type != 'cuda':
86
- return _param_to_4bit(self, *args, **kwargs)
87
- to_ops_4bit[self] = parsed
88
- return self
89
-
90
- def _cuda_op_arg_check(device: Torch.device | int | str | None) -> bool:
91
- if device is None: # pragma: no cover
92
- return True
93
- if isinstance(device, int):
94
- return True
95
- if isinstance(device, str): # pragma: no cover
96
- device = torch.device(device)
97
- return device.type == 'cuda' # pragma: no cover
98
-
99
- def _cuda_op_register_8bit(self: Int8Params, device: Torch.device | int | str | None = None, **kwargs):
100
- if not _cuda_op_arg_check(device): # pragma: no cover
101
- # Let PyTorch handle the fail
102
- return _param_cuda_8bit(self, device, **kwargs)
103
- to_ops_8bit[self] = None
104
- return self
105
-
106
- def _cuda_op_register_4bit(self: Params4bit, device: Torch.device | int | str | None = None, **kwargs):
107
- if not _cuda_op_arg_check(device): # pragma: no cover
108
- # Let PyTorch handle the fail
109
- return _param_cuda_4bit(self, device, **kwargs)
110
- to_ops_4bit[self] = None
111
- return self
112
-
113
- def _patch():
114
- Int8Params.to = _to_op_register_8bit # type: ignore
115
- Int8Params.cuda = _cuda_op_register_8bit # type: ignore
116
- Params4bit.to = _to_op_register_4bit # type: ignore
117
- Params4bit.cuda = _cuda_op_register_4bit # type: ignore
118
-
119
- def _unpatch():
120
- Int8Params.to = _param_to_8bit # type: ignore
121
- Int8Params.cuda = _param_cuda_8bit
122
- Params4bit.to = _param_to_4bit # type: ignore
123
- Params4bit.cuda = _param_cuda_4bit
124
-
125
- def _move():
126
- if CUDASetup is not None:
127
- CUDASetup._instance = None
128
- importlib.reload(cextension)
129
- functional.lib = cextension.lib
130
- for op in to_ops_8bit.items():
131
- tensor, parsed_args = op
132
- if parsed_args:
133
- _, dtype, _, memory_format = parsed_args
134
- else:
135
- dtype, memory_format = None, None
136
- tensor.data = _param_to_8bit(tensor,
137
- device='cuda',
138
- dtype=dtype,
139
- memory_format=memory_format,
140
- ) # type: ignore
141
- for op in to_ops_4bit.items():
142
- tensor, parsed_args = op
143
- if parsed_args:
144
- _, dtype, _, memory_format = parsed_args
145
- else:
146
- dtype, memory_format = None, None
147
- tensor.data = _param_to_4bit(tensor,
148
- device='cuda',
149
- dtype=dtype,
150
- memory_format=memory_format,
151
- ) # type: ignore
152
-
153
- else:
154
-
155
- _patch = lambda: None
156
- _unpatch = lambda: None
157
- _move = lambda: None
158
-
159
-
160
- patch = _patch
161
- unpatch = _unpatch
162
- move = _move
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/zero/torch/packing.py DELETED
@@ -1,209 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import time
6
-
7
- import ctypes
8
- import os
9
- from concurrent.futures import as_completed
10
- from concurrent.futures import ThreadPoolExecutor
11
- from contextvars import copy_context
12
- from dataclasses import dataclass
13
- from queue import Queue
14
- from typing import Callable
15
-
16
- from ...utils import debug
17
-
18
- import torch
19
- from typing_extensions import TypeAlias
20
-
21
-
22
- PAGE_SIZE = 4096
23
- TOTAL_MEMORY = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
24
- VM_MAX_SIZE = min(2**38, TOTAL_MEMORY // 2)
25
-
26
- BUFFER_SIZE = 64 * 2**20
27
- BUFFER_COUNT = 2
28
-
29
-
30
- TensorWithSizes: TypeAlias = 'tuple[torch.Tensor, int, int]'
31
-
32
- @dataclass
33
- class ZeroGPUTensorPack:
34
- base_dir: str
35
- batches: list[list[TensorWithSizes]]
36
- big_tensors: list[TensorWithSizes]
37
- fakes: dict[torch.Tensor, list[torch.Tensor]]
38
- total_size: int
39
- def path(self):
40
- return f'{self.base_dir}/{id(self)}'
41
- def __del__(self):
42
- try:
43
- os.remove(self.path())
44
- except FileNotFoundError: # pragma: no cover
45
- pass
46
-
47
-
48
- def write(fd: int, tensor: torch.Tensor):
49
- clone = torch.empty_like(tensor)
50
- size = clone.untyped_storage().size() # pyright: ignore [reportAttributeAccessIssue]
51
- buffer = torch.UntypedStorage(VM_MAX_SIZE)
52
- buffer_ptr = buffer.data_ptr()
53
- offset = -buffer_ptr % PAGE_SIZE
54
- padding = -size % PAGE_SIZE
55
- clone.set_(buffer[offset:offset+size], 0, clone.shape, clone.stride()) # pyright: ignore [reportArgumentType]
56
- clone.copy_(tensor)
57
- mv = memoryview((ctypes.c_char * (size+padding)).from_address(buffer_ptr+offset))
58
- written_bytes = 0
59
- while written_bytes < size:
60
- written_bytes += os.write(fd, mv[written_bytes:])
61
-
62
-
63
- def pack_tensors(
64
- tensors: set[torch.Tensor],
65
- fakes: dict[torch.Tensor, list[torch.Tensor]],
66
- offload_dir: str,
67
- callback: Callable[[int]] | None = None,
68
- ):
69
-
70
- callback = (lambda bytes: None) if callback is None else callback
71
-
72
- batches: list[list[TensorWithSizes]] = []
73
- big_tensors: list[TensorWithSizes] = []
74
-
75
- tensors_with_sizes: list[tuple[torch.Tensor, int, int]] = []
76
- for tensor in tensors:
77
- size = tensor.numel() * tensor.element_size()
78
- aligned_size = size + (-size % PAGE_SIZE)
79
- tensors_with_sizes += [(tensor, size, aligned_size)]
80
-
81
- current_batch, current_size = [], 0
82
- for (tensor, size, aligned_size) in sorted(tensors_with_sizes, key=lambda item: item[2]):
83
- if aligned_size > BUFFER_SIZE:
84
- big_tensors += [(tensor, size, aligned_size)]
85
- continue
86
- current_size += aligned_size
87
- if current_size > BUFFER_SIZE:
88
- batches += [current_batch]
89
- current_batch, current_size = [(tensor, size, aligned_size)], aligned_size
90
- else:
91
- current_batch += [(tensor, size, aligned_size)]
92
-
93
- if current_batch:
94
- batches += [current_batch]
95
-
96
- get_meta = {tensor: torch.empty_like(tensor) for tensor in tensors}
97
- batches_meta = [[(get_meta[tensor], size, asize) for tensor, size, asize in batch] for batch in batches]
98
- big_tensors_meta = [(get_meta[tensor], size, asize) for tensor, size, asize in big_tensors]
99
- fakes_meta = {get_meta[tensor]: fake_list for tensor, fake_list in fakes.items()}
100
-
101
- pack = ZeroGPUTensorPack(
102
- base_dir=offload_dir,
103
- batches=batches_meta,
104
- big_tensors=big_tensors_meta,
105
- fakes=fakes_meta,
106
- total_size=sum([size for _, size, _ in tensors_with_sizes]),
107
- )
108
-
109
- fd = os.open(pack.path(), os.O_CREAT | os.O_WRONLY | os.O_DIRECT)
110
- try:
111
- total_asize = sum([aligned_size for batch in batches for *_, aligned_size in batch])
112
- total_asize += sum([aligned_size for *_, aligned_size in big_tensors])
113
- if total_asize > 0:
114
- os.posix_fallocate(fd, 0, total_asize)
115
- for batch in batches:
116
- for tensor, size, _ in batch:
117
- write(fd, tensor)
118
- callback(size)
119
- for tensor, size, _ in big_tensors:
120
- write(fd, tensor)
121
- callback(size)
122
- return pack
123
- finally:
124
- os.close(fd)
125
-
126
-
127
- def pack_to_cuda(pack: ZeroGPUTensorPack, callback: Callable[[int]] | None = None):
128
-
129
- callback = (lambda bytes: None) if callback is None else callback
130
-
131
- free_buffers: Queue[torch.Tensor] = Queue()
132
- read_buffers: Queue[torch.Tensor] = Queue()
133
-
134
- for _ in range(BUFFER_COUNT):
135
- free_buffers.put(torch.ByteTensor(BUFFER_SIZE).pin_memory())
136
-
137
- def read(fd: int, buffer: torch.Tensor, size: int):
138
- mv = memoryview((ctypes.c_char * size).from_address(buffer.data_ptr()))
139
- read_bytes = 0
140
- while read_bytes < size:
141
- read_bytes += os.readv(fd, [mv[read_bytes:]])
142
-
143
- def disk_to_pin(fd: int):
144
- for batch in pack.batches:
145
- buffer = free_buffers.get()
146
- batch_size = sum([aligned_size for *_, aligned_size in batch])
147
- read(fd, buffer, batch_size)
148
- read_buffers.put(buffer)
149
- for *_, aligned_size in pack.big_tensors:
150
- read_bytes = 0
151
- while read_bytes < aligned_size:
152
- buffer = free_buffers.get()
153
- read_size = min(BUFFER_SIZE, aligned_size - read_bytes)
154
- read(fd, buffer, read_size)
155
- read_buffers.put(buffer)
156
- read_bytes += read_size
157
-
158
- def pin_to_cuda():
159
- total_duration_in_callback = 0
160
- for batch in pack.batches:
161
- buffer = read_buffers.get()
162
- offset = 0
163
- cuda_storages = []
164
- for tensor, size, aligned_size in batch:
165
- cuda_storages += [buffer[offset:offset+size].cuda(non_blocking=True)]
166
- offset += aligned_size
167
- torch.cuda.synchronize()
168
- free_buffers.put(buffer)
169
- batch_total_size = 0
170
- for (tensor, size, _), cuda_storage in zip(batch, cuda_storages):
171
- cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
172
- cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
173
- for fake in pack.fakes[tensor]:
174
- fake.data = cuda_tensor
175
- batch_total_size += size
176
- t0 = time.perf_counter()
177
- callback(batch_total_size)
178
- total_duration_in_callback += time.perf_counter() - t0
179
- for tensor, size, _ in pack.big_tensors:
180
- cuda_storage = torch.empty(size, dtype=torch.uint8, device='cuda')
181
- offset = 0
182
- while offset < size:
183
- buffer = read_buffers.get()
184
- read_size = min(BUFFER_SIZE, size - offset)
185
- cuda_storage[offset:offset+read_size] = buffer[:read_size]
186
- offset += read_size
187
- torch.cuda.synchronize() # Probably not needed
188
- free_buffers.put(buffer)
189
- t0 = time.perf_counter()
190
- callback(read_size)
191
- total_duration_in_callback += time.perf_counter() - t0
192
- cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda')
193
- cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride())
194
- for fake in pack.fakes[tensor]:
195
- fake.data = cuda_tensor
196
-
197
- debug(f"{total_duration_in_callback=}")
198
-
199
- with ThreadPoolExecutor(2) as e:
200
- fd = os.open(pack.path(), os.O_RDONLY | os.O_DIRECT)
201
- try:
202
- futures = [
203
- e.submit(copy_context().run, disk_to_pin, fd),
204
- e.submit(copy_context().run, pin_to_cuda),
205
- ]
206
- for future in as_completed(futures):
207
- future.result()
208
- finally:
209
- os.close(fd)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
spaces/zero/torch/patching.py DELETED
@@ -1,386 +0,0 @@
1
- """
2
- """
3
- # pyright: reportPrivateImportUsage=false
4
-
5
- from __future__ import annotations
6
-
7
- import gc
8
- import multiprocessing
9
- import os
10
- from collections import defaultdict
11
- from concurrent.futures import ProcessPoolExecutor
12
- from concurrent.futures import ThreadPoolExecutor
13
- from contextlib import nullcontext
14
- from contextvars import copy_context
15
- from types import SimpleNamespace
16
- from typing import Any
17
- from typing import Callable
18
-
19
- import torch
20
- from torch.overrides import TorchFunctionMode
21
- from torch.overrides import resolve_name
22
- from torch.utils._python_dispatch import TorchDispatchMode
23
- from torch.utils._pytree import tree_map_only
24
- from torch.utils.weak import WeakTensorKeyDictionary
25
-
26
- from ...config import Config
27
- from ...utils import malloc_trim
28
- from ..tqdm import tqdm
29
- from . import bitsandbytes
30
- from .packing import ZeroGPUTensorPack
31
- from .packing import pack_tensors
32
- from .packing import pack_to_cuda
33
- from .types import AliasId
34
-
35
-
36
- # Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
37
- CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
38
- CUDA_TOTAL_MEMORY = 42144366592
39
- CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
40
- CUDA_DEVICE_CAPABILITY = (8, 0)
41
- CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
42
-
43
- OPS_INPUTS_CHECK_NO_RETURN = (
44
- torch.Tensor.equal,
45
- )
46
-
47
- OPS_INPUT_CHECK_SELF_RETURN = (
48
- torch.Tensor.set_, # probably never dispatched
49
- torch.ops.aten.set_.source_Tensor, # pyright: ignore [reportAttributeAccessIssue]
50
- )
51
-
52
- OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}"
53
-
54
- _tensor_make_subclass = torch.Tensor._make_subclass
55
- _asarray = torch.asarray
56
- _cuda_init = torch._C._cuda_init
57
- _cuda_exchange_device = torch.cuda._exchange_device
58
- _cuda_available = torch.cuda.is_available
59
- _cuda_device_count = torch.cuda.device_count
60
- _cuda_current_device = torch.cuda.current_device
61
- _cuda_mem_get_info = torch.cuda.mem_get_info
62
- _cuda_get_device_capability = torch.cuda.get_device_capability
63
- _cuda_get_device_properties = torch.cuda.get_device_properties
64
- _cuda_get_device_name = torch.cuda.get_device_name
65
-
66
- # PyTorch 2.3
67
- _cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None)
68
-
69
-
70
- cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary() # pyright: ignore [reportAssignmentType]
71
-
72
- tensor_packs: list[ZeroGPUTensorPack] = []
73
-
74
- class ZeroGPUTensor(torch.Tensor):
75
- pass
76
-
77
- def empty_fake(tensor: torch.Tensor):
78
- fake = torch.empty_like(tensor, requires_grad=tensor.requires_grad)
79
- if fake.__class__ != tensor.__class__:
80
- fake = _tensor_make_subclass(tensor.__class__, fake, require_grad=tensor.requires_grad) # pyright: ignore [reportArgumentType]
81
- return fake
82
-
83
- class ZeroGPUFunctionMode(TorchFunctionMode):
84
-
85
- def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
86
-
87
- kwargs = {} if kwargs is None else kwargs
88
-
89
- if func == torch._C._nn._parse_to:
90
- return func(*args, **kwargs)
91
-
92
- # Redispatch: tensor.cuda() -> tensor.to(device='cuda')
93
- if func == torch.Tensor.cuda or func == torch.Tensor.cpu:
94
- memory_format = kwargs.get('memory_format')
95
- return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
96
- 'device': 'cuda' if func == torch.Tensor.cuda else 'cpu',
97
- **({'memory_format': memory_format} if memory_format is not None else {}),
98
- })
99
-
100
- # Redispatch: tensor.to('cuda') -> tensor.to(device='cuda')
101
- if func == torch.Tensor.to and len(args) > 1:
102
- device, dtype, _, memory_format = torch._C._nn._parse_to(*args[1:], **kwargs)
103
- return self.__torch_function__(torch.Tensor.to, types, (args[0],), {
104
- 'device': device,
105
- 'dtype': dtype,
106
- 'memory_format': memory_format,
107
- })
108
-
109
- if func == torch.Tensor.data.__set__: # pyright: ignore [reportAttributeAccessIssue]
110
- self, target = args
111
- if target in cuda_aliases:
112
- if (target_original := cuda_aliases[target]) is None:
113
- raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), target))
114
- original = empty_fake(self)
115
- original.data = target_original
116
- cuda_aliases[self] = original
117
- elif self in cuda_aliases:
118
- del cuda_aliases[self]
119
- self.data = target
120
- return
121
-
122
- if func == torch.Tensor.device.__get__:
123
- tensor, = args
124
- if tensor in cuda_aliases:
125
- return torch.device('cuda', index=0)
126
-
127
- elif func == torch.Tensor.__repr__:
128
- tensor, = args
129
- if tensor in cuda_aliases:
130
- if (original := cuda_aliases[tensor]) is None:
131
- original = tensor.to('meta')
132
- original_class = original.__class__
133
- original.__class__ = ZeroGPUTensor
134
- try:
135
- return func(original, **kwargs)
136
- finally:
137
- original.__class__ = original_class
138
-
139
- elif func == torch.Tensor.untyped_storage:
140
- tensor, = args
141
- if tensor in cuda_aliases:
142
- if (original := cuda_aliases[tensor]) is None:
143
- raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
144
- res = func(original, **kwargs)
145
- res._zerogpu = True
146
- return res
147
-
148
- cuda: bool | None = None
149
-
150
- # Handle device kwarg
151
- if (device := kwargs.get('device')) is not None:
152
- device = torch.device(device)
153
- if device.type == 'cuda':
154
- kwargs['device'] = torch.device('cpu')
155
- cuda = True
156
- else:
157
- cuda = False
158
-
159
- # Swap fake inputs with original data
160
- swapped = {}
161
- inputs_are_cuda = set()
162
- def swap(tensor: torch.Tensor):
163
- nonlocal inputs_are_cuda
164
- if tensor not in cuda_aliases:
165
- inputs_are_cuda |= {False}
166
- return tensor
167
- if (original := cuda_aliases[tensor]) is None:
168
- raise Exception(OFFLOADED_ERROR_MESSAGE.format(resolve_name(func), tensor))
169
- swapped[original] = tensor
170
- inputs_are_cuda |= {True}
171
- return original
172
- args_ = tree_map_only(torch.Tensor, swap, args)
173
- kwargs_ = tree_map_only(torch.Tensor, swap, kwargs)
174
- if inputs_are_cuda == {True}:
175
- if cuda is not False:
176
- cuda = True
177
-
178
- res = func(*args_, **kwargs_)
179
-
180
- # Re-generate swapped fakes in case of mutation
181
- for original, fake in swapped.items():
182
- fake.data = empty_fake(original)
183
-
184
- # Special case for Tensor indexing where only 'self' matters
185
- if func in {
186
- torch.ops.aten.index.Tensor, # pyright: ignore [reportAttributeAccessIssue]
187
- torch.Tensor.__getitem__, # PyTorch 2.4+
188
- }:
189
- self = args[0]
190
- cuda = self in cuda_aliases
191
- inputs_are_cuda = {cuda}
192
-
193
- # Emulate device check
194
- if isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN:
195
- self = None
196
- if len(args_) >= 1 and isinstance(args_[0], torch.Tensor):
197
- self = args_[0]
198
- # Only raise if func does not return its first input (Tensor.copy_)
199
- if res is not self or func in OPS_INPUT_CHECK_SELF_RETURN:
200
- if inputs_are_cuda == {True, False}:
201
- raise RuntimeError(
202
- "Expected all tensors to be on the same device, "
203
- "but found at least two devices, cuda:0 (ZeroGPU) and cpu!"
204
- )
205
-
206
- # Register output
207
- def register(tensor: torch.Tensor):
208
- if tensor in swapped and cuda is not False:
209
- return swapped[tensor]
210
- if cuda is not True:
211
- return tensor
212
- fake = empty_fake(tensor)
213
- cuda_aliases[fake] = tensor
214
- return fake
215
-
216
- return tree_map_only(torch.Tensor, register, res)
217
-
218
- # When enabling DispatchMode, some aten ops are dispatched to FunctionMode
219
- # We are using it for aten.alias.default and aten.set_.source_Tensor
220
- class DefaultDispatchMode(TorchDispatchMode):
221
- def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None):
222
- return func(*args, **(kwargs or {}))
223
-
224
-
225
- function_mode = ZeroGPUFunctionMode()
226
- dispatch_mode = DefaultDispatchMode()
227
-
228
-
229
- def _untyped_storage_new_register(*args, **kwargs):
230
- cuda = False
231
- if (device := kwargs.get('device')) is not None and device.type == 'cuda':
232
- cuda = True
233
- del kwargs['device']
234
- storage = torch._C.StorageBase.__new__(*args, **kwargs)
235
- if cuda:
236
- storage._zerogpu = True
237
- return storage
238
-
239
- @property
240
- def _untyped_storage_device(self):
241
- if hasattr(self, '_zerogpu'):
242
- return torch.device('cuda', index=0)
243
- return torch._C.StorageBase.device.__get__(self) # pyright: ignore [reportAttributeAccessIssue]
244
-
245
- # Force dispatch
246
- def _tensor_make_subclass_function_mode(*args, **kwargs):
247
- with torch._C.DisableTorchFunction():
248
- return function_mode.__torch_function__(_tensor_make_subclass, (), args=args, kwargs=kwargs)
249
- def _asarray_function_mode(*args, **kwargs):
250
- with torch._C.DisableTorchFunction():
251
- return function_mode.__torch_function__(_asarray, (), args=args, kwargs=kwargs)
252
-
253
- def _cuda_init_raise():
254
- raise RuntimeError(
255
- "CUDA must not be initialized in the main process "
256
- "on Spaces with Stateless GPU environment.\n"
257
- "You can look at this Stacktrace to find out "
258
- "which part of your code triggered a CUDA init"
259
- )
260
-
261
- def _cuda_dummy_exchange_device(device):
262
- assert device in {-1, 0}
263
- return device
264
-
265
- def patch():
266
- function_mode.__enter__()
267
- dispatch_mode.__enter__()
268
- # TODO: only patch bellow methods on current Thread to be consistent with TorchModes
269
- # (or hijack threading.Thread.__init__ to force Modes on all threads)
270
- torch.Tensor._make_subclass = _tensor_make_subclass_function_mode # pyright: ignore [reportAttributeAccessIssue]
271
- torch.UntypedStorage.__new__ = _untyped_storage_new_register
272
- torch.UntypedStorage.device = _untyped_storage_device # pyright: ignore [reportAttributeAccessIssue]
273
- torch.asarray = _asarray_function_mode
274
- torch._C._cuda_init = _cuda_init_raise
275
- torch.cuda._exchange_device = _cuda_dummy_exchange_device
276
- torch.cuda.is_available = lambda: True
277
- torch.cuda.device_count = lambda: 1
278
- torch.cuda.current_device = lambda: 0
279
- torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
280
- torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
281
- torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
282
- torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
283
- # PyTorch 2.3
284
- if _cuda_maybe_exchange_device is not None: # pragma: no cover
285
- setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device)
286
- bitsandbytes.patch()
287
-
288
- def unpatch():
289
- try:
290
- dispatch_mode.__exit__(None, None, None)
291
- function_mode.__exit__(None, None, None)
292
- except RuntimeError:
293
- pass # patch() and unpatch() called from != threads
294
- torch.Tensor._make_subclass = _tensor_make_subclass
295
- torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__
296
- torch.UntypedStorage.device = torch._C.StorageBase.device # pyright: ignore [reportAttributeAccessIssue]
297
- torch.asarray = _asarray
298
- torch._C._cuda_init = _cuda_init
299
- torch.cuda._exchange_device = _cuda_exchange_device
300
- torch.cuda.is_available = _cuda_available
301
- torch.cuda.device_count = _cuda_device_count
302
- torch.cuda.current_device = _cuda_current_device
303
- torch.cuda.mem_get_info = _cuda_mem_get_info
304
- torch.cuda.get_device_capability = _cuda_get_device_capability
305
- torch.cuda.get_device_properties = _cuda_get_device_properties
306
- torch.cuda.get_device_name = _cuda_get_device_name
307
- # PyTorch 2.3
308
- if _cuda_maybe_exchange_device is not None: # pragma: no cover
309
- setattr(torch.cuda, '_maybe_exchange_device', _cuda_exchange_device)
310
- bitsandbytes.unpatch()
311
-
312
-
313
- def _total_unpacked_size():
314
- tensors = [tensor for tensor in cuda_aliases.values() if tensor is not None]
315
- deduped = {AliasId.from_tensor(tensor): tensor for tensor in tensors}
316
- return sum([tensor.numel() * tensor.element_size() for tensor in deduped.values()])
317
-
318
-
319
- def _pack(offload_dir: str):
320
- # Pack to disk
321
- originals: set[torch.Tensor] = set()
322
- originals_dedup: dict[AliasId, torch.Tensor] = {}
323
- fakes: dict[torch.Tensor, list[torch.Tensor]] = defaultdict(list)
324
- for fake, original in cuda_aliases.items():
325
- # TODO filter-out sparse Tensors
326
- if original is not None:
327
- original_id = AliasId.from_tensor(original)
328
- if original_id not in originals_dedup:
329
- originals_dedup[original_id] = original
330
- originals |= {original}
331
- fakes[originals_dedup[original_id]] += [fake]
332
- progress = tqdm(
333
- total=_total_unpacked_size(),
334
- unit='B',
335
- unit_scale=True,
336
- desc="ZeroGPU tensors packing",
337
- ) if tqdm is not None else nullcontext()
338
- with progress as progress:
339
- update = progress.update if progress is not None else lambda _: None
340
- pack = pack_tensors(originals, fakes, offload_dir, callback=update)
341
- tensor_packs.append(pack)
342
- # Free memory
343
- for fake_list in fakes.values():
344
- for fake in fake_list:
345
- cuda_aliases[fake] = None
346
-
347
- def pack():
348
- _pack(Config.zerogpu_offload_dir)
349
- gc.collect()
350
- malloc_trim()
351
-
352
- def init(nvidia_uuid: str):
353
- os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
354
- torch.Tensor([0]).cuda()
355
-
356
- def size():
357
- return _total_unpacked_size() + sum([pack.total_size for pack in tensor_packs])
358
-
359
- def _move(callback: Callable[[int]] | None = None):
360
- callback = callback if callback is not None else lambda _: None
361
- # CPU -> CUDA
362
- moved: dict[AliasId, torch.Tensor] = {}
363
- for fake, original in cuda_aliases.items():
364
- if original is not None:
365
- original_id = AliasId.from_tensor(original)
366
- if original_id not in moved:
367
- moved[original_id] = original.cuda()
368
- callback(fake.numel() * fake.element_size())
369
- for fake, original in cuda_aliases.items():
370
- if original is not None:
371
- fake.data = moved[AliasId.from_tensor(original)]
372
- # Disk -> CUDA
373
- for tensor_pack in tensor_packs:
374
- pack_to_cuda(tensor_pack, callback=callback)
375
- bitsandbytes.move()
376
-
377
- def move(callback: Callable[[int]] | None = None):
378
- callback = callback if callback is not None else lambda _: None
379
- with ThreadPoolExecutor(1) as e:
380
- e.submit(copy_context().run, _move, callback=callback).result()
381
- torch.cuda.synchronize()
382
-
383
- def is_in_bad_fork():
384
- with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
385
- f = e.submit(torch.cuda._is_in_bad_fork)
386
- return f.result()
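Not part of the diff: a minimal sketch of the call order a forked GPU worker is expected to follow with the module deleted above, mirroring the sequence that appears later in spaces/zero/wrappers.py (unpatch, then init, then move). The import path spaces.zero.torch.patching, the helper name bring_up_gpu, and the nvidia_uuid value are assumptions for illustration only.

    # Hypothetical driver code, assuming the file above is importable as
    # spaces.zero.torch.patching and that `nvidia_uuid` comes from the scheduler.
    from spaces.zero.torch import patching

    def bring_up_gpu(nvidia_uuid: str) -> None:
        patching.unpatch()                 # restore the real torch / CUDA entry points
        patching.init(nvidia_uuid)         # set CUDA_VISIBLE_DEVICES and trigger CUDA init
        total = patching.size()            # bytes left to transfer (unpacked RAM + packed disk tensors)
        patching.move(callback=lambda n: print(f"transferred {n} of {total} bytes"))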
 
spaces/zero/torch/patching_legacy.py DELETED
@@ -1,266 +0,0 @@
1
- """
2
- """
3
- # pyright: reportPrivateImportUsage=false
4
-
5
- from __future__ import annotations
6
-
7
- import multiprocessing
8
- import os
9
- from concurrent.futures import ProcessPoolExecutor
10
- from contextlib import suppress
11
- from functools import partial
12
- from types import SimpleNamespace
13
- from typing import Any
14
- from typing import Callable
15
- from typing import Optional
16
- from typing import Tuple
17
-
18
- import torch
19
- from torch.utils.weak import WeakTensorKeyDictionary
20
-
21
- from ...config import Config
22
- from . import bitsandbytes
23
-
24
-
25
- # Nvidia A100.80G MIG (drivers 535) / Torch 2.2.0
26
- CUDA_DEVICE_NAME = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb'
27
- CUDA_TOTAL_MEMORY = 42144366592
28
- CUDA_MEM_GET_INFO = (41911451648, CUDA_TOTAL_MEMORY)
29
- CUDA_DEVICE_CAPABILITY = (8, 0)
30
- CUDA_DEVICE_PROPERTIES = SimpleNamespace(name=CUDA_DEVICE_NAME, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY, multi_processor_count=42)
31
-
32
- GENERIC_METHOD_NAMES = [
33
- 'arange',
34
- 'as_tensor',
35
- 'asarray',
36
- 'bartlett_window',
37
- 'blackman_window',
38
- 'empty',
39
- 'empty_like',
40
- 'empty_strided',
41
- 'eye',
42
- 'full',
43
- 'full_like',
44
- 'hamming_window',
45
- 'hann_window',
46
- 'kaiser_window',
47
- 'linspace',
48
- 'logspace',
49
- 'ones',
50
- 'ones_like',
51
- 'rand',
52
- 'rand_like',
53
- 'randint',
54
- 'randint_like',
55
- 'randn',
56
- 'randn_like',
57
- 'randperm',
58
- 'range',
59
- 'sparse_bsc_tensor',
60
- 'sparse_bsr_tensor',
61
- 'sparse_compressed_tensor',
62
- 'sparse_coo_tensor',
63
- 'sparse_csc_tensor',
64
- 'sparse_csr_tensor',
65
- 'tensor',
66
- 'tril_indices',
67
- 'triu_indices',
68
- 'zeros',
69
- 'zeros_like',
70
- ]
71
-
72
-
73
- TO_CUDA = (torch.device('cuda'), None, False, None)
74
-
75
- _tensor__deepcopy__ = torch.Tensor.__deepcopy__
76
- _tensor_to = torch.Tensor.to
77
- _tensor_cuda = torch.Tensor.cuda
78
- _tensor_cpu = torch.Tensor.cpu
79
- _torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES}
80
- _cuda_init = torch._C._cuda_init
81
- _cuda_available = torch.cuda.is_available
82
- _cuda_device_count = torch.cuda.device_count
83
- _cuda_current_device = torch.cuda.current_device
84
- _cuda_mem_get_info = torch.cuda.mem_get_info
85
- _cuda_get_device_capability = torch.cuda.get_device_capability
86
- _cuda_get_device_properties = torch.cuda.get_device_properties
87
- _cuda_get_device_name = torch.cuda.get_device_name
88
-
89
- TensorToArgs = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]]
90
-
91
- to_ops: dict[torch.Tensor, TensorToArgs] = WeakTensorKeyDictionary() # type: ignore
92
-
93
- def _tensor_new_register(*args, **kwargs):
94
- new_tensor: torch.Tensor = torch._C._TensorBase.__new__(*args, **kwargs)
95
- if (base_tensor := new_tensor._base) is not None:
96
- if base_tensor in to_ops:
97
- to_ops[new_tensor] = to_ops[base_tensor]
98
- return new_tensor
99
-
100
- def _tensor_deepcopy_register(self: torch.Tensor, memo):
101
- new_tensor = _tensor__deepcopy__(self, memo)
102
- if isinstance(new_tensor, torch.Tensor):
103
- if self in to_ops:
104
- to_ops[new_tensor] = to_ops[self]
105
- return new_tensor
106
-
107
- @property
108
- def _tensor_device_property(self: torch.Tensor):
109
- if self in to_ops:
110
- return torch.device(type='cuda', index=0)
111
- del torch.Tensor.device
112
- try:
113
- return self.device
114
- finally:
115
- torch.Tensor.device = _tensor_device_property # type: ignore
116
-
117
- @property
118
- def _tensor_dtype_property(self: torch.Tensor):
119
- if self in to_ops:
120
- if (to_dtype := to_ops[self][1]) is not None:
121
- return to_dtype
122
- del torch.Tensor.dtype
123
- try:
124
- return self.dtype
125
- finally:
126
- torch.Tensor.dtype = _tensor_dtype_property # type: ignore
127
-
128
- def _to_op_register(self: torch.Tensor, *args, **kwargs):
129
- parsed = torch._C._nn._parse_to(*args, **kwargs)
130
- device, dtype, *_ = parsed
131
- try:
132
- to_args = to_ops.pop(self)
133
- except KeyError:
134
- to_args = None
135
- if device is None: # pyright: ignore [reportUnnecessaryComparison]
136
- if to_args is not None:
137
- to_ops[self] = (to_args[0], dtype, *to_args[2:])
138
- return self
139
- return _tensor_to(self, *args, **kwargs)
140
- if device.type != 'cuda':
141
- if to_args is not None:
142
- if (to_dtype := to_args[1]) is not None:
143
- kwargs = {'dtype': to_dtype, **kwargs}
144
- return _tensor_to(self, *args, **kwargs)
145
- to_ops[self] = parsed
146
- return self
147
-
148
- def _cuda_op_arg_check(device: torch.device | int | str | None) -> bool:
149
- if device is None:
150
- return True
151
- if isinstance(device, int):
152
- return True
153
- if isinstance(device, str):
154
- device = torch.device(device)
155
- return device.type == 'cuda'
156
-
157
- def _cuda_op_register(self: torch.Tensor, device: torch.device | int | str | None = None, **kwargs):
158
- if not _cuda_op_arg_check(device):
159
- # Let PyTorch handle the fail
160
- return _tensor_cuda(self, device, **kwargs)
161
- to_ops[self] = TO_CUDA
162
- return self
163
-
164
- def _cpu_op_remove(self: torch.Tensor, **kwargs):
165
- try:
166
- to_args = to_ops.pop(self)
167
- except KeyError:
168
- to_args = None
169
- if to_args is not None:
170
- if (to_dtype := to_args[1]) is not None:
171
- return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs})
172
- return _tensor_cpu(self, **kwargs)
173
-
174
- def _cuda_init_raise():
175
- raise RuntimeError(
176
- "CUDA must not be initialized in the main process "
177
- "on Spaces with Stateless GPU environment.\n"
178
- "You can look at this Stacktrace to find out "
179
- "which part of your code triggered a CUDA init"
180
- )
181
-
182
- def _generic_method_register(name: str, *args: Any, **kwargs: Any):
183
- try:
184
- device = torch.device(kwargs.get('device', "cpu"))
185
- except Exception:
186
- return _torch_generics[name](*args, **kwargs)
187
- if device.type != 'cuda':
188
- return _torch_generics[name](*args, **kwargs)
189
- tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"})
190
- to_ops[tensor] = TO_CUDA
191
- return tensor
192
-
193
- def patch():
194
- torch.Tensor.__deepcopy__ = _tensor_deepcopy_register
195
- torch.Tensor.__new__ = _tensor_new_register # pyright: ignore [reportAttributeAccessIssue]
196
- torch.Tensor.to = _to_op_register # type: ignore
197
- torch.Tensor.cuda = _cuda_op_register # type: ignore
198
- torch.Tensor.cpu = _cpu_op_remove # type: ignore
199
- if Config.zero_patch_torch_device:
200
- torch.Tensor.device = _tensor_device_property # type: ignore
201
- torch.Tensor.dtype = _tensor_dtype_property # pyright: ignore [reportAttributeAccessIssue]
202
- for name in GENERIC_METHOD_NAMES:
203
- setattr(torch, name, partial(_generic_method_register, name))
204
- torch._C._cuda_init = _cuda_init_raise
205
- torch.cuda.is_available = lambda: True
206
- torch.cuda.device_count = lambda: 1
207
- torch.cuda.current_device = lambda: 0
208
- torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO
209
- torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY
210
- torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES
211
- torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME
212
- bitsandbytes.patch()
213
-
214
- def unpatch():
215
- torch.Tensor.__deepcopy__ = _tensor__deepcopy__
216
- with suppress(AttributeError):
217
- del torch.Tensor.__new__
218
- torch.Tensor.to = _tensor_to
219
- torch.Tensor.cuda = _tensor_cuda
220
- torch.Tensor.cpu = _tensor_cpu
221
- with suppress(AttributeError):
222
- del torch.Tensor.device
223
- with suppress(AttributeError):
224
- del torch.Tensor.dtype
225
- for name in GENERIC_METHOD_NAMES:
226
- setattr(torch, name, _torch_generics[name])
227
- torch._C._cuda_init = _cuda_init
228
- torch.cuda.is_available = _cuda_available
229
- torch.cuda.device_count = _cuda_device_count
230
- torch.cuda.current_device = _cuda_current_device
231
- torch.cuda.mem_get_info = _cuda_mem_get_info
232
- torch.cuda.get_device_capability = _cuda_get_device_capability
233
- torch.cuda.get_device_properties = _cuda_get_device_properties
234
- torch.cuda.get_device_name = _cuda_get_device_name
235
- bitsandbytes.unpatch()
236
-
237
- def pack():
238
- pass
239
-
240
- def init(nvidia_uuid: str):
241
- os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid
242
- torch.Tensor([0]).cuda() # CUDA init
243
-
244
- def size():
245
- return 0
246
-
247
- def move(callback: Callable[[int]] | None = None):
248
- for op in to_ops.items():
249
- tensor, parsed_args = op
250
- _, dtype, _, memory_format = parsed_args
251
- tensor.data = _tensor_to(tensor,
252
- device='cuda',
253
- dtype=dtype,
254
- memory_format=memory_format,
255
- ) # type: ignore
256
- bitsandbytes.move()
257
- torch.cuda.synchronize()
258
-
259
- def is_in_bad_fork():
260
- with ProcessPoolExecutor(mp_context=multiprocessing.get_context('fork')) as e:
261
- f = e.submit(torch.cuda._is_in_bad_fork)
262
- return f.result()
263
-
264
- def disable_cuda_intercept():
265
- torch.Tensor.to = _tensor_to
266
- torch.Tensor.cuda = _tensor_cuda
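Not part of the diff: a hedged sketch of the interception behaviour the legacy patcher above implements. While patch() is active, .cuda() only records the requested move in to_ops and returns the still-CPU tensor, and torch.cuda.is_available() reports the fake A100 MIG device; move() replays the recorded moves once a real GPU is attached. The import path is an assumption, and the snippet assumes torch (and optionally bitsandbytes) is installed.

    import torch
    from spaces.zero.torch import patching_legacy as legacy  # assumed import path

    legacy.patch()
    try:
        t = torch.ones(4).cuda()           # no real CUDA init: the move is only recorded in `to_ops`
        assert torch.cuda.is_available()   # patched to report the fake A100 MIG device
    finally:
        legacy.unpatch()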
 
spaces/zero/torch/types.py DELETED
@@ -1,23 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- from typing import NamedTuple
6
-
7
- import torch
8
-
9
-
10
- class AliasId(NamedTuple):
11
- data_ptr: int
12
- dtype: torch.dtype
13
- shape: tuple[int, ...]
14
- stride: tuple[int, ...]
15
-
16
- @classmethod
17
- def from_tensor(cls, tensor: torch.Tensor):
18
- return cls(
19
- tensor.data_ptr(),
20
- tensor.dtype,
21
- tensor.shape,
22
- tensor.stride(),
23
- )
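Not part of the diff: a small example of how AliasId keys tensors by (data_ptr, dtype, shape, stride), so distinct Python objects that view the same memory collapse to one entry. This is what the packing code earlier in the commit relies on to avoid offloading the same weights twice. The import path is an assumption.

    import torch
    from spaces.zero.torch.types import AliasId  # assumed import path

    a = torch.zeros(2, 3)
    b = a.view(2, 3)                                              # same storage, dtype, shape and stride
    assert AliasId.from_tensor(a) == AliasId.from_tensor(b)
    assert AliasId.from_tensor(a) != AliasId.from_tensor(a.t())   # transposed view: different shape/stride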
 
spaces/zero/tqdm.py DELETED
@@ -1,24 +0,0 @@
1
- """
2
- """
3
-
4
- from multiprocessing.synchronize import RLock as MultiprocessingRLock
5
-
6
-
7
- try:
8
- from tqdm import tqdm as _tqdm
9
- except ImportError: # pragma: no cover
10
- _tqdm = None
11
-
12
-
13
- def remove_tqdm_multiprocessing_lock():
14
- if _tqdm is None: # pragma: no cover
15
- return
16
- tqdm_lock = _tqdm.get_lock()
17
- assert tqdm_lock.__class__.__name__ == 'TqdmDefaultWriteLock'
18
- tqdm_lock.locks = [
19
- lock for lock in tqdm_lock.locks
20
- if not isinstance(lock, MultiprocessingRLock)
21
- ]
22
-
23
-
24
- tqdm = _tqdm
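Not part of the diff: a hedged sketch of when the helper above would be called. The forked worker in spaces/zero/wrappers.py (further down in this commit) drops tqdm's fork-inherited multiprocessing RLock right after initialization; the snippet assumes tqdm is installed and the module is importable as spaces.zero.tqdm.

    from spaces.zero.tqdm import remove_tqdm_multiprocessing_lock, tqdm

    if tqdm is not None:                      # tqdm is an optional dependency
        remove_tqdm_multiprocessing_lock()    # drop the fork-inherited multiprocessing RLock
        for _ in tqdm(range(3), desc="worker-local progress"):
            pass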
 
spaces/zero/types.py DELETED
@@ -1,49 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
-
6
- from dataclasses import dataclass
7
- from datetime import timedelta
8
- from typing import Any
9
- from typing import Dict
10
- from typing import Tuple
11
- from typing import TypedDict
12
- from typing_extensions import Callable
13
- from typing_extensions import Generic
14
- from typing_extensions import ParamSpec
15
- from typing_extensions import TypeAlias
16
- from typing_extensions import TypeVar
17
-
18
-
19
- Params = Tuple[Tuple[object, ...], Dict[str, Any]]
20
- Res = TypeVar('Res')
21
- Param = ParamSpec('Param')
22
-
23
- class EmptyKwargs(TypedDict):
24
- pass
25
-
26
- @dataclass
27
- class OkResult(Generic[Res]):
28
- value: Res
29
- @dataclass
30
- class ExceptionResult:
31
- value: Exception
32
- @dataclass
33
- class AbortedResult:
34
- pass
35
- @dataclass
36
- class EndResult:
37
- pass
38
- @dataclass
39
- class GradioQueueEvent:
40
- method_name: str
41
- args: tuple[Any, ...]
42
- kwargs: dict[str, Any]
43
-
44
- RegularResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | GradioQueueEvent"
45
- GeneratorResQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | GradioQueueEvent"
46
- YieldQueueResult: TypeAlias = "OkResult[Res] | ExceptionResult | EndResult | AbortedResult"
47
-
48
- Duration: TypeAlias = "int | timedelta"
49
- DynamicDuration: TypeAlias = "Duration | Callable[Param, Duration] | None"
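Not part of the diff: a hedged sketch of dispatching on the result types defined above, mirroring the isinstance checks used in spaces/zero/wrappers.py below, where these dataclasses tag messages coming back from the worker over the result queue. The handle function is hypothetical.

    from spaces.zero.types import (
        AbortedResult, EndResult, ExceptionResult, GradioQueueEvent, OkResult,
    )

    def handle(res) -> str:
        if isinstance(res, OkResult):
            return f"value: {res.value!r}"
        if isinstance(res, ExceptionResult):
            raise res.value
        if isinstance(res, EndResult):
            return "generator finished"
        if isinstance(res, AbortedResult):
            raise RuntimeError("GPU task aborted")
        if isinstance(res, GradioQueueEvent):
            return f"queue event: {res.method_name}"
        raise TypeError(f"unexpected result: {res!r}")

    print(handle(OkResult(value=42)))   # -> value: 42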
 
spaces/zero/wrappers.py DELETED
@@ -1,418 +0,0 @@
1
- """
2
- """
3
- from __future__ import annotations
4
-
5
- import multiprocessing
6
- import os
7
- import signal
8
- import traceback
9
- import warnings
10
- from concurrent.futures import ThreadPoolExecutor
11
- from contextlib import nullcontext
12
- from contextvars import copy_context
13
- from datetime import timedelta
14
- from functools import partial
15
- from functools import wraps
16
- from multiprocessing.context import ForkProcess
17
- from pickle import PicklingError
18
- from queue import Empty
19
- from queue import Queue as ThreadQueue
20
- from threading import Thread
21
- from typing import TYPE_CHECKING
22
- from typing import Callable
23
- from typing import Generator
24
- from typing import Generic
25
- from typing_extensions import assert_never
26
-
27
- import psutil
28
-
29
- from ..config import Config
30
- from ..utils import debug
31
- from ..utils import drop_params
32
- from ..utils import gradio_request_var
33
- from ..utils import SimpleQueue as Queue
34
- from . import client
35
- from . import torch
36
- from .api import AllowToken
37
- from .api import NvidiaIndex
38
- from .api import NvidiaUUID
39
- from .gradio import GradioPartialContext
40
- from .gradio import get_server_port
41
- from .gradio import patch_gradio_queue
42
- from .gradio import try_process_queue_event
43
- from .tqdm import remove_tqdm_multiprocessing_lock
44
- from .tqdm import tqdm
45
- from .types import * # TODO: Please don't do that
46
-
47
-
48
- GENERATOR_GLOBAL_TIMEOUT = 20 * 60
49
-
50
- SPAWN_PROGRESS_CLEANUP = 0.1
51
- SPAWN_PROGRESS_INIT = 0.1
52
-
53
-
54
- Process = multiprocessing.get_context('fork').Process
55
- forked = False
56
-
57
-
58
- class Worker(Generic[Res]):
59
- process: ForkProcess
60
- arg_queue: Queue[tuple[Params, GradioPartialContext]]
61
- res_queue: Queue[Res | None]
62
- _sentinel: Thread
63
-
64
- def __init__(
65
- self,
66
- target: Callable[[
67
- Queue[tuple[Params, GradioPartialContext]],
68
- Queue[Res | None],
69
- AllowToken,
70
- NvidiaUUID,
71
- list[int],
72
- ], None],
73
- allow_token: str,
74
- nvidia_uuid: str,
75
- ):
76
- self._sentinel = Thread(target=self._close_on_exit, daemon=True)
77
- self.arg_queue = Queue()
78
- self.res_queue = Queue()
79
- debug(f"{self.arg_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
80
- debug(f"{self.res_queue._writer.fileno()=}") # pyright: ignore [reportAttributeAccessIssue]
81
- if (server_port := get_server_port()) is not None:
82
- fds = [c.fd for c in psutil.Process().connections() if c.laddr.port == server_port]
83
- debug(f"{fds=}")
84
- else:
85
- warnings.warn("Using a ZeroGPU function outside of Gradio caching or request might block the app")
86
- fds = []
87
- args = self.arg_queue, self.res_queue, allow_token, nvidia_uuid, fds
88
- if TYPE_CHECKING:
89
- target(*args)
90
- self.process = Process(
91
- target=target,
92
- args=args,
93
- daemon=True,
94
- )
95
- self.process.start()
96
- self._sentinel.start()
97
-
98
- def _close_on_exit(self):
99
- self.process.join()
100
- self.arg_queue.close()
101
- self.res_queue.wlock_release()
102
- self.res_queue.put(None)
103
-
104
-
105
- def worker_init(
106
- res_queue: Queue[RegularResQueueResult | None] | Queue[GeneratorResQueueResult | None],
107
- allow_token: str,
108
- nvidia_uuid: str,
109
- fds: list[int],
110
- ) -> None | ExceptionResult:
111
- # Immediately close file descriptors
112
- for fd in fds:
113
- try:
114
- os.close(fd)
115
- except Exception as e: # pragma: no cover
116
- if isinstance(e, OSError) and e.errno == 9:
117
- continue
118
- traceback.print_exc()
119
- return ExceptionResult(e)
120
- progress = nullcontext()
121
- if tqdm is not None and Config.zero_gpu_v2:
122
- progress = tqdm(total=100, desc="ZeroGPU init", file=open(os.devnull, 'w'))
123
- try: # Unrecoverable init part
124
- patch_gradio_queue(res_queue)
125
- with progress as progress:
126
- current_progress = 0 # Gradio does not support float progress updates
127
- def update(n: float):
128
- nonlocal current_progress
129
- current_progress += n
130
- if progress is not None:
131
- progress.update(round(current_progress * 100) - progress.n)
132
- client.allow(allow_token)
133
- update(SPAWN_PROGRESS_CLEANUP)
134
- torch.unpatch()
135
- torch.init(nvidia_uuid)
136
- update(SPAWN_PROGRESS_INIT)
137
- callback = None
138
- if (transfer_size := torch.size()) > 0:
139
- remaining = 1 - (SPAWN_PROGRESS_CLEANUP + SPAWN_PROGRESS_INIT)
140
- callback = lambda n: update(n * remaining / transfer_size)
141
- torch.move(callback=callback)
142
- except Exception as e: # pragma: no cover
143
- traceback.print_exc()
144
- return ExceptionResult(e)
145
- try:
146
- remove_tqdm_multiprocessing_lock()
147
- except Exception: # pragma: no cover
148
- print("Error while trying to remove tqdm mp_lock:")
149
- traceback.print_exc()
150
-
151
-
152
- def process_duration(duration: Duration | None):
153
- if duration is None or isinstance(duration, timedelta):
154
- return duration
155
- return timedelta(seconds=duration)
156
-
157
-
158
- def static_duration(duration: DynamicDuration[Param], *args: Param.args, **kwargs: Param.kwargs):
159
- if not callable(duration):
160
- return duration
161
- return duration(*args, **kwargs)
162
-
163
-
164
- def regular_function_wrapper(
165
- task: Callable[Param, Res],
166
- duration: DynamicDuration[Param],
167
- ) -> Callable[Param, Res]:
168
-
169
- import gradio as gr
170
-
171
- request_var = gradio_request_var()
172
- workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res]]] = {}
173
- task_id = id(task)
174
-
175
- @wraps(task)
176
- def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Res:
177
-
178
- if forked:
179
- return task(*args, **kwargs)
180
-
181
- request = request_var.get()
182
- duration_ = static_duration(duration, *args, **kwargs)
183
- duration_ = process_duration(duration_)
184
- schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
185
- allow_token = schedule_response.allowToken
186
- nvidia_index = schedule_response.nvidiaIndex
187
- nvidia_uuid = schedule_response.nvidiaUUID
188
- release = partial(client.release, allow_token)
189
-
190
- try:
191
- worker = workers.pop(nvidia_index)
192
- except KeyError:
193
- worker = None
194
-
195
- if worker is not None and worker.process.is_alive() and schedule_response.idle:
196
- assert worker.arg_queue.empty()
197
- assert worker.res_queue.empty()
198
- else:
199
- worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
200
-
201
- try:
202
- worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
203
- except PicklingError: # TODO: detailed serialization diagnostic
204
- release(fail=True)
205
- raise
206
-
207
- while True:
208
- res = worker.res_queue.get()
209
- if res is None:
210
- release(fail=True, allow_404=True)
211
- raise gr.Error("GPU task aborted")
212
- if isinstance(res, ExceptionResult):
213
- release(fail=True)
214
- raise res.value
215
- if isinstance(res, OkResult):
216
- release()
217
- workers[nvidia_index] = worker
218
- return res.value
219
- if isinstance(res, GradioQueueEvent):
220
- try_process_queue_event(res.method_name, *res.args, **res.kwargs)
221
- continue
222
- assert_never(res)
223
-
224
-
225
- def thread_wrapper(
226
- arg_queue: Queue[tuple[Params, GradioPartialContext]],
227
- res_queue: Queue[RegularResQueueResult[Res] | None],
228
- allow_token: str,
229
- nvidia_uuid: str,
230
- fds: list[int],
231
- ):
232
- global forked
233
- forked = True
234
- signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
235
- initialized = False
236
- while True:
237
- try:
238
- (args, kwargs), gradio_context = arg_queue.get()
239
- except OSError:
240
- break
241
- if not initialized:
242
- if (res := worker_init(
243
- res_queue=res_queue,
244
- allow_token=allow_token,
245
- nvidia_uuid=nvidia_uuid,
246
- fds=fds,
247
- )) is not None:
248
- res_queue.put(res)
249
- return
250
- initialized = True
251
- GradioPartialContext.apply(gradio_context)
252
- context = copy_context()
253
- with ThreadPoolExecutor() as executor:
254
- future = executor.submit(context.run, task, *args, **kwargs) # type: ignore
255
- try:
256
- res = future.result()
257
- except Exception as e:
258
- traceback.print_exc()
259
- res = ExceptionResult(e)
260
- else:
261
- res = OkResult(res)
262
- try:
263
- res_queue.put(res)
264
- except PicklingError as e:
265
- res_queue.put(ExceptionResult(e))
266
-
267
- # https://github.com/python/cpython/issues/91002
268
- if not hasattr(task, '__annotations__'):
269
- gradio_handler.__annotations__ = {}
270
-
271
- return gradio_handler
272
-
273
-
274
- def generator_function_wrapper(
275
- task: Callable[Param, Generator[Res, None, None]],
276
- duration: DynamicDuration[Param],
277
- ) -> Callable[Param, Generator[Res, None, None]]:
278
-
279
- import gradio as gr
280
-
281
- request_var = gradio_request_var()
282
- workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res]]] = {}
283
- task_id = id(task)
284
-
285
- @wraps(task)
286
- def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]:
287
-
288
- if forked:
289
- yield from task(*args, **kwargs)
290
- return
291
-
292
- request = request_var.get()
293
- duration_ = static_duration(duration, *args, **kwargs)
294
- duration_ = process_duration(duration_)
295
- schedule_response = client.schedule(task_id=task_id, request=request, duration=duration_)
296
- allow_token = schedule_response.allowToken
297
- nvidia_index = schedule_response.nvidiaIndex
298
- nvidia_uuid = schedule_response.nvidiaUUID
299
- release = partial(client.release, allow_token)
300
-
301
- try:
302
- worker = workers.pop(nvidia_index)
303
- except KeyError:
304
- worker = None
305
-
306
- if worker is not None and worker.process.is_alive() and schedule_response.idle:
307
- assert worker.arg_queue.empty()
308
- assert worker.res_queue.empty()
309
- else:
310
- worker = Worker(thread_wrapper, allow_token, nvidia_uuid)
311
-
312
- try:
313
- worker.arg_queue.put(((args, kwargs), GradioPartialContext.get()))
314
- except PicklingError: # TODO: detailed serialization diagnostic
315
- release(fail=True)
316
- raise
317
-
318
- yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue()
319
- def fill_yield_queue(worker: Worker[GeneratorResQueueResult[Res]]):
320
- while True:
321
- res = worker.res_queue.get()
322
- if res is None:
323
- release(fail=True, allow_404=True)
324
- yield_queue.put(AbortedResult())
325
- return
326
- if isinstance(res, ExceptionResult):
327
- release(fail=True)
328
- yield_queue.put(ExceptionResult(res.value))
329
- return
330
- if isinstance(res, EndResult):
331
- release()
332
- workers[nvidia_index] = worker
333
- yield_queue.put(EndResult())
334
- return
335
- if isinstance(res, OkResult):
336
- yield_queue.put(OkResult(res.value))
337
- continue
338
- if isinstance(res, GradioQueueEvent): # pragma: no cover (not working properly on Gradio side)
339
- try_process_queue_event(res.method_name, *res.args, **res.kwargs)
340
- continue
341
- debug(f"fill_yield_queue: assert_never({res=})")
342
- assert_never(res)
343
- from typing_extensions import assert_never
344
- with ThreadPoolExecutor() as e:
345
- f = e.submit(copy_context().run, fill_yield_queue, worker)
346
- f.add_done_callback(lambda _: debug("fill_yield_queue DONE"))
347
- while True:
348
- try:
349
- res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT)
350
- except Empty: # pragma: no cover
351
- debug(f"yield_queue TIMEOUT ({GENERATOR_GLOBAL_TIMEOUT=})")
352
- raise
353
- if isinstance(res, AbortedResult):
354
- raise gr.Error("GPU task aborted")
355
- if isinstance(res, ExceptionResult):
356
- raise res.value
357
- if isinstance(res, EndResult):
358
- break
359
- if isinstance(res, OkResult):
360
- yield res.value
361
- continue
362
- debug(f"gradio_handler: assert_never({res=})")
363
- assert_never(res)
364
-
365
-
366
- def thread_wrapper(
367
- arg_queue: Queue[tuple[Params, GradioPartialContext]],
368
- res_queue: Queue[GeneratorResQueueResult[Res] | None],
369
- allow_token: str,
370
- nvidia_uuid: str,
371
- fds: list[int],
372
- ):
373
- global forked
374
- forked = True
375
- signal.signal(signal.SIGTERM, drop_params(arg_queue.close))
376
- initialized = False
377
- while True:
378
- try:
379
- (args, kwargs), gradio_context = arg_queue.get()
380
- except OSError:
381
- break
382
- if not initialized:
383
- if (res := worker_init(
384
- res_queue=res_queue,
385
- allow_token=allow_token,
386
- nvidia_uuid=nvidia_uuid,
387
- fds=fds,
388
- )) is not None:
389
- res_queue.put(res)
390
- return
391
- initialized = True
392
- def iterate():
393
- gen = task(*args, **kwargs) # type: ignore
394
- while True:
395
- try:
396
- res = next(gen)
397
- except StopIteration:
398
- break
399
- except Exception as e:
400
- res_queue.put(ExceptionResult(e))
401
- break
402
- try:
403
- res_queue.put(OkResult(res))
404
- except PicklingError as e:
405
- res_queue.put(ExceptionResult(e))
406
- break
407
- else:
408
- continue
409
- GradioPartialContext.apply(gradio_context)
410
- with ThreadPoolExecutor() as executor:
411
- executor.submit(copy_context().run, iterate)
412
- res_queue.put(EndResult())
413
-
414
- # https://github.com/python/cpython/issues/91002
415
- if not hasattr(task, '__annotations__'):
416
- gradio_handler.__annotations__ = {}
417
-
418
- return gradio_handler
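Not part of the diff: a hedged sketch of wrapping a regular task and a generator task directly with the two wrappers defined above. In practice they are applied by the decorator layer rather than called by hand; the snippet assumes gradio is installed, a ZeroGPU environment is active, and the classify/stream functions are placeholders.

    from spaces.zero.wrappers import regular_function_wrapper, generator_function_wrapper

    def classify(prompt: str) -> str:
        return prompt.upper()

    def stream(prompt: str):
        for token in prompt.split():
            yield token

    # Fixed duration for the regular task, dynamic duration for the generator task
    gpu_classify = regular_function_wrapper(classify, duration=60)
    gpu_stream = generator_function_wrapper(stream, duration=lambda prompt: 10 + len(prompt))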