shpotes committed
Commit 9b070b0 · 1 Parent(s): c6a82cb

test gridio

Files changed (3)
  1. README.md +1 -1
  2. app.py +21 -36
  3. requirements.txt +1 -1
README.md CHANGED
@@ -3,7 +3,7 @@ title: Medical diagnosis evaluation via MedCLIP
 emoji: 👩‍⚕️
 colorFrom: blue
 colorTo: indigo
-sdk: streamlit
+sdk: gradio
 app_file: app.py
 pinned: false
 ---
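The `sdk` field tells Spaces which runtime should serve `app.py`, so this one-line change is what actually switches the Space from Streamlit to Gradio. As a rough, illustrative comparison (a hello-world sketch, not code from this repo), the two SDKs structure an app differently; the Gradio form below is the pattern the rest of this commit adopts:

```python
# Streamlit style (what the Space used before this commit): widgets are declared
# top-down and the whole script re-runs on every interaction.
#   import streamlit as st
#   name = st.text_input("Name")
#   st.write(f"Hello {name}")

# Gradio style (what the Space uses after this commit): a plain function plus an
# Interface describing its inputs and outputs, served by launch().
import gradio as gr


def greet(name: str) -> str:
    return f"Hello {name}"


gr.Interface(fn=greet, inputs="text", outputs="text").launch()
```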
app.py CHANGED
@@ -1,9 +1,9 @@
-import os
-import sys
+import sys
+
+import gradio as gr
 import jax
-import streamlit as st
-import transformers
 from huggingface_hub import snapshot_download
+from PIL import Image
 from transformers import AutoTokenizer
 import torch
 from torchvision.io import ImageReadMode, read_image
@@ -15,7 +15,6 @@ sys.path.append(LOCAL_PATH)
 from src.modeling_medclip import FlaxMedCLIP
 from run_medclip import Transform

-
 def prepare_image(image_path, model):
     image = read_image(image_path, mode=ImageReadMode.RGB)
     preprocess = Transform(model.config.vision_config.image_size)
@@ -28,17 +27,11 @@ def prepare_text(text, tokenizer):
     return tokenizer(text, return_tensors="np")

 def save_file_to_disk(uplaoded_file):
-    temp_file = os.path.join("/tmp", uplaoded_file.name)
-    with open(temp_file, "wb") as f:
-        f.write(uploaded_file.getbuffer())
+    temp_file = "/tmp/image.jpeg"
+    im = Image.fromarray(uplaoded_file)
+    im.save(temp_file)
     return temp_file
-@st.cache(
-    hash_funcs={
-        transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: id,
-        FlaxMedCLIP: id,
-    },
-    show_spinner=False
-)
+
 def load_tokenizer_and_model():
     # load the saved model
     tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
@@ -60,24 +53,16 @@ def run_inference(image_path, text, model, tokenizer):
     return score

 tokenizer, model = load_tokenizer_and_model()
-st.title("Diagnosis Scoring")
-uploaded_file = st.file_uploader("Choose a Chest x-ray...", type=["png", "jpg"])
-text_input = st.text_input("Type the doctor diagnosis")
-if uploaded_file is not None and text_input:
-    local_image_path = None
-    try:
-        local_image_path = save_file_to_disk(uploaded_file)
-        score = run_inference(local_image_path, text_input, model, tokenizer).tolist()
-        st.image(
-            uploaded_file,
-            caption=text_input,
-            width=None,
-            use_column_width=None,
-            clamp=False,
-            channels="RGB",
-            output_format="auto",
-        )
-        st.write(f"## Score: {score:.2f}")
-    finally:
-        if local_image_path:
-            os.remove(local_image_path)
+
+def score_image_caption_pair(uploaded_file, text_input):
+    local_image_path = save_file_to_disk(uploaded_file)
+    score = run_inference(
+        local_image_path, text_input, model, tokenizer).tolist()
+    return {"Score": score}, "{:.2f}".format(score)
+
+
+image = gr.inputs.Image(shape=(299, 299))
+iface = gr.Interface(
+    fn=score_image_caption_pair, inputs=[image, "text"], outputs=["label", "text"]
+)
+iface.launch()
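In Gradio 2.x, `gr.inputs.Image(shape=(299, 299))` hands the upload to the callback as a resized NumPy array rather than a file handle, which is why `save_file_to_disk` now goes through `PIL.Image.fromarray` instead of writing an uploaded file buffer, and why the Streamlit caching decorator and widget code could be dropped. A minimal sketch of the same wiring with a stand-in scorer (the `dummy_score` function and its constant score are illustrative, not part of the commit):

```python
import gradio as gr
import numpy as np
from PIL import Image


def dummy_score(image_array: np.ndarray, caption: str):
    # Gradio passes the image as an H x W x C uint8 array; persist it the same
    # way app.py does before handing a path to the model.
    Image.fromarray(image_array).save("/tmp/image.jpeg")
    score = 0.42  # stand-in for run_inference(...)
    # The "label" output accepts a dict of label -> confidence, "text" a string.
    return {"Score": score}, "{}: {:.2f}".format(caption, score)


iface = gr.Interface(
    fn=dummy_score,
    inputs=[gr.inputs.Image(shape=(299, 299)), "text"],
    outputs=["label", "text"],
)

if __name__ == "__main__":
    iface.launch()
```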
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 flax==0.3.4
 huggingface-hub==0.0.12
 jax==0.2.17
-streamlit==0.84.1
+gradio==2.2.2
 torch==1.9.0
 torchvision==0.10.0
 transformers==4.8.2
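Since Spaces installs exactly what is pinned in this file, a quick local sanity check before pushing is to confirm the swapped dependency imports at the expected version (an illustrative snippet, not part of the commit):

```python
# Illustrative check that the pinned packages resolve locally; version strings
# may carry local suffixes (e.g. "+cu102" for torch), so print rather than assert.
import gradio
import torch
import torchvision
import transformers

for module, pin in [
    (gradio, "2.2.2"),
    (torch, "1.9.0"),
    (torchvision, "0.10.0"),
    (transformers, "4.8.2"),
]:
    print(f"{module.__name__}: installed {module.__version__}, pinned {pin}")
```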