Pierre Chapuis commited on
Commit
e619418
·
unverified ·
1 Parent(s): 7e847c7

update to use the official Finegrain API directly

Browse files
README.md CHANGED
@@ -4,7 +4,8 @@ emoji: 🧽
4
  colorFrom: gray
5
  colorTo: blue
6
  sdk: gradio
7
- sdk_version: 4.42.0
 
8
  app_file: src/app.py
9
  pinned: false
10
  license: other
 
4
  colorFrom: gray
5
  colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.44.1
8
+ python_version: 3.12
9
  app_file: src/app.py
10
  pinned: false
11
  license: other
gradio_image_annotation-0.2.3-py3-none-any.whl DELETED
Binary file (85.3 kB)
 
pyproject.toml CHANGED
@@ -6,13 +6,14 @@ authors = [
6
  { name = "Pierre Chapuis", email = "[email protected]" }
7
  ]
8
  dependencies = [
9
- "gradio>=4.41.0",
10
  "environs>=11.0.0",
11
- "gradio-image-annotation @ https://huggingface.co/spaces/finegrain/finegrain-object-eraser/resolve/main/gradio_image_annotation-0.2.3-py3-none-any.whl",
12
  "httpx>=0.27.0",
13
  "pillow>=10.4.0",
14
  "gradio-imageslider>=0.0.20",
15
  "pillow-heif>=0.18.0",
 
16
  ]
17
  readme = "README.md"
18
  requires-python = ">= 3.12, <3.13"
 
6
  { name = "Pierre Chapuis", email = "[email protected]" }
7
  ]
8
  dependencies = [
9
+ "gradio>=4.41.0,<5", # gradio-imageslider requires <5
10
  "environs>=11.0.0",
11
+ "gradio-image-annotation>=0.2.5",
12
  "httpx>=0.27.0",
13
  "pillow>=10.4.0",
14
  "gradio-imageslider>=0.0.20",
15
  "pillow-heif>=0.18.0",
16
+ "httpx-sse>=0.4.0",
17
  ]
18
  readme = "README.md"
19
  requires-python = ">= 3.12, <3.13"
requirements.lock CHANGED
@@ -14,93 +14,95 @@ aiofiles==23.2.1
14
  # via gradio
15
  annotated-types==0.7.0
16
  # via pydantic
17
- anyio==4.4.0
18
  # via gradio
19
  # via httpx
20
  # via starlette
21
- certifi==2024.8.30
22
  # via httpcore
23
  # via httpx
24
  # via requests
25
- charset-normalizer==3.3.2
26
  # via requests
27
- click==8.1.7
28
  # via typer
29
  # via uvicorn
30
- contourpy==1.3.0
31
  # via matplotlib
32
  cycler==0.12.1
33
  # via matplotlib
34
- environs==11.0.0
35
  # via eraser
36
- fastapi==0.112.2
37
  # via gradio
38
- ffmpy==0.4.0
39
  # via gradio
40
- filelock==3.15.4
41
  # via huggingface-hub
42
- fonttools==4.53.1
43
  # via matplotlib
44
- fsspec==2024.6.1
45
  # via gradio-client
46
  # via huggingface-hub
47
- gradio==4.42.0
48
  # via eraser
49
  # via gradio-image-annotation
50
  # via gradio-imageslider
51
  gradio-client==1.3.0
52
  # via gradio
53
- gradio-image-annotation @ https://huggingface.co/spaces/finegrain/finegrain-object-eraser/resolve/main/gradio_image_annotation-0.2.3-py3-none-any.whl
54
  # via eraser
55
  gradio-imageslider==0.0.20
56
  # via eraser
57
  h11==0.14.0
58
  # via httpcore
59
  # via uvicorn
60
- httpcore==1.0.5
61
  # via httpx
62
- httpx==0.27.2
63
  # via eraser
64
  # via gradio
65
  # via gradio-client
66
- huggingface-hub==0.24.6
 
 
67
  # via gradio
