test gridio
- README.md +1 -1
- app.py +21 -36
- requirements.txt +1 -1
README.md
CHANGED
@@ -3,7 +3,7 @@ title: Medical diagnosis evaluation via MedCLIP
 emoji: 👩‍⚕️
 colorFrom: blue
 colorTo: indigo
-sdk:
+sdk: gradio
 app_file: app.py
 pinned: false
 ---
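For context: the `sdk` field in a Space's README front matter tells Hugging Face which runtime to launch for the app, so setting it to `gradio` pairs with the Streamlit-to-Gradio rewrite of app.py below.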
app.py
CHANGED
@@ -1,9 +1,9 @@
-import …
-
+import sys
+
+import gradio as gr
 import jax
-import streamlit as st
-import transformers
 from huggingface_hub import snapshot_download
+from PIL import Image
 from transformers import AutoTokenizer
 import torch
 from torchvision.io import ImageReadMode, read_image
@@ -15,7 +15,6 @@ sys.path.append(LOCAL_PATH)
 from src.modeling_medclip import FlaxMedCLIP
 from run_medclip import Transform
 
-
 def prepare_image(image_path, model):
     image = read_image(image_path, mode=ImageReadMode.RGB)
     preprocess = Transform(model.config.vision_config.image_size)
@@ -28,17 +27,11 @@ def prepare_text(text, tokenizer):
     return tokenizer(text, return_tensors="np")
 
 def save_file_to_disk(uplaoded_file):
-    temp_file = …
-    …
-    …
+    temp_file = "/tmp/image.jpeg"
+    im = Image.fromarray(uplaoded_file)
+    im.save(temp_file)
     return temp_file
-@st.cache(
-    hash_funcs={
-        transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: id,
-        FlaxMedCLIP: id,
-    },
-    show_spinner=False
-)
+
 def load_tokenizer_and_model():
     # load the saved model
     tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
@@ -60,24 +53,16 @@ def run_inference(image_path, text, model, tokenizer):
     return score
 
 tokenizer, model = load_tokenizer_and_model()
-…
-        clamp=False,
-        channels="RGB",
-        output_format="auto",
-    )
-    st.write(f"## Score: {score:.2f}")
-    finally:
-        if local_image_path:
-            os.remove(local_image_path)
+
+def score_image_caption_pair(uploaded_file, text_input):
+    local_image_path = save_file_to_disk(uploaded_file)
+    score = run_inference(
+        local_image_path, text_input, model, tokenizer).tolist()
+    return {"Score": score}, "{:.2f}".format(score)
+
+
+image = gr.inputs.Image(shape=(299, 299))
+iface = gr.Interface(
+    fn=score_image_caption_pair, inputs=[image, "text"], outputs=["label", "text"]
+)
+iface.launch()
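Since the diff only shows changed hunks, the bodies of load_tokenizer_and_model and run_inference are elided above. Here is a minimal runnable sketch of just the Gradio wiring this commit introduces, with a hypothetical stub standing in for the real MedCLIP scorer, written against the gradio==2.2.2 API pinned in requirements.txt:

import gradio as gr

def score_image_caption_pair(uploaded_file, text_input):
    # With gr.inputs.Image(shape=(299, 299)), Gradio resizes the upload and
    # passes it in as an RGB numpy array. That is why the real app round-trips
    # it through PIL to /tmp/image.jpeg before read_image() picks it up.
    score = 0.42  # hypothetical stub; the real app calls run_inference(...)
    return {"Score": score}, "{:.2f}".format(score)

image = gr.inputs.Image(shape=(299, 299))
iface = gr.Interface(
    fn=score_image_caption_pair, inputs=[image, "text"], outputs=["label", "text"]
)
iface.launch()

The save-to-disk round trip is the design choice that keeps the rest of the pipeline untouched: prepare_image still reads from a file path, so swapping Streamlit's uploader for Gradio's numpy input only required the small save_file_to_disk rewrite.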
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
 flax==0.3.4
 huggingface-hub==0.0.12
 jax==0.2.17
-…
+gradio==2.2.2
 torch==1.9.0
 torchvision==0.10.0
 transformers==4.8.2
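With gradio==2.2.2 pinned alongside the existing JAX/PyTorch stack, the app should also be reproducible outside Spaces: install requirements.txt and run app.py directly, since iface.launch() starts the same Gradio server that Spaces otherwise manages via the README's app_file entry.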