fffiloni commited on
Commit
eaa9650
·
verified ·
1 Parent(s): 2715379

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. Dockerfile +31 -0
  2. app.py +117 -0
  3. requirements.txt +2 -0
  4. run.sh +44 -0
  5. utils/gradio_helpers.py +469 -0
Dockerfile ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM r8.im/fofr/consistent-character@sha256:e90b0b680b1dca1bab496f163fdfc7cff174f7cd18d094bf57664f1891f2e857
2
+ RUN apt-get update && apt-get install -y netcat jq
3
+
4
+ RUN useradd -m -u 1000 user
5
+ RUN chown -R user:user / || true
6
+ RUN chown -R user:user /src/
7
+ RUN chown -R user:user /root/
8
+ RUN chown -R user:user /var/
9
+ USER user
10
+ ENV HOME=/home/user \
11
+ PATH=/home/user/.local/bin:$PATH \
12
+ PYTHONPATH=$HOME/app \
13
+ PYTHONUNBUFFERED=1 \
14
+ GRADIO_ALLOW_FLAGGING=never \
15
+ GRADIO_NUM_PORTS=1 \
16
+ GRADIO_SERVER_NAME=0.0.0.0 \
17
+ GRADIO_THEME=huggingface \
18
+ SYSTEM=spaces
19
+
20
+ WORKDIR $HOME/app
21
+ COPY ./requirements.txt /code/requirements.txt
22
+
23
+ # create virtual env for Gradio app
24
+ RUN python -m venv $HOME/.venv && \
25
+ . $HOME/.venv/bin/activate && \
26
+ pip install --no-cache-dir --upgrade pip && \
27
+ pip install --no-cache-dir -r /code/requirements.txt
28
+
29
+ COPY --chown=user . $HOME/app
30
+ RUN chmod +x $HOME/app/run.sh
31
+ CMD ["bash", "-c", "$HOME/app/run.sh"]
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from urllib.parse import urlparse
3
+ import requests
4
+ import time
5
+ import os
6
+
7
+ from utils.gradio_helpers import parse_outputs, process_outputs
8
+
9
+ inputs = []
10
+ inputs.append(gr.Textbox(
11
+ label="Prompt", info='''Describe the subject. Include clothes and hairstyle for more consistency.'''
12
+ ))
13
+
14
+ inputs.append(gr.Textbox(
15
+ label="Negative Prompt", info='''Things you do not want to see in your image'''
16
+ ))
17
+
18
+ inputs.append(gr.Image(
19
+ label="Subject", type="filepath"
20
+ ))
21
+
22
+ inputs.append(gr.Slider(
23
+ label="Number Of Outputs", info='''The number of images to generate.''', value=3,
24
+ minimum=1, maximum=20, step=1,
25
+ ))
26
+
27
+ inputs.append(gr.Slider(
28
+ label="Number Of Images Per Pose", info='''The number of images to generate for each pose.''', value=1,
29
+ minimum=1, maximum=4, step=1,
30
+ ))
31
+
32
+ inputs.append(gr.Checkbox(
33
+ label="Randomise Poses", info='''Randomise the poses used.''', value=True
34
+ ))
35
+
36
+ inputs.append(gr.Dropdown(
37
+ choices=['webp', 'jpg', 'png'], label="output_format", info='''Format of the output images''', value="webp"
38
+ ))
39
+
40
+ inputs.append(gr.Number(
41
+ label="Output Quality", info='''Quality of the output images, from 0 to 100. 100 is best quality, 0 is lowest quality.''', value=80
42
+ ))
43
+
44
+ inputs.append(gr.Number(
45
+ label="Seed", info='''Set a seed for reproducibility. Random by default.''', value=None
46
+ ))
47
+
48
+ names = ['prompt', 'negative_prompt', 'subject', 'number_of_outputs', 'number_of_images_per_pose', 'randomise_poses', 'output_format', 'output_quality', 'seed']
49
+
50
+ outputs = []
51
+ outputs.append(gr.Image())
52
+ outputs.append(gr.Image())
53
+ outputs.append(gr.Image())
54
+ outputs.append(gr.Image())
55
+ outputs.append(gr.Image())
56
+
57
+ expected_outputs = len(outputs)
58
+ def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
59
+ headers = {'Content-Type': 'application/json'}
60
+
61
+ payload = {"input": {}}
62
+
63
+
64
+ base_url = "http://0.0.0.0:7860"
65
+ for i, key in enumerate(names):
66
+ value = args[i]
67
+ if value and (os.path.exists(str(value))):
68
+ value = f"{base_url}/file=" + value
69
+ if value is not None and value != "":
70
+ payload["input"][key] = value
71
+
72
+ response = requests.post("http://0.0.0.0:5000/predictions", headers=headers, json=payload)
73
+
74
+
75
+ if response.status_code == 201:
76
+ follow_up_url = response.json()["urls"]["get"]
77
+ response = requests.get(follow_up_url, headers=headers)
78
+ while response.json()["status"] != "succeeded":
79
+ if response.json()["status"] == "failed":
80
+ raise gr.Error("The submission failed!")
81
+ response = requests.get(follow_up_url, headers=headers)
82
+ time.sleep(1)
83
+ if response.status_code == 200:
84
+ json_response = response.json()
85
+ #If the output component is JSON return the entire output response
86
+ if(outputs[0].get_config()["name"] == "json"):
87
+ return json_response["output"]
88
+ predict_outputs = parse_outputs(json_response["output"])
89
+ processed_outputs = process_outputs(predict_outputs)
90
+ difference_outputs = expected_outputs - len(processed_outputs)
91
+ # If less outputs than expected, hide the extra ones
92
+ if difference_outputs > 0:
93
+ extra_outputs = [gr.update(visible=False)] * difference_outputs
94
+ processed_outputs.extend(extra_outputs)
95
+ # If more outputs than expected, cap the outputs to the expected number
96
+ elif difference_outputs < 0:
97
+ processed_outputs = processed_outputs[:difference_outputs]
98
+
99
+ return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]
100
+ else:
101
+ if(response.status_code == 409):
102
+ raise gr.Error(f"Sorry, the Cog image is still processing. Try again in a bit.")
103
+ raise gr.Error(f"The submission failed! Error: {response.status_code}")
104
+
105
+ title = "Demo for consistent-character cog image by fofr"
106
+ model_description = "Create images of a given character in different poses"
107
+
108
+ app = gr.Interface(
109
+ fn=predict,
110
+ inputs=inputs,
111
+ outputs=outputs,
112
+ title=title,
113
+ description=model_description,
114
+ allow_flagging="never",
115
+ )
116
+ app.launch(share=True)
117
+
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio==4.18.0
2
+ prance
run.sh ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Start the cog server in the background - Ensure correct path to cog
2
+ cd /src && python3 -m cog.server.http --threads=10 &
3
+
4
+ # Initialize counter for the first loop
5
+ counter1=0
6
+
7
+ # Continuous loop for reliably checking cog server's readiness on port 5000
8
+ while true; do
9
+ if nc -z localhost 5000; then
10
+ echo "Cog server is running on port 5000."
11
+ break # Exit the loop when the server is up
12
+ fi
13
+ echo "Waiting for cog server to start on port 5000..."
14
+ sleep 5
15
+ ((counter1++))
16
+ if [ $counter1 -ge 250 ]; then
17
+ echo "Error: Cog server did not start on port 5000 after 250 attempts."
18
+ exit 1 # Exit the script with an error status
19
+ fi
20
+ done
21
+
22
+ # Initialize counter for the second loop
23
+ counter2=0
24
+
25
+ # New check: Waiting for the cog server to be fully ready
26
+ while true; do
27
+ response=$(curl -s http://localhost:5000/health-check) # Replace localhost:5000 with actual hostname and port if necessary
28
+ status=$(echo $response | jq -r '.status') # Parse status from JSON response
29
+ if [ "$status" = "READY" ]; then
30
+ echo "Cog server is fully ready."
31
+ break # Exit the loop when the server is fully ready
32
+ else
33
+ echo "Waiting for cog server (models loading) on port 5000..."
34
+ sleep 5
35
+ fi
36
+ ((counter2++))
37
+ if [ $counter2 -ge 250 ]; then
38
+ echo "Error: Cog server did not become fully ready after 250 attempts."
39
+ exit 1 # Exit the script with an error status
40
+ fi
41
+ done
42
+
43
+ # Run the application - only when cog server is fully ready
44
+ cd $HOME/app && . $HOME/.venv/bin/activate && python app.py
utils/gradio_helpers.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from urllib.parse import urlparse
3
+ import requests
4
+ import time
5
+ from PIL import Image
6
+ import base64
7
+ import io
8
+ import uuid
9
+ import os
10
+
11
+
12
+ def extract_property_info(prop):
13
+ combined_prop = {}
14
+ merge_keywords = ["allOf", "anyOf", "oneOf"]
15
+
16
+ for keyword in merge_keywords:
17
+ if keyword in prop:
18
+ for subprop in prop[keyword]:
19
+ combined_prop.update(subprop)
20
+ del prop[keyword]
21
+
22
+ if not combined_prop:
23
+ combined_prop = prop.copy()
24
+
25
+ for key in ["description", "default"]:
26
+ if key in prop:
27
+ combined_prop[key] = prop[key]
28
+
29
+ return combined_prop
30
+
31
+
32
+ def detect_file_type(filename):
33
+ audio_extensions = [".mp3", ".wav", ".flac", ".aac", ".ogg", ".m4a"]
34
+ image_extensions = [
35
+ ".jpg",
36
+ ".jpeg",
37
+ ".png",
38
+ ".gif",
39
+ ".bmp",
40
+ ".tiff",
41
+ ".svg",
42
+ ".webp",
43
+ ]
44
+ video_extensions = [
45
+ ".mp4",
46
+ ".mov",
47
+ ".wmv",
48
+ ".flv",
49
+ ".avi",
50
+ ".avchd",
51
+ ".mkv",
52
+ ".webm",
53
+ ]
54
+
55
+ # Extract the file extension
56
+ if isinstance(filename, str):
57
+ extension = filename[filename.rfind(".") :].lower()
58
+
59
+ # Check the extension against each list
60
+ if extension in audio_extensions:
61
+ return "audio"
62
+ elif extension in image_extensions:
63
+ return "image"
64
+ elif extension in video_extensions:
65
+ return "video"
66
+ else:
67
+ return "string"
68
+ elif isinstance(filename, list):
69
+ return "list"
70
+
71
+
72
+ def build_gradio_inputs(ordered_input_schema, example_inputs=None):
73
+ inputs = []
74
+ input_field_strings = """inputs = []\n"""
75
+ names = []
76
+ for index, (name, prop) in enumerate(ordered_input_schema):
77
+ names.append(name)
78
+ prop = extract_property_info(prop)
79
+ if "enum" in prop:
80
+ input_field = gr.Dropdown(
81
+ choices=prop["enum"],
82
+ label=prop.get("title"),
83
+ info=prop.get("description"),
84
+ value=prop.get("default"),
85
+ )
86
+ input_field_string = f"""inputs.append(gr.Dropdown(
87
+ choices={prop["enum"]}, label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value="{prop.get("default")}"
88
+ ))\n"""
89
+ elif prop["type"] == "integer":
90
+ if prop.get("minimum") and prop.get("maximum"):
91
+ input_field = gr.Slider(
92
+ label=prop.get("title"),
93
+ info=prop.get("description"),
94
+ value=prop.get("default"),
95
+ minimum=prop.get("minimum"),
96
+ maximum=prop.get("maximum"),
97
+ step=1,
98
+ )
99
+ input_field_string = f"""inputs.append(gr.Slider(
100
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")},
101
+ minimum={prop.get("minimum")}, maximum={prop.get("maximum")}, step=1,
102
+ ))\n"""
103
+ else:
104
+ input_field = gr.Number(
105
+ label=prop.get("title"),
106
+ info=prop.get("description"),
107
+ value=prop.get("default"),
108
+ )
109
+ input_field_string = f"""inputs.append(gr.Number(
110
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")}
111
+ ))\n"""
112
+ elif prop["type"] == "number":
113
+ if prop.get("minimum") and prop.get("maximum"):
114
+ input_field = gr.Slider(
115
+ label=prop.get("title"),
116
+ info=prop.get("description"),
117
+ value=prop.get("default"),
118
+ minimum=prop.get("minimum"),
119
+ maximum=prop.get("maximum"),
120
+ )
121
+ input_field_string = f"""inputs.append(gr.Slider(
122
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")},
123
+ minimum={prop.get("minimum")}, maximum={prop.get("maximum")}
124
+ ))\n"""
125
+ else:
126
+ input_field = gr.Number(
127
+ label=prop.get("title"),
128
+ info=prop.get("description"),
129
+ value=prop.get("default"),
130
+ )
131
+ input_field_string = f"""inputs.append(gr.Number(
132
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")}
133
+ ))\n"""
134
+ elif prop["type"] == "boolean":
135
+ input_field = gr.Checkbox(
136
+ label=prop.get("title"),
137
+ info=prop.get("description"),
138
+ value=prop.get("default"),
139
+ )
140
+ input_field_string = f"""inputs.append(gr.Checkbox(
141
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}, value={prop.get("default")}
142
+ ))\n"""
143
+ elif (
144
+ prop["type"] == "string" and prop.get("format") == "uri" and example_inputs
145
+ ):
146
+ input_type_example = example_inputs.get(name, None)
147
+ if input_type_example:
148
+ input_type = detect_file_type(input_type_example)
149
+ else:
150
+ input_type = None
151
+ if input_type == "image":
152
+ input_field = gr.Image(label=prop.get("title"), type="filepath")
153
+ input_field_string = f"""inputs.append(gr.Image(
154
+ label="{prop.get("title")}", type="filepath"
155
+ ))\n"""
156
+ elif input_type == "audio":
157
+ input_field = gr.Audio(label=prop.get("title"), type="filepath")
158
+ input_field_string = f"""inputs.append(gr.Audio(
159
+ label="{prop.get("title")}", type="filepath"
160
+ ))\n"""
161
+ elif input_type == "video":
162
+ input_field = gr.Video(label=prop.get("title"))
163
+ input_field_string = f"""inputs.append(gr.Video(
164
+ label="{prop.get("title")}"
165
+ ))\n"""
166
+ else:
167
+ input_field = gr.File(label=prop.get("title"))
168
+ input_field_string = f"""inputs.append(gr.File(
169
+ label="{prop.get("title")}"
170
+ ))\n"""
171
+ else:
172
+ input_field = gr.Textbox(
173
+ label=prop.get("title"),
174
+ info=prop.get("description"),
175
+ )
176
+ input_field_string = f"""inputs.append(gr.Textbox(
177
+ label="{prop.get("title")}", info={"'''"+prop.get("description")+"'''" if prop.get("description") else 'None'}
178
+ ))\n"""
179
+ inputs.append(input_field)
180
+ input_field_strings += f"{input_field_string}\n"
181
+
182
+ input_field_strings += f"names = {names}\n"
183
+
184
+ return inputs, input_field_strings, names
185
+
186
+
187
+ def build_gradio_outputs_replicate(output_types):
188
+ outputs = []
189
+ output_field_strings = """outputs = []\n"""
190
+ if output_types:
191
+ for output in output_types:
192
+ if output == "image":
193
+ output_field = gr.Image()
194
+ output_field_string = "outputs.append(gr.Image())"
195
+ elif output == "audio":
196
+ output_field = gr.Audio(type="filepath")
197
+ output_field_string = "outputs.append(gr.Audio(type='filepath'))"
198
+ elif output == "video":
199
+ output_field = gr.Video()
200
+ output_field_string = "outputs.append(gr.Video())"
201
+ elif output == "string":
202
+ output_field = gr.Textbox()
203
+ output_field_string = "outputs.append(gr.Textbox())"
204
+ elif output == "json":
205
+ output_field = gr.JSON()
206
+ output_field_string = "outputs.append(gr.JSON())"
207
+ elif output == "list":
208
+ output_field = gr.JSON()
209
+ output_field_string = "outputs.append(gr.JSON())"
210
+ outputs.append(output_field)
211
+ output_field_strings += f"{output_field_string}\n"
212
+ else:
213
+ output_field = gr.JSON()
214
+ output_field_string = "outputs.append(gr.JSON())"
215
+ outputs.append(output_field)
216
+
217
+ return outputs, output_field_strings
218
+
219
+
220
+ def build_gradio_outputs_cog():
221
+ pass
222
+
223
+
224
+ def process_outputs(outputs):
225
+ output_values = []
226
+ for output in outputs:
227
+ if not output:
228
+ continue
229
+ if isinstance(output, str):
230
+ if output.startswith("data:image"):
231
+ base64_data = output.split(",", 1)[1]
232
+ image_data = base64.b64decode(base64_data)
233
+ image_stream = io.BytesIO(image_data)
234
+ image = Image.open(image_stream)
235
+ output_values.append(image)
236
+ elif output.startswith("data:audio"):
237
+ base64_data = output.split(",", 1)[1]
238
+ audio_data = base64.b64decode(base64_data)
239
+ audio_stream = io.BytesIO(audio_data)
240
+ filename = f"{uuid.uuid4()}.wav" # Change format as needed
241
+ with open(filename, "wb") as audio_file:
242
+ audio_file.write(audio_stream.getbuffer())
243
+ output_values.append(filename)
244
+ elif output.startswith("data:video"):
245
+ base64_data = output.split(",", 1)[1]
246
+ video_data = base64.b64decode(base64_data)
247
+ video_stream = io.BytesIO(video_data)
248
+ # Here you can save the audio or return the stream for further processing
249
+ filename = f"{uuid.uuid4()}.mp4" # Change format as needed
250
+ with open(filename, "wb") as video_file:
251
+ video_file.write(video_stream.getbuffer())
252
+ output_values.append(filename)
253
+ else:
254
+ output_values.append(output)
255
+ else:
256
+ output_values.append(output)
257
+ return output_values
258
+
259
+
260
+ def parse_outputs(data):
261
+ if isinstance(data, dict):
262
+ # Handle case where data is an object
263
+ dict_values = []
264
+ for value in data.values():
265
+ extracted_values = parse_outputs(value)
266
+ # For dict, we append instead of extend to maintain list structure within objects
267
+ if isinstance(value, list):
268
+ dict_values += [extracted_values]
269
+ else:
270
+ dict_values += extracted_values
271
+ return dict_values
272
+ elif isinstance(data, list):
273
+ # Handle case where data is an array
274
+ list_values = []
275
+ for item in data:
276
+ # Here we extend to flatten the list since we're already in an array context
277
+ list_values += parse_outputs(item)
278
+ return list_values
279
+ else:
280
+ # Handle primitive data types directly
281
+ return [data]
282
+
283
+
284
+ def create_dynamic_gradio_app(
285
+ inputs,
286
+ outputs,
287
+ api_url,
288
+ api_id=None,
289
+ replicate_token=None,
290
+ title="",
291
+ model_description="",
292
+ names=[],
293
+ local_base=False,
294
+ hostname="0.0.0.0",
295
+ ):
296
+ expected_outputs = len(outputs)
297
+
298
+ def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):
299
+ payload = {"input": {}}
300
+ if api_id:
301
+ payload["version"] = api_id
302
+ parsed_url = urlparse(str(request.url))
303
+ if local_base:
304
+ base_url = f"http://{hostname}:7860"
305
+ else:
306
+ base_url = parsed_url.scheme + "://" + parsed_url.netloc
307
+ for i, key in enumerate(names):
308
+ value = args[i]
309
+ if value and (os.path.exists(str(value))):
310
+ value = f"{base_url}/file=" + value
311
+ if value is not None and value != "":
312
+ payload["input"][key] = value
313
+ print(payload)
314
+ headers = {"Content-Type": "application/json"}
315
+ if replicate_token:
316
+ headers["Authorization"] = f"Token {replicate_token}"
317
+ print(headers)
318
+ response = requests.post(api_url, headers=headers, json=payload)
319
+ if response.status_code == 201:
320
+ follow_up_url = response.json()["urls"]["get"]
321
+ response = requests.get(follow_up_url, headers=headers)
322
+ while response.json()["status"] != "succeeded":
323
+ if response.json()["status"] == "failed":
324
+ raise gr.Error("The submission failed!")
325
+ response = requests.get(follow_up_url, headers=headers)
326
+ time.sleep(1)
327
+ # TODO: Add a failing mechanism if the API gets stuck
328
+ if response.status_code == 200:
329
+ json_response = response.json()
330
+ # If the output component is JSON return the entire output response
331
+ if outputs[0].get_config()["name"] == "json":
332
+ return json_response["output"]
333
+ predict_outputs = parse_outputs(json_response["output"])
334
+ processed_outputs = process_outputs(predict_outputs)
335
+ difference_outputs = expected_outputs - len(processed_outputs)
336
+ # If less outputs than expected, hide the extra ones
337
+ if difference_outputs > 0:
338
+ extra_outputs = [gr.update(visible=False)] * difference_outputs
339
+ processed_outputs.extend(extra_outputs)
340
+ # If more outputs than expected, cap the outputs to the expected number if
341
+ elif difference_outputs < 0:
342
+ processed_outputs = processed_outputs[:difference_outputs]
343
+
344
+ return (
345
+ tuple(processed_outputs)
346
+ if len(processed_outputs) > 1
347
+ else processed_outputs[0]
348
+ )
349
+
350
+ else:
351
+ if response.status_code == 409:
352
+ raise gr.Error(
353
+ f"Sorry, the Cog image is still processing. Try again in a bit."
354
+ )
355
+ raise gr.Error(f"The submission failed! Error: {response.status_code}")
356
+
357
+ app = gr.Interface(
358
+ fn=predict,
359
+ inputs=inputs,
360
+ outputs=outputs,
361
+ title=title,
362
+ description=model_description,
363
+ allow_flagging="never",
364
+ )
365
+ return app
366
+
367
+
368
+ def create_gradio_app_script(
369
+ inputs_string,
370
+ outputs_string,
371
+ api_url,
372
+ api_id=None,
373
+ replicate_token=None,
374
+ title="",
375
+ model_description="",
376
+ local_base=False,
377
+ hostname="0.0.0.0"
378
+ ):
379
+ headers = {"Content-Type": "application/json"}
380
+ if replicate_token:
381
+ headers["Authorization"] = f"Token {replicate_token}"
382
+
383
+ if local_base:
384
+ base_url = f'base_url = "http://{hostname}:7860"'
385
+ else:
386
+ base_url = """parsed_url = urlparse(str(request.url))
387
+ base_url = parsed_url.scheme + "://" + parsed_url.netloc"""
388
+ headers_string = f"""headers = {headers}\n"""
389
+ api_id_value = f'payload["version"] = "{api_id}"' if api_id is not None else ""
390
+ definition_string = """expected_outputs = len(outputs)
391
+ def predict(request: gr.Request, *args, progress=gr.Progress(track_tqdm=True)):"""
392
+ payload_string = f"""payload = {{"input": {{}}}}
393
+ {api_id_value}
394
+
395
+ {base_url}
396
+ for i, key in enumerate(names):
397
+ value = args[i]
398
+ if value and (os.path.exists(str(value))):
399
+ value = f"{{base_url}}/file=" + value
400
+ if value is not None and value != "":
401
+ payload["input"][key] = value\n"""
402
+
403
+ request_string = (
404
+ f"""response = requests.post("{api_url}", headers=headers, json=payload)\n"""
405
+ )
406
+
407
+ result_string = f"""
408
+ if response.status_code == 201:
409
+ follow_up_url = response.json()["urls"]["get"]
410
+ response = requests.get(follow_up_url, headers=headers)
411
+ while response.json()["status"] != "succeeded":
412
+ if response.json()["status"] == "failed":
413
+ raise gr.Error("The submission failed!")
414
+ response = requests.get(follow_up_url, headers=headers)
415
+ time.sleep(1)
416
+ if response.status_code == 200:
417
+ json_response = response.json()
418
+ #If the output component is JSON return the entire output response
419
+ if(outputs[0].get_config()["name"] == "json"):
420
+ return json_response["output"]
421
+ predict_outputs = parse_outputs(json_response["output"])
422
+ processed_outputs = process_outputs(predict_outputs)
423
+ difference_outputs = expected_outputs - len(processed_outputs)
424
+ # If less outputs than expected, hide the extra ones
425
+ if difference_outputs > 0:
426
+ extra_outputs = [gr.update(visible=False)] * difference_outputs
427
+ processed_outputs.extend(extra_outputs)
428
+ # If more outputs than expected, cap the outputs to the expected number
429
+ elif difference_outputs < 0:
430
+ processed_outputs = processed_outputs[:difference_outputs]
431
+
432
+ return tuple(processed_outputs) if len(processed_outputs) > 1 else processed_outputs[0]
433
+ else:
434
+ if(response.status_code == 409):
435
+ raise gr.Error(f"Sorry, the Cog image is still processing. Try again in a bit.")
436
+ raise gr.Error(f"The submission failed! Error: {{response.status_code}}")\n"""
437
+
438
+ interface_string = f"""title = "{title}"
439
+ model_description = "{model_description}"
440
+
441
+ app = gr.Interface(
442
+ fn=predict,
443
+ inputs=inputs,
444
+ outputs=outputs,
445
+ title=title,
446
+ description=model_description,
447
+ allow_flagging="never",
448
+ )
449
+ app.launch(share=True)
450
+ """
451
+
452
+ app_string = f"""import gradio as gr
453
+ from urllib.parse import urlparse
454
+ import requests
455
+ import time
456
+ import os
457
+
458
+ from utils.gradio_helpers import parse_outputs, process_outputs
459
+
460
+ {inputs_string}
461
+ {outputs_string}
462
+ {definition_string}
463
+ {headers_string}
464
+ {payload_string}
465
+ {request_string}
466
+ {result_string}
467
+ {interface_string}
468
+ """
469
+ return app_string