68
  # via gradio-client
69
- idna==3.8
70
  # via anyio
71
  # via httpx
72
  # via requests
73
- importlib-resources==6.4.4
74
  # via gradio
75
- jinja2==3.1.4
76
  # via gradio
77
- kiwisolver==1.4.5
78
  # via matplotlib
79
  markdown-it-py==3.0.0
80
  # via rich
81
  markupsafe==2.1.5
82
  # via gradio
83
  # via jinja2
84
- marshmallow==3.22.0
85
  # via environs
86
- matplotlib==3.9.2
87
  # via gradio
88
  mdurl==0.1.2
89
  # via markdown-it-py
90
- numpy==2.1.0
91
  # via contourpy
92
  # via gradio
93
  # via matplotlib
94
  # via pandas
95
- orjson==3.10.7
96
  # via gradio
97
- packaging==24.1
98
  # via gradio
99
  # via gradio-client
100
  # via huggingface-hub
101
  # via marshmallow
102
  # via matplotlib
103
- pandas==2.2.2
104
  # via gradio
105
  pillow==10.4.0
106
  # via eraser
@@ -108,55 +110,55 @@ pillow==10.4.0
108
  # via gradio-imageslider
109
  # via matplotlib
110
  # via pillow-heif
111
- pillow-heif==0.18.0
112
  # via eraser
113
- pydantic==2.8.2
114
  # via fastapi
115
  # via gradio
116
- pydantic-core==2.20.1
117
  # via pydantic
118
  pydub==0.25.1
119
  # via gradio
120
- pygments==2.18.0
121
  # via rich
122
- pyparsing==3.1.4
123
  # via matplotlib
124
  python-dateutil==2.9.0.post0
125
  # via matplotlib
126
  # via pandas
127
  python-dotenv==1.0.1
128
  # via environs
129
- python-multipart==0.0.9
130
  # via gradio
131
- pytz==2024.1
132
  # via pandas
133
  pyyaml==6.0.2
134
  # via gradio
135
  # via huggingface-hub
136
  requests==2.32.3
137
  # via huggingface-hub
138
- rich==13.8.0
139
  # via typer
140
- ruff==0.6.3
141
  # via gradio
142
  semantic-version==2.10.0
143
  # via gradio
144
  shellingham==1.5.4
145
  # via typer
146
- six==1.16.0
147
  # via python-dateutil
148
  sniffio==1.3.1
149
  # via anyio
150
- # via httpx
151
- starlette==0.38.4
152
  # via fastapi
153
  tomlkit==0.12.0
154
  # via gradio
155
- tqdm==4.66.5
156
  # via huggingface-hub
157
- typer==0.12.5
158
  # via gradio
159
  typing-extensions==4.12.2
 
160
  # via fastapi
161
  # via gradio
162
  # via gradio-client
@@ -164,12 +166,12 @@ typing-extensions==4.12.2
164
  # via pydantic
165
  # via pydantic-core
166
  # via typer
167
- tzdata==2024.1
168
  # via pandas
169
- urllib3==2.2.2
170
  # via gradio
171
  # via requests
172
- uvicorn==0.30.6
173
  # via gradio
174
  websockets==12.0
175
  # via gradio-client
 
14
  # via gradio
15
  annotated-types==0.7.0
16
  # via pydantic
17
+ anyio==4.8.0
18
  # via gradio
19
  # via httpx
20
  # via starlette
21
+ certifi==2024.12.14
22
  # via httpcore
23
  # via httpx
24
  # via requests
25
+ charset-normalizer==3.4.1
26
  # via requests
27
+ click==8.1.8
28
  # via typer
29
  # via uvicorn
30
+ contourpy==1.3.1
31
  # via matplotlib
32
  cycler==0.12.1
33
  # via matplotlib
34
+ environs==14.1.0
35
  # via eraser
36
+ fastapi==0.115.6
37
  # via gradio
38
+ ffmpy==0.5.0
39
  # via gradio
