root-sajjan committed on
Commit
e2e8ffc
·
verified ·
1 Parent(s): 4e269c9

edited tesseract error handling

Browse files
Files changed (1) hide show
  1. model.py +160 -148
model.py CHANGED
@@ -1,149 +1,161 @@
1
- import torch
2
- from pathlib import Path
3
- from transformers import CLIPProcessor, CLIPModel
4
- from PIL import Image, ImageDraw
5
- import pytesseract
6
- import requests
7
- import os
8
- from llm import inference, upload_image
9
- from fastapi.responses import FileResponse, JSONResponse
10
-
11
- import re
12
-
13
- from io import BytesIO
14
-
15
- cropped_images_dir = "cropped_images"
16
- os.makedirs(cropped_images_dir, exist_ok=True)
17
-
18
- # Load YOLO model
19
- class YOLOModel:
20
- def __init__(self, model_path="yolov5s.pt"):
21
- """
22
- Initialize the YOLO model. Downloads YOLOv5 pretrained model if not available.
23
- """
24
- torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
25
- self.model = torch.hub.load("ultralytics/yolov5", "custom", path=model_path, force_reload=True)
26
-
27
-
28
- def predict_clip(self, image, brand_names):
29
- """
30
- Predict the most probable brand using CLIP.
31
- """
32
- inputs = self.clip_processor(
33
- text=brand_names,
34
- images=image,
35
- return_tensors="pt",
36
- padding=True
37
- )
38
- # print(f'Inputs to clip processor:{inputs}')
39
- outputs = self.clip_model(**inputs)
40
- logits_per_image = outputs.logits_per_image
41
- probs = logits_per_image.softmax(dim=1) # Convert logits to probabilities
42
- best_idx = probs.argmax().item()
43
- return brand_names[best_idx], probs[0, best_idx].item()
44
-
45
-
46
- def predict_text(self, image):
47
- grayscale = image.convert('L')
48
- text = pytesseract.image_to_string(grayscale)
49
- return text.strip()
50
-
51
-
52
- def predict(self, image_path):
53
- """
54
- Run YOLO inference on an image.
55
-
56
- :param image_path: Path to the input image
57
- :return: List of predictions with labels and bounding boxes
58
- """
59
- results = self.model(image_path)
60
- image = Image.open(image_path).convert("RGB")
61
- draw = ImageDraw.Draw(image)
62
- predictions = results.pandas().xyxy[0] # Get predictions as pandas DataFrame
63
- print(f'YOLO predictions:\n\n{predictions}')
64
-
65
-
66
- output = []
67
- file_responses = []
68
-
69
-
70
- for idx, row in predictions.iterrows():
71
- category = row['name']
72
- confidence = row['confidence']
73
- bbox = [row["xmin"], row["ymin"], row["xmax"], row["ymax"]]
74
-
75
- # Crop the detected region
76
- cropped_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
77
- cropped_image_path = os.path.join(cropped_images_dir, f"crop_{idx}.jpg")
78
- cropped_image.save(cropped_image_path, "JPEG")
79
-
80
- # uploading to cloud for getting URL to pass into LLM
81
- print(f'Uploading now to image url')
82
- image_url = upload_image.upload_image_to_imgbb(cropped_image_path)
83
- print(f'Image URL received as{image_url}')
84
- # inferencing llm for possible brands
85
- result_llms = inference.get_name(image_url, category)
86
-
87
- detected_text = self.predict_text(cropped_image)
88
- print(f'Details:{detected_text}')
89
- print(f'Predicted brand: {result_llms["model"]}')
90
- # Draw bounding box and label on the image
91
- draw.rectangle(bbox, outline="red", width=3)
92
- draw.text(
93
- (bbox[0], bbox[1] - 10),
94
- f'{result_llms["brand"]})',
95
- fill="red"
96
- )
97
-
98
- cropped_image_io = BytesIO()
99
- cropped_image.save(cropped_image_io, format="JPEG")
100
- cropped_image_io.seek(0)
101
-
102
- # Append result
103
- output.append({
104
- "category": category,
105
- "bbox": bbox,
106
- "confidence": confidence,
107
- "category_llm":result_llms["brand"],
108
- "predicted_brand": result_llms["model"],
109
- # "clip_confidence": clip_confidence,
110
- "price":result_llms["price"],
111
- "details":result_llms["description"],
112
- "detected_text":detected_text,
113
- "image_path":cropped_image_path,
114
- "image_url":image_url,
115
- })
116
-
117
- # file_responses.append(f"/download_cropped_image/{idx}")
118
-
119
- valid_indices = set(range(len(predictions)))
120
-
121
- # Iterate over all files in the directory
122
- for filename in os.listdir(cropped_images_dir):
123
- # Check if the filename matches the pattern for cropped images
124
- if filename.startswith("crop_") and filename.endswith(".jpg"):
125
- # Extract the index from the filename
126
- try:
127
- file_idx = int(filename.split("_")[1].split(".")[0])
128
- if file_idx not in valid_indices:
129
- # Delete the file if its index is not valid
130
- file_path = os.path.join(cropped_images_dir, filename)
131
- os.remove(file_path)
132
- print(f"Deleted excess file: {filename}")
133
- except ValueError:
134
- # Skip files that don't match the pattern
135
- continue
136
-
137
- return output
138
- # return JSONResponse(
139
- # content={
140
- # "metadata": results,
141
- # "cropped_image_urls": [
142
- # f"/download_cropped_image/{idx}" for idx in range(len(file_responses))
143
- # ],
144
- # }
145
- # )
146
- # return {"metadata": results, "cropped_image_urls": file_responses}
147
-
148
-
 
 
 
 
 
 
 
 
 
 
 
 
149
 
 
1
+ import torch
2
+ from pathlib import Path
3
+ from transformers import CLIPProcessor, CLIPModel
4
+ from PIL import Image, ImageDraw
5
+ import pytesseract
6
+ import requests
7
+ import os
8
+ from llm import inference, upload_image
9
+ from fastapi.responses import FileResponse, JSONResponse
10
+
11
+ import re
12
+
13
+ from io import BytesIO
14
+
15
+ cropped_images_dir = "cropped_images"
16
+ os.makedirs(cropped_images_dir, exist_ok=True)
17
+
18
+ # Load YOLO model
19
class YOLOModel:
    """YOLOv5 detector that crops each detection and enriches it with LLM and OCR data.

    ``predict`` is the main entry point: it runs the detector, writes one JPEG
    crop per detection into the module-level ``cropped_images_dir``, uploads
    each crop, asks the LLM helper for brand details and runs Tesseract OCR on
    the crop.
    """

    def __init__(self, model_path="yolov5s.pt", force_reload=True,
                 clip_model_name="openai/clip-vit-base-patch32"):
        """
        Initialize the YOLO model through ``torch.hub``.

        :param model_path: Weights file handed to the ``custom`` entry point of
            the ``ultralytics/yolov5`` hub repository.
        :param force_reload: Forwarded to ``torch.hub.load``. The default keeps
            the previous behaviour (always refresh the hub repository); pass
            ``False`` to reuse the local hub cache.
        :param clip_model_name: Checkpoint loaded lazily the first time
            ``predict_clip`` is called.
            NOTE(review): the checkpoint name is an assumption, the original
            code never loaded CLIP at all; confirm the intended model.
        """
        # NOTE(review): overrides a private torch.hub hook so the fork check
        # always passes; presumably a workaround for hub validation failures.
        # Confirm it is still required with the installed torch version.
        torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
        self.model = torch.hub.load(
            "ultralytics/yolov5",
            "custom",
            path=model_path,
            force_reload=force_reload,
        )
        self.clip_model_name = clip_model_name
        # CLIP is only needed by predict_clip, so it is loaded on first use.
        self.clip_processor = None
        self.clip_model = None

    def _ensure_clip(self):
        """Load the CLIP processor and model if they are not available yet."""
        if (getattr(self, "clip_processor", None) is None
                or getattr(self, "clip_model", None) is None):
            name = getattr(self, "clip_model_name", "openai/clip-vit-base-patch32")
            self.clip_processor = CLIPProcessor.from_pretrained(name)
            self.clip_model = CLIPModel.from_pretrained(name)

    def predict_clip(self, image, brand_names):
        """
        Predict the most probable brand using CLIP.

        :param image: Image passed to the CLIP processor as ``images``.
        :param brand_names: Non-empty sequence of candidate brand strings.
        :return: Tuple ``(best_brand, probability)`` for the best candidate.
        :raises ValueError: If ``brand_names`` is empty.
        """
        if len(brand_names) == 0:
            raise ValueError("brand_names must contain at least one candidate")

        self._ensure_clip()
        inputs = self.clip_processor(
            text=brand_names,
            images=image,
            return_tensors="pt",
            padding=True,
        )
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            outputs = self.clip_model(**inputs)
        # Convert the image-to-text logits to probabilities over the candidates.
        probs = outputs.logits_per_image.softmax(dim=1)
        best_idx = probs[0].argmax().item()
        return brand_names[best_idx], probs[0, best_idx].item()

    def predict_text(self, image):
        """
        Run OCR on an image and return the recognised text.

        OCR is best effort: any failure (bad image object, Tesseract missing
        or crashing) is reported on stdout and an empty string is returned so
        a single bad crop cannot abort the whole prediction.

        :param image: Object exposing ``convert('L')`` (a PIL image).
        :return: Stripped OCR text, or ``""`` if OCR failed.
        """
        try:
            grayscale = image.convert('L')
            text = pytesseract.image_to_string(grayscale)
            return text.strip()
        except Exception as e:
            print(f"Error during text prediction: {e}")
            return ""

    @staticmethod
    def _remove_stale_crops(directory, valid_indices):
        """Delete ``crop_<n>.jpg`` files in ``directory`` whose ``n`` is not in ``valid_indices``.

        Files that do not follow the ``crop_<int>.jpg`` pattern are left alone.
        """
        for filename in os.listdir(directory):
            if not (filename.startswith("crop_") and filename.endswith(".jpg")):
                continue
            try:
                file_idx = int(filename.split("_")[1].split(".")[0])
            except ValueError:
                # Not one of our numbered crops.
                continue
            if file_idx not in valid_indices:
                os.remove(os.path.join(directory, filename))
                print(f"Deleted excess file: {filename}")

    def predict(self, image_path):
        """
        Run YOLO inference on an image.

        For every detection the region is cropped, saved to
        ``cropped_images_dir`` as ``crop_<idx>.jpg``, uploaded, described by
        the LLM helper and scanned with OCR. Crops left over from earlier
        calls with more detections are deleted afterwards.

        :param image_path: Path to the input image
        :return: List of dicts with the keys ``category``, ``bbox``,
            ``confidence``, ``category_llm``, ``predicted_brand``, ``price``,
            ``details``, ``detected_text``, ``image_path`` and ``image_url``.
        """
        results = self.model(image_path)

        # convert() returns an independent in-memory image, so the file handle
        # can be released immediately.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        # Draw on a copy so boxes of earlier detections never leak into the
        # crops of later, overlapping detections.
        # NOTE(review): the annotated image is currently neither saved nor returned.
        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)

        predictions = results.pandas().xyxy[0]  # Get predictions as pandas DataFrame
        print(f'YOLO predictions:\n\n{predictions}')

        output = []
        written_indices = set()

        for idx, row in predictions.iterrows():
            category = row['name']
            confidence = float(row['confidence'])
            bbox = [float(row[key]) for key in ("xmin", "ymin", "xmax", "ymax")]

            # Crop the detected region from the untouched source image.
            cropped_image = image.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
            cropped_image_path = os.path.join(cropped_images_dir, f"crop_{idx}.jpg")
            cropped_image.save(cropped_image_path, "JPEG")
            written_indices.add(idx)

            # The LLM helper consumes a public URL, so upload the crop first.
            print(f'Uploading now to image url')
            image_url = upload_image.upload_image_to_imgbb(cropped_image_path)
            print(f'Image URL received as{image_url}')

            # Expected to be a mapping with "brand", "model", "price" and
            # "description" entries; those keys are read below.
            result_llms = inference.get_name(image_url, category)

            detected_text = self.predict_text(cropped_image)
            print(f'Details:{detected_text}')
            print(f'Predicted brand: {result_llms["model"]}')

            # Draw bounding box and label on the annotated copy.
            draw.rectangle(bbox, outline="red", width=3)
            draw.text(
                (bbox[0], bbox[1] - 10),
                str(result_llms["brand"]),
                fill="red",
            )

            output.append({
                "category": category,
                "bbox": bbox,
                "confidence": confidence,
                "category_llm": result_llms["brand"],
                "predicted_brand": result_llms["model"],
                "price": result_llms["price"],
                "details": result_llms["description"],
                "detected_text": detected_text,
                "image_path": cropped_image_path,
                "image_url": image_url,
            })

        # Remove crops from earlier runs that were not rewritten this time.
        self._remove_stale_crops(cropped_images_dir, written_indices)

        return output
159
+
160
+
161