WeedDetector / app.py
Snearec's picture
Update app.py
3031969
raw
history blame
2.59 kB
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
import gradio as gr
import torch
# Cargar un modelo YOLOv8n preentrenado
model = YOLO('best.pt')
def contar_detecciones(cls_tensor, nombres_clases):
conteos = {nombre: torch.sum(cls_tensor == indice).item() for indice, nombre in enumerate(nombres_clases)}
return conteos
def clases_detectadas(res):
if res and hasattr(res[0], 'xyxy'):
cls_tensor = res[0].xyxy[0][:, -1] # Obtener tensor de clases
nombres_clases = model.names
conteos = contar_detecciones(cls_tensor, nombres_clases)
respuesta = ""
for nombre, conteo in conteos.items():
respuesta += f"Clase {nombre} : {conteo} detecciones\n"
return respuesta
else:
return "No se encontraron resultados o falta informaci贸n relevante"
def detect_objects(image: Image.Image):
# Realizar la inferencia
results = model.predict(image)
# Guardar la imagen con todas las detecciones
im_array = results.render()[0]
im_all_detections = Image.fromarray(im_array[..., ::-1])
# Contar las clases detectadas
conteo_clases = clases_detectadas(results)
# Guardar informaci贸n de detecci贸n
detections = results.xyxy[0].tolist()
return im_all_detections, conteo_clases, detections
def update_image(original_image: Image.Image, detections, show_potatoes: bool, show_tongues: bool):
# Crear una copia de la imagen original
updated_image = original_image.copy()
draw = ImageDraw.Draw(updated_image)
# Definir la fuente para las etiquetas
try:
font = ImageFont.truetype("arial.ttf", 15)
except IOError:
font = ImageFont.load_default()
# Filtrar y dibujar solo las detecciones seleccionadas
for det in detections:
label = model.names[int(det[5])]
if (label == 'papa' and show_potatoes) or (label == 'lengua' and show_tongues):
box = det[:4]
label_text = f"{label} {det[4]:.2f}"
draw.rectangle(box, outline="red", width=2)
text_size = draw.textsize(label_text, font=font)
draw.rectangle([box[0], box[1] - text_size[1], box[0] + text_size[0], box[1]], fill="red")
draw.text((box[0], box[1] - text_size[1]), label_text, fill="white", font=font)
return updated_image
# Crear la interfaz de Gradio
iface = gr.Interface(
fn=detect_objects,
update=update_image,
inputs=["image", gr.Checkbox(label="Mostrar Papas"), gr.Checkbox(label="Mostrar Lenguas")],
outputs=["image", "text", "image"]
).launch()