40
+ filelock==3.16.1
41
  # via huggingface-hub
42
+ fonttools==4.55.3
43
  # via matplotlib
44
+ fsspec==2024.12.0
45
  # via gradio-client
46
  # via huggingface-hub
47
+ gradio==4.44.1
48
  # via eraser
49
  # via gradio-image-annotation
50
  # via gradio-imageslider
51
  gradio-client==1.3.0
52
  # via gradio
53
+ gradio-image-annotation==0.2.5
54
  # via eraser
55
  gradio-imageslider==0.0.20
56
  # via eraser
57
  h11==0.14.0
58
  # via httpcore
59
  # via uvicorn
60
+ httpcore==1.0.7
61
  # via httpx
62
+ httpx==0.28.1
63
  # via eraser
64
  # via gradio
65
  # via gradio-client
66
+ httpx-sse==0.4.0
67
+ # via eraser
68
+ huggingface-hub==0.27.1
69
  # via gradio
70
  # via gradio-client
71
+ idna==3.10
72
  # via anyio
73
  # via httpx
74
  # via requests
75
+ importlib-resources==6.5.2
76
  # via gradio
77
+ jinja2==3.1.5
78
  # via gradio
79
+ kiwisolver==1.4.8
80
  # via matplotlib
81
  markdown-it-py==3.0.0
82
  # via rich
83
  markupsafe==2.1.5
84
  # via gradio
85
  # via jinja2
86
+ marshmallow==3.25.1
87
  # via environs
88
+ matplotlib==3.10.0
89
  # via gradio
90
  mdurl==0.1.2
91
  # via markdown-it-py
92
+ numpy==2.2.2
93
  # via contourpy
94
  # via gradio
95
  # via matplotlib
96
  # via pandas
97
+ orjson==3.10.15
98
  # via gradio
99
+ packaging==24.2
100
  # via gradio
101
  # via gradio-client
102
  # via huggingface-hub
103
  # via marshmallow
104
  # via matplotlib
105
+ pandas==2.2.3
106
  # via gradio
107
  pillow==10.4.0
108
  # via eraser
 
110
  # via gradio-imageslider
111
  # via matplotlib
112
  # via pillow-heif
113
+ pillow-heif==0.21.0
114
  # via eraser
115
+ pydantic==2.10.5
116
  # via fastapi
117
  # via gradio
118
+ pydantic-core==2.27.2
119
  # via pydantic
120
  pydub==0.25.1
121
  # via gradio
122
+ pygments==2.19.1
123
  # via rich
124
+ pyparsing==3.2.1
125
  # via matplotlib
126
  python-dateutil==2.9.0.post0
127
  # via matplotlib
128
  # via pandas
129
  python-dotenv==1.0.1
130
  # via environs
131
+ python-multipart==0.0.20
132
  # via gradio
133
+ pytz==2024.2
134
  # via pandas
135
  pyyaml==6.0.2
136
  # via gradio
137
  # via huggingface-hub
138
  requests==2.32.3
139
  # via huggingface-hub
140
+ rich==13.9.4
141
  # via typer
142
+ ruff==0.9.2
143
  # via gradio
144
  semantic-version==2.10.0
145
  # via gradio
146
  shellingham==1.5.4
147
  # via typer
148
+ six==1.17.0
149
  # via python-dateutil
150
  sniffio==1.3.1
151
  # via anyio
152
+ starlette==0.41.3
 
153
  # via fastapi
154
  tomlkit==0.12.0
155
  # via gradio
156
+ tqdm==4.67.1
157
  # via huggingface-hub
158
+ typer==0.15.1
159
  # via gradio
160
  typing-extensions==4.12.2
161
+ # via anyio
162
  # via fastapi
163
  # via gradio
164
  # via gradio-client
 
166
  # via pydantic
167
  # via pydantic-core
168
  # via typer
169
+ tzdata==2024.2
170
  # via pandas
