vpcom committed
Commit 108fbc5 · 1 Parent(s): bd7413b

feat: implement our own InferenceClient

Files changed (1)
  1. app.py +127 -4
app.py CHANGED
@@ -20,6 +20,14 @@ from gradio.components import (
     Textbox,
     get_component_instance,
 )
+
+from huggingface_hub.utils import (
+    BadRequestError,
+    build_hf_headers,
+    get_session,
+    hf_raise_for_status,
+)
+
 from gradio.themes import ThemeClass as Theme

 import gradio as gr
@@ -30,6 +38,18 @@ import anyio
 from huggingface_hub import Repository, InferenceClient
 from utils import force_git_push

+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Union,
+    overload,
+)
+
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")
 MODEL_NAME = os.getenv("MODEL_NAME")
@@ -97,10 +117,113 @@ stop_sequences = ["<|endoftext|>"] # ":پایان","@","#","$",
 # ["<%مهدی اخوان ثالث%"],
 # ]

-client = InferenceClient(
+
+class InferenceClientUS(InferenceClient):
+    def __init__(
+        self,
+        model: Optional[str] = None,
+        token: Union[str, bool, None] = None,
+        timeout: Optional[float] = None,
+        headers: Optional[Dict[str, str]] = None,
+        cookies: Optional[Dict[str, str]] = None,
+    ) -> None:
+        super().__init__(
+            model=model,
+            token=token,
+            timeout=timeout,
+            headers=headers,
+            cookies=cookies,
+        )
+
+    def post(
+        self,
+        *,
+        json: Optional[Union[str, Dict, List]] = None,
+        data: Optional[ContentT] = None,
+        model: Optional[str] = None,
+        task: Optional[str] = None,
+        stream: bool = False,
+    ) -> Union[bytes, Iterable[bytes]]:
+        """
+        Make a POST request to the inference server.
+
+        Args:
+            json (`Union[str, Dict, List]`, *optional*):
+                The JSON data to send in the request body. Defaults to None.
+            data (`Union[str, Path, bytes, BinaryIO]`, *optional*):
+                The content to send in the request body. It can be raw bytes, a pointer to an opened file, a
+                local file path, or a URL to an online resource (image, audio file,...). If both `json` and
+                `data` are passed, `data` will take precedence. At least `json` or `data` must be provided.
+                Defaults to None.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to
+                a deployed Inference Endpoint. Will override the model defined at the instance level.
+                Defaults to None.
+            task (`str`, *optional*):
+                The task to perform on the inference. Used only to default to a recommended model if `model`
+                is not provided. At least `model` or `task` must be provided. Defaults to None.
+            stream (`bool`, *optional*):
+                Whether to iterate over streaming APIs.
+
+        Returns:
+            bytes: The raw bytes returned by the server.
+
+        Raises:
+            [`InferenceTimeoutError`]:
+                If the model is unavailable or the request times out.
+            `HTTPError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+        """
+        url = self._resolve_url(model, task)
+
+        if data is not None and json is not None:
+            warnings.warn("Ignoring `json` as `data` is passed as binary.")
+
+        # Set Accept header if relevant
+        headers = self.headers.copy()
+        if task in TASKS_EXPECTING_IMAGES and "Accept" not in headers:
+            headers["Accept"] = "image/png"
+
+        t0 = time.time()
+        timeout = self.timeout
+        while True:
+            with _open_as_binary(data) as data_as_binary:
+                try:
+                    response = get_session().post(
+                        url,
+                        json=json,
+                        data=data_as_binary,
+                        headers=headers,
+                        cookies=self.cookies,
+                        timeout=self.timeout,
+                        stream=stream,
+                    )
+                except TimeoutError as error:
+                    # Convert any `TimeoutError` to an `InferenceTimeoutError`
+                    raise InferenceTimeoutError(f"Inference call timed out: {url}") from error  # type: ignore
+
+            try:
+                hf_raise_for_status(response)
+                return response.iter_lines() if stream else response.content
+            except HTTPError as error:
+                if error.response.status_code == 503:
+                    # If the model is unavailable, either raise a TimeoutError...
+                    if timeout is not None and time.time() - t0 > timeout:
+                        raise InferenceTimeoutError(
+                            f"Model not loaded on the server: {url}. Please retry with a higher timeout (current:"
+                            f" {self.timeout}).",
+                            request=error.request,
+                            response=error.response,
+                        ) from error
+                    # ...or wait 1s and retry
+                    logger.info(f"Waiting for model to be loaded on the server: {error}")
+                    time.sleep(1)
+                    if timeout is not None:
+                        timeout = max(self.timeout - (time.time() - t0), 1)  # type: ignore
+                    continue
+                raise
+
+client = InferenceClientUS(
     API_URL,
-    headers={"Authorization": f"Bearer {HF_TOKEN}",
-             "use_cache": False},
+    headers={"Authorization": f"Bearer {HF_TOKEN}"},
 )

 def asynchronous_push(f_stop):
@@ -209,7 +332,7 @@ additional_inputs=[
     ),
     gr.Slider(
         label="Top-p (nucleus sampling)",
-        value=1.0,
+        value=0.9,
         minimum=0.0,
         maximum=1,
         step=0.05,
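
Note: the overridden post() also references several names this diff never imports: ContentT, _open_as_binary, TASKS_EXPECTING_IMAGES, InferenceTimeoutError, HTTPError, warnings, time, and logger. Only get_session and hf_raise_for_status are covered by the new huggingface_hub.utils import (BadRequestError and build_hf_headers are imported but unused). A minimal sketch of the missing imports, assuming the huggingface_hub release in use still exposes these helpers in the private inference._common module, as the 0.17.x-era releases this method was copied from do:

    # Extra imports the copied post() relies on but the diff does not add.
    # Note: huggingface_hub.inference._common is a private module; these names
    # exist in the 0.17.x-era releases but may move between versions.
    import time
    import logging
    import warnings

    from requests import HTTPError

    from huggingface_hub import InferenceTimeoutError
    from huggingface_hub.inference._common import (
        ContentT,                # Union[bytes, BinaryIO, Path, str]
        TASKS_EXPECTING_IMAGES,  # tasks whose responses are images
        _open_as_binary,         # context manager: path/URL/bytes -> binary stream
    )

    logger = logging.getLogger(__name__)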
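
Because every high-level InferenceClient helper funnels through post(), the retry-on-503 loop above also applies to streamed text generation: a cold endpoint is polled once per second until self.timeout elapses instead of failing fast. A hypothetical call against the custom client (the prompt and sampling values are illustrative; top_p=0.9 mirrors the new slider default and stop_sequences is the list defined earlier in app.py):

    # Hypothetical usage sketch: text_generation() routes through the
    # overridden post(), so model-loading 503s are retried transparently.
    for token in client.text_generation(
        "Write a short poem",            # illustrative prompt
        max_new_tokens=64,
        temperature=0.7,
        top_p=0.9,                       # matches the new slider default
        stop_sequences=stop_sequences,   # ["<|endoftext|>"] from app.py
        stream=True,
    ):
        print(token, end="")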