Pierre Chapuis
commited on
update to use the official Finegrain API directly
Browse files- README.md +2 -1
- gradio_image_annotation-0.2.3-py3-none-any.whl +0 -0
- pyproject.toml +3 -2
- requirements.lock +45 -43
- requirements.txt +2 -1
- src/app.py +100 -53
- src/fg.py +117 -0
README.md
CHANGED
@@ -4,7 +4,8 @@ emoji: 🧽
|
|
4 |
colorFrom: gray
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
|
|
8 |
app_file: src/app.py
|
9 |
pinned: false
|
10 |
license: other
|
|
|
4 |
colorFrom: gray
|
5 |
colorTo: blue
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.44.1
|
8 |
+
python_version: 3.12
|
9 |
app_file: src/app.py
|
10 |
pinned: false
|
11 |
license: other
|
gradio_image_annotation-0.2.3-py3-none-any.whl
DELETED
Binary file (85.3 kB)
|
|
pyproject.toml
CHANGED
@@ -6,13 +6,14 @@ authors = [
|
|
6 |
{ name = "Pierre Chapuis", email = "[email protected]" }
|
7 |
]
|
8 |
dependencies = [
|
9 |
-
"gradio>=4.41.0",
|
10 |
"environs>=11.0.0",
|
11 |
-
"gradio-image-annotation
|
12 |
"httpx>=0.27.0",
|
13 |
"pillow>=10.4.0",
|
14 |
"gradio-imageslider>=0.0.20",
|
15 |
"pillow-heif>=0.18.0",
|
|
|
16 |
]
|
17 |
readme = "README.md"
|
18 |
requires-python = ">= 3.12, <3.13"
|
|
|
6 |
{ name = "Pierre Chapuis", email = "[email protected]" }
|
7 |
]
|
8 |
dependencies = [
|
9 |
+
"gradio>=4.41.0,<5", # gradio-imageslider requires <5
|
10 |
"environs>=11.0.0",
|
11 |
+
"gradio-image-annotation>=0.2.5",
|
12 |
"httpx>=0.27.0",
|
13 |
"pillow>=10.4.0",
|
14 |
"gradio-imageslider>=0.0.20",
|
15 |
"pillow-heif>=0.18.0",
|
16 |
+
"httpx-sse>=0.4.0",
|
17 |
]
|
18 |
readme = "README.md"
|
19 |
requires-python = ">= 3.12, <3.13"
|
requirements.lock
CHANGED
@@ -14,93 +14,95 @@ aiofiles==23.2.1
|
|
14 |
# via gradio
|
15 |
annotated-types==0.7.0
|
16 |
# via pydantic
|
17 |
-
anyio==4.
|
18 |
# via gradio
|
19 |
# via httpx
|
20 |
# via starlette
|
21 |
-
certifi==2024.
|
22 |
# via httpcore
|
23 |
# via httpx
|
24 |
# via requests
|
25 |
-
charset-normalizer==3.
|
26 |
# via requests
|
27 |
-
click==8.1.
|
28 |
# via typer
|
29 |
# via uvicorn
|
30 |
-
contourpy==1.3.
|
31 |
# via matplotlib
|
32 |
cycler==0.12.1
|
33 |
# via matplotlib
|
34 |
-
environs==
|
35 |
# via eraser
|
36 |
-
fastapi==0.
|
37 |
# via gradio
|
38 |
-
ffmpy==0.
|
39 |
# via gradio
|
40 |
-
filelock==3.
|
41 |
# via huggingface-hub
|
42 |
-
fonttools==4.
|
43 |
# via matplotlib
|
44 |
-
fsspec==2024.
|
45 |
# via gradio-client
|
46 |
# via huggingface-hub
|
47 |
-
gradio==4.
|
48 |
# via eraser
|
49 |
# via gradio-image-annotation
|
50 |
# via gradio-imageslider
|
51 |
gradio-client==1.3.0
|
52 |
# via gradio
|
53 |
-
gradio-image-annotation
|
54 |
# via eraser
|
55 |
gradio-imageslider==0.0.20
|
56 |
# via eraser
|
57 |
h11==0.14.0
|
58 |
# via httpcore
|
59 |
# via uvicorn
|
60 |
-
httpcore==1.0.
|
61 |
# via httpx
|
62 |
-
httpx==0.
|
63 |
# via eraser
|
64 |
# via gradio
|
65 |
# via gradio-client
|
66 |
-
|
|
|
|
|
67 |
# via gradio
|
68 |
# via gradio-client
|
69 |
-
idna==3.
|
70 |
# via anyio
|
71 |
# via httpx
|
72 |
# via requests
|
73 |
-
importlib-resources==6.
|
74 |
# via gradio
|
75 |
-
jinja2==3.1.
|
76 |
# via gradio
|
77 |
-
kiwisolver==1.4.
|
78 |
# via matplotlib
|
79 |
markdown-it-py==3.0.0
|
80 |
# via rich
|
81 |
markupsafe==2.1.5
|
82 |
# via gradio
|
83 |
# via jinja2
|
84 |
-
marshmallow==3.
|
85 |
# via environs
|
86 |
-
matplotlib==3.
|
87 |
# via gradio
|
88 |
mdurl==0.1.2
|
89 |
# via markdown-it-py
|
90 |
-
numpy==2.
|
91 |
# via contourpy
|
92 |
# via gradio
|
93 |
# via matplotlib
|
94 |
# via pandas
|
95 |
-
orjson==3.10.
|
96 |
# via gradio
|
97 |
-
packaging==24.
|
98 |
# via gradio
|
99 |
# via gradio-client
|
100 |
# via huggingface-hub
|
101 |
# via marshmallow
|
102 |
# via matplotlib
|
103 |
-
pandas==2.2.
|
104 |
# via gradio
|
105 |
pillow==10.4.0
|
106 |
# via eraser
|
@@ -108,55 +110,55 @@ pillow==10.4.0
|
|
108 |
# via gradio-imageslider
|
109 |
# via matplotlib
|
110 |
# via pillow-heif
|
111 |
-
pillow-heif==0.
|
112 |
# via eraser
|
113 |
-
pydantic==2.
|
114 |
# via fastapi
|
115 |
# via gradio
|
116 |
-
pydantic-core==2.
|
117 |
# via pydantic
|
118 |
pydub==0.25.1
|
119 |
# via gradio
|
120 |
-
pygments==2.
|
121 |
# via rich
|
122 |
-
pyparsing==3.1
|
123 |
# via matplotlib
|
124 |
python-dateutil==2.9.0.post0
|
125 |
# via matplotlib
|
126 |
# via pandas
|
127 |
python-dotenv==1.0.1
|
128 |
# via environs
|
129 |
-
python-multipart==0.0.
|
130 |
# via gradio
|
131 |
-
pytz==2024.
|
132 |
# via pandas
|
133 |
pyyaml==6.0.2
|
134 |
# via gradio
|
135 |
# via huggingface-hub
|
136 |
requests==2.32.3
|
137 |
# via huggingface-hub
|
138 |
-
rich==13.
|
139 |
# via typer
|
140 |
-
ruff==0.
|
141 |
# via gradio
|
142 |
semantic-version==2.10.0
|
143 |
# via gradio
|
144 |
shellingham==1.5.4
|
145 |
# via typer
|
146 |
-
six==1.
|
147 |
# via python-dateutil
|
148 |
sniffio==1.3.1
|
149 |
# via anyio
|
150 |
-
|
151 |
-
starlette==0.38.4
|
152 |
# via fastapi
|
153 |
tomlkit==0.12.0
|
154 |
# via gradio
|
155 |
-
tqdm==4.
|
156 |
# via huggingface-hub
|
157 |
-
typer==0.
|
158 |
# via gradio
|
159 |
typing-extensions==4.12.2
|
|
|
160 |
# via fastapi
|
161 |
# via gradio
|
162 |
# via gradio-client
|
@@ -164,12 +166,12 @@ typing-extensions==4.12.2
|
|
164 |
# via pydantic
|
165 |
# via pydantic-core
|
166 |
# via typer
|
167 |
-
tzdata==2024.
|
168 |
# via pandas
|
169 |
-
urllib3==2.
|
170 |
# via gradio
|
171 |
# via requests
|
172 |
-
uvicorn==0.
|
173 |
# via gradio
|
174 |
websockets==12.0
|
175 |
# via gradio-client
|
|
|
14 |
# via gradio
|
15 |
annotated-types==0.7.0
|
16 |
# via pydantic
|
17 |
+
anyio==4.8.0
|
18 |
# via gradio
|
19 |
# via httpx
|
20 |
# via starlette
|
21 |
+
certifi==2024.12.14
|
22 |
# via httpcore
|
23 |
# via httpx
|
24 |
# via requests
|
25 |
+
charset-normalizer==3.4.1
|
26 |
# via requests
|
27 |
+
click==8.1.8
|
28 |
# via typer
|
29 |
# via uvicorn
|
30 |
+
contourpy==1.3.1
|
31 |
# via matplotlib
|
32 |
cycler==0.12.1
|
33 |
# via matplotlib
|
34 |
+
environs==14.1.0
|
35 |
# via eraser
|
36 |
+
fastapi==0.115.6
|
37 |
# via gradio
|
38 |
+
ffmpy==0.5.0
|
39 |
# via gradio
|
40 |
+
filelock==3.16.1
|
41 |
# via huggingface-hub
|
42 |
+
fonttools==4.55.3
|
43 |
# via matplotlib
|
44 |
+
fsspec==2024.12.0
|
45 |
# via gradio-client
|
46 |
# via huggingface-hub
|
47 |
+
gradio==4.44.1
|
48 |
# via eraser
|
49 |
# via gradio-image-annotation
|
50 |
# via gradio-imageslider
|
51 |
gradio-client==1.3.0
|
52 |
# via gradio
|
53 |
+
gradio-image-annotation==0.2.5
|
54 |
# via eraser
|
55 |
gradio-imageslider==0.0.20
|
56 |
# via eraser
|
57 |
h11==0.14.0
|
58 |
# via httpcore
|
59 |
# via uvicorn
|
60 |
+
httpcore==1.0.7
|
61 |
# via httpx
|
62 |
+
httpx==0.28.1
|
63 |
# via eraser
|
64 |
# via gradio
|
65 |
# via gradio-client
|
66 |
+
httpx-sse==0.4.0
|
67 |
+
# via eraser
|
68 |
+
huggingface-hub==0.27.1
|
69 |
# via gradio
|
70 |
# via gradio-client
|
71 |
+
idna==3.10
|
72 |
# via anyio
|
73 |
# via httpx
|
74 |
# via requests
|
75 |
+
importlib-resources==6.5.2
|
76 |
# via gradio
|
77 |
+
jinja2==3.1.5
|
78 |
# via gradio
|
79 |
+
kiwisolver==1.4.8
|
80 |
# via matplotlib
|
81 |
markdown-it-py==3.0.0
|
82 |
# via rich
|
83 |
markupsafe==2.1.5
|
84 |
# via gradio
|
85 |
# via jinja2
|
86 |
+
marshmallow==3.25.1
|
87 |
# via environs
|
88 |
+
matplotlib==3.10.0
|
89 |
# via gradio
|
90 |
mdurl==0.1.2
|
91 |
# via markdown-it-py
|
92 |
+
numpy==2.2.2
|
93 |
# via contourpy
|
94 |
# via gradio
|
95 |
# via matplotlib
|
96 |
# via pandas
|
97 |
+
orjson==3.10.15
|
98 |
# via gradio
|
99 |
+
packaging==24.2
|
100 |
# via gradio
|
101 |
# via gradio-client
|
102 |
# via huggingface-hub
|
103 |
# via marshmallow
|
104 |
# via matplotlib
|
105 |
+
pandas==2.2.3
|
106 |
# via gradio
|
107 |
pillow==10.4.0
|
108 |
# via eraser
|
|
|
110 |
# via gradio-imageslider
|
111 |
# via matplotlib
|
112 |
# via pillow-heif
|
113 |
+
pillow-heif==0.21.0
|
114 |
# via eraser
|
115 |
+
pydantic==2.10.5
|
116 |
# via fastapi
|
117 |
# via gradio
|
118 |
+
pydantic-core==2.27.2
|
119 |
# via pydantic
|
120 |
pydub==0.25.1
|
121 |
# via gradio
|
122 |
+
pygments==2.19.1
|
123 |
# via rich
|
124 |
+
pyparsing==3.2.1
|
125 |
# via matplotlib
|
126 |
python-dateutil==2.9.0.post0
|
127 |
# via matplotlib
|
128 |
# via pandas
|
129 |
python-dotenv==1.0.1
|
130 |
# via environs
|
131 |
+
python-multipart==0.0.20
|
132 |
# via gradio
|
133 |
+
pytz==2024.2
|
134 |
# via pandas
|
135 |
pyyaml==6.0.2
|
136 |
# via gradio
|
137 |
# via huggingface-hub
|
138 |
requests==2.32.3
|
139 |
# via huggingface-hub
|
140 |
+
rich==13.9.4
|
141 |
# via typer
|
142 |
+
ruff==0.9.2
|
143 |
# via gradio
|
144 |
semantic-version==2.10.0
|
145 |
# via gradio
|
146 |
shellingham==1.5.4
|
147 |
# via typer
|
148 |
+
six==1.17.0
|
149 |
# via python-dateutil
|
150 |
sniffio==1.3.1
|
151 |
# via anyio
|
152 |
+
starlette==0.41.3
|
|
|
153 |
# via fastapi
|
154 |
tomlkit==0.12.0
|
155 |
# via gradio
|
156 |
+
tqdm==4.67.1
|
157 |
# via huggingface-hub
|
158 |
+
typer==0.15.1
|
159 |
# via gradio
|
160 |
typing-extensions==4.12.2
|
161 |
+
# via anyio
|
162 |
# via fastapi
|
163 |
# via gradio
|
164 |
# via gradio-client
|
|
|
166 |
# via pydantic
|
167 |
# via pydantic-core
|
168 |
# via typer
|
169 |
+
tzdata==2024.2
|
170 |
# via pandas
|
171 |
+
urllib3==2.3.0
|
172 |
# via gradio
|
173 |
# via requests
|
174 |
+
uvicorn==0.34.0
|
175 |
# via gradio
|
176 |
websockets==12.0
|
177 |
# via gradio-client
|
requirements.txt
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
-
|
2 |
gradio_imageslider>=0.0.20
|
3 |
environs>=11.0.0
|
4 |
httpx>=0.27.0
|
|
|
5 |
pillow>=10.4.0
|
6 |
pillow-heif>=0.18.0
|
|
|
1 |
+
gradio_image_annotation>=0.2.5
|
2 |
gradio_imageslider>=0.0.20
|
3 |
environs>=11.0.0
|
4 |
httpx>=0.27.0
|
5 |
+
httpx-sse>=0.4.0
|
6 |
pillow>=10.4.0
|
7 |
pillow-heif>=0.18.0
|
src/app.py
CHANGED
@@ -1,14 +1,16 @@
|
|
|
|
1 |
import io
|
2 |
from typing import Any
|
3 |
|
4 |
import gradio as gr
|
5 |
-
import httpx
|
6 |
import pillow_heif
|
7 |
from environs import Env
|
8 |
from gradio_image_annotation import image_annotator
|
9 |
from gradio_imageslider import ImageSlider
|
10 |
from PIL import Image
|
11 |
|
|
|
|
|
12 |
pillow_heif.register_heif_opener()
|
13 |
pillow_heif.register_avif_opener()
|
14 |
|
@@ -16,11 +18,16 @@ env = Env()
|
|
16 |
env.read_env()
|
17 |
|
18 |
with env.prefixed("ERASER_"):
|
19 |
-
API_URL: str = str(env.str("API_URL", "https://
|
20 |
-
|
|
|
21 |
CA_BUNDLE: str | None = env.str("CA_BUNDLE", None)
|
22 |
|
23 |
-
|
|
|
|
|
|
|
|
|
24 |
|
25 |
|
26 |
def resize(image: Image.Image, shortest_side: int = 768) -> Image.Image:
|
@@ -31,38 +38,97 @@ def resize(image: Image.Image, shortest_side: int = 768) -> Image.Image:
|
|
31 |
return image.resize(size=(int(shortest_side * image.width / image.height), shortest_side))
|
32 |
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
assert isinstance(img := prompts["image"], Image.Image)
|
39 |
assert isinstance(boxes := prompts["boxes"], list)
|
40 |
assert len(boxes) == 1
|
41 |
assert isinstance(box := boxes[0], dict)
|
42 |
-
headers = {}
|
43 |
-
if request: # avoid DOS - can be None despite type hint!
|
44 |
-
client_ip = request.headers.get("x-forwarded-for") or request.client.host
|
45 |
-
headers = {"X-HF-Client-IP": client_ip}
|
46 |
|
47 |
resized_img = resize(img)
|
48 |
bbox = [box[k] for k in ["xmin", "ymin", "xmax", "ymax"]]
|
49 |
if resized_img.width != img.width:
|
50 |
bbox = [int(v * resized_img.width / img.width) for v in bbox]
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
timeout=30.0,
|
60 |
-
auth=auth,
|
61 |
-
headers=headers,
|
62 |
-
)
|
63 |
-
r.raise_for_status()
|
64 |
-
|
65 |
-
output_image = Image.open(io.BytesIO(r.content))
|
66 |
return (img, output_image)
|
67 |
|
68 |
|
@@ -70,32 +136,12 @@ def on_change_bbox(prompts: dict[str, Any] | None):
|
|
70 |
return gr.update(interactive=prompts is not None and len(prompts["boxes"]) > 0)
|
71 |
|
72 |
|
73 |
-
def process_prompt(
|
74 |
-
img: Image.Image,
|
75 |
-
prompt: str,
|
76 |
-
request: gr.Request | None,
|
77 |
-
) -> tuple[Image.Image, Image.Image]:
|
78 |
-
headers = {}
|
79 |
-
if request: # avoid DOS - can be None despite type hint!
|
80 |
-
client_ip = request.headers.get("x-forwarded-for") or request.client.host
|
81 |
-
headers = {"X-HF-Client-IP": client_ip}
|
82 |
-
|
83 |
resized_img = resize(img)
|
84 |
-
|
85 |
-
|
86 |
-
resized_img
|
87 |
-
|
88 |
-
API_URL,
|
89 |
-
data={"prompt": prompt},
|
90 |
-
files={"file": f},
|
91 |
-
verify=CA_BUNDLE or True,
|
92 |
-
timeout=30.0,
|
93 |
-
auth=auth,
|
94 |
-
headers=headers,
|
95 |
-
)
|
96 |
-
r.raise_for_status()
|
97 |
-
|
98 |
-
output_image = Image.open(io.BytesIO(r.content))
|
99 |
return (img, output_image)
|
100 |
|
101 |
|
@@ -112,8 +158,9 @@ TITLE = """
|
|
112 |
padding: 0.5rem 1rem;
|
113 |
font-size: 1.25rem;
|
114 |
">
|
115 |
-
🚀 For an optimized version of this space, try out the
|
116 |
-
<a href="https://finegrain.ai/editor?utm_source=hf&utm_campaign=object-eraser" target="_blank">Finegrain Editor</a>!
|
|
|
117 |
</div>
|
118 |
|
119 |
<h1 style="font-size: 1.5rem; margin-bottom: 0.5rem;">
|
|
|
1 |
+
import dataclasses as dc
|
2 |
import io
|
3 |
from typing import Any
|
4 |
|
5 |
import gradio as gr
|
|
|
6 |
import pillow_heif
|
7 |
from environs import Env
|
8 |
from gradio_image_annotation import image_annotator
|
9 |
from gradio_imageslider import ImageSlider
|
10 |
from PIL import Image
|
11 |
|
12 |
+
from fg import EditorAPIContext
|
13 |
+
|
14 |
pillow_heif.register_heif_opener()
|
15 |
pillow_heif.register_avif_opener()
|
16 |
|
|
|
18 |
env.read_env()
|
19 |
|
20 |
with env.prefixed("ERASER_"):
|
21 |
+
API_URL: str = str(env.str("API_URL", "https://api.finegrain.ai/editor"))
|
22 |
+
API_USER: str | None = env.str("API_USER")
|
23 |
+
API_PASSWORD: str | None = env.str("API_PASSWORD")
|
24 |
CA_BUNDLE: str | None = env.str("CA_BUNDLE", None)
|
25 |
|
26 |
+
assert API_USER is not None
|
27 |
+
assert API_PASSWORD is not None
|
28 |
+
CTX = EditorAPIContext(uri=API_URL, user=API_USER, password=API_PASSWORD)
|
29 |
+
if CA_BUNDLE:
|
30 |
+
CTX.verify = CA_BUNDLE
|
31 |
|
32 |
|
33 |
def resize(image: Image.Image, shortest_side: int = 768) -> Image.Image:
|
|
|
38 |
return image.resize(size=(int(shortest_side * image.width / image.height), shortest_side))
|
39 |
|
40 |
|
41 |
+
@dc.dataclass(kw_only=True)
|
42 |
+
class ProcessParams:
|
43 |
+
image: Image.Image
|
44 |
+
prompt: str | None = None
|
45 |
+
bbox: tuple[int, int, int, int] | None = None
|
46 |
+
|
47 |
+
|
48 |
+
async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
|
49 |
+
with io.BytesIO() as f:
|
50 |
+
params.image.save(f, format="JPEG")
|
51 |
+
async with ctx as client:
|
52 |
+
response = await client.post(
|
53 |
+
f"{ctx.uri}/state/upload",
|
54 |
+
files={"file": f},
|
55 |
+
headers=ctx.auth_headers,
|
56 |
+
)
|
57 |
+
response.raise_for_status()
|
58 |
+
st_input = response.json()["state"]
|
59 |
+
|
60 |
+
if params.bbox:
|
61 |
+
segment_input_st = st_input
|
62 |
+
segment_params = {"bbox": list(params.bbox)}
|
63 |
+
else:
|
64 |
+
assert params.prompt
|
65 |
+
async with ctx as client:
|
66 |
+
response = await client.post(
|
67 |
+
f"{ctx.uri}/skills/infer-bbox/{st_input}",
|
68 |
+
json={"product_name": params.prompt},
|
69 |
+
headers=ctx.auth_headers,
|
70 |
+
)
|
71 |
+
response.raise_for_status()
|
72 |
+
st_bbox = response.json()["state"]
|
73 |
+
await ctx.sse_await(st_bbox)
|
74 |
+
segment_input_st = st_bbox
|
75 |
+
segment_params = {}
|
76 |
+
|
77 |
+
async with ctx as client:
|
78 |
+
response = await client.post(
|
79 |
+
f"{ctx.uri}/skills/segment/{segment_input_st}",
|
80 |
+
json=segment_params,
|
81 |
+
headers=ctx.auth_headers,
|
82 |
+
)
|
83 |
+
response.raise_for_status()
|
84 |
+
st_mask = response.json()["state"]
|
85 |
+
await ctx.sse_await(st_mask)
|
86 |
+
|
87 |
+
erase_params: dict[str, str | bool] = {
|
88 |
+
"mode": "free", # new API
|
89 |
+
"restore_original_resolution": False, # legacy API
|
90 |
+
}
|
91 |
+
async with ctx as client:
|
92 |
+
response = await client.post(
|
93 |
+
f"{ctx.uri}/skills/erase/{st_input}/{st_mask}",
|
94 |
+
json=erase_params,
|
95 |
+
headers=ctx.auth_headers,
|
96 |
+
)
|
97 |
+
response.raise_for_status()
|
98 |
+
st_erased = response.json()["state"]
|
99 |
+
await ctx.sse_await(st_erased)
|
100 |
+
|
101 |
+
async with ctx as client:
|
102 |
+
response = await client.get(
|
103 |
+
f"{ctx.uri}/state/image/{st_erased}",
|
104 |
+
params={"format": "JPEG", "resolution": "DISPLAY"},
|
105 |
+
headers=ctx.auth_headers,
|
106 |
+
)
|
107 |
+
response.raise_for_status()
|
108 |
+
f = io.BytesIO()
|
109 |
+
f.write(response.content)
|
110 |
+
f.seek(0)
|
111 |
+
return Image.open(f)
|
112 |
+
|
113 |
+
|
114 |
+
def process_bbox(prompts: dict[str, Any]) -> tuple[Image.Image, Image.Image]:
|
115 |
assert isinstance(img := prompts["image"], Image.Image)
|
116 |
assert isinstance(boxes := prompts["boxes"], list)
|
117 |
assert len(boxes) == 1
|
118 |
assert isinstance(box := boxes[0], dict)
|
|
|
|
|
|
|
|
|
119 |
|
120 |
resized_img = resize(img)
|
121 |
bbox = [box[k] for k in ["xmin", "ymin", "xmax", "ymax"]]
|
122 |
if resized_img.width != img.width:
|
123 |
bbox = [int(v * resized_img.width / img.width) for v in bbox]
|
124 |
|
125 |
+
output_image = CTX.run_one_sync(
|
126 |
+
_process,
|
127 |
+
ProcessParams(
|
128 |
+
image=resized_img,
|
129 |
+
bbox=(bbox[0], bbox[1], bbox[2], bbox[3]),
|
130 |
+
),
|
131 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
return (img, output_image)
|
133 |
|
134 |
|
|
|
136 |
return gr.update(interactive=prompts is not None and len(prompts["boxes"]) > 0)
|
137 |
|
138 |
|
139 |
+
def process_prompt(img: Image.Image, prompt: str) -> tuple[Image.Image, Image.Image]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
resized_img = resize(img)
|
141 |
+
output_image = CTX.run_one_sync(
|
142 |
+
_process,
|
143 |
+
ProcessParams(image=resized_img, prompt=prompt),
|
144 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
return (img, output_image)
|
146 |
|
147 |
|
|
|
158 |
padding: 0.5rem 1rem;
|
159 |
font-size: 1.25rem;
|
160 |
">
|
161 |
+
🚀 For an optimized version of this space, try out the
|
162 |
+
<a href="https://finegrain.ai/editor?utm_source=hf&utm_campaign=object-eraser" target="_blank">Finegrain Editor</a>!
|
163 |
+
You'll find there all our AI tools made available in a nice UI. 🚀
|
164 |
</div>
|
165 |
|
166 |
<h1 style="font-size: 1.5rem; margin-bottom: 0.5rem;">
|
src/fg.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import dataclasses as dc
|
3 |
+
import json
|
4 |
+
from collections import defaultdict
|
5 |
+
from collections.abc import Awaitable, Callable
|
6 |
+
from typing import Any
|
7 |
+
|
8 |
+
import httpx
|
9 |
+
import httpx_sse
|
10 |
+
|
11 |
+
|
12 |
+
def _new_future() -> asyncio.Future[Any]:
|
13 |
+
return asyncio.get_running_loop().create_future()
|
14 |
+
|
15 |
+
|
16 |
+
@dc.dataclass(kw_only=True)
|
17 |
+
class EditorAPIContext:
|
18 |
+
uri: str
|
19 |
+
user: str
|
20 |
+
password: str
|
21 |
+
token: str | None = None
|
22 |
+
verify: bool | str = True
|
23 |
+
_client: httpx.AsyncClient | None = None
|
24 |
+
|
25 |
+
sse_futures: dict[str, asyncio.Future[dict[str, Any]]] = dc.field(default_factory=lambda: defaultdict(_new_future))
|
26 |
+
|
27 |
+
async def __aenter__(self) -> httpx.AsyncClient:
|
28 |
+
if self._client:
|
29 |
+
return self._client
|
30 |
+
self._client = httpx.AsyncClient(verify=self.verify)
|
31 |
+
return self._client
|
32 |
+
|
33 |
+
async def __aexit__(self, *args: Any) -> None:
|
34 |
+
if self._client:
|
35 |
+
await self._client.__aexit__(*args)
|
36 |
+
self._client = None
|
37 |
+
|
38 |
+
@property
|
39 |
+
def auth_headers(self) -> dict[str, str]:
|
40 |
+
assert self.token
|
41 |
+
return {"Authorization": f"Bearer {self.token}"}
|
42 |
+
|
43 |
+
async def login(self) -> None:
|
44 |
+
async with self as client:
|
45 |
+
response = await client.post(
|
46 |
+
f"{self.uri}/auth/login",
|
47 |
+
json={"username": self.user, "password": self.password},
|
48 |
+
)
|
49 |
+
response.raise_for_status()
|
50 |
+
self.token = response.json()["token"]
|
51 |
+
|
52 |
+
async def sse_loop(self) -> None:
|
53 |
+
async with self as client:
|
54 |
+
response = await client.post(f"{self.uri}/sub-auth", headers=self.auth_headers)
|
55 |
+
response.raise_for_status()
|
56 |
+
sub_token = response.json()["token"]
|
57 |
+
url = f"{self.uri}/sub/{sub_token}"
|
58 |
+
async with (
|
59 |
+
httpx.AsyncClient(timeout=None, verify=self.verify) as c,
|
60 |
+
httpx_sse.aconnect_sse(c, "GET", url) as es,
|
61 |
+
):
|
62 |
+
future = self.sse_futures["_sse_loop"]
|
63 |
+
future.set_result({"status": "ok"})
|
64 |
+
async for sse in es.aiter_sse():
|
65 |
+
jdata = json.loads(sse.data)
|
66 |
+
future = self.sse_futures[jdata["state"]]
|
67 |
+
future.set_result(jdata)
|
68 |
+
|
69 |
+
async def sse_await(self, state_id: str) -> None:
|
70 |
+
future = self.sse_futures[state_id]
|
71 |
+
jdata = await future
|
72 |
+
if jdata["status"] != "ok":
|
73 |
+
print("ERROR", jdata)
|
74 |
+
assert jdata["status"] == "ok"
|
75 |
+
del self.sse_futures[state_id]
|
76 |
+
|
77 |
+
async def get_meta(self, state_id: str) -> dict[str, Any]:
|
78 |
+
async with self as client:
|
79 |
+
response = await client.get(
|
80 |
+
f"{self.uri}/state/meta/{state_id}",
|
81 |
+
headers=self.auth_headers,
|
82 |
+
)
|
83 |
+
response.raise_for_status()
|
84 |
+
return response.json()
|
85 |
+
|
86 |
+
async def run_one[Tin, Tout](
|
87 |
+
self,
|
88 |
+
co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
|
89 |
+
params: Tin,
|
90 |
+
) -> Tout:
|
91 |
+
await self.login()
|
92 |
+
async with asyncio.TaskGroup() as tg:
|
93 |
+
sse_task = tg.create_task(self.sse_loop())
|
94 |
+
|
95 |
+
async def outer_co(params: Tin) -> Tout:
|
96 |
+
# _sse_loop is a fake event to wait until the SSE loop is properly setup.
|
97 |
+
await self.sse_await("_sse_loop")
|
98 |
+
r = await co(self, params)
|
99 |
+
sse_task.cancel()
|
100 |
+
return r
|
101 |
+
|
102 |
+
r = tg.create_task(outer_co(params))
|
103 |
+
|
104 |
+
return r.result()
|
105 |
+
|
106 |
+
def run_one_sync[Tin, Tout](
|
107 |
+
self,
|
108 |
+
co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
|
109 |
+
params: Tin,
|
110 |
+
) -> Tout:
|
111 |
+
try:
|
112 |
+
loop = asyncio.get_event_loop()
|
113 |
+
except RuntimeError:
|
114 |
+
loop = asyncio.new_event_loop()
|
115 |
+
asyncio.set_event_loop(loop)
|
116 |
+
|
117 |
+
return loop.run_until_complete(self.run_one(co, params))
|