ferferefer's picture
fin?
65cdf3c
import gradio as gr
import spaces
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image
from datetime import datetime
import numpy as np
import os
import pdf2image
import tempfile
from pathlib import Path
import gc
import torch.cuda
# Release identifier; rendered in the page footer.
VERSION = "1.0.0"

# Upload constraints: accepted suffixes (lower-case, leading dot) and the
# per-file ceiling in bytes (10 MiB).
SUPPORTED_FORMATS = [".pdf", ".png", ".jpg", ".jpeg"]
MAX_FILE_SIZE = 10 * 2**20

# Markdown shown at the top of the page.
DESCRIPTION = """
# AI Ophthalmology Assistant for Corneal Imaging
This application helps analyze Pentacam and anterior segment OCT (MS39) images to assist in diagnosis and progression analysis.
**Capabilities:**
- Analysis of single or multiple Pentacam images
- Interpretation of anterior segment OCT (MS39) results
- Detection of corneal pathologies
- Assessment of disease progression when multiple images are provided
Created by **Dr. Nerea Zubieta**
## Instructions
1. Upload one or multiple images (supported formats: PDF, PNG, JPEG)
2. Type your question about diagnosis or progression
3. Click Submit to get the AI analysis
**Note:** Maximum file size: 10MB per file
"""
# Checkpoint identifiers: the base vision-language model and the adapter that
# is attached to it via `load_adapter`.
model_id = "Qwen/Qwen2-VL-7B-Instruct"
adapter_path = "sergiopaniego/qwen2-7b-instruct-trl-sft-ChartQA"

# Everything below runs once at import time, so the weights are shared by all
# requests handled by this process.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
)
model.load_adapter(adapter_path)

# The processor is built from the base checkpoint id, not the adapter.
processor = Qwen2VLProcessor.from_pretrained(model_id)
def process_uploaded_file(file_obj):
    """Load an uploaded file as a list of PIL images.

    Args:
        file_obj: Upload handle exposing the on-disk path as ``.name``.

    Returns:
        A list of images: whatever ``pdf2image.convert_from_path`` yields for
        a ``.pdf`` upload, or a single-element list for any other suffix.

    Raises:
        Exception: On any failure. The message is
            ``"Error processing file <name>: <stage> error: <detail>"`` and
            the underlying error is attached as ``__cause__`` so the root
            failure stays visible in tracebacks.
    """
    source_path = file_obj.name
    file_extension = Path(source_path).suffix.lower()
    try:
        if file_extension == '.pdf':
            try:
                # Work on a private copy inside a throw-away directory; the
                # directory is removed as soon as the conversion returns.
                with tempfile.TemporaryDirectory() as temp_dir:
                    temp_pdf_path = os.path.join(temp_dir, "temp.pdf")
                    with open(source_path, 'rb') as src_file, \
                            open(temp_pdf_path, 'wb') as dst_file:
                        dst_file.write(src_file.read())
                    try:
                        return pdf2image.convert_from_path(
                            temp_pdf_path,
                            poppler_path=None,  # Will use system poppler if available
                            dpi=200,  # Adjust DPI as needed
                            fmt='PNG',
                        )
                    except Exception as pdf_error:
                        # Turn a missing-poppler failure into an actionable hint.
                        if "poppler" in str(pdf_error).lower():
                            raise Exception(
                                "PDF processing requires poppler to be installed. "
                                "Please install poppler-utils package on your system. "
                                "On Ubuntu/Debian: sudo apt-get install -y poppler-utils"
                            ) from pdf_error
                        raise
            except Exception as e:
                raise Exception(f"PDF processing error: {str(e)}") from e

        # Any non-PDF suffix is handed to PIL as a regular image.
        try:
            return [Image.open(source_path)]
        except Exception as e:
            raise Exception(f"Image processing error: {str(e)}") from e
    except Exception as e:
        raise Exception(f"Error processing file {file_obj.name}: {str(e)}") from e
