Spaces:
Runtime error
Runtime error
Benjamin Bossan
commited on
Commit
·
01ae0bb
1
Parent(s):
64d4f97
Use transformers agents where applicable
Browse files- .gitignore +3 -0
- README.md +7 -4
- pyproject.toml +1 -1
- requests.org +18 -10
- requirements-dev.txt +1 -0
- requirements.txt +3 -1
- src/gistillery/config.py +24 -0
- src/gistillery/db.py +3 -4
- src/gistillery/preprocessing.py +59 -12
- src/gistillery/registry.py +29 -4
- src/gistillery/{ml.py → tools.py} +42 -44
- src/gistillery/webservice.py +21 -1
- src/gistillery/worker.py +6 -41
- tests/test_app.py +42 -29
.gitignore
CHANGED
@@ -10,3 +10,6 @@ build
|
|
10 |
htmlcov
|
11 |
|
12 |
*.db
|
|
|
|
|
|
|
|
10 |
htmlcov
|
11 |
|
12 |
*.db
|
13 |
+
notebooks/
|
14 |
+
*.ipynb
|
15 |
+
.env
|
README.md
CHANGED
@@ -19,18 +19,21 @@ python -m pip install -e .
|
|
19 |
|
20 |
## Starting
|
21 |
|
|
|
|
|
|
|
|
|
|
|
22 |
In one terminal, start the background worker:
|
23 |
|
24 |
```sh
|
25 |
-
|
26 |
-
python worker.py
|
27 |
```
|
28 |
|
29 |
In another terminal, start the web server:
|
30 |
|
31 |
```sh
|
32 |
-
|
33 |
-
uvicorn webservice:app --reload --port 8080
|
34 |
```
|
35 |
|
36 |
For example requests, check `requests.org`.
|
|
|
19 |
|
20 |
## Starting
|
21 |
|
22 |
+
### Preparing environemnt
|
23 |
+
|
24 |
+
Set an environemnt variable called "HF_HUB_TOKEN" with your Hugging Face token
|
25 |
+
or create a `.env` file with that env var.
|
26 |
+
|
27 |
In one terminal, start the background worker:
|
28 |
|
29 |
```sh
|
30 |
+
python src/gistillery/worker.py
|
|
|
31 |
```
|
32 |
|
33 |
In another terminal, start the web server:
|
34 |
|
35 |
```sh
|
36 |
+
uvicorn src.gistillery.webservice:app --reload --port 8080
|
|
|
37 |
```
|
38 |
|
39 |
For example requests, check `requests.org`.
|
pyproject.toml
CHANGED
@@ -18,5 +18,5 @@ no_implicit_optional = true
|
|
18 |
strict = true
|
19 |
|
20 |
[[tool.mypy.overrides]]
|
21 |
-
module = "
|
22 |
ignore_missing_imports = true
|
|
|
18 |
strict = true
|
19 |
|
20 |
[[tool.mypy.overrides]]
|
21 |
+
module = "huggingface_hub,trafilatura,transformers.*"
|
22 |
ignore_missing_imports = true
|
requests.org
CHANGED
@@ -10,19 +10,18 @@ curl -X 'GET' \
|
|
10 |
: OK
|
11 |
|
12 |
#+begin_src bash
|
13 |
-
# curl command to localhost and post the message "hi there"
|
14 |
curl -X 'POST' \
|
15 |
'http://localhost:8080/submit/' \
|
16 |
-H 'accept: application/json' \
|
17 |
-H 'Content-Type: application/json' \
|
18 |
-d '{
|
19 |
"author": "ben",
|
20 |
-
"content": "
|
21 |
}'
|
22 |
#+end_src
|
23 |
|
24 |
#+RESULTS:
|
25 |
-
: Submitted job
|
26 |
|
27 |
#+begin_src bash
|
28 |
curl -X 'POST' \
|
@@ -31,12 +30,12 @@ curl -X 'POST' \
|
|
31 |
-H 'Content-Type: application/json' \
|
32 |
-d '{
|
33 |
"author": "ben",
|
34 |
-
"content": "
|
35 |
}'
|
36 |
#+end_src
|
37 |
|
38 |
#+RESULTS:
|
39 |
-
: Submitted job
|
40 |
|
41 |
#+begin_src bash
|
42 |
curl -X 'POST' \
|
@@ -45,21 +44,21 @@ curl -X 'POST' \
|
|
45 |
-H 'Content-Type: application/json' \
|
46 |
-d '{
|
47 |
"author": "ben",
|
48 |
-
"content": "https://
|
49 |
}'
|
50 |
#+end_src
|
51 |
|
52 |
#+RESULTS:
|
53 |
-
: Submitted job
|
54 |
|
55 |
#+begin_src bash
|
56 |
curl -X 'GET' \
|
57 |
-
'http://localhost:8080/check_job_status/
|
58 |
-H 'accept: application/json'
|
59 |
#+end_src
|
60 |
|
61 |
#+RESULTS:
|
62 |
-
|
63 |
|
64 |
#+begin_src bash
|
65 |
curl -X 'GET' \
|
@@ -68,4 +67,13 @@ curl -X 'GET' \
|
|
68 |
#+end_src
|
69 |
|
70 |
#+RESULTS:
|
71 |
-
| [{"id":"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
: OK
|
11 |
|
12 |
#+begin_src bash
|
|
|
13 |
curl -X 'POST' \
|
14 |
'http://localhost:8080/submit/' \
|
15 |
-H 'accept: application/json' \
|
16 |
-H 'Content-Type: application/json' \
|
17 |
-d '{
|
18 |
"author": "ben",
|
19 |
+
"content": "In literature discussing why ChatGPT is able to capture so much of our imagination, I often come across two narratives: Scale: throwing more data and compute at it. UX: moving from a prompt interface to a more natural chat interface. A narrative that is often glossed over in the demo frenzy is the incredible technical creativity that went into making models like ChatGPT work. One such cool idea is RLHF (Reinforcement Learning from Human Feedback): incorporating reinforcement learning and human feedback into NLP. RL has been notoriously difficult to work with, and therefore, mostly confined to gaming and simulated environments like Atari or MuJoCo. Just five years ago, both RL and NLP were progressing pretty much orthogonally – different stacks, different techniques, and different experimentation setups. It’s impressive to see it work in a new domain at a massive scale. So, how exactly does RLHF work? Why does it work? This post will discuss the answers to those questions."
|
20 |
}'
|
21 |
#+end_src
|
22 |
|
23 |
#+RESULTS:
|
24 |
+
: Submitted job fef72c3aa4394bc7a299291c80a5c06b
|
25 |
|
26 |
#+begin_src bash
|
27 |
curl -X 'POST' \
|
|
|
30 |
-H 'Content-Type: application/json' \
|
31 |
-d '{
|
32 |
"author": "ben",
|
33 |
+
"content": "https://en.wikipedia.org/wiki/Goulburn_Street"
|
34 |
}'
|
35 |
#+end_src
|
36 |
|
37 |
#+RESULTS:
|
38 |
+
: Submitted job f37729bb36104ab4a23cefd0480e4862
|
39 |
|
40 |
#+begin_src bash
|
41 |
curl -X 'POST' \
|
|
|
44 |
-H 'Content-Type: application/json' \
|
45 |
-d '{
|
46 |
"author": "ben",
|
47 |
+
"content": "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e1/Cattle_tyrant_%28Machetornis_rixosa%29_on_Capybara.jpg/1920px-Cattle_tyrant_%28Machetornis_rixosa%29_on_Capybara.jpg"
|
48 |
}'
|
49 |
#+end_src
|
50 |
|
51 |
#+RESULTS:
|
52 |
+
: Submitted job dc3da7b1d5aa47c38dc6713952104f5f
|
53 |
|
54 |
#+begin_src bash
|
55 |
curl -X 'GET' \
|
56 |
+
'http://localhost:8080/check_job_status/' \
|
57 |
-H 'accept: application/json'
|
58 |
#+end_src
|
59 |
|
60 |
#+RESULTS:
|
61 |
+
: Found 3 pending job(s): fef72c3aa4394bc7a299291c80a5c06b, f37729bb36104ab4a23cefd0480e4862, dc3da7b1d5aa47c38dc6713952104f5f
|
62 |
|
63 |
#+begin_src bash
|
64 |
curl -X 'GET' \
|
|
|
67 |
#+end_src
|
68 |
|
69 |
#+RESULTS:
|
70 |
+
| [{"id":"dc3da7b1d5aa47c38dc6713952104f5f" | author:"ben" | summary:"A small bird is perched on the back of a capy capy. It's looking for a place to nestle. It doesn't seem to be finding a suitable place for it | though | because it's not very big. The place is not very flat. " | tags:["#back" | #bird | #capy | #general | #perch | #perched] | date:"2023-05-11T13:16:48"} | {"id":"f37729bb36104ab4a23cefd0480e4862" | author:"ben" | summary:"Goulburn Street is a street in the central business district of Sydney in New South Wales | Australia. It runs from Darling Harbour and Chinatown in the west to Crown Street in the east at Darlinghurst and Surry Hills. It is the only car park operated by Sydney City Council within the CBD and was the first air rights car park in Australia." | tags:["#centralbusinessdistrict" | #darlinghurst | #general | #goulburnstreet | #surryhills | #sydney | #sydneymasoniccentre] | date:"2023-05-11T13:16:47"} | {"id":"fef72c3aa4394bc7a299291c80a5c06b" | author:"ben" | summary:"ChatGPT is able to capture our imagination because of its scale. RLHF (Reinforcement Learning from Human Feedback) is a new approach to NLP that incorporates reinforcement learning and human feedback into NLP. It's impressive to see it work in a new domain at a massive scale." | tags:["#" | #general | #rlhf] | date:"2023-05-11T13:16:45"}] |
|
71 |
+
|
72 |
+
#+begin_src bash
|
73 |
+
curl -X 'GET' \
|
74 |
+
'http://localhost:8080/recent/rlhf' \
|
75 |
+
-H 'accept: application/json'
|
76 |
+
#+end_src
|
77 |
+
|
78 |
+
#+RESULTS:
|
79 |
+
| [{"id":"fef72c3aa4394bc7a299291c80a5c06b" | author:"ben" | summary:"ChatGPT is able to capture our imagination because of its scale. RLHF (Reinforcement Learning from Human Feedback) is a new approach to NLP that incorporates reinforcement learning and human feedback into NLP. It's impressive to see it work in a new domain at a massive scale." | tags:["#" | #general | #rlhf] | date:"2023-05-11T13:16:45"}] |
|
requirements-dev.txt
CHANGED
@@ -4,3 +4,4 @@ mypy
|
|
4 |
ruff
|
5 |
pytest
|
6 |
pytest-cov
|
|
|
|
4 |
ruff
|
5 |
pytest
|
6 |
pytest-cov
|
7 |
+
types-Pillow
|
requirements.txt
CHANGED
@@ -2,6 +2,8 @@ fastapi
|
|
2 |
httpx
|
3 |
uvicorn[standard]
|
4 |
torch
|
5 |
-
transformers
|
|
|
6 |
charset-normalizer
|
7 |
trafilatura
|
|
|
|
2 |
httpx
|
3 |
uvicorn[standard]
|
4 |
torch
|
5 |
+
transformers>=4.29.0
|
6 |
+
accelerate
|
7 |
charset-normalizer
|
8 |
trafilatura
|
9 |
+
pillow
|
src/gistillery/config.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
from pydantic import BaseSettings
|
5 |
+
|
6 |
+
|
7 |
+
class Config(BaseSettings):
|
8 |
+
hf_hub_token: str = "missing"
|
9 |
+
hf_agent: str = "https://api-inference.huggingface.co/models/bigcode/starcoder"
|
10 |
+
db_file_name: Path = Path("sqlite-data.db")
|
11 |
+
|
12 |
+
class Config:
|
13 |
+
# load .env file by default, with provisio to use other .env files if set
|
14 |
+
env_file = os.getenv('ENV_FILE', '.env')
|
15 |
+
|
16 |
+
|
17 |
+
_config = None
|
18 |
+
|
19 |
+
|
20 |
+
def get_config() -> Config:
|
21 |
+
global _config
|
22 |
+
if _config is None:
|
23 |
+
_config = Config()
|
24 |
+
return _config
|
src/gistillery/db.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1 |
import logging
|
2 |
-
import os
|
3 |
import sqlite3
|
4 |
from collections import namedtuple
|
5 |
from contextlib import contextmanager
|
6 |
from typing import Generator
|
7 |
|
|
|
|
|
8 |
logger = logging.getLogger(__name__)
|
9 |
logger.setLevel(logging.DEBUG)
|
10 |
|
11 |
-
db_file = os.getenv("DB_FILE_NAME", "sqlite-data.db")
|
12 |
-
|
13 |
|
14 |
schema_entries = """
|
15 |
CREATE TABLE entries
|
@@ -91,7 +90,7 @@ def _get_db_connection() -> sqlite3.Connection:
|
|
91 |
global TABLES_CREATED
|
92 |
|
93 |
# sqlite cannot deal with concurrent access, so we set a big timeout
|
94 |
-
conn = sqlite3.connect(
|
95 |
conn.row_factory = namedtuple_factory
|
96 |
if TABLES_CREATED:
|
97 |
return conn
|
|
|
1 |
import logging
|
|
|
2 |
import sqlite3
|
3 |
from collections import namedtuple
|
4 |
from contextlib import contextmanager
|
5 |
from typing import Generator
|
6 |
|
7 |
+
from gistillery.config import get_config
|
8 |
+
|
9 |
logger = logging.getLogger(__name__)
|
10 |
logger.setLevel(logging.DEBUG)
|
11 |
|
|
|
|
|
12 |
|
13 |
schema_entries = """
|
14 |
CREATE TABLE entries
|
|
|
90 |
global TABLES_CREATED
|
91 |
|
92 |
# sqlite cannot deal with concurrent access, so we set a big timeout
|
93 |
+
conn = sqlite3.connect(get_config().db_file_name, timeout=30)
|
94 |
conn.row_factory = namedtuple_factory
|
95 |
if TABLES_CREATED:
|
96 |
return conn
|
src/gistillery/preprocessing.py
CHANGED
@@ -1,16 +1,33 @@
|
|
1 |
import abc
|
|
|
2 |
import logging
|
3 |
import re
|
|
|
4 |
|
|
|
5 |
from httpx import Client
|
6 |
-
|
|
|
7 |
|
8 |
from gistillery.base import JobInput
|
|
|
|
|
9 |
|
10 |
logger = logging.getLogger(__name__)
|
11 |
logger.setLevel(logging.DEBUG)
|
12 |
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
class Processor(abc.ABC):
|
15 |
def get_name(self) -> str:
|
16 |
return self.__class__.__name__
|
@@ -40,25 +57,55 @@ class RawTextProcessor(Processor):
|
|
40 |
|
41 |
|
42 |
class DefaultUrlProcessor(Processor):
|
43 |
-
# uses trafilatura to extract text from html
|
44 |
def __init__(self) -> None:
|
45 |
self.client = Client()
|
46 |
-
self.
|
47 |
-
self.url = None
|
48 |
self.template = "{url}\n\n{content}"
|
49 |
|
50 |
def match(self, input: JobInput) -> bool:
|
51 |
-
|
52 |
-
if
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
56 |
|
57 |
def process(self, input: JobInput) -> str:
|
58 |
"""Get content of website and return it as string"""
|
59 |
-
|
|
|
|
|
60 |
text = self.client.get(self.url).text
|
61 |
assert isinstance(text, str)
|
62 |
-
extracted = extract(text)
|
63 |
text = self.template.format(url=self.url, content=extracted)
|
64 |
-
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import abc
|
2 |
+
import io
|
3 |
import logging
|
4 |
import re
|
5 |
+
from typing import Optional
|
6 |
|
7 |
+
import trafilatura
|
8 |
from httpx import Client
|
9 |
+
|
10 |
+
from PIL import Image
|
11 |
|
12 |
from gistillery.base import JobInput
|
13 |
+
from gistillery.tools import get_agent
|
14 |
+
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
logger.setLevel(logging.DEBUG)
|
18 |
|
19 |
|
20 |
+
RE_URL = re.compile(r"(https?://[^\s]+)")
|
21 |
+
|
22 |
+
|
23 |
+
def get_url(text: str) -> str | None:
|
24 |
+
urls: list[str] = list(RE_URL.findall(text))
|
25 |
+
if len(urls) == 1:
|
26 |
+
url = urls[0]
|
27 |
+
return url
|
28 |
+
return None
|
29 |
+
|
30 |
+
|
31 |
class Processor(abc.ABC):
|
32 |
def get_name(self) -> str:
|
33 |
return self.__class__.__name__
|
|
|
57 |
|
58 |
|
59 |
class DefaultUrlProcessor(Processor):
|
|
|
60 |
def __init__(self) -> None:
|
61 |
self.client = Client()
|
62 |
+
self.url = Optional[str]
|
|
|
63 |
self.template = "{url}\n\n{content}"
|
64 |
|
65 |
def match(self, input: JobInput) -> bool:
|
66 |
+
url = get_url(input.content.strip())
|
67 |
+
if url is None:
|
68 |
+
return False
|
69 |
+
|
70 |
+
self.url = url
|
71 |
+
return True
|
72 |
|
73 |
def process(self, input: JobInput) -> str:
|
74 |
"""Get content of website and return it as string"""
|
75 |
+
if not isinstance(self.url, str):
|
76 |
+
raise TypeError("self.url must be a string")
|
77 |
+
|
78 |
text = self.client.get(self.url).text
|
79 |
assert isinstance(text, str)
|
80 |
+
extracted = trafilatura.extract(text)
|
81 |
text = self.template.format(url=self.url, content=extracted)
|
82 |
+
return str(text)
|
83 |
+
|
84 |
+
|
85 |
+
class ImageUrlProcessor(Processor):
|
86 |
+
def __init__(self) -> None:
|
87 |
+
self.client = Client()
|
88 |
+
self.url = Optional[str]
|
89 |
+
self.template = "{url}\n\n{content}"
|
90 |
+
self.image_suffixes = {'jpg', 'jpeg', 'png', 'gif'}
|
91 |
+
|
92 |
+
def match(self, input: JobInput) -> bool:
|
93 |
+
url = get_url(input.content.strip())
|
94 |
+
if url is None:
|
95 |
+
return False
|
96 |
+
|
97 |
+
suffix = url.rsplit(".", 1)[-1].lower()
|
98 |
+
if suffix not in self.image_suffixes:
|
99 |
+
return False
|
100 |
+
|
101 |
+
self.url = url
|
102 |
+
return True
|
103 |
+
|
104 |
+
def process(self, input: JobInput) -> str:
|
105 |
+
if not isinstance(self.url, str):
|
106 |
+
raise TypeError("self.url must be a string")
|
107 |
+
|
108 |
+
response = self.client.get(self.url)
|
109 |
+
image = Image.open(io.BytesIO(response.content)).convert('RGB')
|
110 |
+
caption = get_agent().run("Caption the following image", image=image)
|
111 |
+
return str(caption)
|
src/gistillery/registry.py
CHANGED
@@ -1,10 +1,14 @@
|
|
1 |
-
from gistillery.ml import Summarizer, Tagger
|
2 |
-
from gistillery.preprocessing import Processor, RawTextProcessor
|
3 |
-
|
4 |
from gistillery.base import JobInput
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
-
class
|
8 |
def __init__(self) -> None:
|
9 |
self.processors: list[Processor] = []
|
10 |
self.summerizer: Summarizer | None = None
|
@@ -39,3 +43,24 @@ class MlRegistry:
|
|
39 |
def get_tagger(self) -> Tagger:
|
40 |
assert self.tagger
|
41 |
return self.tagger
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from gistillery.base import JobInput
|
2 |
+
from gistillery.tools import Summarizer, Tagger, HfDefaultSummarizer, HfDefaultTagger
|
3 |
+
from gistillery.preprocessing import (
|
4 |
+
Processor,
|
5 |
+
RawTextProcessor,
|
6 |
+
ImageUrlProcessor,
|
7 |
+
DefaultUrlProcessor,
|
8 |
+
)
|
9 |
|
10 |
|
11 |
+
class ToolRegistry:
|
12 |
def __init__(self) -> None:
|
13 |
self.processors: list[Processor] = []
|
14 |
self.summerizer: Summarizer | None = None
|
|
|
43 |
def get_tagger(self) -> Tagger:
|
44 |
assert self.tagger
|
45 |
return self.tagger
|
46 |
+
|
47 |
+
|
48 |
+
_registry = None
|
49 |
+
|
50 |
+
|
51 |
+
def get_tool_registry() -> ToolRegistry:
|
52 |
+
global _registry
|
53 |
+
if _registry is not None:
|
54 |
+
return _registry
|
55 |
+
|
56 |
+
summarizer = HfDefaultSummarizer()
|
57 |
+
tagger = HfDefaultTagger()
|
58 |
+
|
59 |
+
_registry = ToolRegistry()
|
60 |
+
_registry.register_processor(ImageUrlProcessor())
|
61 |
+
_registry.register_processor(DefaultUrlProcessor())
|
62 |
+
_registry.register_processor(RawTextProcessor())
|
63 |
+
_registry.register_summarizer(summarizer)
|
64 |
+
_registry.register_tagger(tagger)
|
65 |
+
|
66 |
+
return _registry
|
src/gistillery/{ml.py → tools.py}
RENAMED
@@ -1,17 +1,26 @@
|
|
1 |
import abc
|
2 |
-
from typing import Any
|
3 |
-
import logging
|
4 |
|
5 |
-
|
6 |
-
|
|
|
|
|
7 |
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
class Summarizer(abc.ABC):
|
10 |
-
def __init__(
|
11 |
-
self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
|
12 |
-
) -> None:
|
13 |
-
raise NotImplementedError
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def get_name(self) -> str:
|
16 |
raise NotImplementedError
|
17 |
|
@@ -20,12 +29,21 @@ class Summarizer(abc.ABC):
|
|
20 |
raise NotImplementedError
|
21 |
|
22 |
|
23 |
-
class
|
24 |
-
def __init__(
|
25 |
-
self
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
|
|
|
|
29 |
def get_name(self) -> str:
|
30 |
raise NotImplementedError
|
31 |
|
@@ -34,39 +52,19 @@ class Tagger(abc.ABC):
|
|
34 |
raise NotImplementedError
|
35 |
|
36 |
|
37 |
-
class
|
38 |
-
def __init__(
|
39 |
-
self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
|
40 |
-
) -> None:
|
41 |
self.model_name = model_name
|
42 |
-
self.model = model
|
43 |
-
self.tokenizer = tokenizer
|
44 |
-
self.generation_config = generation_config
|
45 |
|
46 |
-
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
outputs = self.model.generate(
|
52 |
-
**inputs, generation_config=self.generation_config
|
53 |
-
)
|
54 |
-
output = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
55 |
-
assert isinstance(output, str)
|
56 |
-
return output
|
57 |
-
|
58 |
-
def get_name(self) -> str:
|
59 |
-
return f"{self.__class__.__name__}({self.model_name})"
|
60 |
-
|
61 |
-
|
62 |
-
class HfTransformersTagger(Tagger):
|
63 |
-
def __init__(
|
64 |
-
self, model_name: str, model: Any, tokenizer: Any, generation_config: Any
|
65 |
-
) -> None:
|
66 |
-
self.model_name = model_name
|
67 |
-
self.model = model
|
68 |
-
self.tokenizer = tokenizer
|
69 |
-
self.generation_config = generation_config
|
70 |
|
71 |
self.template = (
|
72 |
"Create a list of tags for the text below. The tags should be high level "
|
|
|
1 |
import abc
|
|
|
|
|
2 |
|
3 |
+
from huggingface_hub import login
|
4 |
+
from transformers.tools import TextSummarizationTool
|
5 |
+
from transformers import HfAgent
|
6 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
|
7 |
|
8 |
+
from gistillery.config import get_config
|
9 |
+
|
10 |
+
|
11 |
+
agent = None
|
12 |
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
+
def get_agent() -> HfAgent:
|
15 |
+
global agent
|
16 |
+
if agent is None:
|
17 |
+
login(get_config().hf_hub_token)
|
18 |
+
agent = HfAgent(get_config().hf_agent)
|
19 |
+
return agent
|
20 |
+
|
21 |
+
|
22 |
+
class Summarizer(abc.ABC):
|
23 |
+
@abc.abstractmethod
|
24 |
def get_name(self) -> str:
|
25 |
raise NotImplementedError
|
26 |
|
|
|
29 |
raise NotImplementedError
|
30 |
|
31 |
|
32 |
+
class HfDefaultSummarizer(Summarizer):
|
33 |
+
def __init__(self) -> None:
|
34 |
+
self.summarizer = TextSummarizationTool()
|
35 |
+
|
36 |
+
def get_name(self) -> str:
|
37 |
+
return "hf_default"
|
38 |
+
|
39 |
+
def __call__(self, x: str) -> str:
|
40 |
+
summary = self.summarizer(x)
|
41 |
+
assert isinstance(summary, str)
|
42 |
+
return summary
|
43 |
+
|
44 |
|
45 |
+
class Tagger(abc.ABC):
|
46 |
+
@abc.abstractmethod
|
47 |
def get_name(self) -> str:
|
48 |
raise NotImplementedError
|
49 |
|
|
|
52 |
raise NotImplementedError
|
53 |
|
54 |
|
55 |
+
class HfDefaultTagger(Tagger):
|
56 |
+
def __init__(self, model_name: str = "google/flan-t5-large") -> None:
|
|
|
|
|
57 |
self.model_name = model_name
|
|
|
|
|
|
|
58 |
|
59 |
+
config = GenerationConfig.from_pretrained(self.model_name)
|
60 |
+
config.max_new_tokens = 50
|
61 |
+
config.min_new_tokens = 25
|
62 |
+
# increase the temperature to make the model more creative
|
63 |
+
config.temperature = 1.5
|
64 |
|
65 |
+
self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
|
66 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
67 |
+
self.generation_config = config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
self.template = (
|
70 |
"Create a list of tags for the text below. The tags should be high level "
|
src/gistillery/webservice.py
CHANGED
@@ -37,8 +37,28 @@ def submit_job(input: RequestInput) -> str:
|
|
37 |
return f"Submitted job {_id}"
|
38 |
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
@app.get("/check_job_status/{_id}")
|
41 |
-
def
|
42 |
with get_db_cursor() as cursor:
|
43 |
cursor.execute(
|
44 |
"SELECT status, last_updated FROM jobs WHERE entry_id = ?", (_id,)
|
|
|
37 |
return f"Submitted job {_id}"
|
38 |
|
39 |
|
40 |
+
@app.get("/check_job_status/")
|
41 |
+
def check_job_status() -> str:
|
42 |
+
with get_db_cursor() as cursor:
|
43 |
+
cursor.execute(
|
44 |
+
"SELECT entry_id "
|
45 |
+
"FROM jobs WHERE status = 'pending' "
|
46 |
+
"ORDER BY last_updated ASC"
|
47 |
+
)
|
48 |
+
result = cursor.fetchall()
|
49 |
+
|
50 |
+
if not result:
|
51 |
+
return "No pending jobs found"
|
52 |
+
|
53 |
+
entry_ids = [r.entry_id for r in result]
|
54 |
+
num_entries = len(entry_ids)
|
55 |
+
if len(entry_ids) > 3:
|
56 |
+
entry_ids = entry_ids[:3] + ["..."]
|
57 |
+
return f"Found {num_entries} pending job(s): {', '.join(entry_ids)}"
|
58 |
+
|
59 |
+
|
60 |
@app.get("/check_job_status/{_id}")
|
61 |
+
def check_job_status_id(_id: str) -> JobStatusResult:
|
62 |
with get_db_cursor() as cursor:
|
63 |
cursor.execute(
|
64 |
"SELECT status, last_updated FROM jobs WHERE entry_id = ?", (_id,)
|
src/gistillery/worker.py
CHANGED
@@ -3,9 +3,7 @@ from dataclasses import dataclass
|
|
3 |
|
4 |
from gistillery.base import JobInput
|
5 |
from gistillery.db import get_db_cursor
|
6 |
-
from gistillery.
|
7 |
-
from gistillery.preprocessing import DefaultUrlProcessor, RawTextProcessor
|
8 |
-
from gistillery.registry import MlRegistry
|
9 |
|
10 |
SLEEP_INTERVAL = 5
|
11 |
|
@@ -13,7 +11,7 @@ SLEEP_INTERVAL = 5
|
|
13 |
def check_pending_jobs() -> list[JobInput]:
|
14 |
"""Check DB for pending jobs"""
|
15 |
with get_db_cursor() as cursor:
|
16 |
-
# fetch pending jobs, join
|
17 |
query = """
|
18 |
SELECT j.entry_id, e.author, e.source
|
19 |
FROM jobs j
|
@@ -21,7 +19,7 @@ def check_pending_jobs() -> list[JobInput]:
|
|
21 |
ON j.entry_id = e.id
|
22 |
WHERE j.status = 'pending'
|
23 |
"""
|
24 |
-
res =
|
25 |
return [
|
26 |
JobInput(id=_id, author=author, content=content) for _id, author, content in res
|
27 |
]
|
@@ -37,7 +35,7 @@ class JobOutput:
|
|
37 |
tagger_name: str
|
38 |
|
39 |
|
40 |
-
def _process_job(job: JobInput, registry:
|
41 |
processor = registry.get_processor(job)
|
42 |
processor_name = processor.get_name()
|
43 |
processed = processor(job)
|
@@ -79,7 +77,7 @@ def store(job: JobInput, output: JobOutput) -> None:
|
|
79 |
)
|
80 |
|
81 |
|
82 |
-
def process_job(job: JobInput, registry:
|
83 |
tic = time.perf_counter()
|
84 |
print(f"Processing job for (id={job.id[:8]})")
|
85 |
|
@@ -105,41 +103,8 @@ def process_job(job: JobInput, registry: MlRegistry) -> None:
|
|
105 |
print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
|
106 |
|
107 |
|
108 |
-
def load_mlregistry(model_name: str) -> MlRegistry:
|
109 |
-
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig
|
110 |
-
|
111 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
112 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
113 |
-
|
114 |
-
config_summarizer = GenerationConfig.from_pretrained(model_name)
|
115 |
-
config_summarizer.max_new_tokens = 200
|
116 |
-
config_summarizer.min_new_tokens = 100
|
117 |
-
config_summarizer.top_k = 5
|
118 |
-
config_summarizer.repetition_penalty = 1.5
|
119 |
-
|
120 |
-
config_tagger = GenerationConfig.from_pretrained(model_name)
|
121 |
-
config_tagger.max_new_tokens = 50
|
122 |
-
config_tagger.min_new_tokens = 25
|
123 |
-
# increase the temperature to make the model more creative
|
124 |
-
config_tagger.temperature = 1.5
|
125 |
-
|
126 |
-
summarizer = HfTransformersSummarizer(
|
127 |
-
model_name, model, tokenizer, config_summarizer
|
128 |
-
)
|
129 |
-
tagger = HfTransformersTagger(model_name, model, tokenizer, config_tagger)
|
130 |
-
|
131 |
-
registry = MlRegistry()
|
132 |
-
registry.register_processor(DefaultUrlProcessor())
|
133 |
-
registry.register_processor(RawTextProcessor())
|
134 |
-
registry.register_summarizer(summarizer)
|
135 |
-
registry.register_tagger(tagger)
|
136 |
-
|
137 |
-
return registry
|
138 |
-
|
139 |
-
|
140 |
def main() -> None:
|
141 |
-
|
142 |
-
registry = load_mlregistry(model_name)
|
143 |
|
144 |
while True:
|
145 |
jobs = check_pending_jobs()
|
|
|
3 |
|
4 |
from gistillery.base import JobInput
|
5 |
from gistillery.db import get_db_cursor
|
6 |
+
from gistillery.registry import ToolRegistry, get_tool_registry
|
|
|
|
|
7 |
|
8 |
SLEEP_INTERVAL = 5
|
9 |
|
|
|
11 |
def check_pending_jobs() -> list[JobInput]:
|
12 |
"""Check DB for pending jobs"""
|
13 |
with get_db_cursor() as cursor:
|
14 |
+
# fetch pending jobs, join author and content from entries table
|
15 |
query = """
|
16 |
SELECT j.entry_id, e.author, e.source
|
17 |
FROM jobs j
|
|
|
19 |
ON j.entry_id = e.id
|
20 |
WHERE j.status = 'pending'
|
21 |
"""
|
22 |
+
res = cursor.execute(query).fetchall()
|
23 |
return [
|
24 |
JobInput(id=_id, author=author, content=content) for _id, author, content in res
|
25 |
]
|
|
|
35 |
tagger_name: str
|
36 |
|
37 |
|
38 |
+
def _process_job(job: JobInput, registry: ToolRegistry) -> JobOutput:
|
39 |
processor = registry.get_processor(job)
|
40 |
processor_name = processor.get_name()
|
41 |
processed = processor(job)
|
|
|
77 |
)
|
78 |
|
79 |
|
80 |
+
def process_job(job: JobInput, registry: ToolRegistry) -> None:
|
81 |
tic = time.perf_counter()
|
82 |
print(f"Processing job for (id={job.id[:8]})")
|
83 |
|
|
|
103 |
print(f"Finished processing job (id={job.id[:8]}) in {toc - tic:0.3f} seconds")
|
104 |
|
105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
def main() -> None:
|
107 |
+
registry = get_tool_registry()
|
|
|
108 |
|
109 |
while True:
|
110 |
jobs = check_pending_jobs()
|
tests/test_app.py
CHANGED
@@ -35,18 +35,14 @@ class TestWebservice:
|
|
35 |
return client
|
36 |
|
37 |
@pytest.fixture
|
38 |
-
def
|
39 |
# use dummy models
|
40 |
-
from gistillery.
|
41 |
from gistillery.preprocessing import RawTextProcessor
|
42 |
-
from gistillery.registry import
|
43 |
|
44 |
class DummySummarizer(Summarizer):
|
45 |
"""Returns the first 10 characters of the input"""
|
46 |
-
|
47 |
-
def __init__(self, *args, **kwargs):
|
48 |
-
pass
|
49 |
-
|
50 |
def get_name(self):
|
51 |
return "dummy summarizer"
|
52 |
|
@@ -55,24 +51,20 @@ class TestWebservice:
|
|
55 |
|
56 |
class DummyTagger(Tagger):
|
57 |
"""Returns the first 3 words of the input"""
|
58 |
-
|
59 |
-
def __init__(self, *args, **kwargs):
|
60 |
-
pass
|
61 |
-
|
62 |
def get_name(self):
|
63 |
return "dummy tagger"
|
64 |
|
65 |
def __call__(self, x):
|
66 |
return ["#" + word for word in x.split(maxsplit=4)[:3]]
|
67 |
|
68 |
-
registry =
|
69 |
registry.register_processor(RawTextProcessor())
|
70 |
|
71 |
# arguments don't matter for dummy summarizer and tagger
|
72 |
-
summarizer = DummySummarizer(
|
73 |
registry.register_summarizer(summarizer)
|
74 |
|
75 |
-
tagger = DummyTagger(
|
76 |
registry.register_tagger(tagger)
|
77 |
return registry
|
78 |
|
@@ -128,7 +120,7 @@ class TestWebservice:
|
|
128 |
}
|
129 |
assert last_updated is None
|
130 |
|
131 |
-
def test_submitted_job_failed(self, client,
|
132 |
# monkeypatch uuid4 to return a known value
|
133 |
job_id = "abc1234"
|
134 |
monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
|
@@ -143,7 +135,7 @@ class TestWebservice:
|
|
143 |
"gistillery.worker._process_job",
|
144 |
lambda job, registry: raise_(RuntimeError("something went wrong")),
|
145 |
)
|
146 |
-
self.process_jobs(
|
147 |
|
148 |
resp = client.get(f"/check_job_status/{job_id}")
|
149 |
output = resp.json()
|
@@ -153,12 +145,12 @@ class TestWebservice:
|
|
153 |
"status": "failed",
|
154 |
}
|
155 |
|
156 |
-
def test_submitted_job_status_done(self, client,
|
157 |
# monkeypatch uuid4 to return a known value
|
158 |
job_id = "abc1234"
|
159 |
monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
|
160 |
client.post("/submit", json={"author": "ben", "content": "this is a test"})
|
161 |
-
self.process_jobs(
|
162 |
|
163 |
resp = client.get(f"/check_job_status/{job_id}")
|
164 |
output = resp.json()
|
@@ -169,7 +161,28 @@ class TestWebservice:
|
|
169 |
}
|
170 |
assert is_roughly_now(last_updated)
|
171 |
|
172 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
# submit 2 entries
|
174 |
client.post(
|
175 |
"/submit", json={"author": "maxi", "content": "this is a first test"}
|
@@ -178,7 +191,7 @@ class TestWebservice:
|
|
178 |
"/submit",
|
179 |
json={"author": "mini", "content": "this would be something else"},
|
180 |
)
|
181 |
-
self.process_jobs(
|
182 |
resp = client.get("/recent").json()
|
183 |
|
184 |
# results are sorted by recency but since dummy models are so fast, the
|
@@ -196,7 +209,7 @@ class TestWebservice:
|
|
196 |
assert resp1["summary"] == "this would"
|
197 |
assert resp1["tags"] == sorted(["#this", "#would", "#be"])
|
198 |
|
199 |
-
def test_recent_tag_with_entries(self, client,
|
200 |
# submit 2 entries
|
201 |
client.post(
|
202 |
"/submit", json={"author": "maxi", "content": "this is a first test"}
|
@@ -205,7 +218,7 @@ class TestWebservice:
|
|
205 |
"/submit",
|
206 |
json={"author": "mini", "content": "this would be something else"},
|
207 |
)
|
208 |
-
self.process_jobs(
|
209 |
|
210 |
# the "this" tag is in both entries
|
211 |
resp = client.get("/recent/this").json()
|
@@ -220,22 +233,22 @@ class TestWebservice:
|
|
220 |
assert resp0["summary"] == "this would"
|
221 |
assert resp0["tags"] == sorted(["#this", "#would", "#be"])
|
222 |
|
223 |
-
def test_clear(self, client, cursor,
|
224 |
client.post("/submit", json={"author": "ben", "content": "this is a test"})
|
225 |
-
self.process_jobs(
|
226 |
assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 1
|
227 |
|
228 |
client.get("/clear")
|
229 |
assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 0
|
230 |
|
231 |
-
def test_inputs_stored(self, client, cursor,
|
232 |
client.post("/submit", json={"author": "ben", "content": " this is a test\n"})
|
233 |
-
self.process_jobs(
|
234 |
rows = cursor.execute("SELECT * FROM inputs").fetchall()
|
235 |
assert len(rows) == 1
|
236 |
assert rows[0].input == "this is a test"
|
237 |
|
238 |
-
def test_submit_url(self, client, cursor,
|
239 |
class MockClient:
|
240 |
"""Mock httpx Client, return www.example.com content"""
|
241 |
|
@@ -269,7 +282,7 @@ class TestWebservice:
|
|
269 |
from gistillery.preprocessing import DefaultUrlProcessor
|
270 |
|
271 |
# register url processor, put it before the default processor
|
272 |
-
|
273 |
client.post(
|
274 |
"/submit",
|
275 |
json={
|
@@ -277,7 +290,7 @@ class TestWebservice:
|
|
277 |
"content": "https://en.wikipedia.org/wiki/non-existing-page",
|
278 |
},
|
279 |
)
|
280 |
-
self.process_jobs(
|
281 |
|
282 |
rows = cursor.execute("SELECT * FROM inputs").fetchall()
|
283 |
assert len(rows) == 1
|
|
|
35 |
return client
|
36 |
|
37 |
@pytest.fixture
|
38 |
+
def registry(self):
|
39 |
# use dummy models
|
40 |
+
from gistillery.tools import Summarizer, Tagger
|
41 |
from gistillery.preprocessing import RawTextProcessor
|
42 |
+
from gistillery.registry import ToolRegistry
|
43 |
|
44 |
class DummySummarizer(Summarizer):
|
45 |
"""Returns the first 10 characters of the input"""
|
|
|
|
|
|
|
|
|
46 |
def get_name(self):
|
47 |
return "dummy summarizer"
|
48 |
|
|
|
51 |
|
52 |
class DummyTagger(Tagger):
|
53 |
"""Returns the first 3 words of the input"""
|
|
|
|
|
|
|
|
|
54 |
def get_name(self):
|
55 |
return "dummy tagger"
|
56 |
|
57 |
def __call__(self, x):
|
58 |
return ["#" + word for word in x.split(maxsplit=4)[:3]]
|
59 |
|
60 |
+
registry = ToolRegistry()
|
61 |
registry.register_processor(RawTextProcessor())
|
62 |
|
63 |
# arguments don't matter for dummy summarizer and tagger
|
64 |
+
summarizer = DummySummarizer()
|
65 |
registry.register_summarizer(summarizer)
|
66 |
|
67 |
+
tagger = DummyTagger()
|
68 |
registry.register_tagger(tagger)
|
69 |
return registry
|
70 |
|
|
|
120 |
}
|
121 |
assert last_updated is None
|
122 |
|
123 |
+
def test_submitted_job_failed(self, client, registry, monkeypatch):
|
124 |
# monkeypatch uuid4 to return a known value
|
125 |
job_id = "abc1234"
|
126 |
monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
|
|
|
135 |
"gistillery.worker._process_job",
|
136 |
lambda job, registry: raise_(RuntimeError("something went wrong")),
|
137 |
)
|
138 |
+
self.process_jobs(registry)
|
139 |
|
140 |
resp = client.get(f"/check_job_status/{job_id}")
|
141 |
output = resp.json()
|
|
|
145 |
"status": "failed",
|
146 |
}
|
147 |
|
148 |
+
def test_submitted_job_status_done(self, client, registry, monkeypatch):
|
149 |
# monkeypatch uuid4 to return a known value
|
150 |
job_id = "abc1234"
|
151 |
monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=job_id))
|
152 |
client.post("/submit", json={"author": "ben", "content": "this is a test"})
|
153 |
+
self.process_jobs(registry)
|
154 |
|
155 |
resp = client.get(f"/check_job_status/{job_id}")
|
156 |
output = resp.json()
|
|
|
161 |
}
|
162 |
assert is_roughly_now(last_updated)
|
163 |
|
164 |
+
def test_status_pending_jobs(self, client, registry, monkeypatch):
|
165 |
+
resp = client.get("/check_job_status/")
|
166 |
+
output = resp.json()
|
167 |
+
assert output == "No pending jobs found"
|
168 |
+
|
169 |
+
monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex="abc0"))
|
170 |
+
client.post("/submit", json={"author": "ben", "content": "this is a test"})
|
171 |
+
resp = client.get("/check_job_status/")
|
172 |
+
output = resp.json()
|
173 |
+
expected = "Found 1 pending job(s): abc0"
|
174 |
+
assert output == expected
|
175 |
+
|
176 |
+
for i in range(1, 10):
|
177 |
+
monkeypatch.setattr("uuid.uuid4", lambda: SimpleNamespace(hex=f"abc{i}"))
|
178 |
+
client.post("/submit", json={"author": "ben", "content": "this is a test"})
|
179 |
+
|
180 |
+
resp = client.get("/check_job_status/")
|
181 |
+
output = resp.json()
|
182 |
+
expected = "Found 10 pending job(s): abc0, abc1, abc2, ..."
|
183 |
+
assert output == expected
|
184 |
+
|
185 |
+
def test_recent_with_entries(self, client, registry):
|
186 |
# submit 2 entries
|
187 |
client.post(
|
188 |
"/submit", json={"author": "maxi", "content": "this is a first test"}
|
|
|
191 |
"/submit",
|
192 |
json={"author": "mini", "content": "this would be something else"},
|
193 |
)
|
194 |
+
self.process_jobs(registry)
|
195 |
resp = client.get("/recent").json()
|
196 |
|
197 |
# results are sorted by recency but since dummy models are so fast, the
|
|
|
209 |
assert resp1["summary"] == "this would"
|
210 |
assert resp1["tags"] == sorted(["#this", "#would", "#be"])
|
211 |
|
212 |
+
def test_recent_tag_with_entries(self, client, registry):
|
213 |
# submit 2 entries
|
214 |
client.post(
|
215 |
"/submit", json={"author": "maxi", "content": "this is a first test"}
|
|
|
218 |
"/submit",
|
219 |
json={"author": "mini", "content": "this would be something else"},
|
220 |
)
|
221 |
+
self.process_jobs(registry)
|
222 |
|
223 |
# the "this" tag is in both entries
|
224 |
resp = client.get("/recent/this").json()
|
|
|
233 |
assert resp0["summary"] == "this would"
|
234 |
assert resp0["tags"] == sorted(["#this", "#would", "#be"])
|
235 |
|
236 |
+
def test_clear(self, client, cursor, registry):
|
237 |
client.post("/submit", json={"author": "ben", "content": "this is a test"})
|
238 |
+
self.process_jobs(registry)
|
239 |
assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 1
|
240 |
|
241 |
client.get("/clear")
|
242 |
assert cursor.execute("SELECT COUNT(*) c FROM entries").fetchone()[0] == 0
|
243 |
|
244 |
+
def test_inputs_stored(self, client, cursor, registry):
|
245 |
client.post("/submit", json={"author": "ben", "content": " this is a test\n"})
|
246 |
+
self.process_jobs(registry)
|
247 |
rows = cursor.execute("SELECT * FROM inputs").fetchall()
|
248 |
assert len(rows) == 1
|
249 |
assert rows[0].input == "this is a test"
|
250 |
|
251 |
+
def test_submit_url(self, client, cursor, registry, monkeypatch):
|
252 |
class MockClient:
|
253 |
"""Mock httpx Client, return www.example.com content"""
|
254 |
|
|
|
282 |
from gistillery.preprocessing import DefaultUrlProcessor
|
283 |
|
284 |
# register url processor, put it before the default processor
|
285 |
+
registry.register_processor(DefaultUrlProcessor(), last=False)
|
286 |
client.post(
|
287 |
"/submit",
|
288 |
json={
|
|
|
290 |
"content": "https://en.wikipedia.org/wiki/non-existing-page",
|
291 |
},
|
292 |
)
|
293 |
+
self.process_jobs(registry)
|
294 |
|
295 |
rows = cursor.execute("SELECT * FROM inputs").fetchall()
|
296 |
assert len(rows) == 1
|