171
+ urllib3==2.3.0
172
  # via gradio
173
  # via requests
174
+ uvicorn==0.34.0
175
  # via gradio
176
  websockets==12.0
177
  # via gradio-client
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
- https://huggingface.co/spaces/finegrain/finegrain-object-eraser/resolve/main/gradio_image_annotation-0.2.3-py3-none-any.whl
2
  gradio_imageslider>=0.0.20
3
  environs>=11.0.0
4
  httpx>=0.27.0
 
5
  pillow>=10.4.0
6
  pillow-heif>=0.18.0
 
1
+ gradio_image_annotation>=0.2.5
2
  gradio_imageslider>=0.0.20
3
  environs>=11.0.0
4
  httpx>=0.27.0
5
+ httpx-sse>=0.4.0
6
  pillow>=10.4.0
7
  pillow-heif>=0.18.0
src/app.py CHANGED
@@ -1,14 +1,16 @@
 
1
  import io
2
  from typing import Any
3
 
4
  import gradio as gr
5
- import httpx
6
  import pillow_heif
7
  from environs import Env
8
  from gradio_image_annotation import image_annotator
9
  from gradio_imageslider import ImageSlider
10
  from PIL import Image
11
 
 
 
12
  pillow_heif.register_heif_opener()
13
  pillow_heif.register_avif_opener()
14
 
@@ -16,11 +18,16 @@ env = Env()
16
  env.read_env()
17
 
18
  with env.prefixed("ERASER_"):
19
- API_URL: str = str(env.str("API_URL", "https://spaces.finegrain.ai/eraser"))
20
- API_KEY: str | None = env.str("API_KEY", None)
 
21
  CA_BUNDLE: str | None = env.str("CA_BUNDLE", None)
22
 
23
- auth = None if API_KEY is None else httpx.BasicAuth("hf", API_KEY)
 
 
 
 
24
 
25
 
26
  def resize(image: Image.Image, shortest_side: int = 768) -> Image.Image:
@@ -31,38 +38,97 @@ def resize(image: Image.Image, shortest_side: int = 768) -> Image.Image:
31
  return image.resize(size=(int(shortest_side * image.width / image.height), shortest_side))
32
 
33
 
34
- def process_bbox(
35
- prompts: dict[str, Any],
36
- request: gr.Request | None,
37
- ) -> tuple[Image.Image, Image.Image]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  assert isinstance(img := prompts["image"], Image.Image)
39
  assert isinstance(boxes := prompts["boxes"], list)
40
  assert len(boxes) == 1
41
  assert isinstance(box := boxes[0], dict)
42
- headers = {}
43
- if request: # avoid DOS - can be None despite type hint!
44
- client_ip = request.headers.get("x-forwarded-for") or request.client.host
45
- headers = {"X-HF-Client-IP": client_ip}
46
 
47
  resized_img = resize(img)
48
  bbox = [box[k] for k in ["xmin", "ymin", "xmax", "ymax"]]
49
  if resized_img.width != img.width:
50
  bbox = [int(v * resized_img.width / img.width) for v in bbox]
51
 
52
- with io.BytesIO() as f:
53
- resized_img.save(f, format="JPEG")
54
- r = httpx.post(
55
- API_URL,
56
- data={"bbox": ",".join([str(v) for v in bbox])},
57
- files={"file": f},
58
- verify=CA_BUNDLE or True,
59
- timeout=30.0,
60
- auth=auth,
61
- headers=headers,
62
- )
63
- r.raise_for_status()
64
-
65
- output_image = Image.open(io.BytesIO(r.content))
66
  return (img, output_image)
67
 
68
 
@@ -70,32 +136,12 @@ def on_change_bbox(prompts: dict[str, Any] | None):
70
  return gr.update(interactive=prompts is not None and len(prompts["boxes"]) > 0)
71
 
72
 
