import os
import re
from io import BytesIO
from pathlib import Path

import pytesseract
import requests
import torch
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image, ImageDraw
from transformers import CLIPProcessor, CLIPModel

from llm import inference, upload_image

# Directory where per-detection crops are written; created at import time.
cropped_images_dir = "cropped_images"
os.makedirs(cropped_images_dir, exist_ok=True)

# Matches exactly the crop files written by YOLOModel.predict (crop_<idx>.jpg).
_CROP_FILENAME_RE = re.compile(r"crop_(\d+)\.jpg")

# NOTE(review): the original code never said which CLIP checkpoint to use;
# this default is an assumption -- confirm it against the deployment.
_DEFAULT_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"


class YOLOModel:
    """YOLOv5 detector combined with OCR, CLIP and LLM-based brand lookup."""

    def __init__(
        self,
        model_path="yolov5s.pt",
        clip_model_name=_DEFAULT_CLIP_CHECKPOINT,
        force_reload=True,
    ):
        """
        Initialize the YOLO model.

        :param model_path: Path to the YOLOv5 weights passed to ``torch.hub.load``
        :param clip_model_name: Checkpoint used by ``predict_clip``; it is only
            loaded the first time ``predict_clip`` is called
        :param force_reload: Forwarded to ``torch.hub.load``; ``True`` (the
            historical behavior) bypasses the local hub cache
        """
        # Disable torch.hub's forked-repo validation. This patches a private
        # hook and must happen before torch.hub.load is called.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load(
            "ultralytics/yolov5",
            "custom",
            path=model_path,
            force_reload=force_reload,
        )

        # CLIP is loaded lazily so constructing the detector stays as cheap
        # as it was before.
        self.clip_model_name = clip_model_name
        self.clip_model = None
        self.clip_processor = None

        # Copy of the last input image with boxes/labels drawn on it
        # (set by predict()).
        self.last_annotated_image = None

    def _ensure_clip_loaded(self):
        """Load the CLIP model and processor on first use."""
        checkpoint = getattr(self, "clip_model_name", _DEFAULT_CLIP_CHECKPOINT)
        if getattr(self, "clip_model", None) is None:
            self.clip_model = CLIPModel.from_pretrained(checkpoint)
        if getattr(self, "clip_processor", None) is None:
            self.clip_processor = CLIPProcessor.from_pretrained(checkpoint)

    def predict_clip(self, image, brand_names):
        """
        Predict the most probable brand using CLIP.

        :param image: Image to classify (passed to the CLIP processor)
        :param brand_names: Non-empty sequence of candidate brand names
        :return: Tuple ``(best_brand_name, probability)``
        :raises ValueError: If ``brand_names`` is empty
        """
        if not brand_names:
            raise ValueError("brand_names must contain at least one candidate")

        # Previously these attributes were never assigned, so this method
        # always failed with AttributeError.
        self._ensure_clip_loaded()

        inputs = self.clip_processor(
            text=brand_names,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = self.clip_model(**inputs)

        # Convert logits to probabilities over the candidate names.
        probs = outputs.logits_per_image.softmax(dim=1)
        best_idx = probs[0].argmax().item()
        return brand_names[best_idx], probs[0, best_idx].item()

    def predict_text(self, image):
        """
        Run OCR on an image and return the recognized text.

        :param image: PIL image to read text from
        :return: Stripped OCR text, or an empty string if OCR fails
        """
        try:
            # OCR is run on a grayscale version of the image.
            grayscale = image.convert('L')
            text = pytesseract.image_to_string(grayscale)
            return text.strip()
        except Exception as e:
            # Deliberate best-effort: OCR failure must not abort detection.
            print(f"Error during text prediction: {e}")
            return ""

    @staticmethod
    def _remove_stale_crops(valid_indices):
        """Delete crop_<idx>.jpg files whose index is not in ``valid_indices``."""
        for filename in os.listdir(cropped_images_dir):
            match = _CROP_FILENAME_RE.fullmatch(filename)
            if match is None:
                # Not a file written by predict(); leave it alone.
                continue
            if int(match.group(1)) in valid_indices:
                continue
            os.remove(os.path.join(cropped_images_dir, filename))
            print(f"Deleted excess file: {filename}")

    def predict(self, image_path):
        """
        Run YOLO inference on an image.

        Each detection is cropped, saved to ``cropped_images_dir``, uploaded,
        described by the LLM helper, and OCR'd. Crop files left over from
        earlier runs with a higher index are deleted afterwards. The input
        image with boxes and labels drawn on it is stored in
        ``self.last_annotated_image``.

        :param image_path: Path to the input image
        :return: List of predictions with labels and bounding boxes
        """
        results = self.model(image_path)

        # Close the underlying file as soon as the pixels are loaded.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")

        # Draw on a copy so annotations never leak into later
        # (possibly overlapping) crops.
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)

        predictions = results.pandas().xyxy[0]  # Predictions as a DataFrame
        print(f'YOLO predictions:\n\n{predictions}')

        output = []
        for idx, row in predictions.iterrows():
            category = row['name']
            confidence = float(row['confidence'])
            bbox = [
                float(row["xmin"]),
                float(row["ymin"]),
                float(row["xmax"]),
                float(row["ymax"]),
            ]

            # Crop the detected region from the unannotated image.
            cropped_image = image.crop(tuple(bbox))
            cropped_image_path = os.path.join(
                cropped_images_dir, f"crop_{idx}.jpg"
            )
            cropped_image.save(cropped_image_path, "JPEG")

            # The LLM helper takes an image URL, so upload the crop first.
            print('Uploading now to image url')
            image_url = upload_image.upload_image_to_imgbb(cropped_image_path)
            print(f'Image URL received as{image_url}')

            # Ask the LLM for brand/model/price/description of the crop.
            result_llms = inference.get_name(image_url, category)

            detected_text = self.predict_text(cropped_image)
            print(f'Details:{detected_text}')
            print(f'Predicted brand: {result_llms["model"]}')

            # Draw bounding box and label; keep the label inside the image
            # when the box touches the top edge.
            draw.rectangle(bbox, outline="red", width=3)
            draw.text(
                (bbox[0], max(bbox[1] - 10, 0)),
                str(result_llms["brand"]),
                fill="red",
            )

            output.append({
                "category": category,
                "bbox": bbox,
                "confidence": confidence,
                "category_llm": result_llms["brand"],
                "predicted_brand": result_llms["model"],
                "price": result_llms["price"],
                "details": result_llms["description"],
                "detected_text": detected_text,
                "image_path": cropped_image_path,
                "image_url": image_url,
            })

        self.last_annotated_image = annotated

        # Remove crops left behind by earlier runs that had more detections.
        self._remove_stale_crops(set(range(len(predictions))))

        return output