def cleanup_temp_files(file_paths):
    """Best-effort removal of every path in *file_paths*.

    Paths that do not exist are skipped. A failure on one path is printed and
    does not stop the remaining paths from being processed.
    """
    for target in file_paths:
        try:
            if not os.path.exists(target):
                continue
            os.remove(target)
        except Exception as err:
            print(f"Error cleaning up {target}: {err}")
def clear_gpu_memory():
    """Run the garbage collector, then release cached CUDA memory.

    The collector runs first so that memory held by objects it frees is
    already unoccupied when ``torch.cuda.empty_cache()`` is called. On a
    machine without CUDA only the garbage collection happens.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def validate_file(file_obj):
    """Check an upload's suffix and on-disk size against the module limits.

    Args:
        file_obj: Upload handle exposing the on-disk path as ``.name``.

    Returns:
        ``True`` when the file passes both checks.

    Raises:
        ValueError: If the lower-cased suffix is not in ``SUPPORTED_FORMATS``
            or the file is larger than ``MAX_FILE_SIZE`` bytes.
    """
    upload_path = file_obj.name

    suffix = Path(upload_path).suffix.lower()
    if suffix not in SUPPORTED_FORMATS:
        raise ValueError(f"Unsupported file format. Please use: {', '.join(SUPPORTED_FORMATS)}")

    if os.path.getsize(upload_path) > MAX_FILE_SIZE:
        raise ValueError(f"File too large. Maximum size is {MAX_FILE_SIZE/1024/1024}MB")

    return True
def _load_images(file_path):
    """Return the PIL images stored at *file_path* (PDF pages or one image)."""
    if Path(file_path).suffix.lower() == '.pdf':
        return pdf2image.convert_from_path(file_path)
    return [Image.open(file_path)]


def _save_resized_copy(img, temp_paths):
    """Save a 512x512 PNG copy of *img* to a unique temp file; return its path.

    The path is appended to *temp_paths* before the image is written, so the
    caller's cleanup removes the file even when saving fails.
    """
    fd, path = tempfile.mkstemp(prefix="temp_image_", suffix=".png")
    os.close(fd)
    temp_paths.append(path)
    # NOTE(review): the fixed 512x512 target does not preserve aspect ratio;
    # kept as in the original - confirm this is acceptable for these scans.
    img.resize((512, 512), Image.Resampling.LANCZOS).save(path)
    return path


def _generate_answer(image_paths, text_input):
    """Run the model on the saved images plus the question; return the reply text."""
    prompt = text_input if text_input else "Please analyze these ophthalmological images and provide a detailed assessment."
    messages = [
        {
            "role": "user",
            "content": [
                *[{"type": "image", "image": path} for path in image_paths],
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's own device instead of assuming a literal "cuda".
    inputs = inputs.to(model.device)

    # NOTE(review): generation is assumed to belong under autocast, as the
    # original appears to intend. The default CUDA autocast dtype is float16
    # while the weights are bfloat16 - confirm whether dtype=torch.bfloat16
    # (or no autocast at all) is what is wanted here.
    with torch.autocast(device_type="cuda"):
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
            num_beams=1,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
        )

    # Drop the prompt tokens so only the newly generated part is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return output_text[0]


@spaces.GPU
def run_example(files, text_input=None):
    """Analyze the uploaded files with the model and return the answer text.

    Args:
        files: Uploaded file handles, each exposing its path as ``.name``.
        text_input: Optional question; a generic analysis prompt is used when
            it is empty or ``None``.

    Returns:
        The model's answer, or a human-readable error message. Failures are
        reported through the return value rather than raised, because the
        result is displayed directly in the output textbox.
    """
    if not files:
        return "Please upload at least one image for analysis."

    temp_paths = []
    try:
        clear_gpu_memory()

        # Validate and decode every upload; PDFs expand to several images.
        processed_images = []
        for upload in files:
            try:
                # Enforce the size/format limits that the UI text advertises.
                validate_file(upload)
                processed_images.extend(_load_images(upload.name))
            except Exception as e:
                return f"Error processing file {upload.name}: {str(e)}"

        if not processed_images:
            return "No valid images were processed. Please check your files."

        # Write resized copies to uniquely named temp files for the model.
        for img in processed_images:
            try:
                _save_resized_copy(img, temp_paths)
            except Exception as e:
                return f"Error saving processed image: {str(e)}"

        try:
            return _generate_answer(temp_paths, text_input)
        except Exception as e:
            return f"Error during model inference: {str(e)}"
    except Exception as e:
        return f"Error processing images: {str(e)}"
    finally:
        # Runs on every exit path, including the early returns above.
        cleanup_temp_files(temp_paths)
        clear_gpu_memory()
# Custom stylesheet passed to gr.Blocks(css=...). It styles the #output box,
# the .container and .footer classes, shrinks margins and the output height
# on viewports up to 768px wide, and defines a .loading spin animation.
css = """
#output {
height: 500px;
overflow: auto;
border: 1px solid #ccc;
}
.container {
margin: 15px;
padding: 15px;
border-radius: 10px;
background-color: #f5f5f5;
}
.footer {
text-align: center;
margin-top: 20px;
padding: 10px;
border-top: 1px solid #ccc;
}
/* Mobile responsive styles */
@media (max-width: 768px) {
.container {
margin: 5px;
padding: 10px;
}
#output {
height: 300px;
}
}
/* Loading animation */
.loading {
display: inline-block;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
"""
# Gradio UI: upload + question on the left, analysis output on the right,
# followed by the version line and the usage notes.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Ophthalmology Image Analysis"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_files = gr.Files(
                    label="Upload Images",
                    # Extensions must carry a leading dot; a bare "pdf" is
                    # not one of the recognised type keywords.
                    file_types=["image", ".pdf"],
                    file_count="multiple",
                    scale=1,
                )
                text_input = gr.Textbox(
                    label="Question",
                    placeholder="Example: Is there any sign of keratoconus? Has there been progression since the last scan?",
                    lines=2,
                )
                with gr.Row():  # Put buttons in a row
                    submit_btn = gr.Button(
                        value="Analyze Images",
                        variant="primary",
                        scale=1,
                    )
                    clear_btn = gr.Button(
                        value="Clear Results",
                        variant="secondary",
                        scale=1,
                    )
            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Analysis Results",
                    lines=12,
                    placeholder="AI analysis will appear here...",
                    show_copy_button=True,
                )
                # Hidden status box; no handler in this file writes to it.
                error_box = gr.Textbox(
                    label="Status",
                    visible=False,
                    interactive=False,
                )

    # Version info
    gr.Markdown(f"Version: {VERSION}", elem_classes=["footer"])

    # Footer
    gr.Markdown(
        """
### Important Notes
- This tool is intended to assist medical professionals and should not replace professional medical judgment
- For medical emergencies, please contact your healthcare provider
- Created by Dr. Nerea Zubieta
- Maximum file size: 10MB per file
- Supported formats: PDF, PNG, JPEG
""",
        elem_classes=["footer"],
    )

    # Inference is serialized: only one analysis runs at a time.
    submit_btn.click(
        fn=run_example,
        inputs=[input_files, text_input],
        outputs=[output_text],
        api_name="analyze",
        concurrency_limit=1,
    )

    def clear_outputs():
        """Return empty strings for the results box and the question box."""
        return "", ""

    clear_btn.click(
        fn=clear_outputs,
        inputs=[],
        outputs=[output_text, text_input],
        api_name="clear",
        concurrency_limit=10,  # Higher limit for clear operation as it's lightweight
    )

# Simplified launch for HuggingFace Spaces
demo.launch()