73
- def process_prompt(
74
- img: Image.Image,
75
- prompt: str,
76
- request: gr.Request | None,
77
- ) -> tuple[Image.Image, Image.Image]:
78
- headers = {}
79
- if request: # avoid DOS - can be None despite type hint!
80
- client_ip = request.headers.get("x-forwarded-for") or request.client.host
81
- headers = {"X-HF-Client-IP": client_ip}
82
-
83
  resized_img = resize(img)
84
-
85
- with io.BytesIO() as f:
86
- resized_img.save(f, format="JPEG")
87
- r = httpx.post(
88
- API_URL,
89
- data={"prompt": prompt},
90
- files={"file": f},
91
- verify=CA_BUNDLE or True,
92
- timeout=30.0,
93
- auth=auth,
94
- headers=headers,
95
- )
96
- r.raise_for_status()
97
-
98
- output_image = Image.open(io.BytesIO(r.content))
99
  return (img, output_image)
100
 
101
 
@@ -112,8 +158,9 @@ TITLE = """
112
  padding: 0.5rem 1rem;
113
  font-size: 1.25rem;
114
  ">
115
- 🚀 For an optimized version of this space, try out the
116
- <a href="https://finegrain.ai/editor?utm_source=hf&utm_campaign=object-eraser" target="_blank">Finegrain Editor</a>! You'll find there all our AI tools made available in a nice UI. 🚀
 
117
  </div>
118
 
119
  <h1 style="font-size: 1.5rem; margin-bottom: 0.5rem;">
 
1
+ import dataclasses as dc
2
  import io
3
  from typing import Any
4
 
5
  import gradio as gr
 
6
  import pillow_heif
7
  from environs import Env
8
  from gradio_image_annotation import image_annotator
9
  from gradio_imageslider import ImageSlider
10
  from PIL import Image
11
 
12
+ from fg import EditorAPIContext
13
+
14
  pillow_heif.register_heif_opener()
15
  pillow_heif.register_avif_opener()
16
 
 
18
  env.read_env()
19
 
20
  with env.prefixed("ERASER_"):
21
+ API_URL: str = str(env.str("API_URL", "https://api.finegrain.ai/editor"))
22
+ API_USER: str | None = env.str("API_USER")
23
+ API_PASSWORD: str | None = env.str("API_PASSWORD")
24
  CA_BUNDLE: str | None = env.str("CA_BUNDLE", None)
25
 
26
+ assert API_USER is not None
27
+ assert API_PASSWORD is not None
28
+ CTX = EditorAPIContext(uri=API_URL, user=API_USER, password=API_PASSWORD)
29
+ if CA_BUNDLE:
30
+ CTX.verify = CA_BUNDLE
31
 
32
 
33
  def resize(image: Image.Image, shortest_side: int = 768) -> Image.Image:
 
38
  return image.resize(size=(int(shortest_side * image.width / image.height), shortest_side))
39
 
40
 
41
+ @dc.dataclass(kw_only=True)
42
+ class ProcessParams:
43
+ image: Image.Image
44
+ prompt: str | None = None
45
+ bbox: tuple[int, int, int, int] | None = None
46
+
47
+
48
+ async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
49
+ with io.BytesIO() as f:
50
+ params.image.save(f, format="JPEG")
51
+ async with ctx as client:
52
+ response = await client.post(
53
+ f"{ctx.uri}/state/upload",
54
+ files={"file": f},
55
+ headers=ctx.auth_headers,
56
+ )
57
+ response.raise_for_status()
58
+ st_input = response.json()["state"]
59
+
60
+ if params.bbox:
61
+ segment_input_st = st_input
62
+ segment_params = {"bbox": list(params.bbox)}
63
+ else:
64
+ assert params.prompt
65
+ async with ctx as client:
66
+ response = await client.post(
67
+ f"{ctx.uri}/skills/infer-bbox/{st_input}",
68
+ json={"product_name": params.prompt},
69
+ headers=ctx.auth_headers,
70
+ )
71
+ response.raise_for_status()
72
+ st_bbox = response.json()["state"]
73
+ await ctx.sse_await(st_bbox)
74
+ segment_input_st = st_bbox
75
+ segment_params = {}
76
+
77
+ async with ctx as client:
78
+ response = await client.post(
79
+ f"{ctx.uri}/skills/segment/{segment_input_st}",
80
+ json=segment_params,
81
+ headers=ctx.auth_headers,
82
+ )
83
+ response.raise_for_status()
84
+ st_mask = response.json()["state"]
85
+ await ctx.sse_await(st_mask)
86
+
87
+ erase_params: dict[str, str | bool] = {
88
+ "mode": "free", # new API
89
+ "restore_original_resolution": False, # legacy API
90
+ }
91
+ async with ctx as client:
92
+ response = await client.post(
93
+ f"{ctx.uri}/skills/erase/{st_input}/{st_mask}",
94
+ json=erase_params,
95
+ headers=ctx.auth_headers,
96
+ )
97
+ response.raise_for_status()
98
+ st_erased = response.json()["state"]
99
+ await ctx.sse_await(st_erased)
100
+
101
+ async with ctx as client:
102
+ response = await client.get(
103
+ f"{ctx.uri}/state/image/{st_erased}",
104
+ params={"format": "JPEG", "resolution": "DISPLAY"},
105
+ headers=ctx.auth_headers,
106
+ )
107
+ response.raise_for_status()
108
+ f = io.BytesIO()
109
+ f.write(response.content)
110
+ f.seek(0)
111
+ return Image.open(f)
112
+
113
+
114
+ def process_bbox(prompts: dict[str, Any]) -> tuple[Image.Image, Image.Image]:
115
  assert isinstance(img := prompts["image"], Image.Image)
116
  assert isinstance(boxes := prompts["boxes"], list)
117
  assert len(boxes) == 1
118
  assert isinstance(box := boxes[0], dict)
 
 
 
 
119
 
120
  resized_img = resize(img)
121
  bbox = [box[k] for k in ["xmin", "ymin", "xmax", "ymax"]]
122
  if resized_img.width != img.width:
123
  bbox = [int(v * resized_img.width / img.width) for v in bbox]
124
 
125
+ output_image = CTX.run_one_sync(
126
+ _process,
127
+ ProcessParams(
128
+ image=resized_img,
129
+ bbox=(bbox[0], bbox[1], bbox[2], bbox[3]),
130
+ ),
131
+ )
 
 
 
 
 
 
 
132
  return (img, output_image)
133
 
134
 
 
136
  return gr.update(interactive=prompts is not None and len(prompts["boxes"]) > 0)
137
 
138
 
139
+ def process_prompt(img: Image.Image, prompt: str) -> tuple[Image.Image, Image.Image]:
 
 
 
 
 
 
 
 
 
140
  resized_img = resize(img)
141
+ output_image = CTX.run_one_sync(
142
+ _process,
143
+ ProcessParams(image=resized_img, prompt=prompt),
144
+ )
 
 
 
 
 
 
 
 
 
 
 
145
  return (img, output_image)
146
 
147
 
 
158
  padding: 0.5rem 1rem;
159
  font-size: 1.25rem;
160
  ">
161
+ 🚀 For an optimized version of this space, try out the
162
+ <a href="https://finegrain.ai/editor?utm_source=hf&utm_campaign=object-eraser" target="_blank">Finegrain Editor</a>!
163
+ You'll find there all our AI tools made available in a nice UI. 🚀
164
  </div>
165
 
166
  <h1 style="font-size: 1.5rem; margin-bottom: 0.5rem;">
src/fg.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import dataclasses as dc
3
+ import json
4
+ from collections import defaultdict
5
+ from collections.abc import Awaitable, Callable
6
+ from typing import Any
7
+
8
+ import httpx
9
+ import httpx_sse
10
+
11
+
12
+ def _new_future() -> asyncio.Future[Any]:
13
+ return asyncio.get_running_loop().create_future()
14
+
15
+
16
+ @dc.dataclass(kw_only=True)
17
+ class EditorAPIContext:
18
+ uri: str
19
+ user: str
20
+ password: str
21
+ token: str | None = None
22
+ verify: bool | str = True
23
+ _client: httpx.AsyncClient | None = None
24
+
25
+ sse_futures: dict[str, asyncio.Future[dict[str, Any]]] = dc.field(default_factory=lambda: defaultdict(_new_future))
26
+
27
+ async def __aenter__(self) -> httpx.AsyncClient:
28
+ if self._client:
29
+ return self._client
30
+ self._client = httpx.AsyncClient(verify=self.verify)
31
+ return self._client
32
+
33
+ async def __aexit__(self, *args: Any) -> None:
34
+ if self._client:
35
+ await self._client.__aexit__(*args)
36
+ self._client = None
37
+
38
+ @property
39
+ def auth_headers(self) -> dict[str, str]:
40
+ assert self.token
41
+ return {"Authorization": f"Bearer {self.token}"}
42
+
43
+ async def login(self) -> None:
44
+ async with self as client:
45
+ response = await client.post(
46
+ f"{self.uri}/auth/login",
47
+ json={"username": self.user, "password": self.password},
48
+ )
49
+ response.raise_for_status()
50
+ self.token = response.json()["token"]
51
+
52
+ async def sse_loop(self) -> None:
53
+ async with self as client:
54
+ response = await client.post(f"{self.uri}/sub-auth", headers=self.auth_headers)
55
+ response.raise_for_status()
56
+ sub_token = response.json()["token"]
57
+ url = f"{self.uri}/sub/{sub_token}"
58
+ async with (
59
+ httpx.AsyncClient(timeout=None, verify=self.verify) as c,
60
+ httpx_sse.aconnect_sse(c, "GET", url) as es,
61
+ ):
62
+ future = self.sse_futures["_sse_loop"]
63
+ future.set_result({"status": "ok"})
64
+ async for sse in es.aiter_sse():
65
+ jdata = json.loads(sse.data)
66
+ future = self.sse_futures[jdata["state"]]
67
+ future.set_result(jdata)
68
+
69
+ async def sse_await(self, state_id: str) -> None:
70
+ future = self.sse_futures[state_id]
71
+ jdata = await future
72
+ if jdata["status"] != "ok":
73
+ print("ERROR", jdata)
74
+ assert jdata["status"] == "ok"
75
+ del self.sse_futures[state_id]
76
+
77
+ async def get_meta(self, state_id: str) -> dict[str, Any]:
78
+ async with self as client:
79
+ response = await client.get(
80
+ f"{self.uri}/state/meta/{state_id}",
81
+ headers=self.auth_headers,
82
+ )
83
+ response.raise_for_status()
84
+ return response.json()
85
+
86
+ async def run_one[Tin, Tout](
87
+ self,
88
+ co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
89
+ params: Tin,
90
+ ) -> Tout:
91
+ await self.login()
92
+ async with asyncio.TaskGroup() as tg:
93
+ sse_task = tg.create_task(self.sse_loop())
94
+
95
+ async def outer_co(params: Tin) -> Tout:
96
+ # _sse_loop is a fake event to wait until the SSE loop is properly setup.
97
+ await self.sse_await("_sse_loop")
98
+ r = await co(self, params)
99
+ sse_task.cancel()
100
+ return r
101
+
102
+ r = tg.create_task(outer_co(params))
103
+
104
+ return r.result()
105
+
106
+ def run_one_sync[Tin, Tout](
107
+ self,
108
+ co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
109
+ params: Tin,
110
+ ) -> Tout:
111
+ try:
112
+ loop = asyncio.get_event_loop()
113
+ except RuntimeError:
114
+ loop = asyncio.new_event_loop()
115
+ asyncio.set_event_loop(loop)
116
+
117
+ return loop.run_until_complete(self.run_one(co, params))